@@ -17,9 +17,9 @@ function mala(
1717 retval_grad = accepts_output_grad (get_gen_fn (trace)) ? zero (get_retval (trace)) : nothing
1818
1919 # forward proposal
20- (_, choice_values, choice_gradient ) = choice_gradients (trace, selection, retval_grad)
21- values = to_array (choice_values , Float64)
22- gradient = to_array (choice_gradient , Float64)
20+ (_, value_choices, gradient_choices ) = choice_gradients (trace, selection, retval_grad)
21+ values = to_array (value_choices , Float64)
22+ gradient = to_array (gradient_choices , Float64)
2323 forward_mu = values + tau * gradient
2424 forward_score = 0.
2525 proposed_values = Vector {Float64} (undef, length (values))
@@ -29,14 +29,14 @@ function mala(
2929 end
3030
3131 # evaluate model weight
32- constraints = from_array (choice_values , proposed_values)
32+ constraints = from_array (value_choices , proposed_values)
3333 (new_trace, weight, _, discard) = update (trace,
3434 args, argdiffs, constraints)
3535 check && check_observations (get_choices (new_trace), observations)
3636
3737 # backward proposal
38- (_, _, backward_choice_gradient ) = choice_gradients (new_trace, selection, retval_grad)
39- backward_gradient = to_array (backward_choice_gradient , Float64)
38+ (_, _, backward_gradient_choices ) = choice_gradients (new_trace, selection, retval_grad)
39+ backward_gradient = to_array (backward_gradient_choices , Float64)
4040 @assert length (backward_gradient) == length (values)
4141 backward_score = 0.
4242 backward_mu = proposed_values + tau * backward_gradient
0 commit comments