Skip to content

Commit b9020c1

Browse files
committed
value choices refactor
1 parent 90807bf commit b9020c1

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/static_ir/backprop.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@ struct BackpropParamsMode end
44
const gradient_prefix = gensym("gradient")
55
gradient_var(node::StaticIRNode) = Symbol("$(gradient_prefix)_$(node.name)")
66

7-
const choice_value_prefix = gensym("choice_value")
8-
choice_value_var(node::GenerativeFunctionCallNode) = Symbol("$(choice_value_prefix)_$(node.addr)")
7+
const value_choices_prefix = gensym("value_choices")
8+
value_choices_var(node::GenerativeFunctionCallNode) = Symbol("$(value_choices_prefix)_$(node.addr)")
99

10-
const choice_gradient_prefix = gensym("choice_gradient")
10+
const gradient_choices_prefix = gensym("gradient_choices")
1111
choice_gradient_var(node::GenerativeFunctionCallNode) = Symbol("$(choice_gradient_prefix)_$(node.addr)")
1212

1313
const tape_prefix = gensym("tape")
@@ -275,7 +275,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
275275

276276
if node in fwd_marked
277277
input_grads = gensym("call_input_grads")
278-
choice_value = choice_value_var(node)
278+
choice_value = value_choices_var(node)
279279
choice_gradient = choice_gradient_var(node)
280280
subtrace_fieldname = get_subtrace_fieldname(node)
281281
call_selection = gensym("call_selection")

0 commit comments

Comments
 (0)