1
+ use serde:: Serialize ;
1
2
use thiserror:: Error ;
2
3
3
4
use std:: { fmt:: Debug , marker:: PhantomData } ;
@@ -120,7 +121,7 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
120
121
H : Hamiltonian < M > ,
121
122
R : rand:: Rng + ?Sized ,
122
123
{
123
- let mut other = match self . single_step ( math, hamiltonian, direction, collector) {
124
+ let mut other = match self . single_step ( math, hamiltonian, direction, options , collector) {
124
125
Ok ( Ok ( tree) ) => tree,
125
126
Ok ( Err ( info) ) => return ExtendResult :: Diverging ( self , info) ,
126
127
Err ( err) => return ExtendResult :: Err ( err) ,
@@ -213,19 +214,141 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
213
214
math : & mut M ,
214
215
hamiltonian : & mut H ,
215
216
direction : Direction ,
217
+ options : & NutsOptions ,
216
218
collector : & mut C ,
217
219
) -> Result < std:: result:: Result < NutsTree < M , H , C > , DivergenceInfo > > {
218
220
let start = match direction {
219
221
Direction :: Forward => & self . right ,
220
222
Direction :: Backward => & self . left ,
221
223
} ;
222
- let end = match hamiltonian. leapfrog ( math, start, direction, collector) {
223
- LeapfrogResult :: Divergence ( info) => return Ok ( Err ( info) ) ,
224
- LeapfrogResult :: Err ( err) => return Err ( NutsError :: LogpFailure ( err. into ( ) ) ) ,
225
- LeapfrogResult :: Ok ( end) => end,
224
+
225
+ let ( log_size, end) = match options. walnuts_options {
226
+ Some ( ref options) => {
227
+ // Walnuts implementation
228
+ // TODO: Shouldn't all be in this one big function...
229
+ let mut step_size_factor = 1.0 ;
230
+ let mut num_steps = 1 ;
231
+ let mut current = start. clone ( ) ;
232
+
233
+ let mut success = false ;
234
+
235
+ ' step_size_search: for _ in 0 ..options. max_step_size_halvings {
236
+ current = start. clone ( ) ;
237
+ let mut min_energy = current. energy ( ) ;
238
+ let mut max_energy = min_energy;
239
+
240
+ for _ in 0 ..num_steps {
241
+ current = match hamiltonian. leapfrog (
242
+ math,
243
+ & current,
244
+ direction,
245
+ step_size_factor,
246
+ collector,
247
+ ) {
248
+ LeapfrogResult :: Ok ( state) => state,
249
+ LeapfrogResult :: Divergence ( _) => {
250
+ num_steps *= 2 ;
251
+ step_size_factor *= 0.5 ;
252
+ continue ' step_size_search;
253
+ }
254
+ LeapfrogResult :: Err ( err) => {
255
+ return Err ( NutsError :: LogpFailure ( err. into ( ) ) ) ;
256
+ }
257
+ } ;
258
+
259
+ // Update min/max energies
260
+ let current_energy = current. energy ( ) ;
261
+ min_energy = min_energy. min ( current_energy) ;
262
+ max_energy = max_energy. max ( current_energy) ;
263
+ }
264
+
265
+ if max_energy - min_energy > options. max_energy_error {
266
+ num_steps *= 2 ;
267
+ step_size_factor *= 0.5 ;
268
+ continue ' step_size_search;
269
+ }
270
+
271
+ success = true ;
272
+ break ' step_size_search;
273
+ }
274
+
275
+ if !success {
276
+ // TODO: More info
277
+ return Ok ( Err ( DivergenceInfo :: new ( ) ) ) ;
278
+ }
279
+
280
+ // TODO
281
+ let back = direction. reverse ( ) ;
282
+ let mut current_backward;
283
+
284
+ let mut reversible = true ;
285
+
286
+ ' rev_step_size: while num_steps >= 2 {
287
+ num_steps /= 2 ;
288
+ step_size_factor *= 0.5 ;
289
+
290
+ // TODO: Can we share code for the micro steps in the two directions?
291
+ current_backward = current. clone ( ) ;
292
+
293
+ let mut min_energy = current_backward. energy ( ) ;
294
+ let mut max_energy = min_energy;
295
+
296
+ for _ in 0 ..num_steps {
297
+ current_backward = match hamiltonian. leapfrog (
298
+ math,
299
+ & current_backward,
300
+ back,
301
+ step_size_factor,
302
+ collector,
303
+ ) {
304
+ LeapfrogResult :: Ok ( state) => state,
305
+ LeapfrogResult :: Divergence ( _) => {
306
+ // We also reject in the backward direction, all is good so far...
307
+ continue ' rev_step_size;
308
+ }
309
+ LeapfrogResult :: Err ( err) => {
310
+ return Err ( NutsError :: LogpFailure ( err. into ( ) ) ) ;
311
+ }
312
+ } ;
313
+
314
+ // Update min/max energies
315
+ let current_energy = current_backward. energy ( ) ;
316
+ min_energy = min_energy. min ( current_energy) ;
317
+ max_energy = max_energy. max ( current_energy) ;
318
+ if max_energy - min_energy > options. max_energy_error {
319
+ // We reject also in the backward direction, all good so far...
320
+ continue ' rev_step_size;
321
+ }
322
+ }
323
+
324
+ // We did not reject in the backward direction, so we are not reversible
325
+ reversible = false ;
326
+ break ;
327
+ }
328
+
329
+ if reversible {
330
+ let log_size = -current. point ( ) . energy_error ( ) ;
331
+ ( log_size, current)
332
+ } else {
333
+ // TODO: More info
334
+ return Ok ( Err ( DivergenceInfo :: new ( ) ) ) ;
335
+ }
336
+ }
337
+ None => {
338
+ // Classical NUTS
339
+ //
340
+ let end = match hamiltonian. leapfrog ( math, start, direction, 1.0 , collector) {
341
+ LeapfrogResult :: Divergence ( info) => return Ok ( Err ( info) ) ,
342
+ LeapfrogResult :: Err ( err) => return Err ( NutsError :: LogpFailure ( err. into ( ) ) ) ,
343
+ LeapfrogResult :: Ok ( end) => end,
344
+ } ;
345
+
346
+ let log_size = -end. point ( ) . energy_error ( ) ;
347
+
348
+ ( log_size, end)
349
+ }
226
350
} ;
227
351
228
- let log_size = -end. point ( ) . energy_error ( ) ;
229
352
Ok ( Ok ( NutsTree {
230
353
right : end. clone ( ) ,
231
354
left : end. clone ( ) ,
@@ -248,13 +371,22 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
248
371
}
249
372
}
250
373
374
+ #[ derive( Debug , Clone , Copy , Serialize ) ]
375
+ pub struct WalnutsOptions {
376
+ pub max_energy_error : f64 ,
377
+ pub max_step_size_halvings : u64 ,
378
+ }
379
+
380
+ #[ derive( Debug , Clone , Copy ) ]
251
381
pub struct NutsOptions {
252
382
pub maxdepth : u64 ,
253
383
pub mindepth : u64 ,
254
384
pub store_gradient : bool ,
255
385
pub store_unconstrained : bool ,
256
386
pub check_turning : bool ,
257
387
pub store_divergences : bool ,
388
+
389
+ pub walnuts_options : Option < WalnutsOptions > ,
258
390
}
259
391
260
392
pub ( crate ) fn draw < M , H , R , C > (
0 commit comments