Skip to content

Commit 6ac7bdb

Browse files
committed
Gone for _choices
Renamed 'values_trie' and 'gradient_trie' to 'value_choices' and 'gradient_choices' in HMC and MAP optimization code, as well as corresponding test cases, to improve code readability and clarity.
1 parent c39c74d commit 6ac7bdb

File tree

3 files changed

+18
-18
lines changed

3 files changed

+18
-18
lines changed

src/inference/hmc.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,9 @@ function hmc(
8888

8989
# run leapfrog dynamics
9090
new_trace = trace
91-
(_, values_trie, gradient_trie) = choice_gradients(new_trace, selection, retval_grad)
92-
values = to_array(values_trie, Float64)
93-
gradient = to_array(gradient_trie, Float64)
91+
(_, value_choices, gradient_choices) = choice_gradients(new_trace, selection, retval_grad)
92+
values = to_array(value_choices, Float64)
93+
gradient = to_array(gradient_choices, Float64)
9494
momenta = sample_momenta(length(values), metric)
9595
prev_momenta_score = assess_momenta(momenta, metric)
9696
for step=1:L
@@ -102,10 +102,10 @@ function hmc(
102102
values += eps * momenta
103103

104104
# get new gradient
105-
values_trie = from_array(values_trie, values)
106-
(new_trace, _, _) = update(new_trace, args, argdiffs, values_trie)
107-
(_, _, gradient_trie) = choice_gradients(new_trace, selection, retval_grad)
108-
gradient = to_array(gradient_trie, Float64)
105+
value_choices = from_array(value_choices, values)
106+
(new_trace, _, _) = update(new_trace, args, argdiffs, value_choices)
107+
(_, _, gradient_choices) = choice_gradients(new_trace, selection, retval_grad)
108+
gradient = to_array(gradient_choices, Float64)
109109

110110
# half step on momenta
111111
momenta += (eps / 2) * gradient

src/inference/map_optimize.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,16 @@ function map_optimize(trace, selection::Selection;
1111
argdiffs = map((_) -> NoChange(), args)
1212
retval_grad = accepts_output_grad(get_gen_fn(trace)) ? zero(get_retval(trace)) : nothing
1313

14-
(_, values, gradient) = choice_gradients(trace, selection, retval_grad)
15-
values_vec = to_array(values, Float64)
16-
gradient_vec = to_array(gradient, Float64)
14+
(_, value_choices, gradient_choices) = choice_gradients(trace, selection, retval_grad)
15+
values_vec = to_array(value_choices, Float64)
16+
gradient_vec = to_array(gradient_choices, Float64)
1717
step_size = max_step_size
1818
score = get_score(trace)
1919
while true
2020
new_values_vec = values_vec + gradient_vec * step_size
21-
values = from_array(values, new_values_vec)
21+
value_choices = from_array(value_choices, new_values_vec)
2222
# TODO discard and weight are not actually needed, there should be a more specialized variant
23-
(new_trace, _, _, discard) = update(trace, args, argdiffs, values)
23+
(new_trace, _, _, discard) = update(trace, args, argdiffs, value_choices)
2424
new_score = get_score(new_trace)
2525
change = new_score - score
2626
if verbose

test/inference/hmc.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@
2020
(new_trace, accepted) = hmc(trace, select(:x))
2121

2222
# For Normal(0,1), grad should be -x
23-
(_, values_trie, gradient_trie) = choice_gradients(trace, select(:x), 0)
24-
values = to_array(values_trie, Float64)
25-
grad = to_array(gradient_trie, Float64)
23+
(_, value_choices, gradient_choices) = choice_gradients(trace, select(:x), 0)
24+
values = to_array(value_choices, Float64)
25+
grad = to_array(gradient_choices, Float64)
2626
@test values -grad
2727

2828
# smoke test with vector metric
@@ -58,9 +58,9 @@
5858
(new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_vec)
5959

6060
# For each Normal(0,1), grad should be -x
61-
(_, values_trie, gradient_trie) = choice_gradients(trace, select(:x, :y), 0)
62-
values = to_array(values_trie, Float64)
63-
grad = to_array(gradient_trie, Float64)
61+
(_, value_choices, gradient_choices) = choice_gradients(trace, select(:x, :y), 0)
62+
values = to_array(value_choices, Float64)
63+
grad = to_array(gradient_choices, Float64)
6464
@test values -grad
6565

6666
# smoke test with Diagonal metric and retval gradient

0 commit comments

Comments
 (0)