Skip to content

Commit 5553cfb

Browse files
committed
feat(inference): add nuts sampler
1 parent dce003c commit 5553cfb

File tree

4 files changed

+172
-12
lines changed

4 files changed

+172
-12
lines changed

src/inference/hmc.jl

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,3 @@
1-
function sample_momenta(n::Int)
2-
Float64[random(normal, 0, 1) for _=1:n]
3-
end
4-
5-
function assess_momenta(momenta)
6-
logprob = 0.
7-
for val in momenta
8-
logprob += logpdf(normal, val, 0, 1)
9-
end
10-
logprob
11-
end
12-
131
"""
142
(new_trace, accepted) = hmc(
153
trace, selection::Selection; L=10, eps=0.1,

src/inference/hmc_common.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
function sample_momenta(n::Int)
2+
Float64[random(normal, 0, 1) for _=1:n]
3+
end
4+
5+
function assess_momenta(momenta)
6+
logprob = 0.
7+
for val in momenta
8+
logprob += logpdf(normal, val, 0, 1)
9+
end
10+
logprob
11+
end

src/inference/inference.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,13 @@ export logsumexp
1414

1515
include("trace_translators.jl")
1616

17+
include("hmc_common.jl")
18+
1719
# mcmc
1820
include("kernel_dsl.jl")
1921
include("mh.jl")
2022
include("hmc.jl")
23+
include("nuts.jl")
2124
include("mala.jl")
2225
include("elliptical_slice.jl")
2326

src/inference/nuts.jl

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
using LinearAlgebra: dot
2+
3+
Tree = @NamedTuple begin
4+
val_left
5+
momenta_left
6+
val_right
7+
momenta_right
8+
val_sample
9+
n :: Int
10+
weight :: Float64
11+
stop :: Bool
12+
end
13+
14+
Stats = @NamedTuple begin
15+
depth
16+
n
17+
accept
18+
end
19+
20+
function u_turn(values_left, values_right, momenta_left, momenta_right)
21+
return (dot(values_left - values_right, momenta_right) >= 0) &&
22+
(dot(values_right - values_left, momenta_left) >= 0)
23+
end
24+
25+
function leapfrog(values, momenta, eps, integrator_state)
26+
values_trie, selection, retval_grad, trace = integrator_state
27+
28+
values_trie = from_array(values_trie, values)
29+
(trace, _, _) = update(trace, values_trie)
30+
(_, _, gradient_trie) = choice_gradients(trace, selection, retval_grad)
31+
gradient = to_array(gradient_trie, Float64)
32+
33+
# half step on momenta
34+
momenta += (eps / 2) * gradient
35+
36+
# full step on positions
37+
values += eps * momenta
38+
39+
# get new gradient
40+
values_trie = from_array(values_trie, values)
41+
(trace, _, _) = update(trace, values_trie)
42+
(_, _, gradient_trie) = choice_gradients(trace, selection, retval_grad)
43+
gradient = to_array(gradient_trie, Float64)
44+
45+
# half step on momenta
46+
momenta += (eps / 2) * gradient
47+
return values, momenta, get_score(trace)
48+
end
49+
50+
function build_root(val, momenta, eps, direction, integrator_state)
51+
val, momenta, lp = leapfrog(val, momenta, direction * eps, integrator_state)
52+
weight = lp + assess_momenta(momenta)
53+
54+
return Tree((val, momenta, val, momenta, val, 1, weight, false))
55+
end
56+
57+
function merge_trees(tree_left, tree_right)
58+
# multinomial sampling
59+
if log(rand()) < tree_right.weight - tree_left.weight
60+
sample = tree_right.val_sample
61+
else
62+
sample = tree_left.val_sample
63+
end
64+
65+
weight = logsumexp(tree_left.weight, tree_right.weight)
66+
n = tree_left.n + tree_right.n
67+
stop = tree_left.stop || tree_right.stop || u_turn(
68+
tree_left.val_left, tree_right.val_right, tree_left.momenta_left, tree_right.momenta_right
69+
)
70+
71+
return Tree((tree_left.val_left, tree_left.momenta_left, tree_right.val_right,
72+
tree_right.momenta_right, sample, n, weight, stop))
73+
end
74+
75+
function build_tree(val, momenta, depth, eps, direction, integrator_state)
76+
if depth == 0
77+
return build_root(val, momenta, eps, direction, integrator_state)
78+
end
79+
80+
tree = build_tree(val, momenta, depth - 1, eps, direction, integrator_state)
81+
82+
if tree.stop
83+
return tree
84+
end
85+
86+
if direction == 1
87+
other_tree = build_tree(tree.val_right, tree.momenta_right, depth - 1, eps, direction, integrator_state)
88+
else
89+
other_tree = build_tree(tree.val_left, tree.momenta_left, depth - 1, eps, direction, integrator_state)
90+
end
91+
92+
if direction == 1
93+
return merge_trees(tree, other_tree)
94+
else
95+
return merge_trees(other_tree, tree)
96+
end
97+
end
98+
99+
function nuts(
100+
trace::Trace, selection::Selection; eps=0.1, max_treedepth=15,
101+
check=false, observations=EmptyChoiceMap())
102+
prev_model_score = get_score(trace)
103+
retval_grad = accepts_output_grad(get_gen_fn(trace)) ? zero(get_retval(trace)) : nothing
104+
105+
# values needed for a leapfrog step
106+
(_, values_trie, _) = choice_gradients(trace, selection, retval_grad)
107+
values = to_array(values_trie, Float64)
108+
integrator_state = (values_trie, selection, retval_grad, trace)
109+
110+
momenta = sample_momenta(length(values))
111+
prev_momenta_score = assess_momenta(momenta)
112+
113+
tree = Tree((values, momenta, values, momenta, values, 1, -Inf, false))
114+
115+
direction = 0
116+
depth = 0
117+
stop = false
118+
while depth < max_treedepth
119+
direction = rand([-1, 1])
120+
121+
if direction == 1 # going right
122+
other_tree = build_tree(tree.val_right, tree.momenta_right, depth, eps, direction, integrator_state)
123+
tree = merge_trees(tree, other_tree)
124+
else # going left
125+
other_tree = build_tree(tree.val_left, tree.momenta_left, depth, eps, direction, integrator_state)
126+
tree = merge_trees(other_tree, tree)
127+
end
128+
129+
stop = stop || tree.stop
130+
if stop
131+
break
132+
end
133+
depth += 1
134+
end
135+
136+
(new_trace, _, _) = update(trace, from_array(values_trie, tree.val_sample))
137+
check && check_observations(get_choices(new_trace), observations)
138+
139+
# assess new model score (negative potential energy)
140+
new_model_score = get_score(new_trace)
141+
142+
# assess new momenta score (negative kinetic energy)
143+
if direction == 1
144+
new_momenta_score = assess_momenta(-tree.momenta_right)
145+
else
146+
new_momenta_score = assess_momenta(-tree.momenta_left)
147+
end
148+
149+
# accept or reject
150+
alpha = new_model_score - prev_model_score + new_momenta_score - prev_momenta_score
151+
if log(rand()) < alpha
152+
return (new_trace, Stats((depth, tree.n, true)))
153+
else
154+
return (trace, Stats((depth, tree.n, false)))
155+
end
156+
end
157+
158+
export nuts

0 commit comments

Comments
 (0)