
Commit c39c74d

Rename value_trie and gradient_trie to choice_value and choice_gradient
Refactored code and tests to consistently use 'choice_value' and 'choice_gradient' instead of 'value_trie' and 'gradient_trie' for clarity and alignment with choicemap terminology. Updated variable names, function signatures, and related comments across inference, static IR, and test files.
1 parent aa45748 commit c39c74d
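
For orientation, here is a minimal, hypothetical sketch (not part of this commit) of the choice_gradients call whose return values the new names describe. The toy model, its argument, and the constrained value are placeholders chosen only for illustration; with an unannotated argument the input gradient simply comes back as nothing.

    using Gen

    # Hypothetical toy model, used only to show the renamed return values.
    @gen function toy_model(mu::Float64)
        x = @trace(normal(mu, 1.), :x)
        return x
    end

    (trace, _) = generate(toy_model, (0.,), choicemap((:x, 1.5)))
    selection = select(:x)

    # choice_gradients returns gradients w.r.t. the arguments, a choice map of
    # the selected values, and a choice map of gradients w.r.t. those values.
    (input_grads, choice_values, choice_gradient) =
        choice_gradients(trace, selection, nothing)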

5 files changed: +45 −44 lines

.vscode/settings.json

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+{}

src/inference/mala.jl

Lines changed: 6 additions & 6 deletions

@@ -17,9 +17,9 @@ function mala(
    retval_grad = accepts_output_grad(get_gen_fn(trace)) ? zero(get_retval(trace)) : nothing

    # forward proposal
-   (_, values_trie, gradient_trie) = choice_gradients(trace, selection, retval_grad)
-   values = to_array(values_trie, Float64)
-   gradient = to_array(gradient_trie, Float64)
+   (_, choice_values, choice_gradient) = choice_gradients(trace, selection, retval_grad)
+   values = to_array(choice_values, Float64)
+   gradient = to_array(choice_gradient, Float64)
    forward_mu = values + tau * gradient
    forward_score = 0.
    proposed_values = Vector{Float64}(undef, length(values))
@@ -29,14 +29,14 @@ function mala(
    end

    # evaluate model weight
-   constraints = from_array(values_trie, proposed_values)
+   constraints = from_array(choice_values, proposed_values)
    (new_trace, weight, _, discard) = update(trace,
        args, argdiffs, constraints)
    check && check_observations(get_choices(new_trace), observations)

    # backward proposal
-   (_, _, backward_gradient_trie) = choice_gradients(new_trace, selection, retval_grad)
-   backward_gradient = to_array(backward_gradient_trie, Float64)
+   (_, _, backward_choice_gradient) = choice_gradients(new_trace, selection, retval_grad)
+   backward_gradient = to_array(backward_choice_gradient, Float64)
    @assert length(backward_gradient) == length(values)
    backward_score = 0.
    backward_mu = proposed_values + tau * backward_gradient
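
The renamed variables in mala hold choice maps, which to_array and from_array convert to and from flat vectors. A minimal sketch of that round trip, with a hand-built choice map standing in for the one choice_gradients would return:

    using Gen

    # Hand-built choice map standing in for the `choice_values` used above.
    cm = choicemap((:a, 1.0), (:bar => :z, 2.0))

    v = to_array(cm, Float64)                  # flat Float64 vector of the map's values
    v_proposed = v .+ 0.1                      # perturb, as the MALA proposal does
    cm_proposed = from_array(cm, v_proposed)   # same addresses, updated values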

src/static_ir/backprop.jl

Lines changed: 20 additions & 20 deletions

@@ -4,11 +4,11 @@ struct BackpropParamsMode end
const gradient_prefix = gensym("gradient")
gradient_var(node::StaticIRNode) = Symbol("$(gradient_prefix)_$(node.name)")

-const value_trie_prefix = gensym("value_trie")
-value_trie_var(node::GenerativeFunctionCallNode) = Symbol("$(value_trie_prefix)_$(node.addr)")
+const choice_value_prefix = gensym("choice_value")
+choice_value_var(node::GenerativeFunctionCallNode) = Symbol("$(choice_value_prefix)_$(node.addr)")

-const gradient_trie_prefix = gensym("gradient_trie")
-gradient_trie_var(node::GenerativeFunctionCallNode) = Symbol("$(gradient_trie_prefix)_$(node.addr)")
+const choice_gradient_prefix = gensym("choice_gradient")
+choice_gradient_var(node::GenerativeFunctionCallNode) = Symbol("$(choice_gradient_prefix)_$(node.addr)")

const tape_prefix = gensym("tape")
tape_var(node::JuliaNode) = Symbol("$(tape_prefix)_$(node.name)")
@@ -128,7 +128,7 @@ function fwd_codegen!(stmts, fwd_marked, back_marked, node::JuliaNode)
    # we need the value for initializing gradient to zero (to get the type
    # and e.g. shape), and for reference by other nodes during
    # back_codegen! we could be more selective about which JuliaNodes need
-   # to be evalutaed, that is a performance optimization for the future
+   # to be evaluated, that is a performance optimization for the future
    args = map((input_node) -> input_node.name, node.inputs)
    push!(stmts, :($(node.name) = $(QuoteNode(node.fn))($(args...))))
end
@@ -275,8 +275,8 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,

    if node in fwd_marked
        input_grads = gensym("call_input_grads")
-       value_trie = value_trie_var(node)
-       gradient_trie = gradient_trie_var(node)
+       choice_value = choice_value_var(node)
+       choice_gradient = choice_gradient_var(node)
        subtrace_fieldname = get_subtrace_fieldname(node)
        call_selection = gensym("call_selection")
        if node in selected_calls
@@ -285,7 +285,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
            push!(stmts, :($call_selection = EmptySelection()))
        end
        retval_grad = node in back_marked ? gradient_var(node) : :(nothing)
-       push!(stmts, :(($input_grads, $value_trie, $gradient_trie) = choice_gradients(
+       push!(stmts, :(($input_grads, $choice_value, $choice_gradient) = choice_gradients(
            trace.$subtrace_fieldname, $call_selection, $retval_grad)))
    end

@@ -297,7 +297,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
        end
    end

-   # NOTE: the value_trie and gradient_trie are dealt with later
+   # NOTE: the choice_value and choice_gradient are dealt with later
end

function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
@@ -325,9 +325,9 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
    end
end

-function generate_value_gradient_trie(selected_choices::Set{RandomChoiceNode},
+function generate_value_choice_gradient(selected_choices::Set{RandomChoiceNode},
        selected_calls::Set{GenerativeFunctionCallNode},
-       value_trie::Symbol, gradient_trie::Symbol)
+       choice_value::Symbol, choice_gradient::Symbol)
    selected_choices_vec = collect(selected_choices)
    quoted_leaf_keys = map((node) -> QuoteNode(node.addr), selected_choices_vec)
    leaf_values = map((node) -> :(trace.$(get_value_fieldname(node))), selected_choices_vec)
@@ -337,12 +337,12 @@ function generate_value_gradient_trie(selected_choices::Set{RandomChoiceNode},
    quoted_internal_keys = map((node) -> QuoteNode(node.addr), selected_calls_vec)
    internal_values = map((node) -> :(get_choices(trace.$(get_subtrace_fieldname(node)))),
        selected_calls_vec)
-   internal_gradients = map((node) -> gradient_trie_var(node), selected_calls_vec)
+   internal_gradients = map((node) -> choice_gradient_var(node), selected_calls_vec)
    quote
-       $value_trie = StaticChoiceMap(
+       $choice_value = StaticChoiceMap(
            NamedTuple{($(quoted_leaf_keys...),)}(($(leaf_values...),)),
            NamedTuple{($(quoted_internal_keys...),)}(($(internal_values...),)))
-       $gradient_trie = StaticChoiceMap(
+       $choice_gradient = StaticChoiceMap(
            NamedTuple{($(quoted_leaf_keys...),)}(($(leaf_gradients...),)),
            NamedTuple{($(quoted_internal_keys...),)}(($(internal_gradients...),)))
    end
@@ -429,18 +429,18 @@ function codegen_choice_gradients(trace_type::Type{T}, selection_type::Type,
        back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node, BackpropTraceMode())
    end

-   # assemble value_trie and gradient_trie
-   value_trie = gensym("value_trie")
-   gradient_trie = gensym("gradient_trie")
-   push!(stmts, generate_value_gradient_trie(selected_choices, selected_calls,
-       value_trie, gradient_trie))
+   # assemble choice_value and choice_gradient
+   choice_value = gensym("choice_value")
+   choice_gradient = gensym("choice_gradient")
+   push!(stmts, generate_value_choice_gradient(selected_choices, selected_calls,
+       choice_value, choice_gradient))

    # gradients with respect to inputs
    arg_grad = (node::ArgumentNode) -> node.compute_grad ? gradient_var(node) : :(nothing)
    input_grads = Expr(:tuple, map(arg_grad, ir.arg_nodes)...)

    # return values
-   push!(stmts, :(return ($input_grads, $value_trie, $gradient_trie)))
+   push!(stmts, :(return ($input_grads, $choice_value, $choice_gradient)))

    Expr(:block, stmts...)
end
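
For intuition, the expressions spliced in by generate_value_choice_gradient assemble two StaticChoiceMaps with the same address structure. A schematic, hand-written illustration of the generated shape (the field and variable names below are invented for readability; the real names are gensym'd):

    # Schematic only: one selected random choice at :a and one selected call at :bar.
    choice_value = StaticChoiceMap(
        (a = trace.a_value,),                        # leaf values of selected choices
        (bar = get_choices(trace.bar_subtrace),))    # full choice maps of selected calls
    choice_gradient = StaticChoiceMap(
        (a = gradient_a,),                           # accumulated gradient for :a
        (bar = choice_gradient_bar,))                # gradient choice map from the call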

test/dsl/dynamic_dsl.jl

Lines changed: 2 additions & 2 deletions

@@ -345,15 +345,15 @@ end
    # check input gradient
    @test isapprox(mu_a_grad, finite_diff(f, (mu_a, a, b, z, out), 1, dx))

-   # check value trie
+   # check value from choicemap
    @test get_value(choices, :a) == a
    @test get_value(choices, :out) == out
    @test get_value(choices, :bar => :z) == z
    @test !has_value(choices, :b) # was not selected
    @test length(collect(get_submaps_shallow(choices))) == 1
    @test length(collect(get_values_shallow(choices))) == 2

-   # check gradient trie
+   # check gradient from choicemap
    @test length(collect(get_submaps_shallow(gradients))) == 1
    @test length(collect(get_values_shallow(gradients))) == 2
    @test !has_value(gradients, :b) # was not selected

test/static_ir/static_ir.jl

Lines changed: 16 additions & 16 deletions

@@ -351,26 +351,26 @@ end
    selection = select(:bar => :z, :a, :out)
    selection = StaticSelection(selection)
    retval_grad = 2.
-   ((mu_a_grad,), value_trie, gradient_trie) = choice_gradients(trace, selection, retval_grad)
+   ((mu_a_grad,), choice_value, choice_gradient) = choice_gradients(trace, selection, retval_grad)

    # check input gradient
    @test isapprox(mu_a_grad, finite_diff(f, (mu_a, theta, a, b, z, out), 1, dx))

-   # check value trie
-   @test get_value(value_trie, :a) == a
-   @test get_value(value_trie, :out) == out
-   @test get_value(value_trie, :bar => :z) == z
-   @test !has_value(value_trie, :b) # was not selected
-   @test length(get_submaps_shallow(value_trie)) == 1
-   @test length(get_values_shallow(value_trie)) == 2
-
-   # check gradient trie
-   @test length(get_submaps_shallow(gradient_trie)) == 1
-   @test length(get_values_shallow(gradient_trie)) == 2
-   @test !has_value(gradient_trie, :b) # was not selected
-   @test isapprox(get_value(gradient_trie, :a), finite_diff(f, (mu_a, theta, a, b, z, out), 3, dx))
-   @test isapprox(get_value(gradient_trie, :out), finite_diff(f, (mu_a, theta, a, b, z, out), 6, dx))
-   @test isapprox(get_value(gradient_trie, :bar => :z), finite_diff(f, (mu_a, theta, a, b, z, out), 5, dx))
+   # check value from choice map
+   @test get_value(choice_value, :a) == a
+   @test get_value(choice_value, :out) == out
+   @test get_value(choice_value, :bar => :z) == z
+   @test !has_value(choice_value, :b) # was not selected
+   @test length(get_submaps_shallow(choice_value)) == 1
+   @test length(get_values_shallow(choice_value)) == 2
+
+   # check gradient from choice map
+   @test length(get_submaps_shallow(choice_gradient)) == 1
+   @test length(get_values_shallow(choice_gradient)) == 2
+   @test !has_value(choice_gradient, :b) # was not selected
+   @test isapprox(get_value(choice_gradient, :a), finite_diff(f, (mu_a, theta, a, b, z, out), 3, dx))
+   @test isapprox(get_value(choice_gradient, :out), finite_diff(f, (mu_a, theta, a, b, z, out), 6, dx))
+   @test isapprox(get_value(choice_gradient, :bar => :z), finite_diff(f, (mu_a, theta, a, b, z, out), 5, dx))

    # reset the trainable parameter gradient
    zero_param_grad!(foo, :theta)
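
These tests compare the returned gradients against numerical derivatives via a finite_diff helper. A plausible central-difference implementation matching the call pattern finite_diff(f, args, i, dx) seen above (the actual helper lives in the test utilities and may differ):

    # Perturb the i-th argument by ±dx and take the symmetric difference quotient.
    function finite_diff(f::Function, args::Tuple, i::Int, dx::Float64)
        pos = [j == i ? args[j] + dx : args[j] for j in 1:length(args)]
        neg = [j == i ? args[j] - dx : args[j] for j in 1:length(args)]
        (f(pos...) - f(neg...)) / (2 * dx)
    end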
