11using LinearAlgebra: dot
22
3- Tree = @NamedTuple begin
3+ struct Tree
44 val_left
55 momenta_left
66 val_right
@@ -9,9 +9,10 @@ Tree = @NamedTuple begin
99 n :: Int
1010 weight :: Float64
1111 stop :: Bool
12+ diverging :: Bool
1213end
1314
14- Stats = @NamedTuple begin
15+ struct SamplerStats
1516 depth
1617 n
1718 accept
@@ -47,11 +48,13 @@ function leapfrog(values, momenta, eps, integrator_state)
4748 return values, momenta, get_score (trace)
4849end
4950
50- function build_root (val, momenta, eps, direction, integrator_state)
51+ function build_root (val, momenta, eps, direction, weight_init, integrator_state)
5152 val, momenta, lp = leapfrog (val, momenta, direction * eps, integrator_state)
5253 weight = lp + assess_momenta (momenta)
5354
54- return Tree ((val, momenta, val, momenta, val, 1 , weight, false ))
55+ diverging = weight - weight_init > 1000
56+
57+ return Tree (val, momenta, val, momenta, val, 1 , weight, false , diverging)
5558end
5659
5760function merge_trees (tree_left, tree_right)
@@ -67,26 +70,29 @@ function merge_trees(tree_left, tree_right)
6770 stop = tree_left. stop || tree_right. stop || u_turn (
6871 tree_left. val_left, tree_right. val_right, tree_left. momenta_left, tree_right. momenta_right
6972 )
73+ diverging = tree_left. diverging || tree_right. diverging
7074
71- return Tree (( tree_left. val_left, tree_left. momenta_left, tree_right. val_right,
72- tree_right. momenta_right, sample, n, weight, stop) )
75+ return Tree (tree_left. val_left, tree_left. momenta_left, tree_right. val_right,
76+ tree_right. momenta_right, sample, n, weight, stop, diverging )
7377end
7478
75- function build_tree (val, momenta, depth, eps, direction, integrator_state)
79+ function build_tree (val, momenta, depth, eps, direction, weight_init, integrator_state)
7680 if depth == 0
77- return build_root (val, momenta, eps, direction, integrator_state)
81+ return build_root (val, momenta, eps, direction, weight_init, integrator_state)
7882 end
7983
80- tree = build_tree (val, momenta, depth - 1 , eps, direction, integrator_state)
84+ tree = build_tree (val, momenta, depth - 1 , eps, direction, weight_init, integrator_state)
8185
82- if tree. stop
86+ if tree. stop || tree . diverging
8387 return tree
8488 end
8589
8690 if direction == 1
87- other_tree = build_tree (tree. val_right, tree. momenta_right, depth - 1 , eps, direction, integrator_state)
91+ other_tree = build_tree (tree. val_right, tree. momenta_right, depth - 1 , eps, direction,
92+ weight_init, integrator_state)
8893 else
89- other_tree = build_tree (tree. val_left, tree. momenta_left, depth - 1 , eps, direction, integrator_state)
94+ other_tree = build_tree (tree. val_left, tree. momenta_left, depth - 1 , eps, direction,
95+ weight_init, integrator_state)
9096 end
9197
9298 if direction == 1
@@ -96,6 +102,27 @@ function build_tree(val, momenta, depth, eps, direction, integrator_state)
96102 end
97103end
98104
105+ """
106+ (new_trace, sampler_statistics) = nuts(
107+ trace, selection::Selection;eps=0.1,
108+ max_treedepth=15, check=false, observations=EmptyChoiceMap())
109+
110+ Apply a Hamiltonian Monte Carlo (HMC) update with a No U Turn stopping criterion that proposes new values for the selected addresses, returning the new trace (which is equal to the previous trace if the move was not accepted) and a struct `sampler_statistics` containing information about the sampled trajectory.
111+
112+ The NUT sampler allows for sampling trajectories of dynamic lengths, removing the need to specify the length of the trajectory as a parameter.
113+ The sample will be returned early if the height of the sampled tree exceeds `max_treedepth`.
114+
115+ `sampler_statistics` is a struct containing the following fields:
116+ - depth: the depth of the trajectory tree
117+ - n: the number of samples in the trajectory tree
118+ - sum_alpha: the sum of the individual mh acceptance probabilities for each sample in the tree
119+ - n_accept: how many intermediate samples were accepted
120+ - accept: whether the sample was accepted or not
121+
122+ # References
123+ Betancourt, M. (2017). A Conceptual Introduction to Hamiltonian Monte Carlo. URL: https://doi.org/10.48550/arXiv.1701.02434
124+ Hoffman, M. D., & Gelman, A. (2022). The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo. URL: https://arxiv.org/abs/1111.4246
125+ """
99126function nuts (
100127 trace:: Trace , selection:: Selection ; eps= 0.1 , max_treedepth= 15 ,
101128 check= false , observations= EmptyChoiceMap ())
@@ -105,12 +132,15 @@ function nuts(
105132 # values needed for a leapfrog step
106133 (_, values_trie, _) = choice_gradients (trace, selection, retval_grad)
107134 values = to_array (values_trie, Float64)
108- integrator_state = (values_trie, selection, retval_grad, trace)
109135
110136 momenta = sample_momenta (length (values))
111137 prev_momenta_score = assess_momenta (momenta)
112138
113- tree = Tree ((values, momenta, values, momenta, values, 1 , - Inf , false ))
139+ weight_init = prev_model_score + prev_momenta_score
140+
141+ integrator_state = (values_trie, selection, retval_grad, trace)
142+
143+ tree = Tree (values, momenta, values, momenta, values, 1 , - Inf , false , false )
114144
115145 direction = 0
116146 depth = 0
@@ -119,14 +149,16 @@ function nuts(
119149 direction = rand ([- 1 , 1 ])
120150
121151 if direction == 1 # going right
122- other_tree = build_tree (tree. val_right, tree. momenta_right, depth, eps, direction, integrator_state)
152+ other_tree = build_tree (tree. val_right, tree. momenta_right, depth, eps, direction,
153+ weight_init, integrator_state)
123154 tree = merge_trees (tree, other_tree)
124155 else # going left
125- other_tree = build_tree (tree. val_left, tree. momenta_left, depth, eps, direction, integrator_state)
156+ other_tree = build_tree (tree. val_left, tree. momenta_left, depth, eps, direction,
157+ weight_init, integrator_state)
126158 tree = merge_trees (other_tree, tree)
127159 end
128160
129- stop = stop || tree. stop
161+ stop = stop || tree. stop || tree . diverging
130162 if stop
131163 break
132164 end
@@ -147,12 +179,13 @@ function nuts(
147179 end
148180
149181 # accept or reject
150- alpha = new_model_score - prev_model_score + new_momenta_score - prev_momenta_score
182+ alpha = new_model_score + new_momenta_score - weight_init
151183 if log (rand ()) < alpha
152- return (new_trace, Stats (( depth, tree. n, true ) ))
184+ return (new_trace, SamplerStats ( depth, tree. n, true ))
153185 else
154- return (trace, Stats (( depth, tree. n, false ) ))
186+ return (trace, SamplerStats ( depth, tree. n, false ))
155187 end
156188end
157189
158- export nuts
190+ export nuts
191+
0 commit comments