@@ -60,22 +60,41 @@ pub struct SampleInfo {
60
60
}
61
61
62
62
/// A part of the trajectory tree during NUTS sampling.
63
+ ///
64
+ /// Corresponds to SpanW in walnuts C++ code
63
65
struct NutsTree < M : Math , H : Hamiltonian < M > , C : Collector < M , H :: Point > > {
64
66
/// The left position of the tree.
65
67
///
66
68
/// The left side always has the smaller index_in_trajectory.
67
69
/// Leapfrogs in backward direction will replace the left.
70
+ ///
71
+ /// theta_bk_, rho_bk_, grad_theta_bk_, logp_bk_ in C++ code
68
72
left : State < M , H :: Point > ,
73
+
74
+ /// The right position of the tree.
75
+ ///
76
+ /// theta_fw_, rho_fw_, grad_theta_fw_, logp_fw_ in C++ code
69
77
right : State < M , H :: Point > ,
70
78
71
79
/// A draw from the trajectory between left and right using
72
80
/// multinomial sampling.
81
+ ///
82
+ /// theta_select_ in C++ code
73
83
draw : State < M , H :: Point > ,
84
+
85
+ /// Constant for acceptance probability
86
+ ///
87
+ /// logp_ in C++ code
74
88
log_size : f64 ,
89
+
90
+ /// The depth of the tree
75
91
depth : u64 ,
76
92
77
93
/// A tree is the main tree if it contains the initial point
78
94
/// of the trajectory.
95
+ ///
96
+ /// This is used to determine whether to use Metropolis
97
+ /// accptance or Barker
79
98
is_main : bool ,
80
99
_phantom2 : PhantomData < C > ,
81
100
}
@@ -172,6 +191,7 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
172
191
}
173
192
}
174
193
194
+ // `combine` in C++ code
175
195
fn merge_into < R : rand:: Rng + ?Sized > (
176
196
& mut self ,
177
197
_math : & mut M ,
@@ -209,6 +229,7 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
209
229
self . log_size = log_size;
210
230
}
211
231
232
+ // Corresponds to `build_leaf` in C++ code
212
233
fn single_step (
213
234
& self ,
214
235
math : & mut M ,
0 commit comments