Skip to content

Commit 129b94c

Browse files
committed
Add diverging sampler stats
1 parent 73cc0b3 commit 129b94c

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

src/nuts.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,7 @@ pub struct StatsBuilder<H: Hamiltonian, A: AdaptStrategy> {
518518
gradient: Option<MutableFixedSizeListArray<MutablePrimitiveArray<f64>>>,
519519
hamiltonian: <H::Stats as ArrowRow>::Builder,
520520
adapt: <A::Stats as ArrowRow>::Builder,
521+
diverging: MutableBooleanArray,
521522
}
522523

523524
#[cfg(feature = "arrow")]
@@ -555,6 +556,7 @@ impl<H: Hamiltonian, A: AdaptStrategy> StatsBuilder<H, A> {
555556
unconstrained,
556557
hamiltonian: <H::Stats as ArrowRow>::new_builder(dim, settings),
557558
adapt: <A::Stats as ArrowRow>::new_builder(dim, settings),
559+
diverging: MutableBooleanArray::with_capacity(capacity),
558560
}
559561
}
560562
}
@@ -571,6 +573,7 @@ impl<H: Hamiltonian, A: AdaptStrategy> ArrowBuilder<NutsSampleStats<H::Stats, A:
571573
self.energy.push(Some(value.energy));
572574
self.chain.push(Some(value.chain));
573575
self.draw.push(Some(value.draw));
576+
self.diverging.push(Some(value.divergence_info().is_some()));
574577

575578
if let Some(store) = self.gradient.as_mut() {
576579
store
@@ -607,6 +610,7 @@ impl<H: Hamiltonian, A: AdaptStrategy> ArrowBuilder<NutsSampleStats<H::Stats, A:
607610
Field::new("energy", DataType::Float64, false),
608611
Field::new("chain", DataType::UInt64, false),
609612
Field::new("draw", DataType::UInt64, false),
613+
Field::new("diverging", DataType::Boolean, false),
610614
];
611615

612616
let mut arrays = vec![
@@ -617,6 +621,7 @@ impl<H: Hamiltonian, A: AdaptStrategy> ArrowBuilder<NutsSampleStats<H::Stats, A:
617621
self.energy.as_box(),
618622
self.chain.as_box(),
619623
self.draw.as_box(),
624+
self.diverging.as_box(),
620625
];
621626

622627
if let Some(hamiltonian) = self.hamiltonian.finalize() {

0 commit comments

Comments
 (0)