11using Cassette
22using ChainRules
33using ChainRulesCore
4- import ChainRulesCore: Wirtinger, Zero
4+ import ChainRulesCore: Zero
5+
6+ # TODO : remove the copy pasted code and add that package
7+ # copyed from SpecializeVarargs.jl, written by @MasonProtter
8+ using MacroTools: MacroTools, splitdef, combinedef, @capture
9+
10+ macro specialize_vararg (n:: Int , fdef:: Expr )
11+ @assert n > 0
12+
13+ macros = Symbol[]
14+ while fdef. head == :macrocall && length (fdef. args) == 3
15+ push! (macros, fdef. args[1 ])
16+ fdef = fdef. args[3 ]
17+ end
18+
19+ d = splitdef (fdef)
20+ args = d[:args ][end ]
21+ @assert d[:args ][end ] isa Expr && d[:args ][end ]. head == Symbol (" ..." ) && d[:args ][end ]. args[] isa Symbol
22+ args_symbol = d[:args ][end ]. args[]
23+
24+ fdefs = Expr (:block )
25+
26+ for i in 1 : n- 1
27+ di = deepcopy (d)
28+ pop! (di[:args ])
29+ args = Tuple (gensym (" arg$j " ) for j in 1 : i)
30+ Ts = Tuple (gensym (" T$j " ) for j in 1 : i)
31+
32+ args_with_Ts = ((arg, T) -> :($ arg :: $T )). (args, Ts)
33+
34+ di[:whereparams ] = (di[:whereparams ]. .. , Ts... )
35+
36+ push! (di[:args ], args_with_Ts... )
37+ pushfirst! (di[:body ]. args, :($ args_symbol = $ (Expr (:tuple , args... ))))
38+ cfdef = combinedef (di)
39+ mcfdef = isempty (macros) ? cfdef : foldr ((m,f) -> Expr (:macrocall , m, nothing , f), macros, init= cfdef)
40+ push! (fdefs. args, mcfdef)
41+ end
42+
43+ di = deepcopy (d)
44+ pop! (di[:args ])
45+ args = tuple ((gensym () for j in 1 : n). .. , :($ (gensym (" args" )). .. ))
46+ Ts = Tuple (gensym (" T$j " ) for j in 1 : n)
47+
48+ args_with_Ts = (((arg, T) -> :($ arg :: $T )). (args[1 : end - 1 ], Ts). .. , args[end ])
49+
50+ di[:whereparams ] = (di[:whereparams ]. .. , Ts... )
51+
52+ push! (di[:args ], args_with_Ts... )
53+ pushfirst! (di[:body ]. args, :($ args_symbol = $ (Expr (:tuple , args... ))))
54+
55+ cfdef = combinedef (di)
56+ mcfdef = isempty (macros) ? cfdef : foldr ((m,f) -> Expr (:macrocall , m, nothing , f), macros, init= cfdef)
57+ push! (fdefs. args, mcfdef)
58+
59+ esc (fdefs)
60+ end
561
662using Cassette: overdub, Context, nametype, similarcontext
763
3086@inline _partials (:: Any , x) = Zero ()
3187@inline _partials (:: Tag{T} , d:: Dual{Tag{T}} ) where T = d. partials
3288
33- Wirtinger (primal, conjugate) = Wirtinger .(primal, conjugate)
34-
3589@inline _values (S, xs) = map (x-> _value (S, x), xs)
3690@inline _partialss (S, xs) = map (x-> _partials (S, x), xs)
3791
@@ -48,64 +102,54 @@ Wirtinger(primal, conjugate) = Wirtinger.(primal, conjugate)
48102end
49103
50104# actually interesting:
51-
52105@inline isinteresting (ctx:: TaggedCtx , f, a) = anydual (a)
53106@inline isinteresting (ctx:: TaggedCtx , f, a, b) = anydual (a, b)
54107@inline isinteresting (ctx:: TaggedCtx , f, a, b, c) = anydual (a, b, c)
55108@inline isinteresting (ctx:: TaggedCtx , f, a, b, c, d) = anydual (a, b, c, d)
56- @inline isinteresting (ctx:: TaggedCtx , f, args... ) = false
57- @inline isinteresting (ctx:: TaggedCtx , f:: typeof (Base. show), args... ) = false
109+ @inline isinteresting (ctx:: TaggedCtx , f, args... ) = anydual (args... )
110+ @inline isinteresting (ctx:: TaggedCtx , f:: Core.Builtin , args... ) = false
111+ @inline isinteresting (ctx:: TaggedCtx , f:: Union {typeof (ForwardDiff2. find_dual),
112+ typeof (ForwardDiff2. anydual)}, args... ) = false
58113
59- @inline function _frule_overdub2 (ctx:: TaggedCtx{T} , f, args... ) where T
114+ @specialize_vararg 4 @ inline function _frule_overdub2 (ctx:: TaggedCtx{T} , f:: F , args... ) where {T,F}
60115 # Here we can assume that one or more `args` is a Dual with tag
61116 # of type T.
62117
63118 tag = Tag {T} ()
64119 # unwrap only duals with the tag T.
65120 vs = _values (tag, args)
66121
122+ # extract the partials only for the current tag
123+ # so we can pass them to the pushforward
124+ ps = _partialss (tag, args)
125+
126+ # default `dself` to `Zero()`
127+ dself = Zero ()
128+
67129 # call frule to see if there is a rule for this call:
68130 if ctx. metadata isa Tag
69131 ctx1 = similarcontext (ctx, metadata= oldertag (ctx. metadata))
70132
71133 # we call frule with an older context because the Dual numbers may
72134 # themselves contain Dual numbers that were created in an older context
73- frule_result = overdub (ctx1, frule, f, vs... )
135+ frule_result = overdub (ctx1, frule, f, vs... , dself, ps ... )
74136 else
75- frule_result = frule (f, vs... )
137+ frule_result = frule (f, vs... , dself, ps ... )
76138 end
77139
78140 if frule_result === nothing
79141 # this means there is no frule
80142 # We can't just do f(args...) here because `f` might be
81143 # a closure which closes over a Dual number, hence we call
82144 # recurse. Recurse overdubs the calls inside `f` and not `f` itself
83-
84145 return Cassette. overdub (ctx, f, args... )
85146 else
86147 # this means there exists an frule for this specific call.
87148 # frule_result is then a tuple (val, pushforward) where val
88149 # is the primal result. (Note: this may be Dual numbers but only
89150 # with an older tag)
90- val, pushforward = frule_result
91-
92- # extract the partials only for the current tag
93- # so we can pass them to the pushforward
94- ps = _partialss (tag, args)
95-
96- # Call the pushforward to get new partials
97- # we call it with the older context because the partials
98- # might themselves be Duals from older contexts
99- if ctx. metadata isa Tag
100- ctx1 = similarcontext (ctx, metadata= oldertag (ctx. metadata))
101- ∂s = overdub (ctx1, pushforward, Zero (), ps... )
102- else
103- ∂s = pushforward (Zero (), ps... )
104- end
151+ val, ∂s = frule_result
105152
106- # Attach the new partials to the primal result
107- # multi-output `f` such as result in the new partials being
108- # a tuple, we handle both cases:
109153 return if ∂s isa Tuple
110154 map (val, ∂s) do v, ∂
111155 Dual {Tag{T}} (v, ∂)
116160 end
117161end
118162
119- @inline function alternative (ctx:: TaggedCtx{T} , f, args... ) where {T}
163+ @specialize_vararg 4 @ inline function alternative (ctx:: TaggedCtx{T} , f:: F , args... ) where {T,F }
120164 # This method only executes if `args` contains at least 1 Dual
121165 # the question is what is its tag
122166
161205
162206
163207# #### Inference Hacks
164- # this makes `log` work by making throw_complex_domainerror inferable, but not really sure why
165- @inline isinteresting (ctx:: TaggedCtx , f:: typeof (Core. throw), xs) = true
166- # add `DualContext` here to avoid ambiguity
167- @noinline alternative (ctx:: Union{DualContext,TaggedCtx} , f:: typeof (Core. throw), arg) = throw (arg)
168-
169- @inline isinteresting (ctx:: TaggedCtx , f:: typeof (Base. print_to_string), args... ) = true
170- @noinline alternative (ctx:: Union{DualContext,TaggedCtx} , f:: typeof (Base. print_to_string), args... ) = f (args... )
208+ @inline isinteresting (ctx:: TaggedCtx , f:: Union{typeof(Base.print_to_string),typeof(hash)} , args... ) = false
209+ @inline Cassette. overdub (ctx:: TaggedCtx , f:: Union{typeof(Base.print_to_string),typeof(hash)} , args... ) = f (args... )
210+ @inline Cassette. overdub (ctx:: TaggedCtx , f:: Core.Builtin , args... ) = f (args... )
0 commit comments