1- use criterion:: { black_box, criterion_group, criterion_main, Criterion } ;
1+ use std:: hint:: black_box;
2+
3+ use criterion:: { criterion_group, criterion_main, Criterion } ;
24use nix:: sched:: { sched_setaffinity, CpuSet } ;
35use nix:: unistd:: Pid ;
4- use nuts_rs:: math:: { axpy, axpy_out, vector_dot} ;
5- use nuts_rs:: test_logps:: NormalLogp ;
6- use nuts_rs:: { new_sampler, sample_parallel, Chain , JitterInitFunc , SamplerArgs } ;
6+ use nuts_rs:: { Chain , CpuLogpFunc , CpuMath , LogpError , Math , Settings } ;
7+ use rand:: SeedableRng ;
78use rayon:: ThreadPoolBuilder ;
9+ use thiserror:: Error ;
810
9- fn make_sampler ( dim : usize , mu : f64 ) -> impl Chain {
10- let func = NormalLogp :: new ( dim , mu ) ;
11- new_sampler ( func , SamplerArgs :: default ( ) , 0 , 0 )
// Log-density type used by the benchmarks in this file; its only state is the
// number of dimensions.
11+ # [ derive ( Debug ) ]
12+ struct PosteriorDensity {
// Number of dimensions; handed back unchanged by `dim()` in the `CpuLogpFunc` impl.
13+ dim : usize ,
1214}
1315
14- pub fn sample_one ( mu : f64 , out : & mut [ f64 ] ) {
15- let mut sampler = make_sampler ( out. len ( ) , mu) ;
16+ // In general a density might fail in a recoverable or a non-recoverable manner.
// The density benchmarked here never fails (its `logp` only ever returns `Ok`),
// so the error type is an enum without variants: no value of it can be constructed.
17+ #[ derive( Debug , Error ) ]
18+ enum PosteriorLogpError { }
// Implements `LogpError` for the benchmark error type. `PosteriorLogpError` has no
// variants, so no value exists on which `is_recoverable` could be called; `false`
// is simply the declared answer.
19+ impl LogpError for PosteriorLogpError {
20+ fn is_recoverable ( & self ) -> bool {
21+ false
22+ }
23+ }
24+
// `CpuLogpFunc` implementation for the benchmark density: an independent normal
// with mean 3 in every coordinate, evaluated without its normalizing constant.
25+ impl CpuLogpFunc for PosteriorDensity {
26+ type LogpError = PosteriorLogpError ;
27+
28+ // Only used for transforming adaptation.
29+ type TransformParams = ( ) ;
30+
31+ // The dimension is whatever the struct was constructed with (it is not fixed at 10).
32+ fn dim ( & self ) -> usize {
33+ self . dim
34+ }
35+
36+ // The normal log-likelihood with mean 3 and its gradient.
// For each coordinate `x` of `position`, writes `-(x - 3)` into the matching slot
// of `grad` and returns the sum of `-0.5 * (x - 3)^2` over all coordinates.
// `zip` pairs elements positionally and stops at the shorter of the two slices.
// Always returns `Ok`; no error path exists in this body.
37+ fn logp ( & mut self , position : & [ f64 ] , grad : & mut [ f64 ] ) -> Result < f64 , Self :: LogpError > {
38+ let mu = 3f64 ;
39+ let logp = position
40+ . iter ( )
41+ . copied ( )
42+ . zip ( grad. iter_mut ( ) )
43+ . map ( |( x, grad) | {
44+ let diff = x - mu;
// Gradient of `-0.5 * diff^2` with respect to `x`.
45+ * grad = -diff;
46+ -0.5 * diff * diff
47+ } )
48+ . sum ( ) ;
// NOTE(review): the explicit `return` is redundant; a tail expression would be idiomatic.
49+ return Ok ( logp) ;
50+ }
51+ }
52+
// Builds a chain over a `PosteriorDensity` with `dim` dimensions, using
// `DiagGradNutsSettings` with `num_tune = 1000` and `maxdepth = 3` (all other
// settings left at their defaults) and an RNG seeded with the constant 42.
53+ fn make_sampler ( dim : usize ) -> impl Chain < CpuMath < PosteriorDensity > > {
// NOTE(review): `dim : dim` could use field-init shorthand.
54+ let func = PosteriorDensity { dim : dim } ;
55+
56+ let settings = nuts_rs:: DiagGradNutsSettings {
57+ num_tune : 1000 ,
58+ maxdepth : 3 , // small value just for testing...
59+ ..Default :: default ( )
60+ } ;
61+
62+ let math = nuts_rs:: CpuMath :: new ( func) ;
// Fixed seed: every call starts from the same RNG state.
63+ let mut rng = rand:: rngs:: StdRng :: seed_from_u64 ( 42u64 ) ;
// The leading `0` is presumably the chain index/id; confirm against `Settings::new_chain`.
64+ settings. new_chain ( 0 , math, & mut rng)
65+ }
66+
67+ pub fn sample_one ( out : & mut [ f64 ] ) {
68+ let mut sampler = make_sampler ( out. len ( ) ) ;
1669 let init = vec ! [ 3.5 ; out. len( ) ] ;
1770 sampler. set_position ( & init) . unwrap ( ) ;
1871 for _ in 0 ..1000 {
@@ -36,87 +89,79 @@ fn criterion_benchmark(c: &mut Criterion) {
3689 cpu_set. set ( 0 ) . unwrap ( ) ;
3790 sched_setaffinity ( Pid :: from_raw ( 0 ) , & cpu_set) . unwrap ( ) ;
3891
39- for n in [ 10 , 12 , 14 , 100 , 800 , 802 ] {
40- let x = vec ! [ 2.5 ; n] ;
41- let mut y = vec ! [ 3.5 ; n] ;
42- let mut out = vec ! [ 0. ; n] ;
92+ for n in [ 4 , 16 , 17 , 100 , 4567 ] {
93+ let mut math = CpuMath :: new ( PosteriorDensity { dim : n } ) ;
94+
95+ let x = math. new_array ( ) ;
96+ let p = math. new_array ( ) ;
97+ let p2 = math. new_array ( ) ;
98+ let n1 = math. new_array ( ) ;
99+ let mut y = math. new_array ( ) ;
100+ let mut out = math. new_array ( ) ;
101+
102+ let x_vec = vec ! [ 2.5 ; n] ;
103+ let mut y_vec = vec ! [ 2.5 ; n] ;
104+
105+ c. bench_function ( & format ! ( "multiply {}" , n) , |b| {
106+ b. iter ( || math. array_mult ( black_box ( & x) , black_box ( & y) , black_box ( & mut out) ) ) ;
107+ } ) ;
43108
44- //axpy(&x, &mut y, 4.);
45109 c. bench_function ( & format ! ( "axpy {}" , n) , |b| {
46- b. iter ( || axpy ( black_box ( & x) , black_box ( & mut y) , black_box ( 4. ) ) ) ;
110+ b. iter ( || math . axpy ( black_box ( & x) , black_box ( & mut y) , black_box ( 4. ) ) ) ;
47111 } ) ;
48112
49113 c. bench_function ( & format ! ( "axpy_ndarray {}" , n) , |b| {
50114 b. iter ( || {
51- let x = ndarray:: aview1 ( black_box ( & x ) ) ;
52- let mut y = ndarray:: aview_mut1 ( black_box ( & mut y ) ) ;
115+ let x = ndarray:: aview1 ( black_box ( & x_vec ) ) ;
116+ let mut y = ndarray:: aview_mut1 ( black_box ( & mut y_vec ) ) ;
53117 //y *= &x;// * black_box(4.);
54118 y. scaled_add ( black_box ( 4f64 ) , & x) ;
55119 } ) ;
56120 } ) ;
57121
58- //axpy_out(&x, &y, 4., &mut out);
59122 c. bench_function ( & format ! ( "axpy_out {}" , n) , |b| {
60123 b. iter ( || {
61- axpy_out (
124+ math . axpy_out (
62125 black_box ( & x) ,
63126 black_box ( & y) ,
64127 black_box ( 4. ) ,
65128 black_box ( & mut out) ,
66129 )
67130 } ) ;
68131 } ) ;
69- //vector_dot(&x, &y);
132+
70133 c. bench_function ( & format ! ( "vector_dot {}" , n) , |b| {
71- b. iter ( || vector_dot ( black_box ( & x) , black_box ( & y) ) ) ;
134+ b. iter ( || math . array_vector_dot ( black_box ( & x) , black_box ( & y) ) ) ;
72135 } ) ;
73- /*
74- scalar_prods_of_diff(&x, &y, &a, &d);
75- c.bench_function(&format!("scalar_prods_of_diff {}", n), |b| {
136+
137+ c. bench_function ( & format ! ( "scalar_prods2 {}" , n) , |b| {
76138 b. iter ( || {
77- scalar_prods_of_diff(black_box(&x), black_box(&y), black_box(&a), black_box(&d))
139+ math. scalar_prods2 ( black_box ( & p) , black_box ( & p2) , black_box ( & x) , black_box ( & y) )
140+ } ) ;
141+ } ) ;
142+
143+ c. bench_function ( & format ! ( "scalar_prods3 {}" , n) , |b| {
144+ b. iter ( || {
145+ math. scalar_prods3 (
146+ black_box ( & p) ,
147+ black_box ( & p2) ,
148+ black_box ( & n1) ,
149+ black_box ( & x) ,
150+ black_box ( & y) ,
151+ )
78152 } ) ;
79153 } ) ;
80- */
81154 }
82155
83156 let mut out = vec ! [ 0. ; 10 ] ;
84157 c. bench_function ( "sample_1000_10" , |b| {
85- b. iter ( || sample_one ( black_box ( 3. ) , black_box ( & mut out) ) )
158+ b. iter ( || sample_one ( black_box ( & mut out) ) )
86159 } ) ;
87160
88161 let mut out = vec ! [ 0. ; 1000 ] ;
89162 c. bench_function ( "sample_1000_1000" , |b| {
90- b. iter ( || sample_one ( black_box ( 3. ) , black_box ( & mut out) ) )
163+ b. iter ( || sample_one ( black_box ( & mut out) ) )
91164 } ) ;
92-
93- for n in [ 10 , 12 , 1000 ] {
94- c. bench_function ( & format ! ( "sample_parallel_{}" , n) , |b| {
95- b. iter ( || {
96- let func = NormalLogp :: new ( n, 0. ) ;
97- let settings = black_box ( SamplerArgs :: default ( ) ) ;
98- let mut init_point_func = JitterInitFunc :: new ( ) ;
99- let n_chains = black_box ( 10 ) ;
100- let n_draws = black_box ( 1000 ) ;
101- let seed = black_box ( 42 ) ;
102- let n_try_init = 10 ;
103- let ( handle, channel) = sample_parallel (
104- func,
105- & mut init_point_func,
106- settings,
107- n_chains,
108- n_draws,
109- seed,
110- n_try_init,
111- )
112- . unwrap ( ) ;
113- let draws: Vec < _ > = channel. iter ( ) . collect ( ) ;
114- //assert_eq!(draws.len() as u64, (n_draws + settings.num_tune) * n_chains);
115- handle. join ( ) . unwrap ( ) ;
116- draws
117- } ) ;
118- } ) ;
119- }
120165}
121166
122167criterion_group ! ( benches, criterion_benchmark) ;
0 commit comments