Skip to content

Commit 90807bf

Browse files
committed
More consistency on arg refactor
1 parent 6ac7bdb commit 90807bf

File tree

3 files changed

+10
-10
lines changed

3 files changed

+10
-10
lines changed

src/gen_fn_interface.jl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change

```diff
@@ -373,7 +373,7 @@ function accumulate_param_gradients!(trace)
 end
 
 """
-    (arg_grads, choice_values, choice_grads) = choice_gradients(
+    (arg_grads, choice_values, choice_gradient) = choice_gradients(
         trace, selection=EmptySelection(), retgrad=nothing)
 
 Given a previous trace $(x, t)$ (`trace`) and a gradient with respect to the
```

src/inference/mala.jl

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change

```diff
@@ -17,9 +17,9 @@ function mala(
     retval_grad = accepts_output_grad(get_gen_fn(trace)) ? zero(get_retval(trace)) : nothing
 
     # forward proposal
-    (_, choice_values, choice_gradient) = choice_gradients(trace, selection, retval_grad)
-    values = to_array(choice_values, Float64)
-    gradient = to_array(choice_gradient, Float64)
+    (_, value_choices, gradient_choices) = choice_gradients(trace, selection, retval_grad)
+    values = to_array(value_choices, Float64)
+    gradient = to_array(gradient_choices, Float64)
     forward_mu = values + tau * gradient
     forward_score = 0.
     proposed_values = Vector{Float64}(undef, length(values))
```
```diff
@@ -29,14 +29,14 @@ function mala(
     end
 
     # evaluate model weight
-    constraints = from_array(choice_values, proposed_values)
+    constraints = from_array(value_choices, proposed_values)
     (new_trace, weight, _, discard) = update(trace,
         args, argdiffs, constraints)
     check && check_observations(get_choices(new_trace), observations)
 
     # backward proposal
-    (_, _, backward_choice_gradient) = choice_gradients(new_trace, selection, retval_grad)
-    backward_gradient = to_array(backward_choice_gradient, Float64)
+    (_, _, backward_gradient_choices) = choice_gradients(new_trace, selection, retval_grad)
+    backward_gradient = to_array(backward_gradient_choices, Float64)
     @assert length(backward_gradient) == length(values)
     backward_score = 0.
     backward_mu = proposed_values + tau * backward_gradient
```

src/modeling_library/call_at/call_at.jl

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change

```diff
@@ -144,11 +144,11 @@ end
 
 function choice_gradients(trace::CallAtTrace, selection::Selection, retval_grad)
     subselection = selection[trace.key]
-    (kernel_input_grads, value_submap, gradient_submap) = choice_gradients(
+    (kernel_input_grads, value_choices_submap, gradient_choices_submap) = choice_gradients(
         trace.subtrace, subselection, retval_grad)
     input_grads = (kernel_input_grads..., nothing)
-    value_choices = CallAtChoiceMap(trace.key, value_submap)
-    gradient_choices = CallAtChoiceMap(trace.key, gradient_submap)
+    value_choices = CallAtChoiceMap(trace.key, value_choices_submap)
+    gradient_choices = CallAtChoiceMap(trace.key, gradient_choices_submap)
     (input_grads, value_choices, gradient_choices)
 end
 
```

0 commit comments

Comments (0)