1+ using LinearAlgebra: dot
2+
3+ Tree = @NamedTuple begin
4+ val_left
5+ momenta_left
6+ val_right
7+ momenta_right
8+ val_sample
9+ n :: Int
10+ weight :: Float64
11+ stop :: Bool
12+ end
13+
14+ Stats = @NamedTuple begin
15+ depth
16+ n
17+ accept
18+ end
19+
20+ function u_turn (values_left, values_right, momenta_left, momenta_right)
21+ return (dot (values_left - values_right, momenta_right) >= 0 ) &&
22+ (dot (values_right - values_left, momenta_left) >= 0 )
23+ end
24+
25+ function leapfrog (values, momenta, eps, integrator_state)
26+ values_trie, selection, retval_grad, trace = integrator_state
27+
28+ values_trie = from_array (values_trie, values)
29+ (trace, _, _) = update (trace, values_trie)
30+ (_, _, gradient_trie) = choice_gradients (trace, selection, retval_grad)
31+ gradient = to_array (gradient_trie, Float64)
32+
33+ # half step on momenta
34+ momenta += (eps / 2 ) * gradient
35+
36+ # full step on positions
37+ values += eps * momenta
38+
39+ # get new gradient
40+ values_trie = from_array (values_trie, values)
41+ (trace, _, _) = update (trace, values_trie)
42+ (_, _, gradient_trie) = choice_gradients (trace, selection, retval_grad)
43+ gradient = to_array (gradient_trie, Float64)
44+
45+ # half step on momenta
46+ momenta += (eps / 2 ) * gradient
47+ return values, momenta, get_score (trace)
48+ end
49+
50+ function build_root (val, momenta, eps, direction, integrator_state)
51+ val, momenta, lp = leapfrog (val, momenta, direction * eps, integrator_state)
52+ weight = lp + assess_momenta (momenta)
53+
54+ return Tree ((val, momenta, val, momenta, val, 1 , weight, false ))
55+ end
56+
57+ function merge_trees (tree_left, tree_right)
58+ # multinomial sampling
59+ if log (rand ()) < tree_right. weight - tree_left. weight
60+ sample = tree_right. val_sample
61+ else
62+ sample = tree_left. val_sample
63+ end
64+
65+ weight = logsumexp (tree_left. weight, tree_right. weight)
66+ n = tree_left. n + tree_right. n
67+ stop = tree_left. stop || tree_right. stop || u_turn (
68+ tree_left. val_left, tree_right. val_right, tree_left. momenta_left, tree_right. momenta_right
69+ )
70+
71+ return Tree ((tree_left. val_left, tree_left. momenta_left, tree_right. val_right,
72+ tree_right. momenta_right, sample, n, weight, stop))
73+ end
74+
75+ function build_tree (val, momenta, depth, eps, direction, integrator_state)
76+ if depth == 0
77+ return build_root (val, momenta, eps, direction, integrator_state)
78+ end
79+
80+ tree = build_tree (val, momenta, depth - 1 , eps, direction, integrator_state)
81+
82+ if tree. stop
83+ return tree
84+ end
85+
86+ if direction == 1
87+ other_tree = build_tree (tree. val_right, tree. momenta_right, depth - 1 , eps, direction, integrator_state)
88+ else
89+ other_tree = build_tree (tree. val_left, tree. momenta_left, depth - 1 , eps, direction, integrator_state)
90+ end
91+
92+ if direction == 1
93+ return merge_trees (tree, other_tree)
94+ else
95+ return merge_trees (other_tree, tree)
96+ end
97+ end
98+
99+ function nuts (
100+ trace:: Trace , selection:: Selection ; eps= 0.1 , max_treedepth= 15 ,
101+ check= false , observations= EmptyChoiceMap ())
102+ prev_model_score = get_score (trace)
103+ retval_grad = accepts_output_grad (get_gen_fn (trace)) ? zero (get_retval (trace)) : nothing
104+
105+ # values needed for a leapfrog step
106+ (_, values_trie, _) = choice_gradients (trace, selection, retval_grad)
107+ values = to_array (values_trie, Float64)
108+ integrator_state = (values_trie, selection, retval_grad, trace)
109+
110+ momenta = sample_momenta (length (values))
111+ prev_momenta_score = assess_momenta (momenta)
112+
113+ tree = Tree ((values, momenta, values, momenta, values, 1 , - Inf , false ))
114+
115+ direction = 0
116+ depth = 0
117+ stop = false
118+ while depth < max_treedepth
119+ direction = rand ([- 1 , 1 ])
120+
121+ if direction == 1 # going right
122+ other_tree = build_tree (tree. val_right, tree. momenta_right, depth, eps, direction, integrator_state)
123+ tree = merge_trees (tree, other_tree)
124+ else # going left
125+ other_tree = build_tree (tree. val_left, tree. momenta_left, depth, eps, direction, integrator_state)
126+ tree = merge_trees (other_tree, tree)
127+ end
128+
129+ stop = stop || tree. stop
130+ if stop
131+ break
132+ end
133+ depth += 1
134+ end
135+
136+ (new_trace, _, _) = update (trace, from_array (values_trie, tree. val_sample))
137+ check && check_observations (get_choices (new_trace), observations)
138+
139+ # assess new model score (negative potential energy)
140+ new_model_score = get_score (new_trace)
141+
142+ # assess new momenta score (negative kinetic energy)
143+ if direction == 1
144+ new_momenta_score = assess_momenta (- tree. momenta_right)
145+ else
146+ new_momenta_score = assess_momenta (- tree. momenta_left)
147+ end
148+
149+ # accept or reject
150+ alpha = new_model_score - prev_model_score + new_momenta_score - prev_momenta_score
151+ if log (rand ()) < alpha
152+ return (new_trace, Stats ((depth, tree. n, true )))
153+ else
154+ return (trace, Stats ((depth, tree. n, false )))
155+ end
156+ end
157+
158+ export nuts
0 commit comments