@@ -120,7 +120,7 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
120
120
H : Hamiltonian < M > ,
121
121
R : rand:: Rng + ?Sized ,
122
122
{
123
- let mut other = match self . single_step ( math, hamiltonian, direction, collector) {
123
+ let mut other = match self . single_step ( math, hamiltonian, direction, options , collector) {
124
124
Ok ( Ok ( tree) ) => tree,
125
125
Ok ( Err ( info) ) => return ExtendResult :: Diverging ( self , info) ,
126
126
Err ( err) => return ExtendResult :: Err ( err) ,
@@ -213,19 +213,141 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
213
213
math : & mut M ,
214
214
hamiltonian : & mut H ,
215
215
direction : Direction ,
216
+ options : & NutsOptions ,
216
217
collector : & mut C ,
217
218
) -> Result < std:: result:: Result < NutsTree < M , H , C > , DivergenceInfo > > {
218
219
let start = match direction {
219
220
Direction :: Forward => & self . right ,
220
221
Direction :: Backward => & self . left ,
221
222
} ;
222
- let end = match hamiltonian. leapfrog ( math, start, direction, collector) {
223
- LeapfrogResult :: Divergence ( info) => return Ok ( Err ( info) ) ,
224
- LeapfrogResult :: Err ( err) => return Err ( NutsError :: LogpFailure ( err. into ( ) ) ) ,
225
- LeapfrogResult :: Ok ( end) => end,
223
+
224
+ let ( log_size, end) = match options. walnuts_options {
225
+ Some ( ref options) => {
226
+ // Walnuts implementation
227
+ // TODO: Shouldn't all be in this one big function...
228
+ let mut step_size_factor = 1.0 ;
229
+ let mut num_steps = 1 ;
230
+ let mut current = start. clone ( ) ;
231
+
232
+ let mut success = false ;
233
+
234
+ ' step_size_search: for _ in 0 ..options. max_step_size_halvings {
235
+ current = start. clone ( ) ;
236
+ let mut min_energy = current. energy ( ) ;
237
+ let mut max_energy = min_energy;
238
+
239
+ for _ in 0 ..num_steps {
240
+ current = match hamiltonian. leapfrog (
241
+ math,
242
+ & current,
243
+ direction,
244
+ step_size_factor,
245
+ collector,
246
+ ) {
247
+ LeapfrogResult :: Ok ( state) => state,
248
+ LeapfrogResult :: Divergence ( _) => {
249
+ num_steps *= 2 ;
250
+ step_size_factor *= 0.5 ;
251
+ continue ' step_size_search;
252
+ }
253
+ LeapfrogResult :: Err ( err) => {
254
+ return Err ( NutsError :: LogpFailure ( err. into ( ) ) ) ;
255
+ }
256
+ } ;
257
+
258
+ // Update min/max energies
259
+ let current_energy = current. energy ( ) ;
260
+ min_energy = min_energy. min ( current_energy) ;
261
+ max_energy = max_energy. max ( current_energy) ;
262
+ }
263
+
264
+ if max_energy - min_energy > options. max_energy_error {
265
+ num_steps *= 2 ;
266
+ step_size_factor *= 0.5 ;
267
+ continue ' step_size_search;
268
+ }
269
+
270
+ success = true ;
271
+ break ' step_size_search;
272
+ }
273
+
274
+ if !success {
275
+ // TODO: More info
276
+ return Ok ( Err ( DivergenceInfo :: new ( ) ) ) ;
277
+ }
278
+
279
+ // TODO
280
+ let back = direction. reverse ( ) ;
281
+ let mut current_backward;
282
+
283
+ let mut reversible = true ;
284
+
285
+ ' rev_step_size: while num_steps >= 2 {
286
+ num_steps /= 2 ;
287
+ step_size_factor *= 0.5 ;
288
+
289
+ // TODO: Can we share code for the micro steps in the two directions?
290
+ current_backward = current. clone ( ) ;
291
+
292
+ let mut min_energy = current_backward. energy ( ) ;
293
+ let mut max_energy = min_energy;
294
+
295
+ for _ in 0 ..num_steps {
296
+ current_backward = match hamiltonian. leapfrog (
297
+ math,
298
+ & current_backward,
299
+ back,
300
+ step_size_factor,
301
+ collector,
302
+ ) {
303
+ LeapfrogResult :: Ok ( state) => state,
304
+ LeapfrogResult :: Divergence ( _) => {
305
+ // We also reject in the backward direction, all is good so far...
306
+ continue ' rev_step_size;
307
+ }
308
+ LeapfrogResult :: Err ( err) => {
309
+ return Err ( NutsError :: LogpFailure ( err. into ( ) ) ) ;
310
+ }
311
+ } ;
312
+
313
+ // Update min/max energies
314
+ let current_energy = current_backward. energy ( ) ;
315
+ min_energy = min_energy. min ( current_energy) ;
316
+ max_energy = max_energy. max ( current_energy) ;
317
+ if max_energy - min_energy > options. max_energy_error {
318
+ // We reject also in the backward direction, all good so far...
319
+ continue ' rev_step_size;
320
+ }
321
+ }
322
+
323
+ // We did not reject in the backward direction, so we are not reversible
324
+ reversible = false ;
325
+ break ;
326
+ }
327
+
328
+ if reversible {
329
+ let log_size = -current. point ( ) . energy_error ( ) ;
330
+ ( log_size, current)
331
+ } else {
332
+ // TODO: More info
333
+ return Ok ( Err ( DivergenceInfo :: new ( ) ) ) ;
334
+ }
335
+ }
336
+ None => {
337
+ // Classical NUTS
338
+ //
339
+ let end = match hamiltonian. leapfrog ( math, start, direction, 1.0 , collector) {
340
+ LeapfrogResult :: Divergence ( info) => return Ok ( Err ( info) ) ,
341
+ LeapfrogResult :: Err ( err) => return Err ( NutsError :: LogpFailure ( err. into ( ) ) ) ,
342
+ LeapfrogResult :: Ok ( end) => end,
343
+ } ;
344
+
345
+ let log_size = -end. point ( ) . energy_error ( ) ;
346
+
347
+ ( log_size, end)
348
+ }
226
349
} ;
227
350
228
- let log_size = -end. point ( ) . energy_error ( ) ;
229
351
Ok ( Ok ( NutsTree {
230
352
right : end. clone ( ) ,
231
353
left : end. clone ( ) ,
@@ -248,12 +370,21 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
248
370
}
249
371
}
250
372
373
+ #[ derive( Debug , Clone , Copy ) ]
374
+ pub struct WalnutsOptions {
375
+ pub max_energy_error : f64 ,
376
+ pub max_step_size_halvings : u64 ,
377
+ }
378
+
379
+ #[ derive( Debug , Clone , Copy ) ]
251
380
pub struct NutsOptions {
252
381
pub maxdepth : u64 ,
253
382
pub store_gradient : bool ,
254
383
pub store_unconstrained : bool ,
255
384
pub check_turning : bool ,
256
385
pub store_divergences : bool ,
386
+
387
+ pub walnuts_options : Option < WalnutsOptions > ,
257
388
}
258
389
259
390
pub ( crate ) fn draw < M , H , R , C > (
0 commit comments