@@ -88,8 +88,15 @@ macro scalar_rule(call, maybe_setup, partials...)
8888 )
8989 f = call. args[1 ]
9090
91- frule_expr = scalar_frule_expr (__source__, f, call, setup_stmts, inputs, partials)
92- rrule_expr = scalar_rrule_expr (__source__, f, call, setup_stmts, inputs, partials)
91+ # Generate variables to store derivatives named dfi/dxj
92+ derivatives = map (keys (partials)) do i
93+ syms = map (j -> Symbol (" ∂f" , i, " /∂x" , j), keys (inputs))
94+ return Expr (:tuple , syms... )
95+ end
96+
97+ derivative_expr = scalar_derivative_expr (__source__, f, setup_stmts, inputs, partials)
98+ frule_expr = scalar_frule_expr (__source__, f, call, [], inputs, derivatives)
99+ rrule_expr = scalar_rrule_expr (__source__, f, call, [], inputs, derivatives)
93100
94101 # Final return: building the expression to insert in the place of this macro
95102 code = quote
@@ -99,6 +106,7 @@ macro scalar_rule(call, maybe_setup, partials...)
99106 ))
100107 end
101108
109+ $ (derivative_expr)
102110 $ (frule_expr)
103111 $ (rrule_expr)
104112 end
@@ -135,16 +143,45 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
135143 # For consistency in code that follows we make all partials tuple expressions
136144 partials = map (partials) do partial
137145 if Meta. isexpr (partial, :tuple )
138- partial
146+ Expr ( :tuple , map (esc, partial. args) ... )
139147 else
140148 length (inputs) == 1 || error (" Invalid use of `@scalar_rule`" )
141- Expr (:tuple , partial)
149+ Expr (:tuple , esc ( partial) )
142150 end
143151 end
144152
145153 return call, setup_stmts, inputs, partials
146154end
147155
156+ """
157+ derivatives_given_output(Ω, f, xs...)
158+
159+ Compute the derivative of scalar function `f` at primal input point `xs...`,
160+ given that it had primal output `Ω`.
161+ Return a tuple of tuples with the partial derivatives of `f` with respect to the `xs...`.
162+ The derivative of the `i`-th component of `f` with respect to the `j`-th input can be
163+ accessed as `Df[i][j]`, where `Df = derivatives_given_output(Ω, f, xs...)`.
164+
165+ !!! warning "Experimental"
166+ This function is experimental and not part of the stable API.
167+ At the moment, it can be considered an implementation detail of the macro
168+ [`@scalar_rule`](@ref), in which it is used.
169+ In the future, the exact semantics of this function will stabilize, and it
170+ will be added to the stable API.
171+ When that happens, this warning will be removed.
172+
173+ """
174+ function derivatives_given_output end
175+
176+ function scalar_derivative_expr (__source__, f, setup_stmts, inputs, partials)
177+ return @strip_linenos quote
178+ function ChainRulesCore. derivatives_given_output ($ (esc (:Ω )), :: Core.Typeof ($ f), $ (inputs... ))
179+ $ (__source__)
180+ $ (setup_stmts... )
181+ return $ (Expr (:tuple , partials... ))
182+ end
183+ end
184+ end
148185
149186function scalar_frule_expr (__source__, f, call, setup_stmts, inputs, partials)
150187 n_outputs = length (partials)
@@ -173,6 +210,7 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials)
173210 $ (__source__)
174211 $ (esc (:Ω )) = $ call
175212 $ (setup_stmts... )
213+ $ (Expr (:tuple , partials... )) = ChainRulesCore. derivatives_given_output ($ (esc (:Ω )), $ f, $ (inputs... ))
176214 return $ (esc (:Ω )), $ pushforward_returns
177215 end
178216 end
@@ -210,6 +248,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials)
210248 $ (__source__)
211249 $ (esc (:Ω )) = $ call
212250 $ (setup_stmts... )
251+ $ (Expr (:tuple , partials... )) = ChainRulesCore. derivatives_given_output ($ (esc (:Ω )), $ f, $ (inputs... ))
213252 return $ (esc (:Ω )), $ pullback
214253 end
215254 end
@@ -240,9 +279,9 @@ function propagation_expr(Δs, ∂s, _conj=false, proj=identity)
240279 # This is basically Δs ⋅ ∂s
241280 _∂s = map (∂s) do ∂s_i
242281 if _conj
243- :(conj ($ ( esc ( ∂s_i)) ))
282+ :(conj ($ ∂s_i))
244283 else
245- esc ( ∂s_i)
284+ ∂s_i
246285 end
247286 end
248287
0 commit comments