@@ -60,22 +60,41 @@ pub struct SampleInfo {
6060}
6161
6262/// A part of the trajectory tree during NUTS sampling.
63+ ///
64+ /// Corresponds to SpanW in walnuts C++ code
6365struct NutsTree < M : Math , H : Hamiltonian < M > , C : Collector < M , H :: Point > > {
6466 /// The left position of the tree.
6567 ///
6668 /// The left side always has the smaller index_in_trajectory.
6769 /// Leapfrogs in backward direction will replace the left.
70+ ///
71+ /// theta_bk_, rho_bk_, grad_theta_bk_, logp_bk_ in C++ code
6872 left : State < M , H :: Point > ,
73+
74+ /// The right position of the tree.
75+ ///
76+ /// theta_fw_, rho_fw_, grad_theta_fw_, logp_fw_ in C++ code
6977 right : State < M , H :: Point > ,
7078
7179 /// A draw from the trajectory between left and right using
7280 /// multinomial sampling.
81+ ///
82+ /// theta_select_ in C++ code
7383 draw : State < M , H :: Point > ,
84+
85+ /// Constant for acceptance probability
86+ ///
87+ /// logp_ in C++ code
7488 log_size : f64 ,
89+
90+ /// The depth of the tree
7591 depth : u64 ,
7692
7793 /// A tree is the main tree if it contains the initial point
7894 /// of the trajectory.
95+ ///
96+ /// This is used to determine whether to use Metropolis
97+ /// accptance or Barker
7998 is_main : bool ,
8099 _phantom2 : PhantomData < C > ,
81100}
@@ -172,6 +191,7 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
172191 }
173192 }
174193
194+ // `combine` in C++ code
175195 fn merge_into < R : rand:: Rng + ?Sized > (
176196 & mut self ,
177197 _math : & mut M ,
@@ -209,6 +229,7 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
209229 self . log_size = log_size;
210230 }
211231
232+ // Corresponds to `build_leaf` in C++ code
212233 fn single_step (
213234 & self ,
214235 math : & mut M ,
0 commit comments