Skip to content

Commit 23ee010

Browse files
committed
to value_choices
1 parent 239603c commit 23ee010

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

src/static_ir/backprop.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
275275

276276
if node in fwd_marked
277277
input_grads = gensym("call_input_grads")
278-
choice_value = value_choices_var(node)
278+
value_choices = value_choices_var(node)
279279
choice_gradient = gradient_choices_var(node)
280280
subtrace_fieldname = get_subtrace_fieldname(node)
281281
call_selection = gensym("call_selection")
@@ -285,7 +285,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
285285
push!(stmts, :($call_selection = EmptySelection()))
286286
end
287287
retval_grad = node in back_marked ? gradient_var(node) : :(nothing)
288-
push!(stmts, :(($input_grads, $choice_value, $choice_gradient) = choice_gradients(
288+
push!(stmts, :(($input_grads, $value_choices, $choice_gradient) = choice_gradients(
289289
trace.$subtrace_fieldname, $call_selection, $retval_grad)))
290290
end
291291

@@ -297,7 +297,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
297297
end
298298
end
299299

300-
# NOTE: the choice_value and choice_gradient are dealt with later
300+
# NOTE: the value_choices and choice_gradient are dealt with later
301301
end
302302

303303
function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
@@ -327,7 +327,7 @@ end
327327

328328
function generate_value_choice_gradient(selected_choices::Set{RandomChoiceNode},
329329
selected_calls::Set{GenerativeFunctionCallNode},
330-
choice_value::Symbol, choice_gradient::Symbol)
330+
value_choices::Symbol, choice_gradient::Symbol)
331331
selected_choices_vec = collect(selected_choices)
332332
quoted_leaf_keys = map((node) -> QuoteNode(node.addr), selected_choices_vec)
333333
leaf_values = map((node) -> :(trace.$(get_value_fieldname(node))), selected_choices_vec)
@@ -339,7 +339,7 @@ function generate_value_choice_gradient(selected_choices::Set{RandomChoiceNode},
339339
selected_calls_vec)
340340
internal_gradients = map((node) -> gradient_choices_var(node), selected_calls_vec)
341341
quote
342-
$choice_value = StaticChoiceMap(
342+
$value_choices = StaticChoiceMap(
343343
NamedTuple{($(quoted_leaf_keys...),)}(($(leaf_values...),)),
344344
NamedTuple{($(quoted_internal_keys...),)}(($(internal_values...),)))
345345
$choice_gradient = StaticChoiceMap(
@@ -429,18 +429,18 @@ function codegen_choice_gradients(trace_type::Type{T}, selection_type::Type,
429429
back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node, BackpropTraceMode())
430430
end
431431

432-
# assemble choice_value and choice_gradient
433-
choice_value = gensym("choice_value")
432+
# assemble value_choices and choice_gradient
433+
value_choices = gensym("value_choices")
434434
choice_gradient = gensym("choice_gradient")
435435
push!(stmts, generate_value_choice_gradient(selected_choices, selected_calls,
436-
choice_value, choice_gradient))
436+
value_choices, choice_gradient))
437437

438438
# gradients with respect to inputs
439439
arg_grad = (node::ArgumentNode) -> node.compute_grad ? gradient_var(node) : :(nothing)
440440
input_grads = Expr(:tuple, map(arg_grad, ir.arg_nodes)...)
441441

442442
# return values
443-
push!(stmts, :(return ($input_grads, $choice_value, $choice_gradient)))
443+
push!(stmts, :(return ($input_grads, $value_choices, $choice_gradient)))
444444

445445
Expr(:block, stmts...)
446446
end

0 commit comments

Comments
 (0)