@@ -21,30 +21,32 @@ maybe_tracked_value_var(node::JuliaNode) = Symbol("$(maybe_tracked_value_prefix)
 const maybe_tracked_arg_prefix = gensym("maybe_tracked_arg")
 maybe_tracked_arg_var(node::JuliaNode, i::Int) = Symbol("$(maybe_tracked_arg_prefix)_$(node.name)_$i")
 
-function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::TrainableParameterNode)
-    # TODO: only need to mark it if we are doing backprop params
+function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::TrainableParameterNode, ::BackpropParamsMode)
     push!(fwd_marked, node)
 end
 
-function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::ArgumentNode)
+function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::TrainableParameterNode, ::BackpropTraceMode)
+end
+
+function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::ArgumentNode, mode)
     if node.compute_grad
         push!(fwd_marked, node)
     end
 end
 
-function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::JuliaNode)
+function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::JuliaNode, mode)
     if any(input_node in fwd_marked for input_node in node.inputs)
         push!(fwd_marked, node)
     end
 end
 
-function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::RandomChoiceNode)
+function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::RandomChoiceNode, mode)
     if node in selected_choices
         push!(fwd_marked, node)
     end
 end
 
-function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::GenerativeFunctionCallNode)
+function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::GenerativeFunctionCallNode, mode)
     if node in selected_calls || any(input_node in fwd_marked for input_node in node.inputs)
         push!(fwd_marked, node)
     end
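The new `mode` argument threads the backprop variant through the marking pass, so that trainable parameters act as gradient sources only when accumulating parameter gradients. A minimal sketch of the dispatch pattern, with stand-in types rather than Gen's real IR nodes:

    # Hypothetical re-creation of the dispatch idea: empty singleton structs
    # select marking behavior at compile time via multiple dispatch.
    struct BackpropParamsMode end
    struct BackpropTraceMode end

    struct ParamNode
        name::Symbol
    end

    fwd_pass_sketch!(marked, node::ParamNode, ::BackpropParamsMode) = push!(marked, node)
    fwd_pass_sketch!(marked, node::ParamNode, ::BackpropTraceMode) = marked  # not a source here

    marked = Set{ParamNode}()
    fwd_pass_sketch!(marked, ParamNode(:theta), BackpropTraceMode())   # no-op
    fwd_pass_sketch!(marked, ParamNode(:theta), BackpropParamsMode())  # marks :theta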
@@ -73,6 +75,9 @@
 
 function back_pass!(back_marked, node::GenerativeFunctionCallNode)
     # the logpdf of every generative function call is a SINK
+    # (we could ask whether the generative function is deterministic or not
+    # as a performance optimization, because only stochastic generative functions
+    # actually have a non-trivial logpdf)
     for (input_node, has_grad) in zip(node.inputs, has_argument_grads(node.generative_function))
         if has_grad
             push!(back_marked, input_node)
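The backward pass gates each input on `has_argument_grads`, Gen's documented per-argument differentiability query, which returns a `Bool` tuple for both distributions and generative functions. For example:

    using Gen

    has_argument_grads(normal)     # (true, true): mean and std are both differentiable
    has_argument_grads(bernoulli)  # (true,): the probability argument is differentiable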
@@ -82,17 +87,15 @@
 
 function fwd_codegen!(stmts, fwd_marked, back_marked, node::TrainableParameterNode)
     if node in back_marked
-        push!(stmts, :($(node.name) = $(QuoteNode(get_param))($(QuoteNode(get_gen_fn))(trace),
-            $(QuoteNode(node.name)))))
+        push!(stmts, :($(node.name) = $(QuoteNode(get_parameter_value))(trace, $(QuoteNode(node.name)))))
     end
 
     if node in fwd_marked && node in back_marked
 
         # initialize gradient to zero
         # NOTE: we are avoiding allocating a new gradient accumulator for this function;
         # instead, we are using the threadsafe gradient accumulator directly.
-        # push!(stmts, :($(gradient_var(node)) = zero($(node.name))))
-        push!(stmts, :($(gradient_var(node)) = $(QuoteNode(get_gen_fn))(trace).params_grad[$(QuoteNode(node.name))]))
+        push!(stmts, :($(gradient_var(node)) = $(QuoteNode(get_gradient_accumulator))(trace, $(QuoteNode(node.name)))))
    end
 end
 
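The switch from direct `params_grad` field access to `get_gradient_accumulator` suggests a dedicated, lock-protected accumulator object behind the `in_place_add!` contract. A rough sketch of that contract, with hypothetical names rather than Gen's implementation:

    # Hypothetical accumulator: increments are applied in place under a lock,
    # so concurrent traces can share one accumulator per parameter.
    mutable struct GradAccumulator{T}
        value::T
        lock::ReentrantLock
    end

    GradAccumulator(value) = GradAccumulator(value, ReentrantLock())

    function in_place_add_sketch!(acc::GradAccumulator{<:AbstractArray}, increment, scale_factor)
        lock(acc.lock) do
            acc.value .+= increment .* scale_factor
        end
        return acc
    end

    acc = GradAccumulator(zeros(2))
    in_place_add_sketch!(acc, [2.0, 4.0], 0.5)  # acc.value == [1.0, 2.0]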
@@ -106,7 +109,7 @@
 
 function fwd_codegen!(stmts, fwd_marked, back_marked, node::JuliaNode)
 
-    if node in back_marked && any(input_node in fwd_marked for input_node in node.inputs)
+    if (node in fwd_marked) && (node in back_marked)
 
         # tracked forward execution
         tape = tape_var(node)
@@ -128,27 +131,20 @@ function fwd_codegen!(stmts, fwd_marked, back_marked, node::JuliaNode)
 
         # initialize gradient to zero
         push!(stmts, :($(gradient_var(node)) = zero($(node.name))))
-    else
 
-        # regular forward execution.
+    elseif node in back_marked
 
-        # we need the value for initializing gradient to zero (to get the type
-        # and e.g. shape), and for reference by other nodes during
-        # back_codegen! we could be more selective about which JuliaNodes need
-        # to be evaluated, that is a performance optimization for the future
+        # regular forward execution.
         args = map((input_node) -> input_node.name, node.inputs)
         push!(stmts, :($(node.name) = $(QuoteNode(node.fn))($(args...))))
     end
 end
 
 function fwd_codegen!(stmts, fwd_marked, back_marked, node::RandomChoiceNode)
-    # for reference by other nodes during back_codegen!
-    # could performance optimize this away
-    push!(stmts, :($(node.name) = trace.$(get_value_fieldname(node))))
-
     # every random choice is in back_marked, since it affects its logpdf, but
     # also possibly due to other downstream usage of the value
     @assert node in back_marked
+    push!(stmts, :($(node.name) = trace.$(get_value_fieldname(node))))
 
     if node in fwd_marked
         # the only way we are fwd_marked is if this choice was selected
@@ -160,14 +156,16 @@ function fwd_codegen!(stmts, fwd_marked, back_marked, node::RandomChoiceNode)
 end
 
 function fwd_codegen!(stmts, fwd_marked, back_marked, node::GenerativeFunctionCallNode)
-    # for reference by other nodes during back_codegen!
-    # could performance optimize this away
-    subtrace_fieldname = get_subtrace_fieldname(node)
-    push!(stmts, :($(node.name) = $(QuoteNode(get_retval))(trace.$subtrace_fieldname)))
+
+    if node in back_marked
+        # for reference by other nodes during back_codegen!
+        subtrace_fieldname = get_subtrace_fieldname(node)
+        push!(stmts, :($(node.name) = $(QuoteNode(get_retval))(trace.$subtrace_fieldname)))
+    end
 
     # NOTE: we will still potentially run choice_gradients recursively on the generative function,
     # we just might not use its return value gradient.
-    if node in fwd_marked && node in back_marked
+    if (node in fwd_marked) && (node in back_marked)
         # we are fwd_marked if an input was fwd_marked, or if we were selected internally
         push!(stmts, :($(gradient_var(node)) = zero($(node.name))))
     end
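The "tracked forward execution" branch for `JuliaNode` above records the node's function onto an automatic differentiation tape; Gen builds this on ReverseDiff, which is also where `deriv` in the backward pass below comes from. A standalone illustration of the tape mechanics (ReverseDiff's low-level API, used here the way Gen's generated code uses it):

    using ReverseDiff

    tape = ReverseDiff.InstructionTape()
    x = ReverseDiff.track(2.0, tape)   # tracked input, analogous to maybe_tracked args
    y = x * x + 3.0                    # forward execution records onto the tape
    ReverseDiff.seed!(y)               # set the output gradient to 1.0
    ReverseDiff.reverse_pass!(tape)    # backward pass through the tape
    ReverseDiff.deriv(x)               # == 4.0, read off like deriv(arg_maybe_tracked)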
@@ -181,15 +179,6 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node:
         push!(stmts, :(isnothing(retval_grad) && error("Required return value gradient but got nothing")))
         push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))(
             $(gradient_var(node)), retval_grad, scale_factor)))
-        # push!(stmts, :($(gradient_var(node)) += retval_grad))
-    end
-
-    if node in fwd_marked && node in back_marked
-        # NOTE: unnecessary, because we accumulated in-place already
-        # push!(stmts, :($(QuoteNode(increment_param_grad!))(trace.$static_ir_gen_fn_ref,
-        #     $(QuoteNode(node.name)),
-        #     $(gradient_var(node)),
-        #     scale_factor)))
     end
 end
 
@@ -199,8 +188,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node:
     if node === ir.return_node && node in fwd_marked
         @assert node in back_marked
         push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))(
-            $(gradient_var(node)), retval_grad, scale_factor)))
-        # push!(stmts, :($(gradient_var(node)) += retval_grad))
+            $(gradient_var(node)), retval_grad)))
     end
 end
 
@@ -209,8 +197,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node:
     if node === ir.return_node && node in fwd_marked
         @assert node in back_marked
         push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))(
-            $(gradient_var(node)), retval_grad, scale_factor)))
-        # push!(stmts, :($(gradient_var(node)) += retval_grad))
+            $(gradient_var(node)), retval_grad)))
     end
     if node in back_marked && any(input_node in fwd_marked for input_node in node.inputs)
 
@@ -222,9 +209,13 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node:
         for (i, input_node) in enumerate(node.inputs)
             if input_node in fwd_marked
                 arg_maybe_tracked = maybe_tracked_arg_var(node, i)
-                push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))(
-                    $(gradient_var(input_node)), $(QuoteNode(deriv))($arg_maybe_tracked), scale_factor)))
-                # push!(stmts, :($(gradient_var(input_node)) += $(QuoteNode(deriv))($arg_maybe_tracked)))
+                if isa(input_node, TrainableParameterNode)
+                    push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))(
+                        $(gradient_var(input_node)), $(QuoteNode(deriv))($arg_maybe_tracked), scale_factor)))
+                else
+                    push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))(
+                        $(gradient_var(input_node)), $(QuoteNode(deriv))($arg_maybe_tracked))))
+                end
             end
         end
     end
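This branch is the heart of the change: only `TrainableParameterNode` gradients receive `scale_factor`, because only they flow into the persistent per-parameter accumulators used for training, while intermediate value gradients are plain local accumulations. A sketch of the two call shapes, assuming the obvious semantics of `in_place_add!` (hypothetical stand-ins, not Gen's methods):

    # Arrays mutate in place; immutable scalars return a new value, which is
    # why the generated code rebinds the gradient variable to the result.
    add_sketch!(accum::AbstractArray, incr) = (accum .+= incr; accum)
    add_sketch!(accum::AbstractArray, incr, scale) = (accum .+= incr .* scale; accum)
    add_sketch!(accum::Real, incr) = accum + incr
    add_sketch!(accum::Real, incr, scale) = accum + incr * scale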
@@ -246,9 +237,15 @@ function back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marke
             if !has_argument_grads(node.dist)[i]
                 error("Distribution $(node.dist) does not have logpdf gradient for argument $i")
             end
-            push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))(
-                $(gradient_var(input_node)), $logpdf_grad[$(QuoteNode(i+1))], scale_factor)))
-            # push!(stmts, :($(gradient_var(input_node)) += $logpdf_grad[$(QuoteNode(i+1))]))
+            input_node_grad = gradient_var(input_node)
+            increment = :($logpdf_grad[$(QuoteNode(i+1))])
+            if isa(input_node, TrainableParameterNode)
+                push!(stmts, :($input_node_grad = $(QuoteNode(in_place_add!))(
+                    $input_node_grad, $increment, scale_factor)))
+            else
+                push!(stmts, :($input_node_grad = $(QuoteNode(in_place_add!))(
+                    $input_node_grad, $increment)))
+            end
         end
     end
 
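The `i + 1` indexing reflects Gen's documented `logpdf_grad` convention: the first element of the returned tuple is the gradient with respect to the sampled value, followed by one entry per distribution argument. For example:

    using Gen

    g = logpdf_grad(normal, 0.5, 0.0, 1.0)
    g[1]  # gradient w.r.t. the sampled value
    g[2]  # gradient w.r.t. the mean: argument i = 1 lives at index i + 1 = 2
    g[3]  # gradient w.r.t. the std: argument i = 2 lives at index 3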
@@ -257,7 +254,6 @@ function back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marke
         @assert node in back_marked
         push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))(
             $(gradient_var(node)), retval_grad, scale_factor)))
-        # push!(stmts, :($(gradient_var(node)) += retval_grad))
     end
 end
 
@@ -274,8 +270,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
             error("Distribution $dist does not have logpdf gradient for its output value")
         end
         push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))(
-            $(gradient_var(node)), retval_grad, scale_factor)))
-        # push!(stmts, :($(gradient_var(node)) += $logpdf_grad[1]))
+            $(gradient_var(node)), retval_grad)))
     end
 end
 
@@ -292,8 +287,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
     if node === ir.return_node && node in fwd_marked
         @assert node in back_marked
         push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))(
-            $(gradient_var(node)), retval_grad, scale_factor)))
-        # push!(stmts, :($(gradient_var(node)) += retval_grad))
+            $(gradient_var(node)), retval_grad)))
     end
 
     if node in fwd_marked
@@ -316,9 +310,10 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
     for (i, input_node) in enumerate(node.inputs)
         if input_node in fwd_marked
             @assert input_node in back_marked # this ensured its gradient will have been initialized
+            input_node_grad = gradient_var(input_node)
+            increment = :($input_grads[$(QuoteNode(i))])
             push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))(
-                $(gradient_var(input_node)), $input_grads[$(QuoteNode(i))], scale_factor)))
-            # push!(stmts, :($(gradient_var(input_node)) += $input_grads[$(QuoteNode(i))]))
+                $input_node_grad, $increment)))
         end
     end
 
@@ -332,8 +327,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
     if node === ir.return_node && node in fwd_marked
         @assert node in back_marked
         push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))(
-            $(gradient_var(node)), retval_grad, scale_factor)))
-        # push!(stmts, :($(gradient_var(node)) += retval_grad))
+            $(gradient_var(node)), retval_grad)))
     end
 
     if node in fwd_marked
@@ -347,9 +341,15 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
     for (i, (input_node, has_grad)) in enumerate(zip(node.inputs, has_argument_grads(node.generative_function)))
         if input_node in fwd_marked && has_grad
             @assert input_node in back_marked # this ensured its gradient will have been initialized
-            push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))(
-                $(gradient_var(input_node)), $input_grads[$(QuoteNode(i))], scale_factor)))
-            # push!(stmts, :($(gradient_var(input_node)) += $input_grads[$(QuoteNode(i))]))
+            input_node_grad = gradient_var(input_node)
+            increment = :($input_grads[$(QuoteNode(i))])
+            if isa(input_node, TrainableParameterNode)
+                push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))(
+                    $input_node_grad, $increment, scale_factor)))
+            else
+                push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))(
+                    $input_node_grad, $increment)))
+            end
         end
     end
 end
@@ -425,16 +425,14 @@ function codegen_choice_gradients(trace_type::Type{T}, selection_type::Type,
         return quote choice_gradients(trace, StaticSelection(selection), retval_grad) end
     end
 
-    push!(stmts, :(scale_factor = NaN))
-
     ir = get_ir(gen_fn_type)
     selected_choices = get_selected_choices(schema, ir)
     selected_calls = get_selected_calls(schema, ir)
 
     # forward marking pass
     fwd_marked = Set{StaticIRNode}()
     for node in ir.nodes
-        fwd_pass!(selected_choices, selected_calls, fwd_marked, node)
+        fwd_pass!(selected_choices, selected_calls, fwd_marked, node, BackpropTraceMode())
     end
 
     # backward marking pass
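The two marking passes intersect reachability from sources with reachability to sinks; gradient code is generated only for nodes in both sets, and each pass is a single sweep because the IR nodes are topologically ordered. A toy version on a plain DAG (symbols rather than Gen's IR nodes) to make that structure concrete:

    nodes = [:theta, :x, :mu, :y]                     # topologically ordered
    parents = Dict(:mu => [:theta, :x], :y => [:mu])
    sources = Set([:theta])                           # e.g. selected choices or params
    sinks = Set([:y])                                 # e.g. the return node

    # forward marking pass: one sweep in topological order
    fwd = Set{Symbol}()
    for n in nodes
        if n in sources || any(p in fwd for p in get(parents, n, Symbol[]))
            push!(fwd, n)
        end
    end

    # backward marking pass: one sweep in reverse topological order
    back = Set{Symbol}(sinks)
    for n in reverse(nodes)
        n in back && union!(back, get(parents, n, Symbol[]))
    end

    intersect(fwd, back)  # Set([:theta, :mu, :y]): nodes that get gradient code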
@@ -489,13 +487,13 @@ function codegen_accumulate_param_gradients!(trace_type::Type{T},
     selected_calls = Set{GenerativeFunctionCallNode}(
         node for node in ir.nodes if isa(node, GenerativeFunctionCallNode))
 
-    # forward marking pass
+    # forward marking pass (propagate forward from 'sources')
     fwd_marked = Set{StaticIRNode}()
     for node in ir.nodes
-        fwd_pass!(selected_choices, selected_calls, fwd_marked, node)
+        fwd_pass!(selected_choices, selected_calls, fwd_marked, node, BackpropParamsMode())
     end
 
-    # backward marking pass
+    # backward marking pass (propagate backwards from 'sinks')
     back_marked = Set{StaticIRNode}()
     push!(back_marked, ir.return_node)
     for node in reverse(ir.nodes)
@@ -508,12 +506,15 @@ function codegen_accumulate_param_gradients!(trace_type::Type{T},
     arg_names = Symbol[arg_node.name for arg_node in ir.arg_nodes]
     push!(stmts, :($(Expr(:tuple, arg_names...)) = $(QuoteNode(get_args))(trace)))
 
-    # forward code-generation pass (initialize gradients to zero, create needed references)
+    # forward code-generation pass
+    # any node that is backward-marked creates a variable for its current value
+    # any node that is forward-marked and backward-marked initializes a gradient variable
     for node in ir.nodes
         fwd_codegen!(stmts, fwd_marked, back_marked, node)
     end
 
-    # backward code-generation pass (increment gradients)
+    # backward code-generation pass
+    # any node that is forward-marked and backward-marked increments its gradient variable
     for node in reverse(ir.nodes)
         back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node, BackpropParamsMode())
     end
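For orientation, a usage sketch of the two entry points whose codegen this diff touches, using Gen's documented API; the static DSL model here is illustrative only:

    using Gen

    @gen (static) function model(x::Float64)
        @param theta::Float64
        y ~ normal(theta * x, 1.0)
        return y
    end
    @load_generated_functions()  # required on older Gen versions

    init_param!(model, :theta, 0.0)
    trace = simulate(model, (1.5,))

    # choice_gradients uses BackpropTraceMode: no scale_factor on non-parameter
    # gradients; a return value gradient is required since :y is the return node
    (arg_grads, choice_values, choice_grads) = choice_gradients(trace, select(:y), 1.0)

    # accumulate_param_gradients! uses BackpropParamsMode: scale_factor scales
    # the increments flowing into the persistent parameter accumulators
    accumulate_param_gradients!(trace, nothing, 1.0)
    get_param_grad(model, :theta)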