@@ -238,7 +238,6 @@ impl<P: Hamiltonian, C: Collector<State = P::State>> NutsTree<P, C> {
238
238
rng : & mut R ,
239
239
potential : & mut P ,
240
240
direction : Direction ,
241
- options : & NutsOptions ,
242
241
collector : & mut C ,
243
242
) -> ExtendResult < P , C >
244
243
where
@@ -253,7 +252,7 @@ impl<P: Hamiltonian, C: Collector<State = P::State>> NutsTree<P, C> {
253
252
254
253
while other. depth < self . depth {
255
254
use ExtendResult :: * ;
256
- other = match other. extend ( pool, rng, potential, direction, options , collector) {
255
+ other = match other. extend ( pool, rng, potential, direction, collector) {
257
256
Ok ( tree) => tree,
258
257
Turning ( _) => {
259
258
return Turning ( self ) ;
@@ -358,13 +357,9 @@ impl<P: Hamiltonian, C: Collector<State = P::State>> NutsTree<P, C> {
358
357
}
359
358
360
359
fn info ( & self , maxdepth : bool , divergence_info : Option < DivergenceInfo > ) -> SampleInfo {
361
- let info: Option < DivergenceInfo > = match divergence_info {
362
- Some ( info) => Some ( info) ,
363
- None => None ,
364
- } ;
365
360
SampleInfo {
366
361
depth : self . depth ,
367
- divergence_info : info ,
362
+ divergence_info,
368
363
reached_maxdepth : maxdepth,
369
364
}
370
365
}
@@ -395,7 +390,7 @@ where
395
390
let mut tree = NutsTree :: new ( init. clone ( ) ) ;
396
391
while tree. depth < options. maxdepth {
397
392
let direction: Direction = rng. gen ( ) ;
398
- tree = match tree. extend ( pool, rng, potential, direction, options , collector) {
393
+ tree = match tree. extend ( pool, rng, potential, direction, collector) {
399
394
ExtendResult :: Ok ( tree) => tree,
400
395
ExtendResult :: Turning ( tree) => {
401
396
let info = tree. info ( false , None ) ;
@@ -769,6 +764,8 @@ where
769
764
#[ cfg( test) ]
770
765
#[ cfg( feature = "arrow" ) ]
771
766
mod tests {
767
+ use rand:: thread_rng;
768
+
772
769
use crate :: { adapt_strategy:: test_logps:: NormalLogp , new_sampler, Chain , SamplerArgs } ;
773
770
774
771
use super :: ArrowBuilder ;
@@ -779,8 +776,9 @@ mod tests {
779
776
let func = NormalLogp :: new ( ndim, 3. ) ;
780
777
781
778
let settings = SamplerArgs :: default ( ) ;
779
+ let mut rng = thread_rng ( ) ;
782
780
783
- let mut chain = new_sampler ( func, settings, 0 , 0 ) ;
781
+ let mut chain = new_sampler ( func, settings, 0 , & mut rng ) ;
784
782
785
783
let mut builder = chain. stats_builder ( ndim, & settings) ;
786
784
0 commit comments