@@ -4,11 +4,11 @@ struct BackpropParamsMode end
 const gradient_prefix = gensym("gradient")
 gradient_var(node::StaticIRNode) = Symbol("$(gradient_prefix)_$(node.name)")
 
-const value_trie_prefix = gensym("value_trie")
-value_trie_var(node::GenerativeFunctionCallNode) = Symbol("$(value_trie_prefix)_$(node.addr)")
+const value_choices_prefix = gensym("value_choices")
+value_choices_var(node::GenerativeFunctionCallNode) = Symbol("$(value_choices_prefix)_$(node.addr)")
 
-const gradient_trie_prefix = gensym("gradient_trie")
-gradient_trie_var(node::GenerativeFunctionCallNode) = Symbol("$(gradient_trie_prefix)_$(node.addr)")
+const gradient_choices_prefix = gensym("gradient_choices")
+gradient_choices_var(node::GenerativeFunctionCallNode) = Symbol("$(gradient_choices_prefix)_$(node.addr)")
 
 const tape_prefix = gensym("tape")
 tape_var(node::JuliaNode) = Symbol("$(tape_prefix)_$(node.name)")
@@ -128,7 +128,7 @@ function fwd_codegen!(stmts, fwd_marked, back_marked, node::JuliaNode)
         # we need the value for initializing gradient to zero (to get the type
         # and e.g. shape), and for reference by other nodes during
         # back_codegen! we could be more selective about which JuliaNodes need
-        # to be evalutaed, that is a performance optimization for the future
+        # to be evaluated, that is a performance optimization for the future
         args = map((input_node) -> input_node.name, node.inputs)
         push!(stmts, :($(node.name) = $(QuoteNode(node.fn))($(args...))))
     end
@@ -275,8 +275,8 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
 
     if node in fwd_marked
         input_grads = gensym("call_input_grads")
-        value_trie = value_trie_var(node)
-        gradient_trie = gradient_trie_var(node)
+        value_choices = value_choices_var(node)
+        gradient_choices = gradient_choices_var(node)
         subtrace_fieldname = get_subtrace_fieldname(node)
         call_selection = gensym("call_selection")
         if node in selected_calls
@@ -285,7 +285,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
             push!(stmts, :($call_selection = EmptySelection()))
         end
         retval_grad = node in back_marked ? gradient_var(node) : :(nothing)
-        push!(stmts, :(($input_grads, $value_trie, $gradient_trie) = choice_gradients(
+        push!(stmts, :(($input_grads, $value_choices, $gradient_choices) = choice_gradients(
             trace.$subtrace_fieldname, $call_selection, $retval_grad)))
     end
291291
@@ -297,7 +297,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
         end
     end
 
-    # NOTE: the value_trie and gradient_trie are dealt with later
+    # NOTE: the value_choices and gradient_choices are dealt with later
 end
 
 function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
@@ -325,9 +325,9 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
     end
 end
 
-function generate_value_gradient_trie(selected_choices::Set{RandomChoiceNode},
+function generate_value_gradient_choices(selected_choices::Set{RandomChoiceNode},
         selected_calls::Set{GenerativeFunctionCallNode},
-        value_trie::Symbol, gradient_trie::Symbol)
+        value_choices::Symbol, gradient_choices::Symbol)
     selected_choices_vec = collect(selected_choices)
     quoted_leaf_keys = map((node) -> QuoteNode(node.addr), selected_choices_vec)
     leaf_values = map((node) -> :(trace.$(get_value_fieldname(node))), selected_choices_vec)
@@ -337,12 +337,12 @@ function generate_value_gradient_trie(selected_choices::Set{RandomChoiceNode},
     quoted_internal_keys = map((node) -> QuoteNode(node.addr), selected_calls_vec)
     internal_values = map((node) -> :(get_choices(trace.$(get_subtrace_fieldname(node)))),
         selected_calls_vec)
-    internal_gradients = map((node) -> gradient_trie_var(node), selected_calls_vec)
+    internal_gradients = map((node) -> gradient_choices_var(node), selected_calls_vec)
     quote
-        $value_trie = StaticChoiceMap(
+        $value_choices = StaticChoiceMap(
             NamedTuple{($(quoted_leaf_keys...),)}(($(leaf_values...),)),
             NamedTuple{($(quoted_internal_keys...),)}(($(internal_values...),)))
-        $gradient_trie = StaticChoiceMap(
+        $gradient_choices = StaticChoiceMap(
             NamedTuple{($(quoted_leaf_keys...),)}(($(leaf_gradients...),)),
             NamedTuple{($(quoted_internal_keys...),)}(($(internal_gradients...),)))
     end
@@ -429,18 +429,18 @@ function codegen_choice_gradients(trace_type::Type{T}, selection_type::Type,
         back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node, BackpropTraceMode())
     end
 
-    # assemble value_trie and gradient_trie
-    value_trie = gensym("value_trie")
-    gradient_trie = gensym("gradient_trie")
-    push!(stmts, generate_value_gradient_trie(selected_choices, selected_calls,
-        value_trie, gradient_trie))
+    # assemble value_choices and gradient_choices
+    value_choices = gensym("value_choices")
+    gradient_choices = gensym("gradient_choices")
+    push!(stmts, generate_value_gradient_choices(selected_choices, selected_calls,
+        value_choices, gradient_choices))
 
     # gradients with respect to inputs
     arg_grad = (node::ArgumentNode) -> node.compute_grad ? gradient_var(node) : :(nothing)
     input_grads = Expr(:tuple, map(arg_grad, ir.arg_nodes)...)
 
     # return values
-    push!(stmts, :(return ($input_grads, $value_trie, $gradient_trie)))
+    push!(stmts, :(return ($input_grads, $value_choices, $gradient_choices)))
 
     Expr(:block, stmts...)
 end
0 commit comments