@@ -16,6 +16,7 @@ use crate::{
1616/// a cutoff value or nan.
1717/// - The logp function caused a recoverable error (eg if an ODE solver
1818/// failed)
19+ #[ non_exhaustive]
1920#[ derive( Debug , Clone ) ]
2021pub struct DivergenceInfo {
2122 pub start_momentum : Option < Box < [ f64 ] > > ,
@@ -26,6 +27,7 @@ pub struct DivergenceInfo {
2627 pub end_idx_in_trajectory : Option < i64 > ,
2728 pub start_idx_in_trajectory : Option < i64 > ,
2829 pub logp_function_error : Option < Arc < dyn std:: error:: Error + Send + Sync > > ,
30+ pub non_reversible : bool ,
2931}
3032
3133impl DivergenceInfo {
@@ -39,8 +41,67 @@ impl DivergenceInfo {
3941 end_idx_in_trajectory : None ,
4042 start_idx_in_trajectory : None ,
4143 logp_function_error : None ,
44+ non_reversible : false ,
4245 }
4346 }
47+
48+ pub fn new_energy_error_too_large < M : Math > (
49+ math : & mut M ,
50+ start : & State < M , impl Point < M > > ,
51+ stop : & State < M , impl Point < M > > ,
52+ ) -> Self {
53+ DivergenceInfo {
54+ logp_function_error : None ,
55+ start_location : Some ( math. box_array ( start. point ( ) . position ( ) ) ) ,
56+ start_gradient : Some ( math. box_array ( start. point ( ) . gradient ( ) ) ) ,
57+ // TODO
58+ start_momentum : None ,
59+ start_idx_in_trajectory : Some ( start. index_in_trajectory ( ) ) ,
60+ end_location : Some ( math. box_array ( & stop. point ( ) . position ( ) ) ) ,
61+ end_idx_in_trajectory : Some ( stop. index_in_trajectory ( ) ) ,
62+ // TODO
63+ energy_error : None ,
64+ non_reversible : false ,
65+ }
66+ }
67+
68+ pub fn new_logp_function_error < M : Math > (
69+ math : & mut M ,
70+ start : & State < M , impl Point < M > > ,
71+ logp_function_error : Arc < dyn std:: error:: Error + Send + Sync > ,
72+ ) -> Self {
73+ DivergenceInfo {
74+ logp_function_error : Some ( logp_function_error) ,
75+ start_location : Some ( math. box_array ( start. point ( ) . position ( ) ) ) ,
76+ start_gradient : Some ( math. box_array ( start. point ( ) . gradient ( ) ) ) ,
77+ // TODO
78+ start_momentum : None ,
79+ start_idx_in_trajectory : Some ( start. index_in_trajectory ( ) ) ,
80+ end_location : None ,
81+ end_idx_in_trajectory : None ,
82+ energy_error : None ,
83+ non_reversible : false ,
84+ }
85+ }
86+
87+ pub fn new_not_reversible < M : Math > ( math : & mut M , start : & State < M , impl Point < M > > ) -> Self {
88+ // TODO add info about what went wrong
89+ DivergenceInfo {
90+ logp_function_error : None ,
91+ start_location : Some ( math. box_array ( start. point ( ) . position ( ) ) ) ,
92+ start_gradient : Some ( math. box_array ( start. point ( ) . gradient ( ) ) ) ,
93+ // TODO
94+ start_momentum : None ,
95+ start_idx_in_trajectory : Some ( start. index_in_trajectory ( ) ) ,
96+ end_location : None ,
97+ end_idx_in_trajectory : None ,
98+ energy_error : None ,
99+ non_reversible : true ,
100+ }
101+ }
102+ pub fn new_max_step_size_halvings < M : Math > ( math : & mut M , num_steps : u64 , info : Self ) -> Self {
103+ info // TODO
104+ }
44105}
45106
46107#[ derive( Debug , Copy , Clone ) ]
@@ -106,10 +167,44 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
106167 math : & mut M ,
107168 start : & State < M , Self :: Point > ,
108169 dir : Direction ,
109- step_size_factor : f64 ,
170+ step_size_splits : u64 ,
110171 collector : & mut C ,
111172 ) -> LeapfrogResult < M , Self :: Point > ;
112173
174+ fn split_leapfrog < C : Collector < M , Self :: Point > > (
175+ & mut self ,
176+ math : & mut M ,
177+ start : & State < M , Self :: Point > ,
178+ dir : Direction ,
179+ num_steps : u64 ,
180+ collector : & mut C ,
181+ max_error : f64 ,
182+ ) -> LeapfrogResult < M , Self :: Point > {
183+ let mut state = start. clone ( ) ;
184+
185+ let mut min_energy = start. energy ( ) ;
186+ let mut max_energy = min_energy;
187+
188+ for _ in 0 ..num_steps {
189+ state = match self . leapfrog ( math, & state, dir, num_steps, collector) {
190+ LeapfrogResult :: Ok ( state) => state,
191+ LeapfrogResult :: Divergence ( info) => return LeapfrogResult :: Divergence ( info) ,
192+ LeapfrogResult :: Err ( err) => return LeapfrogResult :: Err ( err) ,
193+ } ;
194+ let energy = state. energy ( ) ;
195+ min_energy = min_energy. min ( energy) ;
196+ max_energy = max_energy. max ( energy) ;
197+
198+ // TODO: walnuts papers says to use abs, but c++ code doesn't?
199+ if max_energy - min_energy > max_error {
200+ let info = DivergenceInfo :: new_energy_error_too_large ( math, start, & state) ;
201+ return LeapfrogResult :: Divergence ( info) ;
202+ }
203+ }
204+
205+ LeapfrogResult :: Ok ( state)
206+ }
207+
113208 fn is_turning (
114209 & self ,
115210 math : & mut M ,
@@ -141,4 +236,6 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
141236
142237 fn step_size ( & self ) -> f64 ;
143238 fn step_size_mut ( & mut self ) -> & mut f64 ;
239+
240+ fn max_energy_error ( & self ) -> f64 ;
144241}
0 commit comments