Skip to content

Commit 239603c

Browse files
committed
to gradient_choices
1 parent b9020c1 commit 239603c

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/static_ir/backprop.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ const value_choices_prefix = gensym("value_choices")
88
value_choices_var(node::GenerativeFunctionCallNode) = Symbol("$(value_choices_prefix)_$(node.addr)")
99

1010
const gradient_choices_prefix = gensym("gradient_choices")
11-
choice_gradient_var(node::GenerativeFunctionCallNode) = Symbol("$(choice_gradient_prefix)_$(node.addr)")
11+
gradient_choices_var(node::GenerativeFunctionCallNode) = Symbol("$(gradient_choices_prefix)_$(node.addr)")
1212

1313
const tape_prefix = gensym("tape")
1414
tape_var(node::JuliaNode) = Symbol("$(tape_prefix)_$(node.name)")
@@ -276,7 +276,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
276276
if node in fwd_marked
277277
input_grads = gensym("call_input_grads")
278278
choice_value = value_choices_var(node)
279-
choice_gradient = choice_gradient_var(node)
279+
choice_gradient = gradient_choices_var(node)
280280
subtrace_fieldname = get_subtrace_fieldname(node)
281281
call_selection = gensym("call_selection")
282282
if node in selected_calls
@@ -337,7 +337,7 @@ function generate_value_choice_gradient(selected_choices::Set{RandomChoiceNode},
337337
quoted_internal_keys = map((node) -> QuoteNode(node.addr), selected_calls_vec)
338338
internal_values = map((node) -> :(get_choices(trace.$(get_subtrace_fieldname(node)))),
339339
selected_calls_vec)
340-
internal_gradients = map((node) -> choice_gradient_var(node), selected_calls_vec)
340+
internal_gradients = map((node) -> gradient_choices_var(node), selected_calls_vec)
341341
quote
342342
$choice_value = StaticChoiceMap(
343343
NamedTuple{($(quoted_leaf_keys...),)}(($(leaf_values...),)),

0 commit comments

Comments
 (0)