15
15
struct SamplerStats
16
16
depth
17
17
n
18
+ diverging
18
19
accept
19
20
end
20
21
@@ -23,34 +24,30 @@ function u_turn(values_left, values_right, momenta_left, momenta_right)
23
24
(dot (values_right - values_left, momenta_left) >= 0 )
24
25
end
25
26
26
- function leapfrog (values, momenta , eps, integrator_state)
27
- values_trie, selection, retval_grad, trace = integrator_state
27
+ function leapfrog (values_trie, momenta_trie , eps, integrator_state)
28
+ selection, retval_grad, trace = integrator_state
28
29
29
- values_trie = from_array (values_trie, values)
30
30
(trace, _, _) = update (trace, values_trie)
31
31
(_, _, gradient_trie) = choice_gradients (trace, selection, retval_grad)
32
- gradient = to_array (gradient_trie, Float64)
33
32
34
33
# half step on momenta
35
- momenta += ( eps / 2 ) * gradient
34
+ momenta_trie = add_choicemaps (momenta_trie, scale_choicemap (gradient_trie, eps / 2 ))
36
35
37
36
# full step on positions
38
- values += eps * momenta
37
+ values_trie = add_choicemaps (values_trie, scale_choicemap (momenta_trie, eps))
39
38
40
39
# get new gradient
41
- values_trie = from_array (values_trie, values)
42
40
(trace, _, _) = update (trace, values_trie)
43
41
(_, _, gradient_trie) = choice_gradients (trace, selection, retval_grad)
44
- gradient = to_array (gradient_trie, Float64)
45
42
46
43
# half step on momenta
47
- momenta += ( eps / 2 ) * gradient
48
- return values, momenta , get_score (trace)
44
+ momenta_trie = add_choicemaps (momenta_trie, scale_choicemap (gradient_trie, eps / 2 ))
45
+ return values_trie, momenta_trie , get_score (trace)
49
46
end
50
47
51
48
function build_root (val, momenta, eps, direction, weight_init, integrator_state)
52
49
val, momenta, lp = leapfrog (val, momenta, direction * eps, integrator_state)
53
- weight = lp + assess_momenta (momenta)
50
+ weight = lp + assess_momenta_trie (momenta)
54
51
55
52
diverging = weight - weight_init > 1000
56
53
@@ -67,9 +64,16 @@ function merge_trees(tree_left, tree_right)
67
64
68
65
weight = logsumexp (tree_left. weight, tree_right. weight)
69
66
n = tree_left. n + tree_right. n
70
- stop = tree_left. stop || tree_right. stop || u_turn (
71
- tree_left. val_left, tree_right. val_right, tree_left. momenta_left, tree_right. momenta_right
72
- )
67
+
68
+ if u_turn (to_array (tree_left. val_left, Float64),
69
+ to_array (tree_right. val_right, Float64),
70
+ to_array (tree_left. momenta_left, Float64),
71
+ to_array (tree_right. momenta_right, Float64))
72
+ end
73
+ stop = tree_left. stop || tree_right. stop || u_turn (to_array (tree_left. val_left, Float64),
74
+ to_array (tree_right. val_right, Float64),
75
+ to_array (tree_left. momenta_left, Float64),
76
+ to_array (tree_right. momenta_right, Float64))
73
77
diverging = tree_left. diverging || tree_right. diverging
74
78
75
79
return Tree (tree_left. val_left, tree_left. momenta_left, tree_right. val_right,
@@ -90,14 +94,10 @@ function build_tree(val, momenta, depth, eps, direction, weight_init, integrator
90
94
if direction == 1
91
95
other_tree = build_tree (tree. val_right, tree. momenta_right, depth - 1 , eps, direction,
92
96
weight_init, integrator_state)
97
+ return merge_trees (tree, other_tree)
93
98
else
94
99
other_tree = build_tree (tree. val_left, tree. momenta_left, depth - 1 , eps, direction,
95
100
weight_init, integrator_state)
96
- end
97
-
98
- if direction == 1
99
- return merge_trees (tree, other_tree)
100
- else
101
101
return merge_trees (other_tree, tree)
102
102
end
103
103
end
@@ -131,16 +131,16 @@ function nuts(
131
131
132
132
# values needed for a leapfrog step
133
133
(_, values_trie, _) = choice_gradients (trace, selection, retval_grad)
134
- values = to_array (values_trie, Float64)
135
134
136
- momenta = sample_momenta (length (values))
135
+ momenta = sample_momenta (length (to_array (values_trie, Float64)))
136
+ momenta_trie = from_array (values_trie, momenta)
137
137
prev_momenta_score = assess_momenta (momenta)
138
138
139
139
weight_init = prev_model_score + prev_momenta_score
140
140
141
- integrator_state = (values_trie, selection, retval_grad, trace)
141
+ integrator_state = (selection, retval_grad, trace)
142
142
143
- tree = Tree (values, momenta, values, momenta, values , 1 , - Inf , false , false )
143
+ tree = Tree (values_trie, momenta_trie, values_trie, momenta_trie, values_trie , 1 , - Inf , false , false )
144
144
145
145
direction = 0
146
146
depth = 0
@@ -165,25 +165,25 @@ function nuts(
165
165
depth += 1
166
166
end
167
167
168
- (new_trace, _, _) = update (trace, from_array (values_trie, tree. val_sample) )
168
+ (new_trace, _, _) = update (trace, tree. val_sample)
169
169
check && check_observations (get_choices (new_trace), observations)
170
170
171
171
# assess new model score (negative potential energy)
172
172
new_model_score = get_score (new_trace)
173
173
174
174
# assess new momenta score (negative kinetic energy)
175
175
if direction == 1
176
- new_momenta_score = assess_momenta ( - tree. momenta_right)
176
+ new_momenta_score = assess_momenta_trie ( tree. momenta_right)
177
177
else
178
- new_momenta_score = assess_momenta ( - tree. momenta_left)
178
+ new_momenta_score = assess_momenta_trie ( tree. momenta_left)
179
179
end
180
180
181
181
# accept or reject
182
182
alpha = new_model_score + new_momenta_score - weight_init
183
183
if log (rand ()) < alpha
184
- return (new_trace, SamplerStats (depth, tree. n, true ))
184
+ return (new_trace, SamplerStats (depth, tree. n, tree . diverging, true ))
185
185
else
186
- return (trace, SamplerStats (depth, tree. n, false ))
186
+ return (trace, SamplerStats (depth, tree. n, tree . diverging, false ))
187
187
end
188
188
end
189
189
0 commit comments