Skip to content

Commit 2a959a8

Browse files
committed
wip: check for divergence
1 parent 5553cfb commit 2a959a8

File tree

2 files changed

+55
-22
lines changed

2 files changed

+55
-22
lines changed

src/inference/hmc_common.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@ function assess_momenta(momenta)
88
logprob += logpdf(normal, val, 0, 1)
99
end
1010
logprob
11-
end
11+
end

src/inference/nuts.jl

Lines changed: 54 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using LinearAlgebra: dot
22

3-
Tree = @NamedTuple begin
3+
struct Tree
44
val_left
55
momenta_left
66
val_right
@@ -9,9 +9,10 @@ Tree = @NamedTuple begin
99
n :: Int
1010
weight :: Float64
1111
stop :: Bool
12+
diverging :: Bool
1213
end
1314

14-
Stats = @NamedTuple begin
15+
struct SamplerStats
1516
depth
1617
n
1718
accept
@@ -47,11 +48,13 @@ function leapfrog(values, momenta, eps, integrator_state)
4748
return values, momenta, get_score(trace)
4849
end
4950

"""
    build_root(val, momenta, eps, direction, weight_init, integrator_state)

Build a single-node trajectory tree by taking one leapfrog step of signed size
`direction * eps` from the state `(val, momenta)`.

The node's `weight` is the log joint density of the new state: the score
returned by `leapfrog` plus `assess_momenta(momenta)`. `weight_init` is the same
quantity evaluated at the start of the trajectory.

The node is flagged as `diverging` when the log density has dropped by more
than 1000 relative to `weight_init` (i.e. the energy error exceeds 1000), or
when the new weight is `NaN`.
"""
function build_root(val, momenta, eps, direction, weight_init, integrator_state)
    val, momenta, lp = leapfrog(val, momenta, direction * eps, integrator_state)
    weight = lp + assess_momenta(momenta)

    # `weight` is a log density, i.e. the negative energy. A divergence is a
    # large *increase* in energy, which is a large *decrease* in `weight`, so
    # the difference must be taken as `weight_init - weight`. A NaN weight
    # would make the comparison silently false, so treat it as divergent too.
    diverging = isnan(weight) || weight_init - weight > 1000

    return Tree(val, momenta, val, momenta, val, 1, weight, false, diverging)
end
5659

5760
function merge_trees(tree_left, tree_right)
@@ -67,26 +70,29 @@ function merge_trees(tree_left, tree_right)
6770
stop = tree_left.stop || tree_right.stop || u_turn(
6871
tree_left.val_left, tree_right.val_right, tree_left.momenta_left, tree_right.momenta_right
6972
)
73+
diverging = tree_left.diverging || tree_right.diverging
7074

71-
return Tree((tree_left.val_left, tree_left.momenta_left, tree_right.val_right,
72-
tree_right.momenta_right, sample, n, weight, stop))
75+
return Tree(tree_left.val_left, tree_left.momenta_left, tree_right.val_right,
76+
tree_right.momenta_right, sample, n, weight, stop, diverging)
7377
end
7478

75-
function build_tree(val, momenta, depth, eps, direction, integrator_state)
79+
function build_tree(val, momenta, depth, eps, direction, weight_init, integrator_state)
7680
if depth == 0
77-
return build_root(val, momenta, eps, direction, integrator_state)
81+
return build_root(val, momenta, eps, direction, weight_init, integrator_state)
7882
end
7983

80-
tree = build_tree(val, momenta, depth - 1, eps, direction, integrator_state)
84+
tree = build_tree(val, momenta, depth - 1, eps, direction, weight_init, integrator_state)
8185

82-
if tree.stop
86+
if tree.stop || tree.diverging
8387
return tree
8488
end
8589

8690
if direction == 1
87-
other_tree = build_tree(tree.val_right, tree.momenta_right, depth - 1, eps, direction, integrator_state)
91+
other_tree = build_tree(tree.val_right, tree.momenta_right, depth - 1, eps, direction,
92+
weight_init, integrator_state)
8893
else
89-
other_tree = build_tree(tree.val_left, tree.momenta_left, depth - 1, eps, direction, integrator_state)
94+
other_tree = build_tree(tree.val_left, tree.momenta_left, depth - 1, eps, direction,
95+
weight_init, integrator_state)
9096
end
9197

9298
if direction == 1
@@ -96,6 +102,27 @@ function build_tree(val, momenta, depth, eps, direction, integrator_state)
96102
end
97103
end
98104

105+
"""
106+
(new_trace, sampler_statistics) = nuts(
107+
trace, selection::Selection; eps=0.1,
108+
max_treedepth=15, check=false, observations=EmptyChoiceMap())
109+
110+
Apply a Hamiltonian Monte Carlo (HMC) update with a No-U-Turn stopping criterion that proposes new values for the selected addresses, returning the new trace (which is equal to the previous trace if the move was not accepted) and a struct `sampler_statistics` containing information about the sampled trajectory.
111+
112+
The No-U-Turn Sampler (NUTS) samples trajectories of dynamic length, removing the need to specify the length of the trajectory as a parameter.
113+
The sample will be returned early if the height of the sampled tree exceeds `max_treedepth`.
114+
115+
`sampler_statistics` is a struct containing the following fields:
116+
- depth: the depth of the trajectory tree
117+
- n: the number of samples in the trajectory tree
118+
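- sum_alpha: the sum of the individual MH acceptance probabilities for each sample in the tree (NOTE: the `SamplerStats(depth, n, accept)` calls in `nuts` do not set this field — confirm it exists)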
119+
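- n_accept: how many intermediate samples were accepted (NOTE: the `SamplerStats(depth, n, accept)` calls in `nuts` do not set this field — confirm it exists)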
120+
- accept: whether the sample was accepted or not
121+
122+
# References
123+
Betancourt, M. (2017). A Conceptual Introduction to Hamiltonian Monte Carlo. URL: https://doi.org/10.48550/arXiv.1701.02434
124+
Hoffman, M. D., & Gelman, A. (2011). The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo. URL: https://arxiv.org/abs/1111.4246
125+
"""
99126
function nuts(
100127
trace::Trace, selection::Selection; eps=0.1, max_treedepth=15,
101128
check=false, observations=EmptyChoiceMap())
@@ -105,12 +132,15 @@ function nuts(
105132
# values needed for a leapfrog step
106133
(_, values_trie, _) = choice_gradients(trace, selection, retval_grad)
107134
values = to_array(values_trie, Float64)
108-
integrator_state = (values_trie, selection, retval_grad, trace)
109135

110136
momenta = sample_momenta(length(values))
111137
prev_momenta_score = assess_momenta(momenta)
112138

113-
tree = Tree((values, momenta, values, momenta, values, 1, -Inf, false))
139+
weight_init = prev_model_score + prev_momenta_score
140+
141+
integrator_state = (values_trie, selection, retval_grad, trace)
142+
143+
tree = Tree(values, momenta, values, momenta, values, 1, -Inf, false, false)
114144

115145
direction = 0
116146
depth = 0
@@ -119,14 +149,16 @@ function nuts(
119149
direction = rand([-1, 1])
120150

121151
if direction == 1 # going right
122-
other_tree = build_tree(tree.val_right, tree.momenta_right, depth, eps, direction, integrator_state)
152+
other_tree = build_tree(tree.val_right, tree.momenta_right, depth, eps, direction,
153+
weight_init, integrator_state)
123154
tree = merge_trees(tree, other_tree)
124155
else # going left
125-
other_tree = build_tree(tree.val_left, tree.momenta_left, depth, eps, direction, integrator_state)
156+
other_tree = build_tree(tree.val_left, tree.momenta_left, depth, eps, direction,
157+
weight_init, integrator_state)
126158
tree = merge_trees(other_tree, tree)
127159
end
128160

129-
stop = stop || tree.stop
161+
stop = stop || tree.stop || tree.diverging
130162
if stop
131163
break
132164
end
@@ -147,12 +179,13 @@ function nuts(
147179
end
148180

149181
# accept or reject
150-
alpha = new_model_score - prev_model_score + new_momenta_score - prev_momenta_score
182+
alpha = new_model_score + new_momenta_score - weight_init
151183
if log(rand()) < alpha
152-
return (new_trace, Stats((depth, tree.n, true)))
184+
return (new_trace, SamplerStats(depth, tree.n, true))
153185
else
154-
return (trace, Stats((depth, tree.n, false)))
186+
return (trace, SamplerStats(depth, tree.n, false))
155187
end
156188
end
157189

158-
export nuts
190+
export nuts
191+

0 commit comments

Comments
 (0)