Skip to content

Commit 60ccdbd

Browse files
committed
feat: add rng to Model.math()
1 parent 7f235fb commit 60ccdbd

File tree

9 files changed

+26
-14
lines changed

9 files changed

+26
-14
lines changed

examples/csv_trace.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ impl Model for MvnModel {
179179
where
180180
Self: 'model;
181181

182-
fn math(&self) -> Result<Self::Math<'_>> {
182+
fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
183183
Ok(self.math.clone())
184184
}
185185

examples/hashmap_storage.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ impl MultivariateNormal {
2727
}
2828

2929
// Custom LogpError implementation
30+
#[allow(dead_code)]
3031
#[derive(Debug, Error)]
3132
enum MyLogpError {
3233
#[error("Recoverable error in logp calculation: {0}")]
@@ -114,7 +115,7 @@ impl Model for MvnModel {
114115
where
115116
Self: 'model;
116117

117-
fn math(&self) -> Result<Self::Math<'_>> {
118+
fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
118119
Ok(self.math.clone())
119120
}
120121

@@ -228,6 +229,9 @@ fn main() -> Result<()> {
228229
nuts_rs::HashMapValue::U64(vec) => {
229230
println!(" {}: {} u64 draws", name, vec.len());
230231
}
232+
nuts_rs::HashMapValue::String(vec) => {
233+
println!(" {}: {} string draws", name, vec.len());
234+
}
231235
}
232236
}
233237

examples/ndarray_storage.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ impl Model for MvnModel {
116116
where
117117
Self: 'model;
118118

119-
fn math(&self) -> Result<Self::Math<'_>> {
119+
fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
120120
Ok(self.math.clone())
121121
}
122122

examples/zarr_async_trace.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ impl Model for MvnModel {
186186
where
187187
Self: 'model;
188188

189-
fn math(&self) -> Result<Self::Math<'_>> {
189+
fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
190190
Ok(self.math.clone())
191191
}
192192

examples/zarr_trace.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ impl Model for MvnModel {
185185
where
186186
Self: 'model;
187187

188-
fn math(&self) -> Result<Self::Math<'_>> {
188+
fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
189189
Ok(self.math.clone())
190190
}
191191

src/csv_storage.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -712,7 +712,7 @@ mod tests {
712712
where
713713
Self: 'model;
714714

715-
fn math(&self) -> Result<Self::Math<'_>> {
715+
fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
716716
Ok(self.math.clone())
717717
}
718718

@@ -790,7 +790,7 @@ mod tests {
790790
where
791791
Self: 'model;
792792

793-
fn math(&self) -> Result<Self::Math<'_>> {
793+
fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
794794
Ok(self.math.clone())
795795
}
796796

src/model.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ pub trait Model: Send + Sync + 'static {
2727
Self: 'model;
2828

2929
/// Returns the math backend for this model.
30-
fn math(&self) -> Result<Self::Math<'_>>;
30+
fn math<R: Rng + ?Sized>(&self, rng: &mut R) -> Result<Self::Math<'_>>;
3131

3232
/// Initializes the starting position for MCMC sampling.
3333
///

src/sampler.rs

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,7 @@ impl<T: TraceStorage> ChainProcess<T> {
563563
let (stop_marker_tx, stop_marker_rx) = channel();
564564

565565
let mut rng = ChaCha8Rng::seed_from_u64(seed);
566-
rng.set_stream(chain_id);
566+
rng.set_stream(chain_id + 1);
567567

568568
let chain_trace = Arc::new(Mutex::new(Some(chain_trace)));
569569
let progress = Arc::new(Mutex::new(ChainProgress::new(
@@ -578,7 +578,9 @@ impl<T: TraceStorage> ChainProcess<T> {
578578
let progress = progress_inner;
579579

580580
let mut sample = move || {
581-
let logp = model.math().context("Failed to create model density")?;
581+
let logp = model
582+
.math(&mut rng)
583+
.context("Failed to create model density")?;
582584
let dim = logp.dim();
583585

584586
let mut sampler = settings.new_chain(chain_id, logp, &mut rng);
@@ -660,7 +662,7 @@ impl<T: TraceStorage> ChainProcess<T> {
660662

661663
let result = sample();
662664

663-
// We intentially ignore errors here, because this means some other
665+
// We intentionally ignore errors here, because this means some other
664666
// chain already failed, and should have reported the error.
665667
let _ = results.send(result);
666668
drop(results);
@@ -749,7 +751,12 @@ impl<F: Send + 'static> Sampler<F> {
749751
let results = results_tx;
750752
let mut chains = Vec::with_capacity(settings.num_chains());
751753

752-
let math = model_ref.math().context("Could not create model density")?;
754+
let mut rng = ChaCha8Rng::seed_from_u64(settings.seed());
755+
rng.set_stream(0);
756+
757+
let math = model_ref
758+
.math(&mut rng)
759+
.context("Could not create model density")?;
753760
let trace = trace_config
754761
.new_trace(settings_ref, &math)
755762
.context("Could not create trace object")?;
@@ -962,6 +969,7 @@ pub mod test_logps {
962969
};
963970
use anyhow::Result;
964971
use nuts_storable::HasDims;
972+
use rand::Rng;
965973
use thiserror::Error;
966974

967975
#[derive(Clone, Debug)]
@@ -1103,7 +1111,7 @@ pub mod test_logps {
11031111
{
11041112
type Math<'model> = CpuMath<&'model F>;
11051113

1106-
fn math(&self) -> Result<Self::Math<'_>> {
1114+
fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
11071115
Ok(CpuMath::new(&self.logp))
11081116
}
11091117

tests/sample_normal.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ impl Model for NormalModel {
9595
where
9696
Self: 'model;
9797

98-
fn math(&self) -> anyhow::Result<Self::Math<'_>> {
98+
fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> anyhow::Result<Self::Math<'_>> {
9999
Ok(CpuMath::new(NormalLogp {
100100
dim: self.mu.len(),
101101
mu: &self.mu,

0 commit comments

Comments
 (0)