Commit 027dbc6

Merge pull request #562 from SamuelBrand1/vars-to-choice
Issue 561: Rename variables to reflect that they are choice maps
2 parents aa45748 + a33dfdb commit 027dbc6

10 files changed: 69 additions and 68 deletions


.vscode/settings.json

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+{}

src/gen_fn_interface.jl

Lines changed: 3 additions & 3 deletions

@@ -373,7 +373,7 @@ function accumulate_param_gradients!(trace)
 end

 """
-(arg_grads, choice_values, choice_grads) = choice_gradients(
+(arg_grads, value_choices, gradient_choices) = choice_gradients(
 trace, selection=EmptySelection(), retgrad=nothing)

 Given a previous trace \$(x, t)\$ (`trace`) and a gradient with respect to the
@@ -389,15 +389,15 @@ If an argument is not annotated with `(grad)`, the corresponding value in
 `arg_grads` will be `nothing`.

 Also given a set of addresses \$A\$ (`selection`) that are continuous-valued
-random choices, return the folowing gradient (`choice_grads`) with respect to
+random choices, return the folowing gradient (`gradient_choices`) with respect to
 the values of these choices:
 ```math
 ∇_A \\left( \\log P(t; x) + J \\right)
 ```
 The gradient is represented as a choicemap whose value at (hierarchical)
 address `addr` is \$∂J/∂t[\\texttt{addr}]\$.

-Also return the choicemap (`choice_values`) that is the restriction of \$t\$ to \$A\$.
+Also return the choicemap (`value_choices`) that is the restriction of \$t\$ to \$A\$.
 """
 function choice_gradients(trace, selection::Selection, retgrad)
 error("Not implemented")

src/inference/hmc.jl

Lines changed: 7 additions & 7 deletions

@@ -88,9 +88,9 @@ function hmc(

 # run leapfrog dynamics
 new_trace = trace
-(_, values_trie, gradient_trie) = choice_gradients(new_trace, selection, retval_grad)
-values = to_array(values_trie, Float64)
-gradient = to_array(gradient_trie, Float64)
+(_, value_choices, gradient_choices) = choice_gradients(new_trace, selection, retval_grad)
+values = to_array(value_choices, Float64)
+gradient = to_array(gradient_choices, Float64)
 momenta = sample_momenta(length(values), metric)
 prev_momenta_score = assess_momenta(momenta, metric)
 for step=1:L
@@ -102,10 +102,10 @@ function hmc(
 values += eps * momenta

 # get new gradient
-values_trie = from_array(values_trie, values)
-(new_trace, _, _) = update(new_trace, args, argdiffs, values_trie)
-(_, _, gradient_trie) = choice_gradients(new_trace, selection, retval_grad)
-gradient = to_array(gradient_trie, Float64)
+value_choices = from_array(value_choices, values)
+(new_trace, _, _) = update(new_trace, args, argdiffs, value_choices)
+(_, _, gradient_choices) = choice_gradients(new_trace, selection, retval_grad)
+gradient = to_array(gradient_choices, Float64)

 # half step on momenta
 momenta += (eps / 2) * gradient
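
For context, a hedged sketch of how this kernel is driven from user code (the model is a placeholder; the call pattern mirrors the test in test/inference/hmc.jl further below):

    using Gen

    @gen function example_model()
        x ~ normal(0, 1)
    end

    trace = simulate(example_model, ())

    # One HMC move over the selected continuous choice; returns the new trace
    # and whether the proposal was accepted.
    (new_trace, accepted) = hmc(trace, select(:x))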

src/inference/mala.jl

Lines changed: 6 additions & 6 deletions

@@ -17,9 +17,9 @@ function mala(
 retval_grad = accepts_output_grad(get_gen_fn(trace)) ? zero(get_retval(trace)) : nothing

 # forward proposal
-(_, values_trie, gradient_trie) = choice_gradients(trace, selection, retval_grad)
-values = to_array(values_trie, Float64)
-gradient = to_array(gradient_trie, Float64)
+(_, value_choices, gradient_choices) = choice_gradients(trace, selection, retval_grad)
+values = to_array(value_choices, Float64)
+gradient = to_array(gradient_choices, Float64)
 forward_mu = values + tau * gradient
 forward_score = 0.
 proposed_values = Vector{Float64}(undef, length(values))
@@ -29,14 +29,14 @@ function mala(
 end

 # evaluate model weight
-constraints = from_array(values_trie, proposed_values)
+constraints = from_array(value_choices, proposed_values)
 (new_trace, weight, _, discard) = update(trace,
 args, argdiffs, constraints)
 check && check_observations(get_choices(new_trace), observations)

 # backward proposal
-(_, _, backward_gradient_trie) = choice_gradients(new_trace, selection, retval_grad)
-backward_gradient = to_array(backward_gradient_trie, Float64)
+(_, _, backward_gradient_choices) = choice_gradients(new_trace, selection, retval_grad)
+backward_gradient = to_array(backward_gradient_choices, Float64)
 @assert length(backward_gradient) == length(values)
 backward_score = 0.
 backward_mu = proposed_values + tau * backward_gradient
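
A similar usage sketch for MALA (the model and the step size `tau = 0.01` are illustrative assumptions):

    using Gen

    @gen function example_model()
        x ~ normal(0, 1)
    end

    trace = simulate(example_model, ())

    # MALA proposes values + tau * gradient + noise and accepts or rejects
    # the move with a Metropolis-Hastings correction.
    (new_trace, accepted) = mala(trace, select(:x), 0.01)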

src/inference/map_optimize.jl

Lines changed: 5 additions & 5 deletions

@@ -11,16 +11,16 @@ function map_optimize(trace, selection::Selection;
 argdiffs = map((_) -> NoChange(), args)
 retval_grad = accepts_output_grad(get_gen_fn(trace)) ? zero(get_retval(trace)) : nothing

-(_, values, gradient) = choice_gradients(trace, selection, retval_grad)
-values_vec = to_array(values, Float64)
-gradient_vec = to_array(gradient, Float64)
+(_, value_choices, gradient_choices) = choice_gradients(trace, selection, retval_grad)
+values_vec = to_array(value_choices, Float64)
+gradient_vec = to_array(gradient_choices, Float64)
 step_size = max_step_size
 score = get_score(trace)
 while true
 new_values_vec = values_vec + gradient_vec * step_size
-values = from_array(values, new_values_vec)
+value_choices = from_array(value_choices, new_values_vec)
 # TODO discard and weight are not actually needed, there should be a more specialized variant
-(new_trace, _, _, discard) = update(trace, args, argdiffs, values)
+(new_trace, _, _, discard) = update(trace, args, argdiffs, value_choices)
 new_score = get_score(new_trace)
 change = new_score - score
 if verbose
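
And a hedged sketch for gradient-based MAP optimization (the model and the `max_step_size` value are assumptions; the keyword itself appears in the signature above):

    using Gen

    @gen function example_model()
        x ~ normal(0, 1)
    end

    trace = simulate(example_model, ())

    # Gradient ascent on the selected continuous choices, backing off the step
    # size until the trace score improves; returns the improved trace.
    new_trace = map_optimize(trace, select(:x); max_step_size=0.1)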

src/modeling_library/call_at/call_at.jl

Lines changed: 3 additions & 3 deletions

@@ -144,11 +144,11 @@ end

 function choice_gradients(trace::CallAtTrace, selection::Selection, retval_grad)
 subselection = selection[trace.key]
-(kernel_input_grads, value_submap, gradient_submap) = choice_gradients(
+(kernel_input_grads, value_choices_submap, gradient_choices_submap) = choice_gradients(
 trace.subtrace, subselection, retval_grad)
 input_grads = (kernel_input_grads..., nothing)
-value_choices = CallAtChoiceMap(trace.key, value_submap)
-gradient_choices = CallAtChoiceMap(trace.key, gradient_submap)
+value_choices = CallAtChoiceMap(trace.key, value_choices_submap)
+gradient_choices = CallAtChoiceMap(trace.key, gradient_choices_submap)
 (input_grads, value_choices, gradient_choices)
 end

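
From the caller's side, the nesting shows up as hierarchical addresses. A sketch under the assumption that the combinator is applied as `call_at(kernel, Int)` and invoked with the key as the trailing argument:

    using Gen

    @gen function kernel()
        x ~ normal(0, 1)
    end

    at_kernel = call_at(kernel, Int)
    trace = simulate(at_kernel, (3,))   # the key, 3, is the trailing argument

    (_, value_choices, gradient_choices) =
        choice_gradients(trace, select(3 => :x), nothing)

    # Both returned choice maps nest the kernel's addresses under the key.
    x_val  = get_value(value_choices, 3 => :x)
    x_grad = get_value(gradient_choices, 3 => :x)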

src/static_ir/backprop.jl

Lines changed: 20 additions & 20 deletions

@@ -4,11 +4,11 @@ struct BackpropParamsMode end
 const gradient_prefix = gensym("gradient")
 gradient_var(node::StaticIRNode) = Symbol("$(gradient_prefix)_$(node.name)")

-const value_trie_prefix = gensym("value_trie")
-value_trie_var(node::GenerativeFunctionCallNode) = Symbol("$(value_trie_prefix)_$(node.addr)")
+const value_choices_prefix = gensym("value_choices")
+value_choices_var(node::GenerativeFunctionCallNode) = Symbol("$(value_choices_prefix)_$(node.addr)")

-const gradient_trie_prefix = gensym("gradient_trie")
-gradient_trie_var(node::GenerativeFunctionCallNode) = Symbol("$(gradient_trie_prefix)_$(node.addr)")
+const gradient_choices_prefix = gensym("gradient_choices")
+gradient_choices_var(node::GenerativeFunctionCallNode) = Symbol("$(gradient_choices_prefix)_$(node.addr)")

 const tape_prefix = gensym("tape")
 tape_var(node::JuliaNode) = Symbol("$(tape_prefix)_$(node.name)")
@@ -128,7 +128,7 @@ function fwd_codegen!(stmts, fwd_marked, back_marked, node::JuliaNode)
 # we need the value for initializing gradient to zero (to get the type
 # and e.g. shape), and for reference by other nodes during
 # back_codegen! we could be more selective about which JuliaNodes need
-# to be evalutaed, that is a performance optimization for the future
+# to be evaluated, that is a performance optimization for the future
 args = map((input_node) -> input_node.name, node.inputs)
 push!(stmts, :($(node.name) = $(QuoteNode(node.fn))($(args...))))
 end
@@ -275,8 +275,8 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,

 if node in fwd_marked
 input_grads = gensym("call_input_grads")
-value_trie = value_trie_var(node)
-gradient_trie = gradient_trie_var(node)
+value_choices = value_choices_var(node)
+gradient_choices = gradient_choices_var(node)
 subtrace_fieldname = get_subtrace_fieldname(node)
 call_selection = gensym("call_selection")
 if node in selected_calls
@@ -285,7 +285,7 @@
 push!(stmts, :($call_selection = EmptySelection()))
 end
 retval_grad = node in back_marked ? gradient_var(node) : :(nothing)
-push!(stmts, :(($input_grads, $value_trie, $gradient_trie) = choice_gradients(
+push!(stmts, :(($input_grads, $value_choices, $gradient_choices) = choice_gradients(
 trace.$subtrace_fieldname, $call_selection, $retval_grad)))
 end

@@ -297,7 +297,7 @@
 end
 end

-# NOTE: the value_trie and gradient_trie are dealt with later
+# NOTE: the value_choices and gradient_choices are dealt with later
 end

 function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
@@ -325,9 +325,9 @@
 end
 end

-function generate_value_gradient_trie(selected_choices::Set{RandomChoiceNode},
+function generate_value_gradient_choices(selected_choices::Set{RandomChoiceNode},
 selected_calls::Set{GenerativeFunctionCallNode},
-value_trie::Symbol, gradient_trie::Symbol)
+value_choices::Symbol, gradient_choices::Symbol)
 selected_choices_vec = collect(selected_choices)
 quoted_leaf_keys = map((node) -> QuoteNode(node.addr), selected_choices_vec)
 leaf_values = map((node) -> :(trace.$(get_value_fieldname(node))), selected_choices_vec)
@@ -337,12 +337,12 @@ function generate_value_gradient_trie(selected_choices::Set{RandomChoiceNode},
 quoted_internal_keys = map((node) -> QuoteNode(node.addr), selected_calls_vec)
 internal_values = map((node) -> :(get_choices(trace.$(get_subtrace_fieldname(node)))),
 selected_calls_vec)
-internal_gradients = map((node) -> gradient_trie_var(node), selected_calls_vec)
+internal_gradients = map((node) -> gradient_choices_var(node), selected_calls_vec)
 quote
-$value_trie = StaticChoiceMap(
+$value_choices = StaticChoiceMap(
 NamedTuple{($(quoted_leaf_keys...),)}(($(leaf_values...),)),
 NamedTuple{($(quoted_internal_keys...),)}(($(internal_values...),)))
-$gradient_trie = StaticChoiceMap(
+$gradient_choices = StaticChoiceMap(
 NamedTuple{($(quoted_leaf_keys...),)}(($(leaf_gradients...),)),
 NamedTuple{($(quoted_internal_keys...),)}(($(internal_gradients...),)))
 end
@@ -429,18 +429,18 @@ function codegen_choice_gradients(trace_type::Type{T}, selection_type::Type,
 back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node, BackpropTraceMode())
 end

-# assemble value_trie and gradient_trie
-value_trie = gensym("value_trie")
-gradient_trie = gensym("gradient_trie")
-push!(stmts, generate_value_gradient_trie(selected_choices, selected_calls,
-value_trie, gradient_trie))
+# assemble value_choices and gradient_choices
+value_choices = gensym("value_choices")
+gradient_choices = gensym("gradient_choices")
+push!(stmts, generate_value_gradient_choices(selected_choices, selected_calls,
+value_choices, gradient_choices))

 # gradients with respect to inputs
 arg_grad = (node::ArgumentNode) -> node.compute_grad ? gradient_var(node) : :(nothing)
 input_grads = Expr(:tuple, map(arg_grad, ir.arg_nodes)...)

 # return values
-push!(stmts, :(return ($input_grads, $value_trie, $gradient_trie)))
+push!(stmts, :(return ($input_grads, $value_choices, $gradient_choices)))

 Expr(:block, stmts...)
 end
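
To visualize what the renamed codegen assembles, here is an illustrative, hypothetical shape of the expression emitted by `generate_value_gradient_choices` for one selected choice `:a` and one selected call `:bar`; the real variable names are gensym'd and the trace field names below are placeholders, not the actual generated symbols:

    quote
        # values of selected leaf choices, plus choice maps of selected calls
        value_choices = StaticChoiceMap(
            NamedTuple{(:a,)}((trace.a_value,)),
            NamedTuple{(:bar,)}((get_choices(trace.bar_subtrace),)))
        # matching gradients at the same addresses
        gradient_choices = StaticChoiceMap(
            NamedTuple{(:a,)}((gradient_a,)),
            NamedTuple{(:bar,)}((gradient_choices_bar,)))
    end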

test/dsl/dynamic_dsl.jl

Lines changed: 2 additions & 2 deletions

@@ -345,15 +345,15 @@ end
 # check input gradient
 @test isapprox(mu_a_grad, finite_diff(f, (mu_a, a, b, z, out), 1, dx))

-# check value trie
+# check value from choicemap
 @test get_value(choices, :a) == a
 @test get_value(choices, :out) == out
 @test get_value(choices, :bar => :z) == z
 @test !has_value(choices, :b) # was not selected
 @test length(collect(get_submaps_shallow(choices))) == 1
 @test length(collect(get_values_shallow(choices))) == 2

-# check gradient trie
+# check gradient from choicemap
 @test length(collect(get_submaps_shallow(gradients))) == 1
 @test length(collect(get_values_shallow(gradients))) == 2
 @test !has_value(gradients, :b) # was not selected

test/inference/hmc.jl

Lines changed: 6 additions & 6 deletions

@@ -20,9 +20,9 @@
 (new_trace, accepted) = hmc(trace, select(:x))

 # For Normal(0,1), grad should be -x
-(_, values_trie, gradient_trie) = choice_gradients(trace, select(:x), 0)
-values = to_array(values_trie, Float64)
-grad = to_array(gradient_trie, Float64)
+(_, value_choices, gradient_choices) = choice_gradients(trace, select(:x), 0)
+values = to_array(value_choices, Float64)
+grad = to_array(gradient_choices, Float64)
 @test values ≈ -grad

 # smoke test with vector metric
@@ -58,9 +58,9 @@
 (new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_vec)

 # For each Normal(0,1), grad should be -x
-(_, values_trie, gradient_trie) = choice_gradients(trace, select(:x, :y), 0)
-values = to_array(values_trie, Float64)
-grad = to_array(gradient_trie, Float64)
+(_, value_choices, gradient_choices) = choice_gradients(trace, select(:x, :y), 0)
+values = to_array(value_choices, Float64)
+grad = to_array(gradient_choices, Float64)
 @test values ≈ -grad

 # smoke test with Diagonal metric and retval gradient

test/static_ir/static_ir.jl

Lines changed: 16 additions & 16 deletions

@@ -351,26 +351,26 @@ end
 selection = select(:bar => :z, :a, :out)
 selection = StaticSelection(selection)
 retval_grad = 2.
-((mu_a_grad,), value_trie, gradient_trie) = choice_gradients(trace, selection, retval_grad)
+((mu_a_grad,), value_choices, gradient_choices) = choice_gradients(trace, selection, retval_grad)

 # check input gradient
 @test isapprox(mu_a_grad, finite_diff(f, (mu_a, theta, a, b, z, out), 1, dx))

-# check value trie
-@test get_value(value_trie, :a) == a
-@test get_value(value_trie, :out) == out
-@test get_value(value_trie, :bar => :z) == z
-@test !has_value(value_trie, :b) # was not selected
-@test length(get_submaps_shallow(value_trie)) == 1
-@test length(get_values_shallow(value_trie)) == 2
-
-# check gradient trie
-@test length(get_submaps_shallow(gradient_trie)) == 1
-@test length(get_values_shallow(gradient_trie)) == 2
-@test !has_value(gradient_trie, :b) # was not selected
-@test isapprox(get_value(gradient_trie, :a), finite_diff(f, (mu_a, theta, a, b, z, out), 3, dx))
-@test isapprox(get_value(gradient_trie, :out), finite_diff(f, (mu_a, theta, a, b, z, out), 6, dx))
-@test isapprox(get_value(gradient_trie, :bar => :z), finite_diff(f, (mu_a, theta, a, b, z, out), 5, dx))
+# check value from choice map
+@test get_value(value_choices, :a) == a
+@test get_value(value_choices, :out) == out
+@test get_value(value_choices, :bar => :z) == z
+@test !has_value(value_choices, :b) # was not selected
+@test length(get_submaps_shallow(value_choices)) == 1
+@test length(get_values_shallow(value_choices)) == 2
+
+# check gradient from choice map
+@test length(get_submaps_shallow(gradient_choices)) == 1
+@test length(get_values_shallow(gradient_choices)) == 2
+@test !has_value(gradient_choices, :b) # was not selected
+@test isapprox(get_value(gradient_choices, :a), finite_diff(f, (mu_a, theta, a, b, z, out), 3, dx))
+@test isapprox(get_value(gradient_choices, :out), finite_diff(f, (mu_a, theta, a, b, z, out), 6, dx))
+@test isapprox(get_value(gradient_choices, :bar => :z), finite_diff(f, (mu_a, theta, a, b, z, out), 5, dx))

 # reset the trainable parameter gradient
 zero_param_grad!(foo, :theta)