1
1
using LinearAlgebra: dot
2
2
3
- Tree = @NamedTuple begin
3
+ struct Tree
4
4
val_left
5
5
momenta_left
6
6
val_right
@@ -9,9 +9,10 @@ Tree = @NamedTuple begin
9
9
n :: Int
10
10
weight :: Float64
11
11
stop :: Bool
12
+ diverging :: Bool
12
13
end
13
14
14
- Stats = @NamedTuple begin
15
+ struct SamplerStats
15
16
depth
16
17
n
17
18
accept
@@ -47,11 +48,13 @@ function leapfrog(values, momenta, eps, integrator_state)
47
48
return values, momenta, get_score (trace)
48
49
end
49
50
50
- function build_root (val, momenta, eps, direction, integrator_state)
51
+ function build_root (val, momenta, eps, direction, weight_init, integrator_state)
51
52
val, momenta, lp = leapfrog (val, momenta, direction * eps, integrator_state)
52
53
weight = lp + assess_momenta (momenta)
53
54
54
- return Tree ((val, momenta, val, momenta, val, 1 , weight, false ))
55
+ diverging = weight - weight_init > 1000
56
+
57
+ return Tree (val, momenta, val, momenta, val, 1 , weight, false , diverging)
55
58
end
56
59
57
60
function merge_trees (tree_left, tree_right)
@@ -67,26 +70,29 @@ function merge_trees(tree_left, tree_right)
67
70
stop = tree_left. stop || tree_right. stop || u_turn (
68
71
tree_left. val_left, tree_right. val_right, tree_left. momenta_left, tree_right. momenta_right
69
72
)
73
+ diverging = tree_left. diverging || tree_right. diverging
70
74
71
- return Tree (( tree_left. val_left, tree_left. momenta_left, tree_right. val_right,
72
- tree_right. momenta_right, sample, n, weight, stop) )
75
+ return Tree (tree_left. val_left, tree_left. momenta_left, tree_right. val_right,
76
+ tree_right. momenta_right, sample, n, weight, stop, diverging )
73
77
end
74
78
75
- function build_tree (val, momenta, depth, eps, direction, integrator_state)
79
+ function build_tree (val, momenta, depth, eps, direction, weight_init, integrator_state)
76
80
if depth == 0
77
- return build_root (val, momenta, eps, direction, integrator_state)
81
+ return build_root (val, momenta, eps, direction, weight_init, integrator_state)
78
82
end
79
83
80
- tree = build_tree (val, momenta, depth - 1 , eps, direction, integrator_state)
84
+ tree = build_tree (val, momenta, depth - 1 , eps, direction, weight_init, integrator_state)
81
85
82
- if tree. stop
86
+ if tree. stop || tree . diverging
83
87
return tree
84
88
end
85
89
86
90
if direction == 1
87
- other_tree = build_tree (tree. val_right, tree. momenta_right, depth - 1 , eps, direction, integrator_state)
91
+ other_tree = build_tree (tree. val_right, tree. momenta_right, depth - 1 , eps, direction,
92
+ weight_init, integrator_state)
88
93
else
89
- other_tree = build_tree (tree. val_left, tree. momenta_left, depth - 1 , eps, direction, integrator_state)
94
+ other_tree = build_tree (tree. val_left, tree. momenta_left, depth - 1 , eps, direction,
95
+ weight_init, integrator_state)
90
96
end
91
97
92
98
if direction == 1
@@ -96,6 +102,27 @@ function build_tree(val, momenta, depth, eps, direction, integrator_state)
96
102
end
97
103
end
98
104
105
+ """
106
+ (new_trace, sampler_statistics) = nuts(
107
+ trace, selection::Selection;eps=0.1,
108
+ max_treedepth=15, check=false, observations=EmptyChoiceMap())
109
+
110
+ Apply a Hamiltonian Monte Carlo (HMC) update with a No U Turn stopping criterion that proposes new values for the selected addresses, returning the new trace (which is equal to the previous trace if the move was not accepted) and a struct `sampler_statistics` containing information about the sampled trajectory.
111
+
112
+ The NUT sampler allows for sampling trajectories of dynamic lengths, removing the need to specify the length of the trajectory as a parameter.
113
+ The sample will be returned early if the height of the sampled tree exceeds `max_treedepth`.
114
+
115
+ `sampler_statistics` is a struct containing the following fields:
116
+ - depth: the depth of the trajectory tree
117
+ - n: the number of samples in the trajectory tree
118
+ - sum_alpha: the sum of the individual mh acceptance probabilities for each sample in the tree
119
+ - n_accept: how many intermediate samples were accepted
120
+ - accept: whether the sample was accepted or not
121
+
122
+ # References
123
+ Betancourt, M. (2017). A Conceptual Introduction to Hamiltonian Monte Carlo. URL: https://doi.org/10.48550/arXiv.1701.02434
124
+ Hoffman, M. D., & Gelman, A. (2022). The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo. URL: https://arxiv.org/abs/1111.4246
125
+ """
99
126
function nuts (
100
127
trace:: Trace , selection:: Selection ; eps= 0.1 , max_treedepth= 15 ,
101
128
check= false , observations= EmptyChoiceMap ())
@@ -105,12 +132,15 @@ function nuts(
105
132
# values needed for a leapfrog step
106
133
(_, values_trie, _) = choice_gradients (trace, selection, retval_grad)
107
134
values = to_array (values_trie, Float64)
108
- integrator_state = (values_trie, selection, retval_grad, trace)
109
135
110
136
momenta = sample_momenta (length (values))
111
137
prev_momenta_score = assess_momenta (momenta)
112
138
113
- tree = Tree ((values, momenta, values, momenta, values, 1 , - Inf , false ))
139
+ weight_init = prev_model_score + prev_momenta_score
140
+
141
+ integrator_state = (values_trie, selection, retval_grad, trace)
142
+
143
+ tree = Tree (values, momenta, values, momenta, values, 1 , - Inf , false , false )
114
144
115
145
direction = 0
116
146
depth = 0
@@ -119,14 +149,16 @@ function nuts(
119
149
direction = rand ([- 1 , 1 ])
120
150
121
151
if direction == 1 # going right
122
- other_tree = build_tree (tree. val_right, tree. momenta_right, depth, eps, direction, integrator_state)
152
+ other_tree = build_tree (tree. val_right, tree. momenta_right, depth, eps, direction,
153
+ weight_init, integrator_state)
123
154
tree = merge_trees (tree, other_tree)
124
155
else # going left
125
- other_tree = build_tree (tree. val_left, tree. momenta_left, depth, eps, direction, integrator_state)
156
+ other_tree = build_tree (tree. val_left, tree. momenta_left, depth, eps, direction,
157
+ weight_init, integrator_state)
126
158
tree = merge_trees (other_tree, tree)
127
159
end
128
160
129
- stop = stop || tree. stop
161
+ stop = stop || tree. stop || tree . diverging
130
162
if stop
131
163
break
132
164
end
@@ -147,12 +179,13 @@ function nuts(
147
179
end
148
180
149
181
# accept or reject
150
- alpha = new_model_score - prev_model_score + new_momenta_score - prev_momenta_score
182
+ alpha = new_model_score + new_momenta_score - weight_init
151
183
if log (rand ()) < alpha
152
- return (new_trace, Stats (( depth, tree. n, true ) ))
184
+ return (new_trace, SamplerStats ( depth, tree. n, true ))
153
185
else
154
- return (trace, Stats (( depth, tree. n, false ) ))
186
+ return (trace, SamplerStats ( depth, tree. n, false ))
155
187
end
156
188
end
157
189
158
- export nuts
190
+ export nuts
191
+
0 commit comments