1+ use serde:: Serialize ;
12use thiserror:: Error ;
23
34use std:: { fmt:: Debug , marker:: PhantomData } ;
@@ -120,7 +121,7 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
120121 H : Hamiltonian < M > ,
121122 R : rand:: Rng + ?Sized ,
122123 {
123- let mut other = match self . single_step ( math, hamiltonian, direction, collector) {
124+ let mut other = match self . single_step ( math, hamiltonian, direction, options , collector) {
124125 Ok ( Ok ( tree) ) => tree,
125126 Ok ( Err ( info) ) => return ExtendResult :: Diverging ( self , info) ,
126127 Err ( err) => return ExtendResult :: Err ( err) ,
@@ -213,19 +214,141 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
213214 math : & mut M ,
214215 hamiltonian : & mut H ,
215216 direction : Direction ,
217+ options : & NutsOptions ,
216218 collector : & mut C ,
217219 ) -> Result < std:: result:: Result < NutsTree < M , H , C > , DivergenceInfo > > {
218220 let start = match direction {
219221 Direction :: Forward => & self . right ,
220222 Direction :: Backward => & self . left ,
221223 } ;
222- let end = match hamiltonian. leapfrog ( math, start, direction, collector) {
223- LeapfrogResult :: Divergence ( info) => return Ok ( Err ( info) ) ,
224- LeapfrogResult :: Err ( err) => return Err ( NutsError :: LogpFailure ( err. into ( ) ) ) ,
225- LeapfrogResult :: Ok ( end) => end,
224+
225+ let ( log_size, end) = match options. walnuts_options {
226+ Some ( ref options) => {
227+ // Walnuts implementation
228+ // TODO: Shouldn't all be in this one big function...
229+ let mut step_size_factor = 1.0 ;
230+ let mut num_steps = 1 ;
231+ let mut current = start. clone ( ) ;
232+
233+ let mut success = false ;
234+
235+ ' step_size_search: for _ in 0 ..options. max_step_size_halvings {
236+ current = start. clone ( ) ;
237+ let mut min_energy = current. energy ( ) ;
238+ let mut max_energy = min_energy;
239+
240+ for _ in 0 ..num_steps {
241+ current = match hamiltonian. leapfrog (
242+ math,
243+ & current,
244+ direction,
245+ step_size_factor,
246+ collector,
247+ ) {
248+ LeapfrogResult :: Ok ( state) => state,
249+ LeapfrogResult :: Divergence ( _) => {
250+ num_steps *= 2 ;
251+ step_size_factor *= 0.5 ;
252+ continue ' step_size_search;
253+ }
254+ LeapfrogResult :: Err ( err) => {
255+ return Err ( NutsError :: LogpFailure ( err. into ( ) ) ) ;
256+ }
257+ } ;
258+
259+ // Update min/max energies
260+ let current_energy = current. energy ( ) ;
261+ min_energy = min_energy. min ( current_energy) ;
262+ max_energy = max_energy. max ( current_energy) ;
263+ }
264+
265+ if max_energy - min_energy > options. max_energy_error {
266+ num_steps *= 2 ;
267+ step_size_factor *= 0.5 ;
268+ continue ' step_size_search;
269+ }
270+
271+ success = true ;
272+ break ' step_size_search;
273+ }
274+
275+ if !success {
276+ // TODO: More info
277+ return Ok ( Err ( DivergenceInfo :: new ( ) ) ) ;
278+ }
279+
280+ // TODO
281+ let back = direction. reverse ( ) ;
282+ let mut current_backward;
283+
284+ let mut reversible = true ;
285+
286+ ' rev_step_size: while num_steps >= 2 {
287+ num_steps /= 2 ;
288+ step_size_factor *= 0.5 ;
289+
290+ // TODO: Can we share code for the micro steps in the two directions?
291+ current_backward = current. clone ( ) ;
292+
293+ let mut min_energy = current_backward. energy ( ) ;
294+ let mut max_energy = min_energy;
295+
296+ for _ in 0 ..num_steps {
297+ current_backward = match hamiltonian. leapfrog (
298+ math,
299+ & current_backward,
300+ back,
301+ step_size_factor,
302+ collector,
303+ ) {
304+ LeapfrogResult :: Ok ( state) => state,
305+ LeapfrogResult :: Divergence ( _) => {
306+ // We also reject in the backward direction, all is good so far...
307+ continue ' rev_step_size;
308+ }
309+ LeapfrogResult :: Err ( err) => {
310+ return Err ( NutsError :: LogpFailure ( err. into ( ) ) ) ;
311+ }
312+ } ;
313+
314+ // Update min/max energies
315+ let current_energy = current_backward. energy ( ) ;
316+ min_energy = min_energy. min ( current_energy) ;
317+ max_energy = max_energy. max ( current_energy) ;
318+ if max_energy - min_energy > options. max_energy_error {
319+ // We reject also in the backward direction, all good so far...
320+ continue ' rev_step_size;
321+ }
322+ }
323+
324+ // We did not reject in the backward direction, so we are not reversible
325+ reversible = false ;
326+ break ;
327+ }
328+
329+ if reversible {
330+ let log_size = -current. point ( ) . energy_error ( ) ;
331+ ( log_size, current)
332+ } else {
333+ // TODO: More info
334+ return Ok ( Err ( DivergenceInfo :: new ( ) ) ) ;
335+ }
336+ }
337+ None => {
338+ // Classical NUTS
339+ //
340+ let end = match hamiltonian. leapfrog ( math, start, direction, 1.0 , collector) {
341+ LeapfrogResult :: Divergence ( info) => return Ok ( Err ( info) ) ,
342+ LeapfrogResult :: Err ( err) => return Err ( NutsError :: LogpFailure ( err. into ( ) ) ) ,
343+ LeapfrogResult :: Ok ( end) => end,
344+ } ;
345+
346+ let log_size = -end. point ( ) . energy_error ( ) ;
347+
348+ ( log_size, end)
349+ }
226350 } ;
227351
228- let log_size = -end. point ( ) . energy_error ( ) ;
229352 Ok ( Ok ( NutsTree {
230353 right : end. clone ( ) ,
231354 left : end. clone ( ) ,
@@ -248,13 +371,22 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
248371 }
249372}
250373
374+ #[ derive( Debug , Clone , Copy , Serialize ) ]
375+ pub struct WalnutsOptions {
376+ pub max_energy_error : f64 ,
377+ pub max_step_size_halvings : u64 ,
378+ }
379+
380+ #[ derive( Debug , Clone , Copy ) ]
251381pub struct NutsOptions {
252382 pub maxdepth : u64 ,
253383 pub mindepth : u64 ,
254384 pub store_gradient : bool ,
255385 pub store_unconstrained : bool ,
256386 pub check_turning : bool ,
257387 pub store_divergences : bool ,
388+
389+ pub walnuts_options : Option < WalnutsOptions > ,
258390}
259391
260392impl Default for NutsOptions {
0 commit comments