4 changes: 3 additions & 1 deletion .github/workflows/coverage.yml
@@ -16,7 +16,9 @@ jobs:
- name: Generate code coverage
run: cargo llvm-cov --workspace --lcov --output-path lcov.info
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
uses: codecov/codecov-action@v4
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
with:
files: lcov.info
fail_ci_if_error: true
2 changes: 0 additions & 2 deletions .github/workflows/test.yml
@@ -27,7 +27,6 @@ jobs:
uses: actions-rs/cargo@v1
with:
command: check
args: --features=nightly
- name: Run cargo check
uses: actions-rs/cargo@v1
with:
@@ -57,7 +56,6 @@ jobs:
uses: actions-rs/cargo@v1
with:
command: test
args: --features=nightly
- name: Run cargo test
uses: actions-rs/cargo@v1
with:
33 changes: 32 additions & 1 deletion CHANGELOG.md
@@ -2,12 +2,43 @@

All notable changes to this project will be documented in this file.

## [0.15.0] - 2025-02-12
## [0.16.0] - 2025-05-27

### Bug Fixes

- Eigen decomposition error for low rank mass matrix (Adrian Seyboldt)


### Miscellaneous Tasks

- Bump arrow version (Adrian Seyboldt)


### Performance

- Replace multiversion with pulp for simd (Adrian Seyboldt)


### Build

- Remove simd_support feature (Adrian Seyboldt)


## [0.15.1] - 2025-03-18

### Features

- Change defaults for transform adapt (Adrian Seyboldt)


### Miscellaneous Tasks

- Update dependencies (Adrian Seyboldt)

- Update dependencies (Adrian Seyboldt)

- Bump version (Adrian Seyboldt)


### Ci

16 changes: 5 additions & 11 deletions Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "nuts-rs"
version = "0.15.1"
version = "0.16.0"
authors = [
"Adrian Seyboldt <[email protected]>",
"PyMC Developers <[email protected]>",
@@ -20,13 +20,12 @@ codegen-units = 1
[dependencies]
rand = { version = "0.9.0", features = ["small_rng"] }
rand_distr = "0.5.0"
multiversion = "0.8.0"
itertools = "0.14.0"
thiserror = "2.0.3"
arrow = { version = "54.2.0", default-features = false, features = ["ffi"] }
arrow = { version = "55.1.0", default-features = false, features = ["ffi"] }
rand_chacha = "0.9.0"
anyhow = "1.0.72"
faer = { version = "0.21.4", default-features = false, features = [
faer = { version = "0.22.6", default-features = false, features = [
"std",
"npy",
"linalg",
@@ -37,17 +36,12 @@ rayon = "1.10.0"
[dev-dependencies]
proptest = "1.6.0"
pretty_assertions = "1.4.0"
criterion = "0.5.1"
nix = "0.29.0"
criterion = "0.6.0"
nix = { version = "0.30.0", features = ["sched"] }
approx = "0.5.1"
ndarray = "0.16.1"
equator = "0.4.2"

[[bench]]
name = "sample"
harness = false

[features]
nightly = ["simd_support"]

simd_support = []
157 changes: 101 additions & 56 deletions benches/sample.rs
@@ -1,18 +1,71 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use std::hint::black_box;

use criterion::{criterion_group, criterion_main, Criterion};
use nix::sched::{sched_setaffinity, CpuSet};
use nix::unistd::Pid;
use nuts_rs::math::{axpy, axpy_out, vector_dot};
use nuts_rs::test_logps::NormalLogp;
use nuts_rs::{new_sampler, sample_parallel, Chain, JitterInitFunc, SamplerArgs};
use nuts_rs::{Chain, CpuLogpFunc, CpuMath, LogpError, Math, Settings};
use rand::SeedableRng;
use rayon::ThreadPoolBuilder;
use thiserror::Error;

fn make_sampler(dim: usize, mu: f64) -> impl Chain {
let func = NormalLogp::new(dim, mu);
new_sampler(func, SamplerArgs::default(), 0, 0)
#[derive(Debug)]
struct PosteriorDensity {
dim: usize,
}

pub fn sample_one(mu: f64, out: &mut [f64]) {
let mut sampler = make_sampler(out.len(), mu);
// The density might fail in a recoverable or non-recoverable manner...
#[derive(Debug, Error)]
enum PosteriorLogpError {}
impl LogpError for PosteriorLogpError {
fn is_recoverable(&self) -> bool {
false
}
}

impl CpuLogpFunc for PosteriorDensity {
type LogpError = PosteriorLogpError;

// Only used for transforming adaptation.
type TransformParams = ();

// We define a 10 dimensional normal distribution
fn dim(&self) -> usize {
self.dim
}

// The normal likelihood with mean 3 and its gradient.
fn logp(&mut self, position: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
let mu = 3f64;
let logp = position
.iter()
.copied()
.zip(grad.iter_mut())
.map(|(x, grad)| {
let diff = x - mu;
*grad = -diff;
-0.5 * diff * diff
})
.sum();
return Ok(logp);
}
}

fn make_sampler(dim: usize) -> impl Chain<CpuMath<PosteriorDensity>> {
let func = PosteriorDensity { dim: dim };

let settings = nuts_rs::DiagGradNutsSettings {
num_tune: 1000,
maxdepth: 3, // small value just for testing...
..Default::default()
};

let math = nuts_rs::CpuMath::new(func);
let mut rng = rand::rngs::StdRng::seed_from_u64(42u64);
settings.new_chain(0, math, &mut rng)
}

pub fn sample_one(out: &mut [f64]) {
let mut sampler = make_sampler(out.len());
let init = vec![3.5; out.len()];
sampler.set_position(&init).unwrap();
for _ in 0..1000 {
@@ -36,87 +89,79 @@ fn criterion_benchmark(c: &mut Criterion) {
cpu_set.set(0).unwrap();
sched_setaffinity(Pid::from_raw(0), &cpu_set).unwrap();

for n in [10, 12, 14, 100, 800, 802] {
let x = vec![2.5; n];
let mut y = vec![3.5; n];
let mut out = vec![0.; n];
for n in [4, 16, 17, 100, 4567] {
let mut math = CpuMath::new(PosteriorDensity { dim: n });

let x = math.new_array();
let p = math.new_array();
let p2 = math.new_array();
let n1 = math.new_array();
let mut y = math.new_array();
let mut out = math.new_array();

let x_vec = vec![2.5; n];
let mut y_vec = vec![2.5; n];

c.bench_function(&format!("multiply {}", n), |b| {
b.iter(|| math.array_mult(black_box(&x), black_box(&y), black_box(&mut out)));
});

//axpy(&x, &mut y, 4.);
c.bench_function(&format!("axpy {}", n), |b| {
b.iter(|| axpy(black_box(&x), black_box(&mut y), black_box(4.)));
b.iter(|| math.axpy(black_box(&x), black_box(&mut y), black_box(4.)));
});

c.bench_function(&format!("axpy_ndarray {}", n), |b| {
b.iter(|| {
let x = ndarray::aview1(black_box(&x));
let mut y = ndarray::aview_mut1(black_box(&mut y));
let x = ndarray::aview1(black_box(&x_vec));
let mut y = ndarray::aview_mut1(black_box(&mut y_vec));
//y *= &x;// * black_box(4.);
y.scaled_add(black_box(4f64), &x);
});
});

//axpy_out(&x, &y, 4., &mut out);
c.bench_function(&format!("axpy_out {}", n), |b| {
b.iter(|| {
axpy_out(
math.axpy_out(
black_box(&x),
black_box(&y),
black_box(4.),
black_box(&mut out),
)
});
});
//vector_dot(&x, &y);

c.bench_function(&format!("vector_dot {}", n), |b| {
b.iter(|| vector_dot(black_box(&x), black_box(&y)));
b.iter(|| math.array_vector_dot(black_box(&x), black_box(&y)));
});
/*
scalar_prods_of_diff(&x, &y, &a, &d);
c.bench_function(&format!("scalar_prods_of_diff {}", n), |b| {

c.bench_function(&format!("scalar_prods2 {}", n), |b| {
b.iter(|| {
scalar_prods_of_diff(black_box(&x), black_box(&y), black_box(&a), black_box(&d))
math.scalar_prods2(black_box(&p), black_box(&p2), black_box(&x), black_box(&y))
});
});

c.bench_function(&format!("scalar_prods3 {}", n), |b| {
b.iter(|| {
math.scalar_prods3(
black_box(&p),
black_box(&p2),
black_box(&n1),
black_box(&x),
black_box(&y),
)
});
});
*/
}

let mut out = vec![0.; 10];
c.bench_function("sample_1000_10", |b| {
b.iter(|| sample_one(black_box(3.), black_box(&mut out)))
b.iter(|| sample_one(black_box(&mut out)))
});

let mut out = vec![0.; 1000];
c.bench_function("sample_1000_1000", |b| {
b.iter(|| sample_one(black_box(3.), black_box(&mut out)))
b.iter(|| sample_one(black_box(&mut out)))
});

for n in [10, 12, 1000] {
c.bench_function(&format!("sample_parallel_{}", n), |b| {
b.iter(|| {
let func = NormalLogp::new(n, 0.);
let settings = black_box(SamplerArgs::default());
let mut init_point_func = JitterInitFunc::new();
let n_chains = black_box(10);
let n_draws = black_box(1000);
let seed = black_box(42);
let n_try_init = 10;
let (handle, channel) = sample_parallel(
func,
&mut init_point_func,
settings,
n_chains,
n_draws,
seed,
n_try_init,
)
.unwrap();
let draws: Vec<_> = channel.iter().collect();
//assert_eq!(draws.len() as u64, (n_draws + settings.num_tune) * n_chains);
handle.join().unwrap();
draws
});
});
}
}

criterion_group!(benches, criterion_benchmark);
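For reference, the `PosteriorDensity` used in the benchmark above is an isotropic normal with mean 3: logp(x) = −0.5 · Σᵢ (xᵢ − 3)², with gradient ∂logp/∂xᵢ = −(xᵢ − 3). A self-contained check of just that formula, independent of criterion and nuts-rs (the free function below is purely illustrative, not part of the crate):

```rust
// Standalone version of the log-density and gradient computed in the bench's
// `CpuLogpFunc::logp` implementation (illustrative helper, not part of nuts-rs).
fn logp_and_grad(position: &[f64], grad: &mut [f64]) -> f64 {
    let mu = 3.0_f64;
    position
        .iter()
        .zip(grad.iter_mut())
        .map(|(&x, g)| {
            let diff = x - mu;
            *g = -diff; // derivative of -0.5 * diff^2 with respect to x
            -0.5 * diff * diff
        })
        .sum()
}

fn main() {
    let x = [3.5, 2.0];
    let mut grad = [0.0_f64; 2];
    let logp = logp_and_grad(&x, &mut grad);
    assert!((logp - (-0.625)).abs() < 1e-12); // -0.5 * (0.25 + 1.0)
    assert_eq!(grad, [-0.5, 1.0]);
}
```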
1 change: 1 addition & 0 deletions proptest-regressions/math.txt
@@ -9,3 +9,4 @@ cc cf16a8d08e8ee8f7f3d3cfd60840e136ac51d130dffcd42db1a9a68d7e51f394 # shrinks to
cc 28897b64919482133f3885c3de51da0895409d23c9dd503a7b51a3e949bda307 # shrinks to (x1, x2, x3, y1, y2) = ([0.0], [0.0], [-4.0946726283401733e139], [0.0], [1.3157422010991668e73])
cc acf6caef8a89a75ddab31ec3e391850723a625084df032aec2b650c2f95ba1fb # shrinks to (x, y) = ([0.0, 0.0, 0.0, 1.2271235629394547e205, 0.0, 0.0, -0.0, 0.0], [0.0, 0.0, 0.0, 7.121658452243713e81, 0.0, 0.0, 0.0, 0.0]), a = -6.261465657118442e-124
cc 7ef2902af043f2f37325a29f48a403a32a2593b8089f085492b1010c68627341 # shrinks to a = 1.033664102276113e155, (x, y, out) = ([-1.847508293460042e-54, 0.0, 0.0], [1.8293708670672727e101, 0.0, 0.0], [0.0, 0.0, 0.0])
cc 934b98345a50e6ded57733192b3f9f126cd28c04398fdb896353a19d00e9455c # shrinks to (x, y) = ([0.0, 0.0, 0.0, -0.0], [-0.0, 0.0, 0.0, inf])
2 changes: 1 addition & 1 deletion src/chain.rs
@@ -416,7 +416,7 @@ impl<M: Math, R: rand::Rng, A: AdaptStrategy<M>> StatTraceBuilder<M, NutsChain<M

if let Some(div_msg) = divergence_msg.as_mut() {
if let Some(err) = div_info.and_then(|info| info.logp_function_error.as_ref()) {
div_msg.append_value(format!("{}", err));
div_msg.append_value(format!("{err}"));
} else {
div_msg.append_null();
}
8 changes: 7 additions & 1 deletion src/cpu_math.rs
@@ -103,6 +103,7 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
y: &Self::Vector,
) -> (f64, f64) {
scalar_prods3(
self.arch,
positive1.try_as_col_major().unwrap().as_slice(),
negative1.try_as_col_major().unwrap().as_slice(),
positive2.try_as_col_major().unwrap().as_slice(),
@@ -119,6 +120,7 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
y: &Self::Vector,
) -> (f64, f64) {
scalar_prods2(
self.arch,
positive1.try_as_col_major().unwrap().as_slice(),
positive2.try_as_col_major().unwrap().as_slice(),
x.try_as_col_major().unwrap().as_slice(),
@@ -153,6 +155,7 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {

fn axpy_out(&mut self, x: &Self::Vector, y: &Self::Vector, a: f64, out: &mut Self::Vector) {
axpy_out(
self.arch,
x.try_as_col_major().unwrap().as_slice(),
y.try_as_col_major().unwrap().as_slice(),
a,
@@ -162,6 +165,7 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {

fn axpy(&mut self, x: &Self::Vector, y: &mut Self::Vector, a: f64) {
axpy(
self.arch,
x.try_as_col_major().unwrap().as_slice(),
y.try_as_col_major_mut().unwrap().as_slice_mut(),
a,
@@ -174,7 +178,7 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {

fn array_all_finite(&mut self, array: &Self::Vector) -> bool {
let mut ok = true;
faer::zip!(array).for_each(|faer::unzip!(val)| ok = ok & val.is_finite());
faer::zip!(array).for_each(|faer::unzip!(val)| ok &= val.is_finite());
ok
}

@@ -196,6 +200,7 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
dest: &mut Self::Vector,
) {
multiply(
self.arch,
array1.try_as_col_major().unwrap().as_slice(),
array2.try_as_col_major().unwrap().as_slice(),
dest.try_as_col_major_mut().unwrap().as_slice_mut(),
@@ -220,6 +225,7 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {

fn array_vector_dot(&mut self, array1: &Self::Vector, array2: &Self::Vector) -> f64 {
vector_dot(
self.arch,
array1.try_as_col_major().unwrap().as_slice(),
array2.try_as_col_major().unwrap().as_slice(),
)
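The `self.arch` argument threaded through `axpy`, `axpy_out`, `multiply`, `vector_dot`, and the `scalar_prods*` helpers reflects the changelog entry "Replace multiversion with pulp for simd": a `pulp::Arch` value detects the best available SIMD feature set once and dispatches each inner loop through it at runtime. Below is a minimal sketch of that dispatch pattern, assuming only pulp's documented `Arch::new()` / `dispatch` API; the `axpy_sketch` function is hypothetical and not the crate's actual helper, it only mirrors the way the diff passes `self.arch` as the first argument:

```rust
use pulp::Arch;

// Hypothetical standalone axpy (y += a * x), for illustration only.
fn axpy_sketch(arch: Arch, x: &[f64], y: &mut [f64], a: f64) {
    assert_eq!(x.len(), y.len());
    // `dispatch` runs the closure inside a function compiled with the SIMD
    // target features detected at runtime, so the loop can vectorize there.
    arch.dispatch(|| {
        for (xi, yi) in x.iter().zip(y.iter_mut()) {
            *yi += a * *xi;
        }
    });
}

fn main() {
    let arch = Arch::new();
    let x = vec![1.0_f64; 16];
    let mut y = vec![2.0_f64; 16];
    axpy_sketch(arch, &x, &mut y, 4.0);
    assert!(y.iter().all(|v| (*v - 6.0).abs() < 1e-12));
}
```

Because the feature detection happens at runtime, no compile-time SIMD feature is required, which is consistent with the removal of the `nightly`/`simd_support` features from Cargo.toml and of the `--features=nightly` arguments from the test workflow.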