@@ -563,7 +563,7 @@ impl<T: TraceStorage> ChainProcess<T> {
563
563
let ( stop_marker_tx, stop_marker_rx) = channel ( ) ;
564
564
565
565
let mut rng = ChaCha8Rng :: seed_from_u64 ( seed) ;
566
- rng. set_stream ( chain_id) ;
566
+ rng. set_stream ( chain_id + 1 ) ;
567
567
568
568
let chain_trace = Arc :: new ( Mutex :: new ( Some ( chain_trace) ) ) ;
569
569
let progress = Arc :: new ( Mutex :: new ( ChainProgress :: new (
@@ -578,7 +578,9 @@ impl<T: TraceStorage> ChainProcess<T> {
578
578
let progress = progress_inner;
579
579
580
580
let mut sample = move || {
581
- let logp = model. math ( ) . context ( "Failed to create model density" ) ?;
581
+ let logp = model
582
+ . math ( & mut rng)
583
+ . context ( "Failed to create model density" ) ?;
582
584
let dim = logp. dim ( ) ;
583
585
584
586
let mut sampler = settings. new_chain ( chain_id, logp, & mut rng) ;
@@ -660,7 +662,7 @@ impl<T: TraceStorage> ChainProcess<T> {
660
662
661
663
let result = sample ( ) ;
662
664
663
- // We intentially ignore errors here, because this means some other
665
+ // We intentionally ignore errors here, because this means some other
664
666
// chain already failed, and should have reported the error.
665
667
let _ = results. send ( result) ;
666
668
drop ( results) ;
@@ -749,7 +751,12 @@ impl<F: Send + 'static> Sampler<F> {
749
751
let results = results_tx;
750
752
let mut chains = Vec :: with_capacity ( settings. num_chains ( ) ) ;
751
753
752
- let math = model_ref. math ( ) . context ( "Could not create model density" ) ?;
754
+ let mut rng = ChaCha8Rng :: seed_from_u64 ( settings. seed ( ) ) ;
755
+ rng. set_stream ( 0 ) ;
756
+
757
+ let math = model_ref
758
+ . math ( & mut rng)
759
+ . context ( "Could not create model density" ) ?;
753
760
let trace = trace_config
754
761
. new_trace ( settings_ref, & math)
755
762
. context ( "Could not create trace object" ) ?;
@@ -962,6 +969,7 @@ pub mod test_logps {
962
969
} ;
963
970
use anyhow:: Result ;
964
971
use nuts_storable:: HasDims ;
972
+ use rand:: Rng ;
965
973
use thiserror:: Error ;
966
974
967
975
#[ derive( Clone , Debug ) ]
@@ -1103,7 +1111,7 @@ pub mod test_logps {
1103
1111
{
1104
1112
type Math < ' model > = CpuMath < & ' model F > ;
1105
1113
1106
- fn math ( & self ) -> Result < Self :: Math < ' _ > > {
1114
+ fn math < R : Rng + ? Sized > ( & self , _rng : & mut R ) -> Result < Self :: Math < ' _ > > {
1107
1115
Ok ( CpuMath :: new ( & self . logp ) )
1108
1116
}
1109
1117
0 commit comments