Skip to content

Commit 89594a5

Browse files
committed
wip
1 parent 2a959a8 commit 89594a5

File tree

3 files changed

+89
-28
lines changed

3 files changed

+89
-28
lines changed

src/inference/hmc_common.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,44 @@ function assess_momenta(momenta)
99
end
1010
logprob
1111
end
12+
13+
function add_choicemaps(a::ChoiceMap, b::ChoiceMap)
14+
out = choicemap()
15+
16+
for (name, val) in get_values_shallow(a)
17+
out[name] = val + b[name]
18+
end
19+
20+
for (name, submap) in get_submaps_shallow(a)
21+
out.internal_nodes[name] = add_choicemaps(submap, get_submap(b, name))
22+
end
23+
24+
return out
25+
end
26+
27+
function scale_choicemap(a::ChoiceMap, scale)
28+
out = choicemap()
29+
30+
for (name, val) in get_values_shallow(a)
31+
out[name] = val * scale
32+
end
33+
34+
for (name, submap) in get_submaps_shallow(a)
35+
out.internal_nodes[name] = scale_choicemap(submap, scale)
36+
end
37+
38+
return out
39+
end
40+
41+
function assess_momenta_trie(momenta_trie)
42+
logprob = 0.
43+
for (_, val) in get_values_shallow(momenta_trie)
44+
logprob += logpdf(normal, val, 0, 1)
45+
end
46+
47+
for (_, submap) in get_submaps_shallow(momenta_trie)
48+
logprob += assess_momenta_trie(submap)
49+
end
50+
51+
return logprob
52+
end

src/inference/nuts.jl

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ end
1515
struct SamplerStats
1616
depth
1717
n
18+
diverging
1819
accept
1920
end
2021

@@ -23,34 +24,30 @@ function u_turn(values_left, values_right, momenta_left, momenta_right)
2324
(dot(values_right - values_left, momenta_left) >= 0)
2425
end
2526

26-
function leapfrog(values, momenta, eps, integrator_state)
27-
values_trie, selection, retval_grad, trace = integrator_state
27+
function leapfrog(values_trie, momenta_trie, eps, integrator_state)
28+
selection, retval_grad, trace = integrator_state
2829

29-
values_trie = from_array(values_trie, values)
3030
(trace, _, _) = update(trace, values_trie)
3131
(_, _, gradient_trie) = choice_gradients(trace, selection, retval_grad)
32-
gradient = to_array(gradient_trie, Float64)
3332

3433
# half step on momenta
35-
momenta += (eps / 2) * gradient
34+
momenta_trie = add_choicemaps(momenta_trie, scale_choicemap(gradient_trie, eps / 2))
3635

3736
# full step on positions
38-
values += eps * momenta
37+
values_trie = add_choicemaps(values_trie, scale_choicemap(momenta_trie, eps))
3938

4039
# get new gradient
41-
values_trie = from_array(values_trie, values)
4240
(trace, _, _) = update(trace, values_trie)
4341
(_, _, gradient_trie) = choice_gradients(trace, selection, retval_grad)
44-
gradient = to_array(gradient_trie, Float64)
4542

4643
# half step on momenta
47-
momenta += (eps / 2) * gradient
48-
return values, momenta, get_score(trace)
44+
momenta_trie = add_choicemaps(momenta_trie, scale_choicemap(gradient_trie, eps / 2))
45+
return values_trie, momenta_trie, get_score(trace)
4946
end
5047

5148
function build_root(val, momenta, eps, direction, weight_init, integrator_state)
5249
val, momenta, lp = leapfrog(val, momenta, direction * eps, integrator_state)
53-
weight = lp + assess_momenta(momenta)
50+
weight = lp + assess_momenta_trie(momenta)
5451

5552
diverging = weight - weight_init > 1000
5653

@@ -67,9 +64,16 @@ function merge_trees(tree_left, tree_right)
6764

