diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 000000000..9e26dfeeb --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/src/gen_fn_interface.jl b/src/gen_fn_interface.jl index 5e5232035..7ff55030a 100644 --- a/src/gen_fn_interface.jl +++ b/src/gen_fn_interface.jl @@ -373,7 +373,7 @@ function accumulate_param_gradients!(trace) end """ - (arg_grads, choice_values, choice_grads) = choice_gradients( + (arg_grads, value_choices, gradient_choices) = choice_gradients( trace, selection=EmptySelection(), retgrad=nothing) Given a previous trace \$(x, t)\$ (`trace`) and a gradient with respect to the @@ -389,7 +389,7 @@ If an argument is not annotated with `(grad)`, the corresponding value in `arg_grads` will be `nothing`. Also given a set of addresses \$A\$ (`selection`) that are continuous-valued -random choices, return the folowing gradient (`choice_grads`) with respect to +random choices, return the following gradient (`gradient_choices`) with respect to the values of these choices: ```math ∇_A \\left( \\log P(t; x) + J \\right) @@ -397,7 +397,7 @@ the values of these choices: The gradient is represented as a choicemap whose value at (hierarchical) address `addr` is \$∂J/∂t[\\texttt{addr}]\$. -Also return the choicemap (`choice_values`) that is the restriction of \$t\$ to \$A\$. +Also return the choicemap (`value_choices`) that is the restriction of \$t\$ to \$A\$. """ function choice_gradients(trace, selection::Selection, retgrad) error("Not implemented") end diff --git a/src/inference/hmc.jl b/src/inference/hmc.jl index 5951aeea1..57bd1cf8c 100644 --- a/src/inference/hmc.jl +++ b/src/inference/hmc.jl @@ -88,9 +88,9 @@ function hmc( # run leapfrog dynamics new_trace = trace - (_, values_trie, gradient_trie) = choice_gradients(new_trace, selection, retval_grad) - values = to_array(values_trie, Float64) - gradient = to_array(gradient_trie, Float64) + (_, value_choices, gradient_choices) = choice_gradients(new_trace, selection, retval_grad) + values = to_array(value_choices, Float64) + gradient = to_array(gradient_choices, Float64) momenta = sample_momenta(length(values), metric) prev_momenta_score = assess_momenta(momenta, metric) for step=1:L @@ -102,10 +102,10 @@ function hmc( values += eps * momenta # get new gradient - values_trie = from_array(values_trie, values) - (new_trace, _, _) = update(new_trace, args, argdiffs, values_trie) - (_, _, gradient_trie) = choice_gradients(new_trace, selection, retval_grad) - gradient = to_array(gradient_trie, Float64) + value_choices = from_array(value_choices, values) + (new_trace, _, _) = update(new_trace, args, argdiffs, value_choices) + (_, _, gradient_choices) = choice_gradients(new_trace, selection, retval_grad) + gradient = to_array(gradient_choices, Float64) # half step on momenta momenta += (eps / 2) * gradient diff --git a/src/inference/mala.jl b/src/inference/mala.jl index 033a45a7e..84765de7e 100644 --- a/src/inference/mala.jl +++ b/src/inference/mala.jl @@ -17,9 +17,9 @@ function mala( retval_grad = accepts_output_grad(get_gen_fn(trace)) ?
zero(get_retval(trace)) : nothing # forward proposal - (_, values_trie, gradient_trie) = choice_gradients(trace, selection, retval_grad) - values = to_array(values_trie, Float64) - gradient = to_array(gradient_trie, Float64) + (_, value_choices, gradient_choices) = choice_gradients(trace, selection, retval_grad) + values = to_array(value_choices, Float64) + gradient = to_array(gradient_choices, Float64) forward_mu = values + tau * gradient forward_score = 0. proposed_values = Vector{Float64}(undef, length(values)) @@ -29,14 +29,14 @@ function mala( end # evaluate model weight - constraints = from_array(values_trie, proposed_values) + constraints = from_array(value_choices, proposed_values) (new_trace, weight, _, discard) = update(trace, args, argdiffs, constraints) check && check_observations(get_choices(new_trace), observations) # backward proposal - (_, _, backward_gradient_trie) = choice_gradients(new_trace, selection, retval_grad) - backward_gradient = to_array(backward_gradient_trie, Float64) + (_, _, backward_gradient_choices) = choice_gradients(new_trace, selection, retval_grad) + backward_gradient = to_array(backward_gradient_choices, Float64) @assert length(backward_gradient) == length(values) backward_score = 0. backward_mu = proposed_values + tau * backward_gradient diff --git a/src/inference/map_optimize.jl b/src/inference/map_optimize.jl index 3250e9e38..d5bb80d1a 100644 --- a/src/inference/map_optimize.jl +++ b/src/inference/map_optimize.jl @@ -11,16 +11,16 @@ function map_optimize(trace, selection::Selection; argdiffs = map((_) -> NoChange(), args) retval_grad = accepts_output_grad(get_gen_fn(trace)) ? zero(get_retval(trace)) : nothing - (_, values, gradient) = choice_gradients(trace, selection, retval_grad) - values_vec = to_array(values, Float64) - gradient_vec = to_array(gradient, Float64) + (_, value_choices, gradient_choices) = choice_gradients(trace, selection, retval_grad) + values_vec = to_array(value_choices, Float64) + gradient_vec = to_array(gradient_choices, Float64) step_size = max_step_size score = get_score(trace) while true new_values_vec = values_vec + gradient_vec * step_size - values = from_array(values, new_values_vec) + value_choices = from_array(value_choices, new_values_vec) # TODO discard and weight are not actually needed, there should be a more specialized variant - (new_trace, _, _, discard) = update(trace, args, argdiffs, values) + (new_trace, _, _, discard) = update(trace, args, argdiffs, value_choices) new_score = get_score(new_trace) change = new_score - score if verbose diff --git a/src/modeling_library/call_at/call_at.jl b/src/modeling_library/call_at/call_at.jl index 0c5b997bc..e4c384641 100644 --- a/src/modeling_library/call_at/call_at.jl +++ b/src/modeling_library/call_at/call_at.jl @@ -144,11 +144,11 @@ end function choice_gradients(trace::CallAtTrace, selection::Selection, retval_grad) subselection = selection[trace.key] - (kernel_input_grads, value_submap, gradient_submap) = choice_gradients( + (kernel_input_grads, value_choices_submap, gradient_choices_submap) = choice_gradients( trace.subtrace, subselection, retval_grad) input_grads = (kernel_input_grads..., nothing) - value_choices = CallAtChoiceMap(trace.key, value_submap) - gradient_choices = CallAtChoiceMap(trace.key, gradient_submap) + value_choices = CallAtChoiceMap(trace.key, value_choices_submap) + gradient_choices = CallAtChoiceMap(trace.key, gradient_choices_submap) (input_grads, value_choices, gradient_choices) end diff --git a/src/static_ir/backprop.jl 
b/src/static_ir/backprop.jl index 8549be4e3..ea0f030fc 100644 --- a/src/static_ir/backprop.jl +++ b/src/static_ir/backprop.jl @@ -4,11 +4,11 @@ struct BackpropParamsMode end const gradient_prefix = gensym("gradient") gradient_var(node::StaticIRNode) = Symbol("$(gradient_prefix)_$(node.name)") -const value_trie_prefix = gensym("value_trie") -value_trie_var(node::GenerativeFunctionCallNode) = Symbol("$(value_trie_prefix)_$(node.addr)") +const value_choices_prefix = gensym("value_choices") +value_choices_var(node::GenerativeFunctionCallNode) = Symbol("$(value_choices_prefix)_$(node.addr)") -const gradient_trie_prefix = gensym("gradient_trie") -gradient_trie_var(node::GenerativeFunctionCallNode) = Symbol("$(gradient_trie_prefix)_$(node.addr)") +const gradient_choices_prefix = gensym("gradient_choices") +gradient_choices_var(node::GenerativeFunctionCallNode) = Symbol("$(gradient_choices_prefix)_$(node.addr)") const tape_prefix = gensym("tape") tape_var(node::JuliaNode) = Symbol("$(tape_prefix)_$(node.name)") @@ -128,7 +128,7 @@ function fwd_codegen!(stmts, fwd_marked, back_marked, node::JuliaNode) # we need the value for initializing gradient to zero (to get the type # and e.g. shape), and for reference by other nodes during # back_codegen! we could be more selective about which JuliaNodes need - # to be evalutaed, that is a performance optimization for the future + # to be evaluated, that is a performance optimization for the future args = map((input_node) -> input_node.name, node.inputs) push!(stmts, :($(node.name) = $(QuoteNode(node.fn))($(args...)))) end @@ -275,8 +275,8 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, if node in fwd_marked input_grads = gensym("call_input_grads") - value_trie = value_trie_var(node) - gradient_trie = gradient_trie_var(node) + value_choices = value_choices_var(node) + gradient_choices = gradient_choices_var(node) subtrace_fieldname = get_subtrace_fieldname(node) call_selection = gensym("call_selection") if node in selected_calls @@ -285,7 +285,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, push!(stmts, :($call_selection = EmptySelection())) end retval_grad = node in back_marked ? 
gradient_var(node) : :(nothing) - push!(stmts, :(($input_grads, $value_trie, $gradient_trie) = choice_gradients( + push!(stmts, :(($input_grads, $value_choices, $gradient_choices) = choice_gradients( trace.$subtrace_fieldname, $call_selection, $retval_grad))) end @@ -297,7 +297,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, end end - # NOTE: the value_trie and gradient_trie are dealt with later + # NOTE: the value_choices and gradient_choices are dealt with later end function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, @@ -325,9 +325,9 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, end end -function generate_value_gradient_trie(selected_choices::Set{RandomChoiceNode}, +function generate_value_gradient_choices(selected_choices::Set{RandomChoiceNode}, selected_calls::Set{GenerativeFunctionCallNode}, - value_trie::Symbol, gradient_trie::Symbol) + value_choices::Symbol, gradient_choices::Symbol) selected_choices_vec = collect(selected_choices) quoted_leaf_keys = map((node) -> QuoteNode(node.addr), selected_choices_vec) leaf_values = map((node) -> :(trace.$(get_value_fieldname(node))), selected_choices_vec) @@ -337,12 +337,12 @@ function generate_value_gradient_trie(selected_choices::Set{RandomChoiceNode}, quoted_internal_keys = map((node) -> QuoteNode(node.addr), selected_calls_vec) internal_values = map((node) -> :(get_choices(trace.$(get_subtrace_fieldname(node)))), selected_calls_vec) - internal_gradients = map((node) -> gradient_trie_var(node), selected_calls_vec) + internal_gradients = map((node) -> gradient_choices_var(node), selected_calls_vec) quote - $value_trie = StaticChoiceMap( + $value_choices = StaticChoiceMap( NamedTuple{($(quoted_leaf_keys...),)}(($(leaf_values...),)), NamedTuple{($(quoted_internal_keys...),)}(($(internal_values...),))) - $gradient_trie = StaticChoiceMap( + $gradient_choices = StaticChoiceMap( NamedTuple{($(quoted_leaf_keys...),)}(($(leaf_gradients...),)), NamedTuple{($(quoted_internal_keys...),)}(($(internal_gradients...),))) end @@ -429,18 +429,18 @@ function codegen_choice_gradients(trace_type::Type{T}, selection_type::Type, back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node, BackpropTraceMode()) end - # assemble value_trie and gradient_trie - value_trie = gensym("value_trie") - gradient_trie = gensym("gradient_trie") - push!(stmts, generate_value_gradient_trie(selected_choices, selected_calls, - value_trie, gradient_trie)) + # assemble value_choices and gradient_choices + value_choices = gensym("value_choices") + gradient_choices = gensym("gradient_choices") + push!(stmts, generate_value_gradient_choices(selected_choices, selected_calls, + value_choices, gradient_choices)) # gradients with respect to inputs arg_grad = (node::ArgumentNode) -> node.compute_grad ? gradient_var(node) : :(nothing) input_grads = Expr(:tuple, map(arg_grad, ir.arg_nodes)...) # return values - push!(stmts, :(return ($input_grads, $value_trie, $gradient_trie))) + push!(stmts, :(return ($input_grads, $value_choices, $gradient_choices))) Expr(:block, stmts...) 
end diff --git a/test/dsl/dynamic_dsl.jl b/test/dsl/dynamic_dsl.jl index dbe77030d..34d4c4f21 100644 --- a/test/dsl/dynamic_dsl.jl +++ b/test/dsl/dynamic_dsl.jl @@ -345,7 +345,7 @@ end # check input gradient @test isapprox(mu_a_grad, finite_diff(f, (mu_a, a, b, z, out), 1, dx)) - # check value trie + # check value from choicemap @test get_value(choices, :a) == a @test get_value(choices, :out) == out @test get_value(choices, :bar => :z) == z @@ -353,7 +353,7 @@ end @test length(collect(get_submaps_shallow(choices))) == 1 @test length(collect(get_values_shallow(choices))) == 2 - # check gradient trie + # check gradient from choicemap @test length(collect(get_submaps_shallow(gradients))) == 1 @test length(collect(get_values_shallow(gradients))) == 2 @test !has_value(gradients, :b) # was not selected diff --git a/test/inference/hmc.jl b/test/inference/hmc.jl index e465351b6..2e2fb9b08 100644 --- a/test/inference/hmc.jl +++ b/test/inference/hmc.jl @@ -20,9 +20,9 @@ (new_trace, accepted) = hmc(trace, select(:x)) # For Normal(0,1), grad should be -x - (_, values_trie, gradient_trie) = choice_gradients(trace, select(:x), 0) - values = to_array(values_trie, Float64) - grad = to_array(gradient_trie, Float64) + (_, value_choices, gradient_choices) = choice_gradients(trace, select(:x), 0) + values = to_array(value_choices, Float64) + grad = to_array(gradient_choices, Float64) @test values ≈ -grad # smoke test with vector metric @@ -58,9 +58,9 @@ (new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_vec) # For each Normal(0,1), grad should be -x - (_, values_trie, gradient_trie) = choice_gradients(trace, select(:x, :y), 0) - values = to_array(values_trie, Float64) - grad = to_array(gradient_trie, Float64) + (_, value_choices, gradient_choices) = choice_gradients(trace, select(:x, :y), 0) + values = to_array(value_choices, Float64) + grad = to_array(gradient_choices, Float64) @test values ≈ -grad # smoke test with Diagonal metric and retval gradient diff --git a/test/static_ir/static_ir.jl b/test/static_ir/static_ir.jl index 44a308b83..b1895ef42 100644 --- a/test/static_ir/static_ir.jl +++ b/test/static_ir/static_ir.jl @@ -351,26 +351,26 @@ end selection = select(:bar => :z, :a, :out) selection = StaticSelection(selection) retval_grad = 2. 
- ((mu_a_grad,), value_trie, gradient_trie) = choice_gradients(trace, selection, retval_grad) + ((mu_a_grad,), value_choices, gradient_choices) = choice_gradients(trace, selection, retval_grad) # check input gradient @test isapprox(mu_a_grad, finite_diff(f, (mu_a, theta, a, b, z, out), 1, dx)) - # check value trie - @test get_value(value_trie, :a) == a - @test get_value(value_trie, :out) == out - @test get_value(value_trie, :bar => :z) == z - @test !has_value(value_trie, :b) # was not selected - @test length(get_submaps_shallow(value_trie)) == 1 - @test length(get_values_shallow(value_trie)) == 2 - - # check gradient trie - @test length(get_submaps_shallow(gradient_trie)) == 1 - @test length(get_values_shallow(gradient_trie)) == 2 - @test !has_value(gradient_trie, :b) # was not selected - @test isapprox(get_value(gradient_trie, :a), finite_diff(f, (mu_a, theta, a, b, z, out), 3, dx)) - @test isapprox(get_value(gradient_trie, :out), finite_diff(f, (mu_a, theta, a, b, z, out), 6, dx)) - @test isapprox(get_value(gradient_trie, :bar => :z), finite_diff(f, (mu_a, theta, a, b, z, out), 5, dx)) + # check value from choice map + @test get_value(value_choices, :a) == a + @test get_value(value_choices, :out) == out + @test get_value(value_choices, :bar => :z) == z + @test !has_value(value_choices, :b) # was not selected + @test length(get_submaps_shallow(value_choices)) == 1 + @test length(get_values_shallow(value_choices)) == 2 + + # check gradient from choice map + @test length(get_submaps_shallow(gradient_choices)) == 1 + @test length(get_values_shallow(gradient_choices)) == 2 + @test !has_value(gradient_choices, :b) # was not selected + @test isapprox(get_value(gradient_choices, :a), finite_diff(f, (mu_a, theta, a, b, z, out), 3, dx)) + @test isapprox(get_value(gradient_choices, :out), finite_diff(f, (mu_a, theta, a, b, z, out), 6, dx)) + @test isapprox(get_value(gradient_choices, :bar => :z), finite_diff(f, (mu_a, theta, a, b, z, out), 5, dx)) # reset the trainable parameter gradient zero_param_grad!(foo, :theta)
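As a usage sketch of the renamed API (not part of the diff; the toy model, the addresses `:x` and `:y`, and the step size below are hypothetical), the following mirrors how `hmc`, `mala`, and `map_optimize` consume the `(arg_grads, value_choices, gradient_choices)` return value of `choice_gradients`, flattening the selected choices with `to_array` and rebuilding a constraint choicemap with `from_array`:

```julia
# Usage sketch only (toy model, addresses, and 0.01 step size are hypothetical);
# it illustrates the renamed return convention, not hmc/mala/map_optimize themselves.
using Gen

@gen function toy_model()
    x = @trace(normal(0, 1), :x)
    @trace(normal(x, 1), :y)
end

trace = simulate(toy_model, ())
selection = select(:x)

# value_choices: the restriction of the trace's choices to the selection.
# gradient_choices: the gradient of the log density with respect to those choices.
(arg_grads, value_choices, gradient_choices) =
    choice_gradients(trace, selection, nothing)

# Flatten to vectors, take one (hypothetical) gradient step, and rebuild a
# choicemap with the same shape to feed back into update.
values = to_array(value_choices, Float64)
gradient = to_array(gradient_choices, Float64)
constraints = from_array(value_choices, values + 0.01 * gradient)
(new_trace, _, _, _) = update(trace, (), (), constraints)
```

Reusing `value_choices` as the prototype for `from_array` keeps the address ordering of the rebuilt choicemap consistent with the flattened vector, which is why the inference code in this diff passes it back in rather than constructing a fresh choicemap.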