1
- use criterion:: { black_box, criterion_group, criterion_main, Criterion } ;
1
+ use std:: hint:: black_box;
2
+
3
+ use criterion:: { criterion_group, criterion_main, Criterion } ;
2
4
use nix:: sched:: { sched_setaffinity, CpuSet } ;
3
5
use nix:: unistd:: Pid ;
4
- use nuts_rs:: math:: { axpy, axpy_out, vector_dot} ;
5
- use nuts_rs:: test_logps:: NormalLogp ;
6
- use nuts_rs:: { new_sampler, sample_parallel, Chain , JitterInitFunc , SamplerArgs } ;
6
+ use nuts_rs:: { Chain , CpuLogpFunc , CpuMath , LogpError , Math , Settings } ;
7
+ use rand:: SeedableRng ;
7
8
use rayon:: ThreadPoolBuilder ;
9
+ use thiserror:: Error ;
8
10
9
- fn make_sampler ( dim : usize , mu : f64 ) -> impl Chain {
10
- let func = NormalLogp :: new ( dim , mu ) ;
11
- new_sampler ( func , SamplerArgs :: default ( ) , 0 , 0 )
11
/// Target density used by the benchmarks in this file.
///
/// Its only state is the dimensionality of the parameter space.
#[derive(Debug)]
struct PosteriorDensity {
    /// Number of dimensions of the parameter space.
    dim: usize,
}
13
15
14
- pub fn sample_one ( mu : f64 , out : & mut [ f64 ] ) {
15
- let mut sampler = make_sampler ( out. len ( ) , mu) ;
16
+ // The density might fail in a recoverable or non-recoverable manner...
17
+ #[ derive( Debug , Error ) ]
18
+ enum PosteriorLogpError { }
19
+ impl LogpError for PosteriorLogpError {
20
+ fn is_recoverable ( & self ) -> bool {
21
+ false
22
+ }
23
+ }
24
+
25
+ impl CpuLogpFunc for PosteriorDensity {
26
+ type LogpError = PosteriorLogpError ;
27
+
28
+ // Only used for transforming adaptation.
29
+ type TransformParams = ( ) ;
30
+
31
+ // We define a 10 dimensional normal distribution
32
+ fn dim ( & self ) -> usize {
33
+ self . dim
34
+ }
35
+
36
+ // The normal likelihood with mean 3 and its gradient.
37
+ fn logp ( & mut self , position : & [ f64 ] , grad : & mut [ f64 ] ) -> Result < f64 , Self :: LogpError > {
38
+ let mu = 3f64 ;
39
+ let logp = position
40
+ . iter ( )
41
+ . copied ( )
42
+ . zip ( grad. iter_mut ( ) )
43
+ . map ( |( x, grad) | {
44
+ let diff = x - mu;
45
+ * grad = -diff;
46
+ -0.5 * diff * diff
47
+ } )
48
+ . sum ( ) ;
49
+ return Ok ( logp) ;
50
+ }
51
+ }
52
+
53
+ fn make_sampler ( dim : usize ) -> impl Chain < CpuMath < PosteriorDensity > > {
54
+ let func = PosteriorDensity { dim : dim } ;
55
+
56
+ let settings = nuts_rs:: DiagGradNutsSettings {
57
+ num_tune : 1000 ,
58
+ maxdepth : 3 , // small value just for testing...
59
+ ..Default :: default ( )
60
+ } ;
61
+
62
+ let math = nuts_rs:: CpuMath :: new ( func) ;
63
+ let mut rng = rand:: rngs:: StdRng :: seed_from_u64 ( 42u64 ) ;
64
+ settings. new_chain ( 0 , math, & mut rng)
65
+ }
66
+
67
+ pub fn sample_one ( out : & mut [ f64 ] ) {
68
+ let mut sampler = make_sampler ( out. len ( ) ) ;
16
69
let init = vec ! [ 3.5 ; out. len( ) ] ;
17
70
sampler. set_position ( & init) . unwrap ( ) ;
18
71
for _ in 0 ..1000 {
@@ -36,87 +89,79 @@ fn criterion_benchmark(c: &mut Criterion) {
36
89
cpu_set. set ( 0 ) . unwrap ( ) ;
37
90
sched_setaffinity ( Pid :: from_raw ( 0 ) , & cpu_set) . unwrap ( ) ;
38
91
39
- for n in [ 10 , 12 , 14 , 100 , 800 , 802 ] {
40
- let x = vec ! [ 2.5 ; n] ;
41
- let mut y = vec ! [ 3.5 ; n] ;
42
- let mut out = vec ! [ 0. ; n] ;
92
+ for n in [ 4 , 16 , 17 , 100 , 4567 ] {
93
+ let mut math = CpuMath :: new ( PosteriorDensity { dim : n } ) ;
94
+
95
+ let x = math. new_array ( ) ;
96
+ let p = math. new_array ( ) ;
97
+ let p2 = math. new_array ( ) ;
98
+ let n1 = math. new_array ( ) ;
99
+ let mut y = math. new_array ( ) ;
100
+ let mut out = math. new_array ( ) ;
101
+
102
+ let x_vec = vec ! [ 2.5 ; n] ;
103
+ let mut y_vec = vec ! [ 2.5 ; n] ;
104
+
105
+ c. bench_function ( & format ! ( "multiply {}" , n) , |b| {
106
+ b. iter ( || math. array_mult ( black_box ( & x) , black_box ( & y) , black_box ( & mut out) ) ) ;
107
+ } ) ;
43
108
44
- //axpy(&x, &mut y, 4.);
45
109
c. bench_function ( & format ! ( "axpy {}" , n) , |b| {
46
- b. iter ( || axpy ( black_box ( & x) , black_box ( & mut y) , black_box ( 4. ) ) ) ;
110
+ b. iter ( || math . axpy ( black_box ( & x) , black_box ( & mut y) , black_box ( 4. ) ) ) ;
47
111
} ) ;
48
112
49
113
c. bench_function ( & format ! ( "axpy_ndarray {}" , n) , |b| {
50
114
b. iter ( || {
51
- let x = ndarray:: aview1 ( black_box ( & x ) ) ;
52
- let mut y = ndarray:: aview_mut1 ( black_box ( & mut y ) ) ;
115
+ let x = ndarray:: aview1 ( black_box ( & x_vec ) ) ;
116
+ let mut y = ndarray:: aview_mut1 ( black_box ( & mut y_vec ) ) ;
53
117
//y *= &x;// * black_box(4.);
54
118
y. scaled_add ( black_box ( 4f64 ) , & x) ;
55
119
} ) ;
56
120
} ) ;
57
121
58
- //axpy_out(&x, &y, 4., &mut out);
59
122
c. bench_function ( & format ! ( "axpy_out {}" , n) , |b| {
60
123
b. iter ( || {
61
- axpy_out (
124
+ math . axpy_out (
62
125
black_box ( & x) ,
63
126
black_box ( & y) ,
64
127
black_box ( 4. ) ,
65
128
black_box ( & mut out) ,
66
129
)
67
130
} ) ;
68
131
} ) ;
69
- //vector_dot(&x, &y);
132
+
70
133
c. bench_function ( & format ! ( "vector_dot {}" , n) , |b| {
71
- b. iter ( || vector_dot ( black_box ( & x) , black_box ( & y) ) ) ;
134
+ b. iter ( || math . array_vector_dot ( black_box ( & x) , black_box ( & y) ) ) ;
72
135
} ) ;
73
- /*
74
- scalar_prods_of_diff(&x, &y, &a, &d);
75
- c.bench_function(&format!("scalar_prods_of_diff {}", n), |b| {
136
+
137
+ c. bench_function ( & format ! ( "scalar_prods2 {}" , n) , |b| {
76
138
b. iter ( || {
77
- scalar_prods_of_diff(black_box(&x), black_box(&y), black_box(&a), black_box(&d))
139
+ math. scalar_prods2 ( black_box ( & p) , black_box ( & p2) , black_box ( & x) , black_box ( & y) )
140
+ } ) ;
141
+ } ) ;
142
+
143
+ c. bench_function ( & format ! ( "scalar_prods3 {}" , n) , |b| {
144
+ b. iter ( || {
145
+ math. scalar_prods3 (
146
+ black_box ( & p) ,
147
+ black_box ( & p2) ,
148
+ black_box ( & n1) ,
149
+ black_box ( & x) ,
150
+ black_box ( & y) ,
151
+ )
78
152
} ) ;
79
153
} ) ;
80
- */
81
154
}
82
155
83
156
let mut out = vec ! [ 0. ; 10 ] ;
84
157
c. bench_function ( "sample_1000_10" , |b| {
85
- b. iter ( || sample_one ( black_box ( 3. ) , black_box ( & mut out) ) )
158
+ b. iter ( || sample_one ( black_box ( & mut out) ) )
86
159
} ) ;
87
160
88
161
let mut out = vec ! [ 0. ; 1000 ] ;
89
162
c. bench_function ( "sample_1000_1000" , |b| {
90
- b. iter ( || sample_one ( black_box ( 3. ) , black_box ( & mut out) ) )
163
+ b. iter ( || sample_one ( black_box ( & mut out) ) )
91
164
} ) ;
92
-
93
- for n in [ 10 , 12 , 1000 ] {
94
- c. bench_function ( & format ! ( "sample_parallel_{}" , n) , |b| {
95
- b. iter ( || {
96
- let func = NormalLogp :: new ( n, 0. ) ;
97
- let settings = black_box ( SamplerArgs :: default ( ) ) ;
98
- let mut init_point_func = JitterInitFunc :: new ( ) ;
99
- let n_chains = black_box ( 10 ) ;
100
- let n_draws = black_box ( 1000 ) ;
101
- let seed = black_box ( 42 ) ;
102
- let n_try_init = 10 ;
103
- let ( handle, channel) = sample_parallel (
104
- func,
105
- & mut init_point_func,
106
- settings,
107
- n_chains,
108
- n_draws,
109
- seed,
110
- n_try_init,
111
- )
112
- . unwrap ( ) ;
113
- let draws: Vec < _ > = channel. iter ( ) . collect ( ) ;
114
- //assert_eq!(draws.len() as u64, (n_draws + settings.num_tune) * n_chains);
115
- handle. join ( ) . unwrap ( ) ;
116
- draws
117
- } ) ;
118
- } ) ;
119
- }
120
165
}
121
166
122
167
criterion_group ! ( benches, criterion_benchmark) ;
0 commit comments