1
+ using LinearAlgebra: dot
2
+
3
+ Tree = @NamedTuple begin
4
+ val_left
5
+ momenta_left
6
+ val_right
7
+ momenta_right
8
+ val_sample
9
+ n :: Int
10
+ weight :: Float64
11
+ stop :: Bool
12
+ end
13
+
14
+ Stats = @NamedTuple begin
15
+ depth
16
+ n
17
+ accept
18
+ end
19
+
20
+ function u_turn (values_left, values_right, momenta_left, momenta_right)
21
+ return (dot (values_left - values_right, momenta_right) >= 0 ) &&
22
+ (dot (values_right - values_left, momenta_left) >= 0 )
23
+ end
24
+
25
+ function leapfrog (values, momenta, eps, integrator_state)
26
+ values_trie, selection, retval_grad, trace = integrator_state
27
+
28
+ values_trie = from_array (values_trie, values)
29
+ (trace, _, _) = update (trace, values_trie)
30
+ (_, _, gradient_trie) = choice_gradients (trace, selection, retval_grad)
31
+ gradient = to_array (gradient_trie, Float64)
32
+
33
+ # half step on momenta
34
+ momenta += (eps / 2 ) * gradient
35
+
36
+ # full step on positions
37
+ values += eps * momenta
38
+
39
+ # get new gradient
40
+ values_trie = from_array (values_trie, values)
41
+ (trace, _, _) = update (trace, values_trie)
42
+ (_, _, gradient_trie) = choice_gradients (trace, selection, retval_grad)
43
+ gradient = to_array (gradient_trie, Float64)
44
+
45
+ # half step on momenta
46
+ momenta += (eps / 2 ) * gradient
47
+ return values, momenta, get_score (trace)
48
+ end
49
+
50
+ function build_root (val, momenta, eps, direction, integrator_state)
51
+ val, momenta, lp = leapfrog (val, momenta, direction * eps, integrator_state)
52
+ weight = lp + assess_momenta (momenta)
53
+
54
+ return Tree ((val, momenta, val, momenta, val, 1 , weight, false ))
55
+ end
56
+
57
+ function merge_trees (tree_left, tree_right)
58
+ # multinomial sampling
59
+ if log (rand ()) < tree_right. weight - tree_left. weight
60
+ sample = tree_right. val_sample
61
+ else
62
+ sample = tree_left. val_sample
63
+ end
64
+
65
+ weight = logsumexp (tree_left. weight, tree_right. weight)
66
+ n = tree_left. n + tree_right. n
67
+ stop = tree_left. stop || tree_right. stop || u_turn (
68
+ tree_left. val_left, tree_right. val_right, tree_left. momenta_left, tree_right. momenta_right
69
+ )
70
+
71
+ return Tree ((tree_left. val_left, tree_left. momenta_left, tree_right. val_right,
72
+ tree_right. momenta_right, sample, n, weight, stop))
73
+ end
74
+
75
+ function build_tree (val, momenta, depth, eps, direction, integrator_state)
76
+ if depth == 0
77
+ return build_root (val, momenta, eps, direction, integrator_state)
78
+ end
79
+
80
+ tree = build_tree (val, momenta, depth - 1 , eps, direction, integrator_state)
81
+
82
+ if tree. stop
83
+ return tree
84
+ end
85
+
86
+ if direction == 1
87
+ other_tree = build_tree (tree. val_right, tree. momenta_right, depth - 1 , eps, direction, integrator_state)
88
+ else
89
+ other_tree = build_tree (tree. val_left, tree. momenta_left, depth - 1 , eps, direction, integrator_state)
90
+ end
91
+
92
+ if direction == 1
93
+ return merge_trees (tree, other_tree)
94
+ else
95
+ return merge_trees (other_tree, tree)
96
+ end
97
+ end
98
+
99
+ function nuts (
100
+ trace:: Trace , selection:: Selection ; eps= 0.1 , max_treedepth= 15 ,
101
+ check= false , observations= EmptyChoiceMap ())
102
+ prev_model_score = get_score (trace)
103
+ retval_grad = accepts_output_grad (get_gen_fn (trace)) ? zero (get_retval (trace)) : nothing
104
+
105
+ # values needed for a leapfrog step
106
+ (_, values_trie, _) = choice_gradients (trace, selection, retval_grad)
107
+ values = to_array (values_trie, Float64)
108
+ integrator_state = (values_trie, selection, retval_grad, trace)
109
+
110
+ momenta = sample_momenta (length (values))
111
+ prev_momenta_score = assess_momenta (momenta)
112
+
113
+ tree = Tree ((values, momenta, values, momenta, values, 1 , - Inf , false ))
114
+
115
+ direction = 0
116
+ depth = 0
117
+ stop = false
118
+ while depth < max_treedepth
119
+ direction = rand ([- 1 , 1 ])
120
+
121
+ if direction == 1 # going right
122
+ other_tree = build_tree (tree. val_right, tree. momenta_right, depth, eps, direction, integrator_state)
123
+ tree = merge_trees (tree, other_tree)
124
+ else # going left
125
+ other_tree = build_tree (tree. val_left, tree. momenta_left, depth, eps, direction, integrator_state)
126
+ tree = merge_trees (other_tree, tree)
127
+ end
128
+
129
+ stop = stop || tree. stop
130
+ if stop
131
+ break
132
+ end
133
+ depth += 1
134
+ end
135
+
136
+ (new_trace, _, _) = update (trace, from_array (values_trie, tree. val_sample))
137
+ check && check_observations (get_choices (new_trace), observations)
138
+
139
+ # assess new model score (negative potential energy)
140
+ new_model_score = get_score (new_trace)
141
+
142
+ # assess new momenta score (negative kinetic energy)
143
+ if direction == 1
144
+ new_momenta_score = assess_momenta (- tree. momenta_right)
145
+ else
146
+ new_momenta_score = assess_momenta (- tree. momenta_left)
147
+ end
148
+
149
+ # accept or reject
150
+ alpha = new_model_score - prev_model_score + new_momenta_score - prev_momenta_score
151
+ if log (rand ()) < alpha
152
+ return (new_trace, Stats ((depth, tree. n, true )))
153
+ else
154
+ return (trace, Stats ((depth, tree. n, false )))
155
+ end
156
+ end
157
+
158
+ export nuts
0 commit comments