@@ -275,7 +275,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
275275
276276 if node in fwd_marked
277277 input_grads = gensym (" call_input_grads" )
278- choice_value = value_choices_var (node)
278+ value_choices = value_choices_var (node)
279279 choice_gradient = gradient_choices_var (node)
280280 subtrace_fieldname = get_subtrace_fieldname (node)
281281 call_selection = gensym (" call_selection" )
@@ -285,7 +285,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
285285 push! (stmts, :($ call_selection = EmptySelection ()))
286286 end
287287 retval_grad = node in back_marked ? gradient_var (node) : :(nothing )
288- push! (stmts, :(($ input_grads, $ choice_value , $ choice_gradient) = choice_gradients (
288+ push! (stmts, :(($ input_grads, $ value_choices , $ choice_gradient) = choice_gradients (
289289 trace.$ subtrace_fieldname, $ call_selection, $ retval_grad)))
290290 end
291291
@@ -297,7 +297,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
297297 end
298298 end
299299
300- # NOTE: the choice_value and choice_gradient are dealt with later
300+ # NOTE: the value_choices and choice_gradient are dealt with later
301301end
302302
303303function back_codegen! (stmts, ir, selected_calls, fwd_marked, back_marked,
327327
328328function generate_value_choice_gradient (selected_choices:: Set{RandomChoiceNode} ,
329329 selected_calls:: Set{GenerativeFunctionCallNode} ,
330- choice_value :: Symbol , choice_gradient:: Symbol )
330+ value_choices :: Symbol , choice_gradient:: Symbol )
331331 selected_choices_vec = collect (selected_choices)
332332 quoted_leaf_keys = map ((node) -> QuoteNode (node. addr), selected_choices_vec)
333333 leaf_values = map ((node) -> :(trace.$ (get_value_fieldname (node))), selected_choices_vec)
@@ -339,7 +339,7 @@ function generate_value_choice_gradient(selected_choices::Set{RandomChoiceNode},
339339 selected_calls_vec)
340340 internal_gradients = map ((node) -> gradient_choices_var (node), selected_calls_vec)
341341 quote
342- $ choice_value = StaticChoiceMap (
342+ $ value_choices = StaticChoiceMap (
343343 NamedTuple {($(quoted_leaf_keys...),)} (($ (leaf_values... ),)),
344344 NamedTuple {($(quoted_internal_keys...),)} (($ (internal_values... ),)))
345345 $ choice_gradient = StaticChoiceMap (
@@ -429,18 +429,18 @@ function codegen_choice_gradients(trace_type::Type{T}, selection_type::Type,
429429 back_codegen! (stmts, ir, selected_calls, fwd_marked, back_marked, node, BackpropTraceMode ())
430430 end
431431
432- # assemble choice_value and choice_gradient
433- choice_value = gensym (" choice_value " )
432+ # assemble value_choices and choice_gradient
433+ value_choices = gensym (" value_choices " )
434434 choice_gradient = gensym (" choice_gradient" )
435435 push! (stmts, generate_value_choice_gradient (selected_choices, selected_calls,
436- choice_value , choice_gradient))
436+ value_choices , choice_gradient))
437437
438438 # gradients with respect to inputs
439439 arg_grad = (node:: ArgumentNode ) -> node. compute_grad ? gradient_var (node) : :(nothing )
440440 input_grads = Expr (:tuple , map (arg_grad, ir. arg_nodes)... )
441441
442442 # return values
443- push! (stmts, :(return ($ input_grads, $ choice_value , $ choice_gradient)))
443+ push! (stmts, :(return ($ input_grads, $ value_choices , $ choice_gradient)))
444444
445445 Expr (:block , stmts... )
446446end
0 commit comments