@@ -310,7 +310,11 @@ where
310310
311311#[ derive( Debug , Clone ) ]
312312#[ non_exhaustive]
313- pub struct NutsSampleStats < HStats : Send + Debug + Clone , AdaptStats : Send + Debug + Clone > {
313+ pub struct NutsSampleStats <
314+ PointStats : Send + Debug + Clone ,
315+ HStats : Send + Debug + Clone ,
316+ AdaptStats : Send + Debug + Clone ,
317+ > {
314318 pub depth : u64 ,
315319 pub maxdepth_reached : bool ,
316320 pub idx_in_trajectory : i64 ,
@@ -324,6 +328,7 @@ pub struct NutsSampleStats<HStats: Send + Debug + Clone, AdaptStats: Send + Debu
324328 pub unconstrained : Option < Box < [ f64 ] > > ,
325329 pub potential_stats : HStats ,
326330 pub strategy_stats : AdaptStats ,
331+ pub point_stats : PointStats ,
327332 pub tuning : bool ,
328333}
329334
@@ -338,7 +343,7 @@ pub struct SampleStats {
338343 pub num_steps : u64 ,
339344}
340345
341- pub struct NutsStatsBuilder < H , A > {
346+ pub struct NutsStatsBuilder < P , H , A > {
342347 depth : PrimitiveBuilder < UInt64Type > ,
343348 maxdepth_reached : BooleanBuilder ,
344349 index_in_trajectory : PrimitiveBuilder < Int64Type > ,
@@ -351,6 +356,7 @@ pub struct NutsStatsBuilder<H, A> {
351356 gradient : Option < FixedSizeListBuilder < PrimitiveBuilder < Float64Type > > > ,
352357 hamiltonian : H ,
353358 adapt : A ,
359+ point : P ,
354360 diverging : BooleanBuilder ,
355361 divergence_start : Option < FixedSizeListBuilder < PrimitiveBuilder < Float64Type > > > ,
356362 divergence_start_grad : Option < FixedSizeListBuilder < PrimitiveBuilder < Float64Type > > > ,
@@ -360,15 +366,17 @@ pub struct NutsStatsBuilder<H, A> {
360366 n_dim : usize ,
361367}
362368
363- impl < HB , AB > NutsStatsBuilder < HB , AB > {
369+ impl < PB , HB , AB > NutsStatsBuilder < PB , HB , AB > {
364370 pub fn new_with_capacity <
365371 M : Math ,
366- H : Hamiltonian < M , Builder = HB > ,
372+ P : Point < M , Builder = PB > ,
373+ H : Hamiltonian < M , Builder = HB , Point = P > ,
367374 A : AdaptStrategy < M , Builder = AB > ,
368375 > (
369376 settings : & impl Settings ,
370377 hamiltonian : & H ,
371378 adapt : & A ,
379+ point : & P ,
372380 dim : usize ,
373381 options : & NutsOptions ,
374382 ) -> Self {
@@ -430,6 +438,7 @@ impl<HB, AB> NutsStatsBuilder<HB, AB> {
430438 unconstrained,
431439 hamiltonian : hamiltonian. new_builder ( settings, dim) ,
432440 adapt : adapt. new_builder ( settings, dim) ,
441+ point : point. new_builder ( settings, dim) ,
433442 diverging : BooleanBuilder :: with_capacity ( capacity) ,
434443 divergence_start : div_start,
435444 divergence_start_grad : div_start_grad,
@@ -441,14 +450,17 @@ impl<HB, AB> NutsStatsBuilder<HB, AB> {
441450 }
442451}
443452
444- impl < HS , AS , HB , AB > StatTraceBuilder < NutsSampleStats < HS , AS > > for NutsStatsBuilder < HB , AB >
453+ impl < PS , HS , AS , PB , HB , AB > StatTraceBuilder < NutsSampleStats < PS , HS , AS > >
454+ for NutsStatsBuilder < PB , HB , AB >
445455where
446456 HB : StatTraceBuilder < HS > ,
447457 AB : StatTraceBuilder < AS > ,
458+ PB : StatTraceBuilder < PS > ,
448459 HS : Clone + Send + Debug ,
449460 AS : Clone + Send + Debug ,
461+ PS : Clone + Send + Debug ,
450462{
451- fn append_value ( & mut self , value : NutsSampleStats < HS , AS > ) {
463+ fn append_value ( & mut self , value : NutsSampleStats < PS , HS , AS > ) {
452464 let NutsSampleStats {
453465 depth,
454466 maxdepth_reached,
@@ -463,6 +475,7 @@ where
463475 unconstrained,
464476 potential_stats,
465477 strategy_stats,
478+ point_stats,
466479 tuning,
467480 } = value;
468481
@@ -532,6 +545,7 @@ where
532545
533546 self . hamiltonian . append_value ( potential_stats) ;
534547 self . adapt . append_value ( strategy_stats) ;
548+ self . point . append_value ( point_stats) ;
535549 }
536550
537551 fn finalize ( self ) -> Option < StructArray > {
@@ -548,6 +562,7 @@ where
548562 gradient,
549563 hamiltonian,
550564 adapt,
565+ point,
551566 mut diverging,
552567 divergence_start,
553568 divergence_start_grad,
@@ -615,6 +630,7 @@ where
615630
616631 merge_into ( hamiltonian, & mut arrays, & mut fields) ;
617632 merge_into ( adapt, & mut arrays, & mut fields) ;
633+ merge_into ( point, & mut arrays, & mut fields) ;
618634
619635 add_field ( gradient, "gradient" , & mut arrays, & mut fields) ;
620636 add_field (
@@ -667,6 +683,7 @@ where
667683 gradient,
668684 hamiltonian,
669685 adapt,
686+ point,
670687 diverging,
671688 divergence_start,
672689 divergence_start_grad,
@@ -734,6 +751,7 @@ where
734751
735752 merge_into ( hamiltonian, & mut arrays, & mut fields) ;
736753 merge_into ( adapt, & mut arrays, & mut fields) ;
754+ merge_into ( point, & mut arrays, & mut fields) ;
737755
738756 add_field ( gradient, "gradient" , & mut arrays, & mut fields) ;
739757 add_field (
0 commit comments