Commit 442d807

perf: replace multiversion with pulp for simd

1 parent e4642f1

6 files changed (+506 -320 lines)

Cargo.toml

Lines changed: 8 additions & 9 deletions
@@ -2,8 +2,8 @@
 name = "nuts-rs"
 version = "0.15.1"
 authors = [
-    "Adrian Seyboldt <[email protected]>",
-    "PyMC Developers <[email protected]>",
+    "Adrian Seyboldt <[email protected]>",
+    "PyMC Developers <[email protected]>",
 ]
 edition = "2021"
 license = "MIT"
@@ -20,25 +20,24 @@ codegen-units = 1
 [dependencies]
 rand = { version = "0.9.0", features = ["small_rng"] }
 rand_distr = "0.5.0"
-multiversion = "0.8.0"
 itertools = "0.14.0"
 thiserror = "2.0.3"
 arrow = { version = "54.2.0", default-features = false, features = ["ffi"] }
 rand_chacha = "0.9.0"
 anyhow = "1.0.72"
-faer = { version = "0.21.4", default-features = false, features = [
-    "std",
-    "npy",
-    "linalg",
+faer = { version = "0.22.6", default-features = false, features = [
+    "std",
+    "npy",
+    "linalg",
 ] }
 pulp = "0.21.4"
 rayon = "1.10.0"

 [dev-dependencies]
 proptest = "1.6.0"
 pretty_assertions = "1.4.0"
-criterion = "0.5.1"
-nix = "0.29.0"
+criterion = "0.6.0"
+nix = { version = "0.30.0", features = ["sched"] }
 approx = "0.5.1"
 ndarray = "0.16.1"
 equator = "0.4.2"
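Note that pulp was already in the dependency list (faer is built on it); the commit drops multiversion, which compiled several statically selected copies of each kernel through an attribute macro, and moves SIMD feature selection to pulp's runtime dispatch. As a rough illustration of the pattern (not code from this commit), a caller constructs a pulp::Arch once and runs hot loops through it:

use pulp::Arch;

fn main() {
    let mut v: Vec<f64> = (0..1000).map(|i| i as f64).collect();

    // Detect the widest SIMD feature set available on this CPU, once.
    let arch = Arch::new();

    // pulp instantiates the closure for several instruction-set levels and
    // picks the best one at runtime, so the loop can auto-vectorize with
    // AVX2/AVX-512 etc. without raising the crate's baseline target.
    arch.dispatch(|| {
        for x in v.iter_mut() {
            *x *= 2.0;
        }
    });

    assert_eq!(v[2], 4.0);
}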

benches/sample.rs

Lines changed: 101 additions & 56 deletions
@@ -1,18 +1,71 @@
-use criterion::{black_box, criterion_group, criterion_main, Criterion};
+use std::hint::black_box;
+
+use criterion::{criterion_group, criterion_main, Criterion};
 use nix::sched::{sched_setaffinity, CpuSet};
 use nix::unistd::Pid;
-use nuts_rs::math::{axpy, axpy_out, vector_dot};
-use nuts_rs::test_logps::NormalLogp;
-use nuts_rs::{new_sampler, sample_parallel, Chain, JitterInitFunc, SamplerArgs};
+use nuts_rs::{Chain, CpuLogpFunc, CpuMath, LogpError, Math, Settings};
+use rand::SeedableRng;
 use rayon::ThreadPoolBuilder;
+use thiserror::Error;

-fn make_sampler(dim: usize, mu: f64) -> impl Chain {
-    let func = NormalLogp::new(dim, mu);
-    new_sampler(func, SamplerArgs::default(), 0, 0)
+#[derive(Debug)]
+struct PosteriorDensity {
+    dim: usize,
 }

-pub fn sample_one(mu: f64, out: &mut [f64]) {
-    let mut sampler = make_sampler(out.len(), mu);
+// The density might fail in a recoverable or non-recoverable manner...
+#[derive(Debug, Error)]
+enum PosteriorLogpError {}
+impl LogpError for PosteriorLogpError {
+    fn is_recoverable(&self) -> bool {
+        false
+    }
+}
+
+impl CpuLogpFunc for PosteriorDensity {
+    type LogpError = PosteriorLogpError;
+
+    // Only used for transforming adaptation.
+    type TransformParams = ();
+
+    // We define a 10 dimensional normal distribution
+    fn dim(&self) -> usize {
+        self.dim
+    }
+
+    // The normal likelihood with mean 3 and its gradient.
+    fn logp(&mut self, position: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
+        let mu = 3f64;
+        let logp = position
+            .iter()
+            .copied()
+            .zip(grad.iter_mut())
+            .map(|(x, grad)| {
+                let diff = x - mu;
+                *grad = -diff;
+                -0.5 * diff * diff
+            })
+            .sum();
+        return Ok(logp);
+    }
+}
+
+fn make_sampler(dim: usize) -> impl Chain<CpuMath<PosteriorDensity>> {
+    let func = PosteriorDensity { dim: dim };
+
+    let settings = nuts_rs::DiagGradNutsSettings {
+        num_tune: 1000,
+        maxdepth: 3, // small value just for testing...
+        ..Default::default()
+    };
+
+    let math = nuts_rs::CpuMath::new(func);
+    let mut rng = rand::rngs::StdRng::seed_from_u64(42u64);
+    settings.new_chain(0, math, &mut rng)
+}
+
+pub fn sample_one(out: &mut [f64]) {
+    let mut sampler = make_sampler(out.len());
     let init = vec![3.5; out.len()];
     sampler.set_position(&init).unwrap();
     for _ in 0..1000 {
@@ -36,87 +89,79 @@ fn criterion_benchmark(c: &mut Criterion) {
     cpu_set.set(0).unwrap();
     sched_setaffinity(Pid::from_raw(0), &cpu_set).unwrap();

-    for n in [10, 12, 14, 100, 800, 802] {
-        let x = vec![2.5; n];
-        let mut y = vec![3.5; n];
-        let mut out = vec![0.; n];
+    for n in [4, 16, 17, 100, 4567] {
+        let mut math = CpuMath::new(PosteriorDensity { dim: n });
+
+        let x = math.new_array();
+        let p = math.new_array();
+        let p2 = math.new_array();
+        let n1 = math.new_array();
+        let mut y = math.new_array();
+        let mut out = math.new_array();
+
+        let x_vec = vec![2.5; n];
+        let mut y_vec = vec![2.5; n];
+
+        c.bench_function(&format!("multiply {}", n), |b| {
+            b.iter(|| math.array_mult(black_box(&x), black_box(&y), black_box(&mut out)));
+        });

-        //axpy(&x, &mut y, 4.);
         c.bench_function(&format!("axpy {}", n), |b| {
-            b.iter(|| axpy(black_box(&x), black_box(&mut y), black_box(4.)));
+            b.iter(|| math.axpy(black_box(&x), black_box(&mut y), black_box(4.)));
         });

         c.bench_function(&format!("axpy_ndarray {}", n), |b| {
             b.iter(|| {
-                let x = ndarray::aview1(black_box(&x));
-                let mut y = ndarray::aview_mut1(black_box(&mut y));
+                let x = ndarray::aview1(black_box(&x_vec));
+                let mut y = ndarray::aview_mut1(black_box(&mut y_vec));
                 //y *= &x;// * black_box(4.);
                 y.scaled_add(black_box(4f64), &x);
             });
         });

-        //axpy_out(&x, &y, 4., &mut out);
         c.bench_function(&format!("axpy_out {}", n), |b| {
             b.iter(|| {
-                axpy_out(
+                math.axpy_out(
                     black_box(&x),
                     black_box(&y),
                     black_box(4.),
                     black_box(&mut out),
                 )
             });
         });
-        //vector_dot(&x, &y);
+
         c.bench_function(&format!("vector_dot {}", n), |b| {
-            b.iter(|| vector_dot(black_box(&x), black_box(&y)));
+            b.iter(|| math.array_vector_dot(black_box(&x), black_box(&y)));
         });
-        /*
-        scalar_prods_of_diff(&x, &y, &a, &d);
-        c.bench_function(&format!("scalar_prods_of_diff {}", n), |b| {
+
+        c.bench_function(&format!("scalar_prods2 {}", n), |b| {
             b.iter(|| {
-                scalar_prods_of_diff(black_box(&x), black_box(&y), black_box(&a), black_box(&d))
+                math.scalar_prods2(black_box(&p), black_box(&p2), black_box(&x), black_box(&y))
+            });
+        });
+
+        c.bench_function(&format!("scalar_prods3 {}", n), |b| {
+            b.iter(|| {
+                math.scalar_prods3(
+                    black_box(&p),
+                    black_box(&p2),
+                    black_box(&n1),
+                    black_box(&x),
+                    black_box(&y),
+                )
             });
         });
-        */
     }

     let mut out = vec![0.; 10];
     c.bench_function("sample_1000_10", |b| {
-        b.iter(|| sample_one(black_box(3.), black_box(&mut out)))
+        b.iter(|| sample_one(black_box(&mut out)))
     });

     let mut out = vec![0.; 1000];
     c.bench_function("sample_1000_1000", |b| {
-        b.iter(|| sample_one(black_box(3.), black_box(&mut out)))
+        b.iter(|| sample_one(black_box(&mut out)))
     });
-
-    for n in [10, 12, 1000] {
-        c.bench_function(&format!("sample_parallel_{}", n), |b| {
-            b.iter(|| {
-                let func = NormalLogp::new(n, 0.);
-                let settings = black_box(SamplerArgs::default());
-                let mut init_point_func = JitterInitFunc::new();
-                let n_chains = black_box(10);
-                let n_draws = black_box(1000);
-                let seed = black_box(42);
-                let n_try_init = 10;
-                let (handle, channel) = sample_parallel(
-                    func,
-                    &mut init_point_func,
-                    settings,
-                    n_chains,
-                    n_draws,
-                    seed,
-                    n_try_init,
-                )
-                .unwrap();
-                let draws: Vec<_> = channel.iter().collect();
-                //assert_eq!(draws.len() as u64, (n_draws + settings.num_tune) * n_chains);
-                handle.join().unwrap();
-                draws
-            });
-        });
-    }
 }

 criterion_group!(benches, criterion_benchmark);
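For reference, the benchmark's target density is an isotropic normal with mean 3, so log p(x) = -1/2 * sum_i (x_i - 3)^2 + const, and the gradient written into grad is d log p / d x_i = -(x_i - 3), exactly what the logp implementation above computes. The suite runs via cargo bench, and a single benchmark can be selected with criterion's name filter, e.g. cargo bench -- axpy.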

proptest-regressions/math.txt

Lines changed: 1 addition & 0 deletions
@@ -9,3 +9,4 @@ cc cf16a8d08e8ee8f7f3d3cfd60840e136ac51d130dffcd42db1a9a68d7e51f394 # shrinks to
 cc 28897b64919482133f3885c3de51da0895409d23c9dd503a7b51a3e949bda307 # shrinks to (x1, x2, x3, y1, y2) = ([0.0], [0.0], [-4.0946726283401733e139], [0.0], [1.3157422010991668e73])
 cc acf6caef8a89a75ddab31ec3e391850723a625084df032aec2b650c2f95ba1fb # shrinks to (x, y) = ([0.0, 0.0, 0.0, 1.2271235629394547e205, 0.0, 0.0, -0.0, 0.0], [0.0, 0.0, 0.0, 7.121658452243713e81, 0.0, 0.0, 0.0, 0.0]), a = -6.261465657118442e-124
 cc 7ef2902af043f2f37325a29f48a403a32a2593b8089f085492b1010c68627341 # shrinks to a = 1.033664102276113e155, (x, y, out) = ([-1.847508293460042e-54, 0.0, 0.0], [1.8293708670672727e101, 0.0, 0.0], [0.0, 0.0, 0.0])
+cc 934b98345a50e6ded57733192b3f9f126cd28c04398fdb896353a19d00e9455c # shrinks to (x, y) = ([0.0, 0.0, 0.0, -0.0], [-0.0, 0.0, 0.0, inf])
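The cc lines are seeds that proptest writes automatically when a generated case fails: the input is shrunk to a minimal reproducer and replayed on every later run. The new entry records a failing input containing inf and negative zero. A minimal sketch of a property test in this style (illustrative only; the crate's actual tests are not shown in this excerpt):

use proptest::prelude::*;

proptest! {
    // Illustrative sketch: when a case fails, proptest shrinks it and appends
    // a "cc <hash> # shrinks to ..." line to the regressions file above.
    #[test]
    fn sum_of_squares_is_non_negative(
        x in prop::collection::vec(-1e6f64..1e6, 1..32),
    ) {
        let s: f64 = x.iter().map(|v| v * v).sum();
        prop_assert!(s >= 0.0);
    }
}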

src/cpu_math.rs

Lines changed: 7 additions & 1 deletion
@@ -103,6 +103,7 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
         y: &Self::Vector,
     ) -> (f64, f64) {
         scalar_prods3(
+            self.arch,
             positive1.try_as_col_major().unwrap().as_slice(),
             negative1.try_as_col_major().unwrap().as_slice(),
             positive2.try_as_col_major().unwrap().as_slice(),
@@ -119,6 +120,7 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
         y: &Self::Vector,
     ) -> (f64, f64) {
         scalar_prods2(
+            self.arch,
             positive1.try_as_col_major().unwrap().as_slice(),
             positive2.try_as_col_major().unwrap().as_slice(),
             x.try_as_col_major().unwrap().as_slice(),
@@ -153,6 +155,7 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {

     fn axpy_out(&mut self, x: &Self::Vector, y: &Self::Vector, a: f64, out: &mut Self::Vector) {
         axpy_out(
+            self.arch,
             x.try_as_col_major().unwrap().as_slice(),
             y.try_as_col_major().unwrap().as_slice(),
             a,
@@ -162,6 +165,7 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {

     fn axpy(&mut self, x: &Self::Vector, y: &mut Self::Vector, a: f64) {
         axpy(
+            self.arch,
             x.try_as_col_major().unwrap().as_slice(),
             y.try_as_col_major_mut().unwrap().as_slice_mut(),
             a,
@@ -174,7 +178,7 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {

     fn array_all_finite(&mut self, array: &Self::Vector) -> bool {
         let mut ok = true;
-        faer::zip!(array).for_each(|faer::unzip!(val)| ok = ok & val.is_finite());
+        faer::zip!(array).for_each(|faer::unzip!(val)| ok &= val.is_finite());
         ok
     }

@@ -196,6 +200,7 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
         dest: &mut Self::Vector,
     ) {
         multiply(
+            self.arch,
             array1.try_as_col_major().unwrap().as_slice(),
             array2.try_as_col_major().unwrap().as_slice(),
             dest.try_as_col_major_mut().unwrap().as_slice_mut(),
@@ -220,6 +225,7 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {

     fn array_vector_dot(&mut self, array1: &Self::Vector, array2: &Self::Vector) -> f64 {
         vector_dot(
+            self.arch,
             array1.try_as_col_major().unwrap().as_slice(),
             array2.try_as_col_major().unwrap().as_slice(),
         )
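Each of these call sites now threads self.arch, the pulp::Arch value stored on CpuMath, into the kernels in src/math.rs. That file's diff is not part of this excerpt; the axpy below is only a sketch of the likely pattern, with the signature and body assumed rather than taken from the commit:

use pulp::Arch;

// Hypothetical sketch, not the commit's actual src/math.rs. Passing the Arch
// in and dispatching through it lets one runtime feature detection (done when
// the CpuMath value is created) cover every kernel invocation.
pub fn axpy(arch: Arch, x: &[f64], y: &mut [f64], a: f64) {
    assert_eq!(x.len(), y.len());
    arch.dispatch(|| {
        for (x, y) in x.iter().zip(y.iter_mut()) {
            // y <- a * x + y
            *y += a * *x;
        }
    });
}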