6865
weight = logsumexp(tree_left.weight, tree_right.weight)
6966
n = tree_left.n + tree_right.n
70-
stop = tree_left.stop || tree_right.stop || u_turn(
71-
tree_left.val_left, tree_right.val_right, tree_left.momenta_left, tree_right.momenta_right
72-
)
67+
68+
if u_turn(to_array(tree_left.val_left, Float64),
69+
to_array(tree_right.val_right, Float64),
70+
to_array(tree_left.momenta_left, Float64),
71+
to_array(tree_right.momenta_right, Float64))
72+
end
73+
stop = tree_left.stop || tree_right.stop || u_turn(to_array(tree_left.val_left, Float64),
74+
to_array(tree_right.val_right, Float64),
75+
to_array(tree_left.momenta_left, Float64),
76+
to_array(tree_right.momenta_right, Float64))
7377
diverging = tree_left.diverging || tree_right.diverging
7478

7579
return Tree(tree_left.val_left, tree_left.momenta_left, tree_right.val_right,
@@ -90,14 +94,10 @@ function build_tree(val, momenta, depth, eps, direction, weight_init, integrator
9094
if direction == 1
9195
other_tree = build_tree(tree.val_right, tree.momenta_right, depth - 1, eps, direction,
9296
weight_init, integrator_state)
97+
return merge_trees(tree, other_tree)
9398
else
9499
other_tree = build_tree(tree.val_left, tree.momenta_left, depth - 1, eps, direction,
95100
weight_init, integrator_state)
96-
end
97-
98-
if direction == 1
99-
return merge_trees(tree, other_tree)
100-
else
101101
return merge_trees(other_tree, tree)
102102
end
103103
end
@@ -131,16 +131,16 @@ function nuts(
131131

132132
# values needed for a leapfrog step
133133
(_, values_trie, _) = choice_gradients(trace, selection, retval_grad)
134-
values = to_array(values_trie, Float64)
135134

136-
momenta = sample_momenta(length(values))
135+
momenta = sample_momenta(length(to_array(values_trie, Float64)))
136+
momenta_trie = from_array(values_trie, momenta)
137137
prev_momenta_score = assess_momenta(momenta)
138138

139139
weight_init = prev_model_score + prev_momenta_score
140140

141-
integrator_state = (values_trie, selection, retval_grad, trace)
141+
integrator_state = (selection, retval_grad, trace)
142142

143-
tree = Tree(values, momenta, values, momenta, values, 1, -Inf, false, false)
143+
tree = Tree(values_trie, momenta_trie, values_trie, momenta_trie, values_trie, 1, -Inf, false, false)
144144

145145
direction = 0
146146
depth = 0
@@ -165,25 +165,25 @@ function nuts(
165165
depth += 1
166166
end
167167

168-
(new_trace, _, _) = update(trace, from_array(values_trie, tree.val_sample))
168+
(new_trace, _, _) = update(trace, tree.val_sample)
169169
check && check_observations(get_choices(new_trace), observations)
170170

171171
# assess new model score (negative potential energy)
172172
new_model_score = get_score(new_trace)
173173

174174
# assess new momenta score (negative kinetic energy)
175175
if direction == 1
176-
new_momenta_score = assess_momenta(-tree.momenta_right)
176+
new_momenta_score = assess_momenta_trie(tree.momenta_right)
177177
else
178-
new_momenta_score = assess_momenta(-tree.momenta_left)
178+
new_momenta_score = assess_momenta_trie(tree.momenta_left)
179179
end
180180

181181
# accept or reject
182182
alpha = new_model_score + new_momenta_score - weight_init
183183
if log(rand()) < alpha
184-
return (new_trace, SamplerStats(depth, tree.n, true))
184+
return (new_trace, SamplerStats(depth, tree.n, tree.diverging, true))
185185
else
186-
return (trace, SamplerStats(depth, tree.n, false))
186+
return (trace, SamplerStats(depth, tree.n, tree.diverging, false))
187187
end
188188
end
189189

test/inference/nuts.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
@testset "nuts" begin
2+
3+
# smoke test a function without retval gradient
4+
@gen function foo()
5+
x = @trace(normal(0, 1), :x)
6+
return x
7+
end
8+
9+
(trace, _) = generate(foo, ())
10+
(new_trace, accepted) = nuts(trace, select(:x))
11+
12+
# smoke test a function with retval gradient
13+
@gen (grad) function foo()
14+
x = @trace(normal(0, 1), :x)
15+
return x
16+
end
17+
18+
(trace, _) = generate(foo, ())
19+
(new_trace, accepted) = nuts(trace, select(:x))
20+
end

0 commit comments

Comments
 (0)