@@ -65,6 +65,13 @@ Cassette.@context DualContext
6565
6666const TaggedCtx{T} = Context{nametype (DualContext),T}
6767
68+ untagtype (:: Type{<:Dual{Tag{T},V}} , :: Type{<:TaggedCtx{T}} ) where {T,V} = V
69+
70+ @inline @generated function _overdub (ctx:: TaggedCtx{T} , f, args... ) where T
71+ F = Cassette. ReflectOn{Tuple{f, (untagtype (args[i], ctx) for i in 1 : nfields (args)). .. }}
72+ :(overdub (ctx, $ F (), f, args... ))
73+ end
74+
6875function dualcontext ()
6976 # Note that the `dualtag()` is not of the same type as that of the
7077 # Duals constructed in this context, because it is called in the older context
136143
137144 # we call frule with an older context because the Dual numbers may
138145 # themselves contain Dual numbers that were created in an older context
139- frule_result = overdub (ctx1, frule, f, vs... , dself, ps... )
146+ frule_result = _overdub (ctx1, frule, f, vs... , dself, ps... )
140147 else
141148 frule_result = frule (f, vs... , dself, ps... )
142149 end
146153 # We can't just do f(args...) here because `f` might be
147154 # a closure which closes over a Dual number, hence we call
148155 # recurse. Recurse overdubs the calls inside `f` and not `f` itself
149- return Cassette . overdub (ctx, f, args... )
156+ return _overdub (ctx, f, args... )
150157 else
151158 # this means there exists an frule for this specific call.
152159 # frule_result is then a tuple (val, pushforward) where val
172179
173180 idx = find_dual (tag, args... )
174181 if f === Dual
175- return overdub (ctx, f, args... )
182+ return _overdub (ctx, f, args... )
176183 elseif idx === 0
177184 # This is the base case for the recursion in this function which
178185 # tries to do the alternative with successively older contexts
183190 # none of the arguments have the same tag as the context
184191 # try with the parent context
185192 ctx1 = similarcontext (ctx, metadata= oldertag (ctx. metadata))
186- return overdub (ctx1, f, args... )
193+ return _overdub (ctx1, f, args... )
187194 else
188195 # call ChainRules.frule to execute `f` and
189196 # get a function that computes the partials
193200
194201function dualrun (f, args... )
195202 ctx = dualcontext ()
196- return overdub (ctx, f, args... )
203+ return _overdub (ctx, f, args... )
197204end
198205
199206const BINARY_PREDICATES = Symbol[:isequal , :isless , :< , :> , :(== ), :(!= ), :(<= ), :(>= )]
0 commit comments