@@ -310,7 +310,11 @@ where
310
310
311
311
#[ derive( Debug , Clone ) ]
312
312
#[ non_exhaustive]
313
- pub struct NutsSampleStats < HStats : Send + Debug + Clone , AdaptStats : Send + Debug + Clone > {
313
+ pub struct NutsSampleStats <
314
+ PointStats : Send + Debug + Clone ,
315
+ HStats : Send + Debug + Clone ,
316
+ AdaptStats : Send + Debug + Clone ,
317
+ > {
314
318
pub depth : u64 ,
315
319
pub maxdepth_reached : bool ,
316
320
pub idx_in_trajectory : i64 ,
@@ -324,6 +328,7 @@ pub struct NutsSampleStats<HStats: Send + Debug + Clone, AdaptStats: Send + Debu
324
328
pub unconstrained : Option < Box < [ f64 ] > > ,
325
329
pub potential_stats : HStats ,
326
330
pub strategy_stats : AdaptStats ,
331
+ pub point_stats : PointStats ,
327
332
pub tuning : bool ,
328
333
}
329
334
@@ -338,7 +343,7 @@ pub struct SampleStats {
338
343
pub num_steps : u64 ,
339
344
}
340
345
341
- pub struct NutsStatsBuilder < H , A > {
346
+ pub struct NutsStatsBuilder < P , H , A > {
342
347
depth : PrimitiveBuilder < UInt64Type > ,
343
348
maxdepth_reached : BooleanBuilder ,
344
349
index_in_trajectory : PrimitiveBuilder < Int64Type > ,
@@ -351,6 +356,7 @@ pub struct NutsStatsBuilder<H, A> {
351
356
gradient : Option < FixedSizeListBuilder < PrimitiveBuilder < Float64Type > > > ,
352
357
hamiltonian : H ,
353
358
adapt : A ,
359
+ point : P ,
354
360
diverging : BooleanBuilder ,
355
361
divergence_start : Option < FixedSizeListBuilder < PrimitiveBuilder < Float64Type > > > ,
356
362
divergence_start_grad : Option < FixedSizeListBuilder < PrimitiveBuilder < Float64Type > > > ,
@@ -360,15 +366,17 @@ pub struct NutsStatsBuilder<H, A> {
360
366
n_dim : usize ,
361
367
}
362
368
363
- impl < HB , AB > NutsStatsBuilder < HB , AB > {
369
+ impl < PB , HB , AB > NutsStatsBuilder < PB , HB , AB > {
364
370
pub fn new_with_capacity <
365
371
M : Math ,
366
- H : Hamiltonian < M , Builder = HB > ,
372
+ P : Point < M , Builder = PB > ,
373
+ H : Hamiltonian < M , Builder = HB , Point = P > ,
367
374
A : AdaptStrategy < M , Builder = AB > ,
368
375
> (
369
376
settings : & impl Settings ,
370
377
hamiltonian : & H ,
371
378
adapt : & A ,
379
+ point : & P ,
372
380
dim : usize ,
373
381
options : & NutsOptions ,
374
382
) -> Self {
@@ -430,6 +438,7 @@ impl<HB, AB> NutsStatsBuilder<HB, AB> {
430
438
unconstrained,
431
439
hamiltonian : hamiltonian. new_builder ( settings, dim) ,
432
440
adapt : adapt. new_builder ( settings, dim) ,
441
+ point : point. new_builder ( settings, dim) ,
433
442
diverging : BooleanBuilder :: with_capacity ( capacity) ,
434
443
divergence_start : div_start,
435
444
divergence_start_grad : div_start_grad,
@@ -441,14 +450,17 @@ impl<HB, AB> NutsStatsBuilder<HB, AB> {
441
450
}
442
451
}
443
452
444
- impl < HS , AS , HB , AB > StatTraceBuilder < NutsSampleStats < HS , AS > > for NutsStatsBuilder < HB , AB >
453
+ impl < PS , HS , AS , PB , HB , AB > StatTraceBuilder < NutsSampleStats < PS , HS , AS > >
454
+ for NutsStatsBuilder < PB , HB , AB >
445
455
where
446
456
HB : StatTraceBuilder < HS > ,
447
457
AB : StatTraceBuilder < AS > ,
458
+ PB : StatTraceBuilder < PS > ,
448
459
HS : Clone + Send + Debug ,
449
460
AS : Clone + Send + Debug ,
461
+ PS : Clone + Send + Debug ,
450
462
{
451
- fn append_value ( & mut self , value : NutsSampleStats < HS , AS > ) {
463
+ fn append_value ( & mut self , value : NutsSampleStats < PS , HS , AS > ) {
452
464
let NutsSampleStats {
453
465
depth,
454
466
maxdepth_reached,
@@ -463,6 +475,7 @@ where
463
475
unconstrained,
464
476
potential_stats,
465
477
strategy_stats,
478
+ point_stats,
466
479
tuning,
467
480
} = value;
468
481
@@ -532,6 +545,7 @@ where
532
545
533
546
self . hamiltonian . append_value ( potential_stats) ;
534
547
self . adapt . append_value ( strategy_stats) ;
548
+ self . point . append_value ( point_stats) ;
535
549
}
536
550
537
551
fn finalize ( self ) -> Option < StructArray > {
@@ -548,6 +562,7 @@ where
548
562
gradient,
549
563
hamiltonian,
550
564
adapt,
565
+ point,
551
566
mut diverging,
552
567
divergence_start,
553
568
divergence_start_grad,
@@ -615,6 +630,7 @@ where
615
630
616
631
merge_into ( hamiltonian, & mut arrays, & mut fields) ;
617
632
merge_into ( adapt, & mut arrays, & mut fields) ;
633
+ merge_into ( point, & mut arrays, & mut fields) ;
618
634
619
635
add_field ( gradient, "gradient" , & mut arrays, & mut fields) ;
620
636
add_field (
@@ -667,6 +683,7 @@ where
667
683
gradient,
668
684
hamiltonian,
669
685
adapt,
686
+ point,
670
687
diverging,
671
688
divergence_start,
672
689
divergence_start_grad,
@@ -734,6 +751,7 @@ where
734
751
735
752
merge_into ( hamiltonian, & mut arrays, & mut fields) ;
736
753
merge_into ( adapt, & mut arrays, & mut fields) ;
754
+ merge_into ( point, & mut arrays, & mut fields) ;
737
755
738
756
add_field ( gradient, "gradient" , & mut arrays, & mut fields) ;
739
757
add_field (
0 commit comments