1 change: 1 addition & 0 deletions .vscode/settings.json
@@ -0,0 +1 @@
{}
6 changes: 3 additions & 3 deletions src/gen_fn_interface.jl
@@ -373,7 +373,7 @@ function accumulate_param_gradients!(trace)
end

"""
(arg_grads, choice_values, choice_grads) = choice_gradients(
(arg_grads, value_choices, gradient_choices) = choice_gradients(
trace, selection=EmptySelection(), retgrad=nothing)

Given a previous trace \$(x, t)\$ (`trace`) and a gradient with respect to the
@@ -389,15 +389,15 @@ If an argument is not annotated with `(grad)`, the corresponding value in
`arg_grads` will be `nothing`.

Also given a set of addresses \$A\$ (`selection`) that are continuous-valued
random choices, return the folowing gradient (`choice_grads`) with respect to
random choices, return the following gradient (`gradient_choices`) with respect to
the values of these choices:
```math
∇_A \\left( \\log P(t; x) + J \\right)
```
The gradient is represented as a choicemap whose value at (hierarchical)
address `addr` is \$∂J/∂t[\\texttt{addr}]\$.

Also return the choicemap (`choice_values`) that is the restriction of \$t\$ to \$A\$.
Also return the choicemap (`value_choices`) that is the restriction of \$t\$ to \$A\$.
"""
function choice_gradients(trace, selection::Selection, retgrad)
error("Not implemented")
14 changes: 7 additions & 7 deletions src/inference/hmc.jl
@@ -88,9 +88,9 @@ function hmc(

# run leapfrog dynamics
new_trace = trace
(_, values_trie, gradient_trie) = choice_gradients(new_trace, selection, retval_grad)
values = to_array(values_trie, Float64)
gradient = to_array(gradient_trie, Float64)
(_, value_choices, gradient_choices) = choice_gradients(new_trace, selection, retval_grad)
values = to_array(value_choices, Float64)
gradient = to_array(gradient_choices, Float64)
momenta = sample_momenta(length(values), metric)
prev_momenta_score = assess_momenta(momenta, metric)
for step=1:L
@@ -102,10 +102,10 @@
values += eps * momenta

# get new gradient
values_trie = from_array(values_trie, values)
(new_trace, _, _) = update(new_trace, args, argdiffs, values_trie)
(_, _, gradient_trie) = choice_gradients(new_trace, selection, retval_grad)
gradient = to_array(gradient_trie, Float64)
value_choices = from_array(value_choices, values)
(new_trace, _, _) = update(new_trace, args, argdiffs, value_choices)
(_, _, gradient_choices) = choice_gradients(new_trace, selection, retval_grad)
gradient = to_array(gradient_choices, Float64)

# half step on momenta
momenta += (eps / 2) * gradient
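The pattern this hunk renames, which recurs in `mala.jl` and `map_optimize.jl` below, is a round trip between choicemaps and flat `Float64` vectors. A minimal sketch of that round trip, assuming `trace`, `selection`, `args`, `argdiffs`, and `retval_grad` are already in scope:

```julia
# Flatten the selected values and their gradients into plain vectors.
(_, value_choices, gradient_choices) = choice_gradients(trace, selection, retval_grad)
values   = to_array(value_choices, Float64)
gradient = to_array(gradient_choices, Float64)

# Any vector-space update; a single gradient step is shown for illustration.
new_values = values .+ 0.01 .* gradient

# `from_array` uses `value_choices` as an address template, so the updated
# numbers land back at the same hierarchical addresses.
constraints = from_array(value_choices, new_values)
(new_trace, weight, _, _) = update(trace, args, argdiffs, constraints)
```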
12 changes: 6 additions & 6 deletions src/inference/mala.jl
@@ -17,9 +17,9 @@ function mala(
retval_grad = accepts_output_grad(get_gen_fn(trace)) ? zero(get_retval(trace)) : nothing

# forward proposal
(_, values_trie, gradient_trie) = choice_gradients(trace, selection, retval_grad)
values = to_array(values_trie, Float64)
gradient = to_array(gradient_trie, Float64)
(_, value_choices, gradient_choices) = choice_gradients(trace, selection, retval_grad)
values = to_array(value_choices, Float64)
gradient = to_array(gradient_choices, Float64)
forward_mu = values + tau * gradient
forward_score = 0.
proposed_values = Vector{Float64}(undef, length(values))
@@ -29,14 +29,14 @@
end

# evaluate model weight
constraints = from_array(values_trie, proposed_values)
constraints = from_array(value_choices, proposed_values)
(new_trace, weight, _, discard) = update(trace,
args, argdiffs, constraints)
check && check_observations(get_choices(new_trace), observations)

# backward proposal
(_, _, backward_gradient_trie) = choice_gradients(new_trace, selection, retval_grad)
backward_gradient = to_array(backward_gradient_trie, Float64)
(_, _, backward_gradient_choices) = choice_gradients(new_trace, selection, retval_grad)
backward_gradient = to_array(backward_gradient_choices, Float64)
@assert length(backward_gradient) == length(values)
backward_score = 0.
backward_mu = proposed_values + tau * backward_gradient
10 changes: 5 additions & 5 deletions src/inference/map_optimize.jl
@@ -11,16 +11,16 @@ function map_optimize(trace, selection::Selection;
argdiffs = map((_) -> NoChange(), args)
retval_grad = accepts_output_grad(get_gen_fn(trace)) ? zero(get_retval(trace)) : nothing

(_, values, gradient) = choice_gradients(trace, selection, retval_grad)
values_vec = to_array(values, Float64)
gradient_vec = to_array(gradient, Float64)
(_, value_choices, gradient_choices) = choice_gradients(trace, selection, retval_grad)
values_vec = to_array(value_choices, Float64)
gradient_vec = to_array(gradient_choices, Float64)
step_size = max_step_size
score = get_score(trace)
while true
new_values_vec = values_vec + gradient_vec * step_size
values = from_array(values, new_values_vec)
value_choices = from_array(value_choices, new_values_vec)
# TODO discard and weight are not actually needed, there should be a more specialized variant
(new_trace, _, _, discard) = update(trace, args, argdiffs, values)
(new_trace, _, _, discard) = update(trace, args, argdiffs, value_choices)
new_score = get_score(new_trace)
change = new_score - score
if verbose
6 changes: 3 additions & 3 deletions src/modeling_library/call_at/call_at.jl
@@ -144,11 +144,11 @@ end

function choice_gradients(trace::CallAtTrace, selection::Selection, retval_grad)
subselection = selection[trace.key]
(kernel_input_grads, value_submap, gradient_submap) = choice_gradients(
(kernel_input_grads, value_choices_submap, gradient_choices_submap) = choice_gradients(
trace.subtrace, subselection, retval_grad)
input_grads = (kernel_input_grads..., nothing)
value_choices = CallAtChoiceMap(trace.key, value_submap)
gradient_choices = CallAtChoiceMap(trace.key, gradient_submap)
value_choices = CallAtChoiceMap(trace.key, value_choices_submap)
gradient_choices = CallAtChoiceMap(trace.key, gradient_choices_submap)
(input_grads, value_choices, gradient_choices)
end
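As a hedged illustration of what this wrapping produces at the call site (reusing the hypothetical `foo` from the sketch above), the returned choicemaps nest the kernel's addresses under the key:

```julia
# Hypothetical: wrap `foo` with the call_at combinator, keyed by an Int.
foo_at = call_at(foo, Int)
trace  = simulate(foo_at, (3,))   # foo's choices live under key 3

# retgrad = nothing, as in the earlier sketch.
(_, value_choices, gradient_choices) =
    choice_gradients(trace, select(3 => :x), nothing)

value_choices[3 => :x]      # value of :x inside the keyed subtrace
gradient_choices[3 => :x]   # matching gradient, nested the same way
```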

40 changes: 20 additions & 20 deletions src/static_ir/backprop.jl
@@ -4,11 +4,11 @@ struct BackpropParamsMode end
const gradient_prefix = gensym("gradient")
gradient_var(node::StaticIRNode) = Symbol("$(gradient_prefix)_$(node.name)")

const value_trie_prefix = gensym("value_trie")
value_trie_var(node::GenerativeFunctionCallNode) = Symbol("$(value_trie_prefix)_$(node.addr)")
const value_choices_prefix = gensym("value_choices")
value_choices_var(node::GenerativeFunctionCallNode) = Symbol("$(value_choices_prefix)_$(node.addr)")

const gradient_trie_prefix = gensym("gradient_trie")
gradient_trie_var(node::GenerativeFunctionCallNode) = Symbol("$(gradient_trie_prefix)_$(node.addr)")
const gradient_choices_prefix = gensym("gradient_choices")
gradient_choices_var(node::GenerativeFunctionCallNode) = Symbol("$(gradient_choices_prefix)_$(node.addr)")

const tape_prefix = gensym("tape")
tape_var(node::JuliaNode) = Symbol("$(tape_prefix)_$(node.name)")
@@ -128,7 +128,7 @@ function fwd_codegen!(stmts, fwd_marked, back_marked, node::JuliaNode)
# we need the value for initializing gradient to zero (to get the type
# and e.g. shape), and for reference by other nodes during
# back_codegen! we could be more selective about which JuliaNodes need
# to be evalutaed, that is a performance optimization for the future
# to be evaluated, that is a performance optimization for the future
args = map((input_node) -> input_node.name, node.inputs)
push!(stmts, :($(node.name) = $(QuoteNode(node.fn))($(args...))))
end
@@ -275,8 +275,8 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,

if node in fwd_marked
input_grads = gensym("call_input_grads")
value_trie = value_trie_var(node)
gradient_trie = gradient_trie_var(node)
value_choices = value_choices_var(node)
gradient_choices = gradient_choices_var(node)
subtrace_fieldname = get_subtrace_fieldname(node)
call_selection = gensym("call_selection")
if node in selected_calls
Expand All @@ -285,7 +285,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
push!(stmts, :($call_selection = EmptySelection()))
end
retval_grad = node in back_marked ? gradient_var(node) : :(nothing)
push!(stmts, :(($input_grads, $value_trie, $gradient_trie) = choice_gradients(
push!(stmts, :(($input_grads, $value_choices, $gradient_choices) = choice_gradients(
trace.$subtrace_fieldname, $call_selection, $retval_grad)))
end

Expand All @@ -297,7 +297,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
end
end

# NOTE: the value_trie and gradient_trie are dealt with later
# NOTE: the value_choices and gradient_choices are dealt with later
end

function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
@@ -325,9 +325,9 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked,
end
end

function generate_value_gradient_trie(selected_choices::Set{RandomChoiceNode},
function generate_value_gradient_choices(selected_choices::Set{RandomChoiceNode},
selected_calls::Set{GenerativeFunctionCallNode},
value_trie::Symbol, gradient_trie::Symbol)
value_choices::Symbol, gradient_choices::Symbol)
selected_choices_vec = collect(selected_choices)
quoted_leaf_keys = map((node) -> QuoteNode(node.addr), selected_choices_vec)
leaf_values = map((node) -> :(trace.$(get_value_fieldname(node))), selected_choices_vec)
@@ -337,12 +337,12 @@ function generate_value_gradient_trie(selected_choices::Set{RandomChoiceNode},
quoted_internal_keys = map((node) -> QuoteNode(node.addr), selected_calls_vec)
internal_values = map((node) -> :(get_choices(trace.$(get_subtrace_fieldname(node)))),
selected_calls_vec)
internal_gradients = map((node) -> gradient_trie_var(node), selected_calls_vec)
internal_gradients = map((node) -> gradient_choices_var(node), selected_calls_vec)
quote
$value_trie = StaticChoiceMap(
$value_choices = StaticChoiceMap(
NamedTuple{($(quoted_leaf_keys...),)}(($(leaf_values...),)),
NamedTuple{($(quoted_internal_keys...),)}(($(internal_values...),)))
$gradient_trie = StaticChoiceMap(
$gradient_choices = StaticChoiceMap(
NamedTuple{($(quoted_leaf_keys...),)}(($(leaf_gradients...),)),
NamedTuple{($(quoted_internal_keys...),)}(($(internal_gradients...),)))
end
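For readers new to the static IR, roughly the kind of choicemaps the generated expression assembles, shown here with concrete stand-in values for a hypothetical model with one selected choice `:a` and one selected call `:bar` (the real generated code reads gensym'd trace fields instead):

```julia
using Gen

# Illustrative stand-ins for what the generated code reads off the trace.
a_value  = 1.2                    # value of the selected choice :a
grad_a   = -0.4                   # its gradient
bar_sub  = choicemap((:z, 0.7))   # choices of the selected call :bar
grad_bar = choicemap((:z, 0.1))   # gradients returned for that call

value_choices = Gen.StaticChoiceMap(
    NamedTuple{(:a,)}((a_value,)),
    NamedTuple{(:bar,)}((bar_sub,)))
gradient_choices = Gen.StaticChoiceMap(
    NamedTuple{(:a,)}((grad_a,)),
    NamedTuple{(:bar,)}((grad_bar,)))

value_choices[:a]           # 1.2
value_choices[:bar => :z]   # 0.7
gradient_choices[:a]        # -0.4
```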
@@ -429,18 +429,18 @@ function codegen_choice_gradients(trace_type::Type{T}, selection_type::Type,
back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node, BackpropTraceMode())
end

# assemble value_trie and gradient_trie
value_trie = gensym("value_trie")
gradient_trie = gensym("gradient_trie")
push!(stmts, generate_value_gradient_trie(selected_choices, selected_calls,
value_trie, gradient_trie))
# assemble value_choices and gradient_choices
value_choices = gensym("value_choices")
gradient_choices = gensym("gradient_choices")
push!(stmts, generate_value_gradient_choices(selected_choices, selected_calls,
value_choices, gradient_choices))

# gradients with respect to inputs
arg_grad = (node::ArgumentNode) -> node.compute_grad ? gradient_var(node) : :(nothing)
input_grads = Expr(:tuple, map(arg_grad, ir.arg_nodes)...)

# return values
push!(stmts, :(return ($input_grads, $value_trie, $gradient_trie)))
push!(stmts, :(return ($input_grads, $value_choices, $gradient_choices)))

Expr(:block, stmts...)
end
4 changes: 2 additions & 2 deletions test/dsl/dynamic_dsl.jl
@@ -345,15 +345,15 @@ end
# check input gradient
@test isapprox(mu_a_grad, finite_diff(f, (mu_a, a, b, z, out), 1, dx))

# check value trie
# check value choicemap
@test get_value(choices, :a) == a
@test get_value(choices, :out) == out
@test get_value(choices, :bar => :z) == z
@test !has_value(choices, :b) # was not selected
@test length(collect(get_submaps_shallow(choices))) == 1
@test length(collect(get_values_shallow(choices))) == 2

# check gradient trie
# check gradient choicemap
@test length(collect(get_submaps_shallow(gradients))) == 1
@test length(collect(get_values_shallow(gradients))) == 2
@test !has_value(gradients, :b) # was not selected
12 changes: 6 additions & 6 deletions test/inference/hmc.jl
@@ -20,9 +20,9 @@
(new_trace, accepted) = hmc(trace, select(:x))

# For Normal(0,1), grad should be -x
(_, values_trie, gradient_trie) = choice_gradients(trace, select(:x), 0)
values = to_array(values_trie, Float64)
grad = to_array(gradient_trie, Float64)
(_, value_choices, gradient_choices) = choice_gradients(trace, select(:x), 0)
values = to_array(value_choices, Float64)
grad = to_array(gradient_choices, Float64)
@test values ≈ -grad

# smoke test with vector metric
@@ -58,9 +58,9 @@
(new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_vec)

# For each Normal(0,1), grad should be -x
(_, values_trie, gradient_trie) = choice_gradients(trace, select(:x, :y), 0)
values = to_array(values_trie, Float64)
grad = to_array(gradient_trie, Float64)
(_, value_choices, gradient_choices) = choice_gradients(trace, select(:x, :y), 0)
values = to_array(value_choices, Float64)
grad = to_array(gradient_choices, Float64)
@test values ≈ -grad

# smoke test with Diagonal metric and retval gradient
32 changes: 16 additions & 16 deletions test/static_ir/static_ir.jl
@@ -351,26 +351,26 @@ end
selection = select(:bar => :z, :a, :out)
selection = StaticSelection(selection)
retval_grad = 2.
((mu_a_grad,), value_trie, gradient_trie) = choice_gradients(trace, selection, retval_grad)
((mu_a_grad,), value_choices, gradient_choices) = choice_gradients(trace, selection, retval_grad)

# check input gradient
@test isapprox(mu_a_grad, finite_diff(f, (mu_a, theta, a, b, z, out), 1, dx))

# check value trie
@test get_value(value_trie, :a) == a
@test get_value(value_trie, :out) == out
@test get_value(value_trie, :bar => :z) == z
@test !has_value(value_trie, :b) # was not selected
@test length(get_submaps_shallow(value_trie)) == 1
@test length(get_values_shallow(value_trie)) == 2

# check gradient trie
@test length(get_submaps_shallow(gradient_trie)) == 1
@test length(get_values_shallow(gradient_trie)) == 2
@test !has_value(gradient_trie, :b) # was not selected
@test isapprox(get_value(gradient_trie, :a), finite_diff(f, (mu_a, theta, a, b, z, out), 3, dx))
@test isapprox(get_value(gradient_trie, :out), finite_diff(f, (mu_a, theta, a, b, z, out), 6, dx))
@test isapprox(get_value(gradient_trie, :bar => :z), finite_diff(f, (mu_a, theta, a, b, z, out), 5, dx))
# check value choicemap
@test get_value(value_choices, :a) == a
@test get_value(value_choices, :out) == out
@test get_value(value_choices, :bar => :z) == z
@test !has_value(value_choices, :b) # was not selected
@test length(get_submaps_shallow(value_choices)) == 1
@test length(get_values_shallow(value_choices)) == 2

# check gradient choicemap
@test length(get_submaps_shallow(gradient_choices)) == 1
@test length(get_values_shallow(gradient_choices)) == 2
@test !has_value(gradient_choices, :b) # was not selected
@test isapprox(get_value(gradient_choices, :a), finite_diff(f, (mu_a, theta, a, b, z, out), 3, dx))
@test isapprox(get_value(gradient_choices, :out), finite_diff(f, (mu_a, theta, a, b, z, out), 6, dx))
@test isapprox(get_value(gradient_choices, :bar => :z), finite_diff(f, (mu_a, theta, a, b, z, out), 5, dx))

# reset the trainable parameter gradient
zero_param_grad!(foo, :theta)