@@ -8,7 +8,7 @@ const value_choices_prefix = gensym("value_choices")
88value_choices_var (node:: GenerativeFunctionCallNode ) = Symbol (" $(value_choices_prefix) _$(node. addr) " )
99
1010const gradient_choices_prefix = gensym (" gradient_choices" )
11- choice_gradient_var (node:: GenerativeFunctionCallNode ) = Symbol (" $(choice_gradient_prefix ) _$(node. addr) " )
11+ gradient_choices_var (node:: GenerativeFunctionCallNode ) = Symbol (" $(gradient_choices_prefix ) _$(node. addr) " )
1212
1313const tape_prefix = gensym (" tape" )
1414tape_var (node:: JuliaNode ) = Symbol (" $(tape_prefix) _$(node. name) " )
@@ -276,7 +276,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
276276 if node in fwd_marked
277277 input_grads = gensym (" call_input_grads" )
278278 choice_value = value_choices_var (node)
279- choice_gradient = choice_gradient_var (node)
279+ choice_gradient = gradient_choices_var (node)
280280 subtrace_fieldname = get_subtrace_fieldname (node)
281281 call_selection = gensym (" call_selection" )
282282 if node in selected_calls
@@ -337,7 +337,7 @@ function generate_value_choice_gradient(selected_choices::Set{RandomChoiceNode},
337337 quoted_internal_keys = map ((node) -> QuoteNode (node. addr), selected_calls_vec)
338338 internal_values = map ((node) -> :(get_choices (trace.$ (get_subtrace_fieldname (node)))),
339339 selected_calls_vec)
340- internal_gradients = map ((node) -> choice_gradient_var (node), selected_calls_vec)
340+ internal_gradients = map ((node) -> gradient_choices_var (node), selected_calls_vec)
341341 quote
342342 $ choice_value = StaticChoiceMap (
343343 NamedTuple {($(quoted_leaf_keys...),)} (($ (leaf_values... ),)),
0 commit comments