1515struct SamplerStats
1616 depth
1717 n
18+ diverging
1819 accept
1920end
2021
@@ -23,34 +24,30 @@ function u_turn(values_left, values_right, momenta_left, momenta_right)
2324 (dot (values_right - values_left, momenta_left) >= 0 )
2425end
2526
26- function leapfrog (values, momenta , eps, integrator_state)
27- values_trie, selection, retval_grad, trace = integrator_state
27+ function leapfrog (values_trie, momenta_trie , eps, integrator_state)
28+ selection, retval_grad, trace = integrator_state
2829
29- values_trie = from_array (values_trie, values)
3030 (trace, _, _) = update (trace, values_trie)
3131 (_, _, gradient_trie) = choice_gradients (trace, selection, retval_grad)
32- gradient = to_array (gradient_trie, Float64)
3332
3433 # half step on momenta
35- momenta += ( eps / 2 ) * gradient
34+ momenta_trie = add_choicemaps (momenta_trie, scale_choicemap (gradient_trie, eps / 2 ))
3635
3736 # full step on positions
38- values += eps * momenta
37+ values_trie = add_choicemaps (values_trie, scale_choicemap (momenta_trie, eps))
3938
4039 # get new gradient
41- values_trie = from_array (values_trie, values)
4240 (trace, _, _) = update (trace, values_trie)
4341 (_, _, gradient_trie) = choice_gradients (trace, selection, retval_grad)
44- gradient = to_array (gradient_trie, Float64)
4542
4643 # half step on momenta
47- momenta += ( eps / 2 ) * gradient
48- return values, momenta , get_score (trace)
44+ momenta_trie = add_choicemaps (momenta_trie, scale_choicemap (gradient_trie, eps / 2 ))
45+ return values_trie, momenta_trie , get_score (trace)
4946end
5047
5148function build_root (val, momenta, eps, direction, weight_init, integrator_state)
5249 val, momenta, lp = leapfrog (val, momenta, direction * eps, integrator_state)
53- weight = lp + assess_momenta (momenta)
50+ weight = lp + assess_momenta_trie (momenta)
5451
5552 diverging = weight - weight_init > 1000
5653
@@ -67,9 +64,16 @@ function merge_trees(tree_left, tree_right)
6764
6865 weight = logsumexp (tree_left. weight, tree_right. weight)
6966 n = tree_left. n + tree_right. n
70- stop = tree_left. stop || tree_right. stop || u_turn (
71- tree_left. val_left, tree_right. val_right, tree_left. momenta_left, tree_right. momenta_right
72- )
67+
68+ if u_turn (to_array (tree_left. val_left, Float64),
69+ to_array (tree_right. val_right, Float64),
70+ to_array (tree_left. momenta_left, Float64),
71+ to_array (tree_right. momenta_right, Float64))
72+ end
73+ stop = tree_left. stop || tree_right. stop || u_turn (to_array (tree_left. val_left, Float64),
74+ to_array (tree_right. val_right, Float64),
75+ to_array (tree_left. momenta_left, Float64),
76+ to_array (tree_right. momenta_right, Float64))
7377 diverging = tree_left. diverging || tree_right. diverging
7478
7579 return Tree (tree_left. val_left, tree_left. momenta_left, tree_right. val_right,
@@ -90,14 +94,10 @@ function build_tree(val, momenta, depth, eps, direction, weight_init, integrator
9094 if direction == 1
9195 other_tree = build_tree (tree. val_right, tree. momenta_right, depth - 1 , eps, direction,
9296 weight_init, integrator_state)
97+ return merge_trees (tree, other_tree)
9398 else
9499 other_tree = build_tree (tree. val_left, tree. momenta_left, depth - 1 , eps, direction,
95100 weight_init, integrator_state)
96- end
97-
98- if direction == 1
99- return merge_trees (tree, other_tree)
100- else
101101 return merge_trees (other_tree, tree)
102102 end
103103end
@@ -131,16 +131,16 @@ function nuts(
131131
132132 # values needed for a leapfrog step
133133 (_, values_trie, _) = choice_gradients (trace, selection, retval_grad)
134- values = to_array (values_trie, Float64)
135134
136- momenta = sample_momenta (length (values))
135+ momenta = sample_momenta (length (to_array (values_trie, Float64)))
136+ momenta_trie = from_array (values_trie, momenta)
137137 prev_momenta_score = assess_momenta (momenta)
138138
139139 weight_init = prev_model_score + prev_momenta_score
140140
141- integrator_state = (values_trie, selection, retval_grad, trace)
141+ integrator_state = (selection, retval_grad, trace)
142142
143- tree = Tree (values, momenta, values, momenta, values , 1 , - Inf , false , false )
143+ tree = Tree (values_trie, momenta_trie, values_trie, momenta_trie, values_trie , 1 , - Inf , false , false )
144144
145145 direction = 0
146146 depth = 0
@@ -165,25 +165,25 @@ function nuts(
165165 depth += 1
166166 end
167167
168- (new_trace, _, _) = update (trace, from_array (values_trie, tree. val_sample) )
168+ (new_trace, _, _) = update (trace, tree. val_sample)
169169 check && check_observations (get_choices (new_trace), observations)
170170
171171 # assess new model score (negative potential energy)
172172 new_model_score = get_score (new_trace)
173173
174174 # assess new momenta score (negative kinetic energy)
175175 if direction == 1
176- new_momenta_score = assess_momenta ( - tree. momenta_right)
176+ new_momenta_score = assess_momenta_trie ( tree. momenta_right)
177177 else
178- new_momenta_score = assess_momenta ( - tree. momenta_left)
178+ new_momenta_score = assess_momenta_trie ( tree. momenta_left)
179179 end
180180
181181 # accept or reject
182182 alpha = new_model_score + new_momenta_score - weight_init
183183 if log (rand ()) < alpha
184- return (new_trace, SamplerStats (depth, tree. n, true ))
184+ return (new_trace, SamplerStats (depth, tree. n, tree . diverging, true ))
185185 else
186- return (trace, SamplerStats (depth, tree. n, false ))
186+ return (trace, SamplerStats (depth, tree. n, tree . diverging, false ))
187187 end
188188end
189189
0 commit comments