diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 6d01189..151301f 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -16,7 +16,9 @@ jobs: - name: Generate code coverage run: cargo llvm-cov --workspace --lcov --output-path lcov.info - name: Upload coverage to Codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} with: files: lcov.info fail_ci_if_error: true diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4747ffe..06da3ba 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -27,7 +27,6 @@ jobs: uses: actions-rs/cargo@v1 with: command: check - args: --features=nightly - name: Run cargo check uses: actions-rs/cargo@v1 with: @@ -57,7 +56,6 @@ jobs: uses: actions-rs/cargo@v1 with: command: test - args: --features=nightly - name: Run cargo test uses: actions-rs/cargo@v1 with: diff --git a/CHANGELOG.md b/CHANGELOG.md index 2a53edc..5c6ac8d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,12 +2,43 @@ All notable changes to this project will be documented in this file. -## [0.15.0] - 2025-02-12 +## [0.16.0] - 2025-05-27 + +### Bug Fixes + +- Eigen decomposition error for low rank mass matrix (Adrian Seyboldt) + ### Miscellaneous Tasks +- Bump arrow version (Adrian Seyboldt) + + +### Performance + +- Replace multiversion with pulp for simd (Adrian Seyboldt) + + +### Build + +- Remove simd_support feature (Adrian Seyboldt) + + +## [0.15.1] - 2025-03-18 + +### Features + +- Change defaults for transform adapt (Adrian Seyboldt) + + +### Miscellaneous Tasks + +- Update dependencies (Adrian Seyboldt) + - Update dependencies (Adrian Seyboldt) +- Bump version (Adrian Seyboldt) + ### Ci diff --git a/Cargo.toml b/Cargo.toml index 596e96f..2ac675b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "nuts-rs" -version = "0.15.1" +version = "0.16.0" authors = [ "Adrian Seyboldt ", "PyMC Developers ", @@ -20,13 +20,12 @@ codegen-units = 1 [dependencies] rand = { version = "0.9.0", features = ["small_rng"] } rand_distr = "0.5.0" -multiversion = "0.8.0" itertools = "0.14.0" thiserror = "2.0.3" -arrow = { version = "54.2.0", default-features = false, features = ["ffi"] } +arrow = { version = "55.1.0", default-features = false, features = ["ffi"] } rand_chacha = "0.9.0" anyhow = "1.0.72" -faer = { version = "0.21.4", default-features = false, features = [ +faer = { version = "0.22.6", default-features = false, features = [ "std", "npy", "linalg", @@ -37,8 +36,8 @@ rayon = "1.10.0" [dev-dependencies] proptest = "1.6.0" pretty_assertions = "1.4.0" -criterion = "0.5.1" -nix = "0.29.0" +criterion = "0.6.0" +nix = { version = "0.30.0", features = ["sched"] } approx = "0.5.1" ndarray = "0.16.1" equator = "0.4.2" @@ -46,8 +45,3 @@ equator = "0.4.2" [[bench]] name = "sample" harness = false - -[features] -nightly = ["simd_support"] - -simd_support = [] diff --git a/benches/sample.rs b/benches/sample.rs index 3eec493..9148493 100644 --- a/benches/sample.rs +++ b/benches/sample.rs @@ -1,18 +1,71 @@ -use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use std::hint::black_box; + +use criterion::{criterion_group, criterion_main, Criterion}; use nix::sched::{sched_setaffinity, CpuSet}; use nix::unistd::Pid; -use nuts_rs::math::{axpy, axpy_out, vector_dot}; -use nuts_rs::test_logps::NormalLogp; -use nuts_rs::{new_sampler, sample_parallel, Chain, JitterInitFunc, SamplerArgs}; +use nuts_rs::{Chain, CpuLogpFunc, CpuMath, LogpError, Math, Settings}; +use rand::SeedableRng; use rayon::ThreadPoolBuilder; +use thiserror::Error; -fn make_sampler(dim: usize, mu: f64) -> impl Chain { - let func = NormalLogp::new(dim, mu); - new_sampler(func, SamplerArgs::default(), 0, 0) +#[derive(Debug)] +struct PosteriorDensity { + dim: usize, } -pub fn sample_one(mu: f64, out: &mut [f64]) { - let mut sampler = make_sampler(out.len(), mu); +// The density might fail in a recoverable or non-recoverable manner... +#[derive(Debug, Error)] +enum PosteriorLogpError {} +impl LogpError for PosteriorLogpError { + fn is_recoverable(&self) -> bool { + false + } +} + +impl CpuLogpFunc for PosteriorDensity { + type LogpError = PosteriorLogpError; + + // Only used for transforming adaptation. + type TransformParams = (); + + // We define a 10 dimensional normal distribution + fn dim(&self) -> usize { + self.dim + } + + // The normal likelihood with mean 3 and its gradient. + fn logp(&mut self, position: &[f64], grad: &mut [f64]) -> Result { + let mu = 3f64; + let logp = position + .iter() + .copied() + .zip(grad.iter_mut()) + .map(|(x, grad)| { + let diff = x - mu; + *grad = -diff; + -0.5 * diff * diff + }) + .sum(); + return Ok(logp); + } +} + +fn make_sampler(dim: usize) -> impl Chain> { + let func = PosteriorDensity { dim: dim }; + + let settings = nuts_rs::DiagGradNutsSettings { + num_tune: 1000, + maxdepth: 3, // small value just for testing... + ..Default::default() + }; + + let math = nuts_rs::CpuMath::new(func); + let mut rng = rand::rngs::StdRng::seed_from_u64(42u64); + settings.new_chain(0, math, &mut rng) +} + +pub fn sample_one(out: &mut [f64]) { + let mut sampler = make_sampler(out.len()); let init = vec![3.5; out.len()]; sampler.set_position(&init).unwrap(); for _ in 0..1000 { @@ -36,29 +89,39 @@ fn criterion_benchmark(c: &mut Criterion) { cpu_set.set(0).unwrap(); sched_setaffinity(Pid::from_raw(0), &cpu_set).unwrap(); - for n in [10, 12, 14, 100, 800, 802] { - let x = vec![2.5; n]; - let mut y = vec![3.5; n]; - let mut out = vec![0.; n]; + for n in [4, 16, 17, 100, 4567] { + let mut math = CpuMath::new(PosteriorDensity { dim: n }); + + let x = math.new_array(); + let p = math.new_array(); + let p2 = math.new_array(); + let n1 = math.new_array(); + let mut y = math.new_array(); + let mut out = math.new_array(); + + let x_vec = vec![2.5; n]; + let mut y_vec = vec![2.5; n]; + + c.bench_function(&format!("multiply {}", n), |b| { + b.iter(|| math.array_mult(black_box(&x), black_box(&y), black_box(&mut out))); + }); - //axpy(&x, &mut y, 4.); c.bench_function(&format!("axpy {}", n), |b| { - b.iter(|| axpy(black_box(&x), black_box(&mut y), black_box(4.))); + b.iter(|| math.axpy(black_box(&x), black_box(&mut y), black_box(4.))); }); c.bench_function(&format!("axpy_ndarray {}", n), |b| { b.iter(|| { - let x = ndarray::aview1(black_box(&x)); - let mut y = ndarray::aview_mut1(black_box(&mut y)); + let x = ndarray::aview1(black_box(&x_vec)); + let mut y = ndarray::aview_mut1(black_box(&mut y_vec)); //y *= &x;// * black_box(4.); y.scaled_add(black_box(4f64), &x); }); }); - //axpy_out(&x, &y, 4., &mut out); c.bench_function(&format!("axpy_out {}", n), |b| { b.iter(|| { - axpy_out( + math.axpy_out( black_box(&x), black_box(&y), black_box(4.), @@ -66,57 +129,39 @@ fn criterion_benchmark(c: &mut Criterion) { ) }); }); - //vector_dot(&x, &y); + c.bench_function(&format!("vector_dot {}", n), |b| { - b.iter(|| vector_dot(black_box(&x), black_box(&y))); + b.iter(|| math.array_vector_dot(black_box(&x), black_box(&y))); }); - /* - scalar_prods_of_diff(&x, &y, &a, &d); - c.bench_function(&format!("scalar_prods_of_diff {}", n), |b| { + + c.bench_function(&format!("scalar_prods2 {}", n), |b| { b.iter(|| { - scalar_prods_of_diff(black_box(&x), black_box(&y), black_box(&a), black_box(&d)) + math.scalar_prods2(black_box(&p), black_box(&p2), black_box(&x), black_box(&y)) + }); + }); + + c.bench_function(&format!("scalar_prods3 {}", n), |b| { + b.iter(|| { + math.scalar_prods3( + black_box(&p), + black_box(&p2), + black_box(&n1), + black_box(&x), + black_box(&y), + ) }); }); - */ } let mut out = vec![0.; 10]; c.bench_function("sample_1000_10", |b| { - b.iter(|| sample_one(black_box(3.), black_box(&mut out))) + b.iter(|| sample_one(black_box(&mut out))) }); let mut out = vec![0.; 1000]; c.bench_function("sample_1000_1000", |b| { - b.iter(|| sample_one(black_box(3.), black_box(&mut out))) + b.iter(|| sample_one(black_box(&mut out))) }); - - for n in [10, 12, 1000] { - c.bench_function(&format!("sample_parallel_{}", n), |b| { - b.iter(|| { - let func = NormalLogp::new(n, 0.); - let settings = black_box(SamplerArgs::default()); - let mut init_point_func = JitterInitFunc::new(); - let n_chains = black_box(10); - let n_draws = black_box(1000); - let seed = black_box(42); - let n_try_init = 10; - let (handle, channel) = sample_parallel( - func, - &mut init_point_func, - settings, - n_chains, - n_draws, - seed, - n_try_init, - ) - .unwrap(); - let draws: Vec<_> = channel.iter().collect(); - //assert_eq!(draws.len() as u64, (n_draws + settings.num_tune) * n_chains); - handle.join().unwrap(); - draws - }); - }); - } } criterion_group!(benches, criterion_benchmark); diff --git a/proptest-regressions/math.txt b/proptest-regressions/math.txt index 023b873..4f88092 100644 --- a/proptest-regressions/math.txt +++ b/proptest-regressions/math.txt @@ -9,3 +9,4 @@ cc cf16a8d08e8ee8f7f3d3cfd60840e136ac51d130dffcd42db1a9a68d7e51f394 # shrinks to cc 28897b64919482133f3885c3de51da0895409d23c9dd503a7b51a3e949bda307 # shrinks to (x1, x2, x3, y1, y2) = ([0.0], [0.0], [-4.0946726283401733e139], [0.0], [1.3157422010991668e73]) cc acf6caef8a89a75ddab31ec3e391850723a625084df032aec2b650c2f95ba1fb # shrinks to (x, y) = ([0.0, 0.0, 0.0, 1.2271235629394547e205, 0.0, 0.0, -0.0, 0.0], [0.0, 0.0, 0.0, 7.121658452243713e81, 0.0, 0.0, 0.0, 0.0]), a = -6.261465657118442e-124 cc 7ef2902af043f2f37325a29f48a403a32a2593b8089f085492b1010c68627341 # shrinks to a = 1.033664102276113e155, (x, y, out) = ([-1.847508293460042e-54, 0.0, 0.0], [1.8293708670672727e101, 0.0, 0.0], [0.0, 0.0, 0.0]) +cc 934b98345a50e6ded57733192b3f9f126cd28c04398fdb896353a19d00e9455c # shrinks to (x, y) = ([0.0, 0.0, 0.0, -0.0], [-0.0, 0.0, 0.0, inf]) diff --git a/src/chain.rs b/src/chain.rs index a182283..441a32e 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -416,7 +416,7 @@ impl> StatTraceBuilder Math for CpuMath { y: &Self::Vector, ) -> (f64, f64) { scalar_prods3( + self.arch, positive1.try_as_col_major().unwrap().as_slice(), negative1.try_as_col_major().unwrap().as_slice(), positive2.try_as_col_major().unwrap().as_slice(), @@ -119,6 +120,7 @@ impl Math for CpuMath { y: &Self::Vector, ) -> (f64, f64) { scalar_prods2( + self.arch, positive1.try_as_col_major().unwrap().as_slice(), positive2.try_as_col_major().unwrap().as_slice(), x.try_as_col_major().unwrap().as_slice(), @@ -153,6 +155,7 @@ impl Math for CpuMath { fn axpy_out(&mut self, x: &Self::Vector, y: &Self::Vector, a: f64, out: &mut Self::Vector) { axpy_out( + self.arch, x.try_as_col_major().unwrap().as_slice(), y.try_as_col_major().unwrap().as_slice(), a, @@ -162,6 +165,7 @@ impl Math for CpuMath { fn axpy(&mut self, x: &Self::Vector, y: &mut Self::Vector, a: f64) { axpy( + self.arch, x.try_as_col_major().unwrap().as_slice(), y.try_as_col_major_mut().unwrap().as_slice_mut(), a, @@ -174,7 +178,7 @@ impl Math for CpuMath { fn array_all_finite(&mut self, array: &Self::Vector) -> bool { let mut ok = true; - faer::zip!(array).for_each(|faer::unzip!(val)| ok = ok & val.is_finite()); + faer::zip!(array).for_each(|faer::unzip!(val)| ok &= val.is_finite()); ok } @@ -196,6 +200,7 @@ impl Math for CpuMath { dest: &mut Self::Vector, ) { multiply( + self.arch, array1.try_as_col_major().unwrap().as_slice(), array2.try_as_col_major().unwrap().as_slice(), dest.try_as_col_major_mut().unwrap().as_slice_mut(), @@ -220,6 +225,7 @@ impl Math for CpuMath { fn array_vector_dot(&mut self, array1: &Self::Vector, array2: &Self::Vector) -> f64 { vector_dot( + self.arch, array1.try_as_col_major().unwrap().as_slice(), array2.try_as_col_major().unwrap().as_slice(), ) diff --git a/src/lib.rs b/src/lib.rs index 53e0273..b4798a0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,3 @@ -#![cfg_attr(feature = "simd_support", feature(portable_simd))] -#![cfg_attr(feature = "simd_support", feature(slice_as_chunks))] //! Sample from posterior distributions using the No U-turn Sampler (NUTS). //! For details see the original [NUTS paper](https://arxiv.org/abs/1111.4246) //! and the more recent [introduction](https://arxiv.org/abs/1701.02434). diff --git a/src/low_rank_mass_matrix.rs b/src/low_rank_mass_matrix.rs index 384c641..316dbe4 100644 --- a/src/low_rank_mass_matrix.rs +++ b/src/low_rank_mass_matrix.rs @@ -18,13 +18,13 @@ use crate::{ fn mat_all_finite(mat: &MatRef) -> bool { let mut ok = true; - faer::zip!(mat).for_each(|faer::unzip!(val)| ok = ok & val.is_finite()); + faer::zip!(mat).for_each(|faer::unzip!(val)| ok &= val.is_finite()); ok } fn col_all_finite(mat: &ColRef) -> bool { let mut ok = true; - faer::zip!(mat).for_each(|faer::unzip!(val)| ok = ok & val.is_finite()); + faer::zip!(mat).for_each(|faer::unzip!(val)| ok &= val.is_finite()); ok } @@ -414,10 +414,22 @@ impl LowRankMassMatrixStrategy { grads.col_as_slice_mut(i).copy_from_slice(&grad[..]); } + let Some((stds, vals, vecs)) = self.compute_update(draws, grads) else { + return; + }; + + matrix.update(math, stds, vals, vecs); + } + + fn compute_update( + &self, + mut draws: Mat, + mut grads: Mat, + ) -> Option<(Col, Col, Mat)> { let stds = rescale_points(&mut draws, &mut grads); - let svd_draws = draws.thin_svd().expect("Failed to compute SVD"); - let svd_grads = grads.thin_svd().expect("Failed to compute SVD"); + let svd_draws = draws.thin_svd().ok()?; + let svd_grads = grads.thin_svd().ok()?; let subspace = faer::concat![[svd_draws.U(), svd_grads.U()]]; @@ -428,7 +440,7 @@ impl LowRankMassMatrixStrategy { let draws_proj = subspace_basis.transpose() * (&draws); let grads_proj = subspace_basis.transpose() * (&grads); - let (vals, vecs) = estimate_mass_matrix(draws_proj, grads_proj, self.settings.gamma); + let (vals, vecs) = estimate_mass_matrix(draws_proj, grads_proj, self.settings.gamma)?; let filtered = vals .iter() @@ -448,8 +460,7 @@ impl LowRankMassMatrixStrategy { .for_each(|(mut col, vals)| col.copy_from(vals)); let vecs = subspace_basis * vecs; - - matrix.update(math, stds, vals, vecs); + Some((stds, vals, vecs)) } } @@ -490,7 +501,11 @@ fn rescale_points(draws: &mut Mat, grads: &mut Mat) -> Col { stds } -fn estimate_mass_matrix(draws: Mat, grads: Mat, gamma: f64) -> (Col, Mat) { +fn estimate_mass_matrix( + draws: Mat, + grads: Mat, + gamma: f64, +) -> Option<(Col, Mat)> { let mut cov_draws = (&draws) * draws.transpose(); let mut cov_grads = (&grads) * grads.transpose(); @@ -509,22 +524,18 @@ fn estimate_mass_matrix(draws: Mat, grads: Mat, gamma: f64) -> (Col, cov_grads: Mat) -> Mat { - let eigs_grads = cov_grads - .self_adjoint_eigen(faer::Side::Lower) - .expect("Failed to compute eigenvalue decomposition"); +fn spd_mean(cov_draws: Mat, cov_grads: Mat) -> Option> { + let eigs_grads = cov_grads.self_adjoint_eigen(faer::Side::Lower).ok()?; let u = eigs_grads.U(); let eigs = eigs_grads.S().column_vector().to_owned(); @@ -534,9 +545,7 @@ fn spd_mean(cov_draws: Mat, cov_grads: Mat) -> Mat { let cov_grads_sqrt = u * eigs_sqrt.into_diagonal() * u.transpose(); let m = (&cov_grads_sqrt) * cov_draws * cov_grads_sqrt; - let m_eig = m - .self_adjoint_eigen(faer::Side::Lower) - .expect("Failed to compute eigenvalue decomposition"); + let m_eig = m.self_adjoint_eigen(faer::Side::Lower).ok()?; let m_u = m_eig.U(); let mut m_s = m_eig.S().column_vector().to_owned(); @@ -550,7 +559,7 @@ fn spd_mean(cov_draws: Mat, cov_grads: Mat) -> Mat { .for_each(|val| *val = val.sqrt().recip()); let grads_inv_sqrt = u * eigs_grads_inv.into_diagonal() * u.transpose(); - (&grads_inv_sqrt) * m_sqrt * grads_inv_sqrt + Some((&grads_inv_sqrt) * m_sqrt * grads_inv_sqrt) } impl SamplerStats for LowRankMassMatrixStrategy { @@ -656,7 +665,7 @@ mod test { x.diagonal_mut().column_vector_mut().add_assign(x_diag); y.diagonal_mut().column_vector_mut().add_assign(y_diag); - let out = spd_mean(x, y); + let out = spd_mean(x, y).expect("Failed to compute spd mean"); let expected_diag = faer::col![1., 2., 4.]; let mut expected = faer::Mat::zeros(3, 3); expected @@ -684,7 +693,8 @@ mod test { //let grads: Mat = Mat::from_fn(20, 3, |_, _| rng.sample(distr)); let grads = -(&draws); - let (vals, vecs) = estimate_mass_matrix(draws, grads, 0.0001); + let (vals, vecs) = + estimate_mass_matrix(draws, grads, 0.0001).expect("Failed to compute mass matrix"); assert!(vals.iter().cloned().all(|x| x > 0.)); assert!(mat_all_finite(&vecs.as_ref())); diff --git a/src/math.rs b/src/math.rs index da7d00a..462815a 100644 --- a/src/math.rs +++ b/src/math.rs @@ -1,8 +1,5 @@ use itertools::izip; -use multiversion::multiversion; - -#[cfg(feature = "simd_support")] -use std::simd::{f64x4, num::SimdFloat, StdFloat}; +use pulp::{Arch, WithSimd}; pub(crate) fn logaddexp(a: f64, b: f64) -> f64 { if a == b { @@ -19,99 +16,146 @@ pub(crate) fn logaddexp(a: f64, b: f64) -> f64 { } } -#[cfg(feature = "simd_support")] -#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))] -pub fn multiply(x: &[f64], y: &[f64], out: &mut [f64]) { - let n = x.len(); - assert!(y.len() == n); - assert!(out.len() == n); - - let (out, out_tail) = out.as_chunks_mut(); - let (x, x_tail) = x.as_chunks(); - let (y, y_tail) = y.as_chunks(); - - izip!(out, x, y).for_each(|(out, x, y)| { - let x = f64x4::from_array(*x); - let y = f64x4::from_array(*y); - *out = (x * y).to_array(); - }); - - izip!(out_tail.iter_mut(), x_tail.iter(), y_tail.iter()).for_each(|(out, &x, &y)| { - *out = x * y; - }); -} - -#[cfg(not(feature = "simd_support"))] -#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))] -pub fn multiply(x: &[f64], y: &[f64], out: &mut [f64]) { - let n = x.len(); - assert!(y.len() == n); - assert!(out.len() == n); - - izip!(out.iter_mut(), x.iter(), y.iter()).for_each(|(out, &x, &y)| { - *out = x * y; - }); +struct Multiply<'a> { + x: &'a [f64], + y: &'a [f64], + out: &'a mut [f64], } -#[cfg(feature = "simd_support")] -#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))] -pub fn scalar_prods2(positive1: &[f64], positive2: &[f64], x: &[f64], y: &[f64]) -> (f64, f64) { - let n = positive1.len(); - - assert!(positive1.len() == n); - assert!(positive2.len() == n); - assert!(x.len() == n); - assert!(y.len() == n); - - let zero = f64x4::splat(0.); - - let (a, a_tail) = positive1.as_chunks(); - let (b, b_tail) = positive2.as_chunks(); - let (c, c_tail) = x.as_chunks(); - let (d, d_tail) = y.as_chunks(); - - let out = izip!(a, b, c, d) - .map(|(&a, &b, &c, &d)| { - ( - f64x4::from_array(a), - f64x4::from_array(b), - f64x4::from_array(c), - f64x4::from_array(d), - ) - }) - .fold((zero, zero), |(s1, s2), (a, b, c, d)| { - let sum = a + b; - (c.mul_add(sum, s1), d.mul_add(sum, s2)) +impl<'a> WithSimd for Multiply<'a> { + type Output = (); + + #[inline(always)] + fn with_simd(self, simd: S) -> Self::Output { + let x = self.x; + let y = self.y; + let out = self.out; + + let (x_out, x_tail) = S::as_simd_f64s(x); + let (y_out, y_tail) = S::as_simd_f64s(y); + let (out_head, out_tail) = S::as_mut_simd_f64s(out); + + let (out_arrays, out_simd_tail) = pulp::as_arrays_mut::<4, _>(out_head); + let (x_arrays, x_simd_tail) = pulp::as_arrays::<4, _>(x_out); + let (y_arrays, y_simd_tail) = pulp::as_arrays::<4, _>(y_out); + + izip!(out_arrays, x_arrays, y_arrays).for_each( + |([out0, out1, out2, out3], [x0, x1, x2, x3], [y0, y1, y2, y3])| { + *out0 = simd.mul_f64s(*x0, *y0); + *out1 = simd.mul_f64s(*x1, *y1); + *out2 = simd.mul_f64s(*x2, *y2); + *out3 = simd.mul_f64s(*x3, *y3); + }, + ); + + izip!( + out_simd_tail.iter_mut(), + x_simd_tail.iter(), + y_simd_tail.iter() + ) + .for_each(|(out, &x, &y)| { + *out = simd.mul_f64s(x, y); }); - let out_head = (out.0.reduce_sum(), out.1.reduce_sum()); - let out = izip!(a_tail, b_tail, c_tail, d_tail,).fold((0., 0.), |(s1, s2), (a, b, c, d)| { - (s1 + c * (a + b), s2 + d * (a + b)) - }); + izip!(out_tail.iter_mut(), x_tail.iter(), y_tail.iter()).for_each(|(out, &x, &y)| { + *out = x * y; + }); + } +} - (out_head.0 + out.0, out_head.1 + out.1) +#[inline(never)] +pub fn multiply(arch: Arch, x: &[f64], y: &[f64], out: &mut [f64]) { + arch.dispatch(Multiply { x, y, out }) } -#[cfg(not(feature = "simd_support"))] -#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))] -pub fn scalar_prods2(positive1: &[f64], positive2: &[f64], x: &[f64], y: &[f64]) -> (f64, f64) { - let n = positive1.len(); +struct ScalarProds2<'a> { + positive1: &'a [f64], + positive2: &'a [f64], + x: &'a [f64], + y: &'a [f64], +} - assert!(positive1.len() == n); - assert!(positive2.len() == n); - assert!(x.len() == n); - assert!(y.len() == n); +impl<'a> WithSimd for ScalarProds2<'a> { + type Output = (f64, f64); + + #[inline(always)] + fn with_simd(self, simd: S) -> Self::Output { + let positive1 = self.positive1; + let positive2 = self.positive2; + let x = self.x; + let y = self.y; + + let (p1_out, p1_tail) = S::as_simd_f64s(positive1); + let (p2_out, p2_tail) = S::as_simd_f64s(positive2); + let (x_out, x_tail) = S::as_simd_f64s(x); + let (y_out, y_tail) = S::as_simd_f64s(y); + + let mut s1_0 = simd.splat_f64s(0.0); + let mut s1_1 = simd.splat_f64s(0.0); + let mut s1_2 = simd.splat_f64s(0.0); + let mut s1_3 = simd.splat_f64s(0.0); + let mut s2_0 = simd.splat_f64s(0.0); + let mut s2_1 = simd.splat_f64s(0.0); + let mut s2_2 = simd.splat_f64s(0.0); + let mut s2_3 = simd.splat_f64s(0.0); + + let (p1_out, p1_simd_tail) = pulp::as_arrays::<4, _>(p1_out); + let (p2_out, p2_simd_tail) = pulp::as_arrays::<4, _>(p2_out); + let (x_out, x_simd_tail) = pulp::as_arrays::<4, _>(x_out); + let (y_out, y_simd_tail) = pulp::as_arrays::<4, _>(y_out); + + izip!(p1_out, p2_out, x_out, y_out).for_each( + |( + [p1_0, p1_1, p1_2, p1_3], + [p2_0, p2_1, p2_2, p2_3], + [x_0, x_1, x_2, x_3], + [y_0, y_1, y_2, y_3], + )| { + let sum0 = simd.add_f64s(*p1_0, *p2_0); + let sum1 = simd.add_f64s(*p1_1, *p2_1); + let sum2 = simd.add_f64s(*p1_2, *p2_2); + let sum3 = simd.add_f64s(*p1_3, *p2_3); + s1_0 = simd.mul_add_e_f64s(sum0, *x_0, s1_0); + s1_1 = simd.mul_add_e_f64s(sum1, *x_1, s1_1); + s1_2 = simd.mul_add_e_f64s(sum2, *x_2, s1_2); + s1_3 = simd.mul_add_e_f64s(sum3, *x_3, s1_3); + s2_0 = simd.mul_add_e_f64s(sum0, *y_0, s2_0); + s2_1 = simd.mul_add_e_f64s(sum1, *y_1, s2_1); + s2_2 = simd.mul_add_e_f64s(sum2, *y_2, s2_2); + s2_3 = simd.mul_add_e_f64s(sum3, *y_3, s2_3); + }, + ); + + izip!(p1_simd_tail, p2_simd_tail, x_simd_tail, y_simd_tail).for_each(|(p1, p2, x, y)| { + let sum = simd.add_f64s(*p1, *p2); + s1_0 = simd.mul_add_e_f64s(sum, *x, s1_0); + s2_0 = simd.mul_add_e_f64s(sum, *y, s2_0); + }); - izip!(positive1, positive2, x, y).fold((0., 0.), |(s1, s2), (a, b, c, d)| { - (s1 + c * (a + b), s2 + d * (a + b)) - }) + let mut out = ( + simd.reduce_sum_f64s( + simd.add_f64s(simd.add_f64s(s1_0, s1_1), simd.add_f64s(s1_2, s1_3)), + ), + simd.reduce_sum_f64s( + simd.add_f64s(simd.add_f64s(s2_0, s2_1), simd.add_f64s(s2_2, s2_3)), + ), + ); + + izip!(p1_tail.iter(), p2_tail.iter(), x_tail.iter(), y_tail.iter()).for_each( + |(p1, p2, x, y)| { + let sum = *p1 + *p2; + out.0 += sum * *x; + out.1 += sum * *y; + }, + ); + out + } } -#[cfg(feature = "simd_support")] -#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))] -pub fn scalar_prods3( +#[inline(never)] +pub fn scalar_prods2( + arch: Arch, positive1: &[f64], - negative1: &[f64], positive2: &[f64], x: &[f64], y: &[f64], @@ -120,45 +164,122 @@ pub fn scalar_prods3( assert!(positive1.len() == n); assert!(positive2.len() == n); - assert!(negative1.len() == n); assert!(x.len() == n); assert!(y.len() == n); - let zero = f64x4::splat(0.); - - let (a, a_tail) = positive1.as_chunks(); - let (b, b_tail) = negative1.as_chunks(); - let (c, c_tail) = positive2.as_chunks(); - let (x, x_tail) = x.as_chunks(); - let (y, y_tail) = y.as_chunks(); - - let out = izip!(a, b, c, x, y) - .map(|(&a, &b, &c, &x, &y)| { - ( - f64x4::from_array(a), - f64x4::from_array(b), - f64x4::from_array(c), - f64x4::from_array(x), - f64x4::from_array(y), - ) - }) - .fold((zero, zero), |(s1, s2), (a, b, c, x, y)| { - let sum = a - b + c; - (x.mul_add(sum, s1), y.mul_add(sum, s2)) + arch.dispatch(ScalarProds2 { + positive1, + positive2, + x, + y, + }) +} + +struct ScalarProds3<'a> { + positive1: &'a [f64], + negative1: &'a [f64], + positive2: &'a [f64], + x: &'a [f64], + y: &'a [f64], +} + +impl<'a> WithSimd for ScalarProds3<'a> { + type Output = (f64, f64); + + #[inline(always)] + fn with_simd(self, simd: S) -> Self::Output { + let positive1 = self.positive1; + let negative1 = self.negative1; + let positive2 = self.positive2; + let x = self.x; + let y = self.y; + + let (p1_out, p1_tail) = S::as_simd_f64s(positive1); + let (n1_out, n1_tail) = S::as_simd_f64s(negative1); + let (p2_out, p2_tail) = S::as_simd_f64s(positive2); + let (x_out, x_tail) = S::as_simd_f64s(x); + let (y_out, y_tail) = S::as_simd_f64s(y); + + let mut s1_0 = simd.splat_f64s(0.0); + let mut s1_1 = simd.splat_f64s(0.0); + let mut s1_2 = simd.splat_f64s(0.0); + let mut s1_3 = simd.splat_f64s(0.0); + let mut s2_0 = simd.splat_f64s(0.0); + let mut s2_1 = simd.splat_f64s(0.0); + let mut s2_2 = simd.splat_f64s(0.0); + let mut s2_3 = simd.splat_f64s(0.0); + + let (p1_out, p1_simd_tail) = pulp::as_arrays::<4, _>(p1_out); + let (n1_out, n1_simd_tail) = pulp::as_arrays::<4, _>(n1_out); + let (p2_out, p2_simd_tail) = pulp::as_arrays::<4, _>(p2_out); + let (x_out, x_simd_tail) = pulp::as_arrays::<4, _>(x_out); + let (y_out, y_simd_tail) = pulp::as_arrays::<4, _>(y_out); + + izip!(p1_out, n1_out, p2_out, x_out, y_out).for_each( + |( + [p1_0, p1_1, p1_2, p1_3], + [n1_0, n1_1, n1_2, n1_3], + [p2_0, p2_1, p2_2, p2_3], + [x_0, x_1, x_2, x_3], + [y_0, y_1, y_2, y_3], + )| { + let sum0 = simd.sub_f64s(simd.add_f64s(*p1_0, *p2_0), *n1_0); + let sum1 = simd.sub_f64s(simd.add_f64s(*p1_1, *p2_1), *n1_1); + let sum2 = simd.sub_f64s(simd.add_f64s(*p1_2, *p2_2), *n1_2); + let sum3 = simd.sub_f64s(simd.add_f64s(*p1_3, *p2_3), *n1_3); + s1_0 = simd.mul_add_e_f64s(sum0, *x_0, s1_0); + s1_1 = simd.mul_add_e_f64s(sum1, *x_1, s1_1); + s1_2 = simd.mul_add_e_f64s(sum2, *x_2, s1_2); + s1_3 = simd.mul_add_e_f64s(sum3, *x_3, s1_3); + s2_0 = simd.mul_add_e_f64s(sum0, *y_0, s2_0); + s2_1 = simd.mul_add_e_f64s(sum1, *y_1, s2_1); + s2_2 = simd.mul_add_e_f64s(sum2, *y_2, s2_2); + s2_3 = simd.mul_add_e_f64s(sum3, *y_3, s2_3); + }, + ); + + izip!( + p1_simd_tail, + n1_simd_tail, + p2_simd_tail, + x_simd_tail, + y_simd_tail + ) + .for_each(|(p1, n1, p2, x, y)| { + let sum = simd.sub_f64s(simd.add_f64s(*p1, *p2), *n1); + s1_0 = simd.mul_add_e_f64s(sum, *x, s1_0); + s2_0 = simd.mul_add_e_f64s(sum, *y, s2_0); }); - let out_head = (out.0.reduce_sum(), out.1.reduce_sum()); - let out = izip!(a_tail, b_tail, c_tail, x_tail, y_tail) - .fold((0., 0.), |(s1, s2), (a, b, c, x, y)| { - (s1 + x * (a - b + c), s2 + y * (a - b + c)) + let mut out = ( + simd.reduce_sum_f64s( + simd.add_f64s(simd.add_f64s(s1_0, s1_1), simd.add_f64s(s1_2, s1_3)), + ), + simd.reduce_sum_f64s( + simd.add_f64s(simd.add_f64s(s2_0, s2_1), simd.add_f64s(s2_2, s2_3)), + ), + ); + + izip!( + p1_tail.iter(), + n1_tail.iter(), + p2_tail.iter(), + x_tail.iter(), + y_tail.iter() + ) + .for_each(|(p1, n1, p2, x, y)| { + let sum = *p1 - *n1 + *p2; + out.0 += sum * *x; + out.1 += sum * *y; }); - (out_head.0 + out.0, out_head.1 + out.1) + out + } } -#[cfg(not(feature = "simd_support"))] -#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))] +#[inline(never)] pub fn scalar_prods3( + arch: Arch, positive1: &[f64], negative1: &[f64], positive2: &[f64], @@ -173,117 +294,171 @@ pub fn scalar_prods3( assert!(x.len() == n); assert!(y.len() == n); - izip!(positive1, negative1, positive2, x, y).fold((0., 0.), |(s1, s2), (a, b, c, x, y)| { - (s1 + x * (a - b + c), s2 + y * (a - b + c)) + arch.dispatch(ScalarProds3 { + positive1, + negative1, + positive2, + x, + y, }) } -#[cfg(feature = "simd_support")] -#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))] -pub fn vector_dot(a: &[f64], b: &[f64]) -> f64 { - assert!(a.len() == b.len()); - - let (x, x_tail) = a.as_chunks(); - let (y, y_tail) = b.as_chunks(); - - let sum: f64x4 = izip!(x, y) - .map(|(&x, &y)| { - let x = f64x4::from_array(x); - let y = f64x4::from_array(y); - x * y - }) - .sum(); - - let mut result = sum.reduce_sum(); - for (val1, val2) in x_tail.iter().zip(y_tail).take(3) { - result += *val1 * *val2; - } - result +struct VectorDot<'a> { + x: &'a [f64], + y: &'a [f64], } -#[cfg(not(feature = "simd_support"))] -#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))] -pub fn vector_dot(a: &[f64], b: &[f64]) -> f64 { - assert!(a.len() == b.len()); +impl<'a> WithSimd for VectorDot<'a> { + type Output = f64; - let mut result = 0f64; - for (&val1, &val2) in a.iter().zip(b) { - result += val1 * val2; - } - result -} + #[inline(always)] + fn with_simd(self, simd: S) -> Self::Output { + let a = self.x; + let b = self.y; -#[cfg(feature = "simd_support")] -#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))] -pub fn axpy(x: &[f64], y: &mut [f64], a: f64) { - let n = x.len(); - assert!(y.len() == n); + assert!(a.len() == b.len()); - let (x, x_tail) = x.as_chunks(); - let (y, y_tail) = y.as_chunks_mut(); + let (x, x_tail) = S::as_simd_f64s(a); + let (y, y_tail) = S::as_simd_f64s(b); - let a_splat = f64x4::splat(a); + let mut out0 = simd.splat_f64s(0f64); + let mut out1 = simd.splat_f64s(0f64); + let mut out2 = simd.splat_f64s(0f64); + let mut out3 = simd.splat_f64s(0f64); - izip!(x, y).for_each(|(x, y)| { - let x = f64x4::from_array(*x); - let y_val = f64x4::from_array(*y); - let out = x.mul_add(a_splat, y_val); - *y = out.to_array(); - }); + let (x, x_simd_tail) = pulp::as_arrays::<4, _>(x); + let (y, y_simd_tail) = pulp::as_arrays::<4, _>(y); - izip!(x_tail, y_tail).for_each(|(x, y)| { - *y = x.mul_add(a, *y); - }); + izip!(x, y).for_each(|([x0, x1, x2, x3], [y0, y1, y2, y3])| { + out0 = simd.mul_add_e_f64s(*x0, *y0, out0); + out1 = simd.mul_add_e_f64s(*x1, *y1, out1); + out2 = simd.mul_add_e_f64s(*x2, *y2, out2); + out3 = simd.mul_add_e_f64s(*x3, *y3, out3); + }); + + izip!(x_simd_tail, y_simd_tail).for_each(|(&x, &y)| { + out0 = simd.mul_add_e_f64s(x, y, out0); + }); + + out0 = simd.add_f64s(out0, out1); + out1 = simd.add_f64s(out2, out3); + out0 = simd.add_f64s(out0, out1); + let mut result = simd.reduce_sum_f64s(out0); + + x_tail.iter().zip(y_tail).for_each(|(&x, &y)| { + result += x * y; + }); + result + } } -#[cfg(not(feature = "simd_support"))] -#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))] -pub fn axpy(x: &[f64], y: &mut [f64], a: f64) { - let n = x.len(); - assert!(y.len() == n); +pub fn vector_dot(arch: Arch, a: &[f64], b: &[f64]) -> f64 { + arch.dispatch(VectorDot { x: a, y: b }) +} - izip!(x, y).for_each(|(x, y)| { - *y = x.mul_add(a, *y); - }); +struct Axpy<'a> { + x: &'a [f64], + y: &'a mut [f64], + a: f64, } -#[cfg(feature = "simd_support")] -#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))] -pub fn axpy_out(x: &[f64], y: &[f64], a: f64, out: &mut [f64]) { +impl<'a> WithSimd for Axpy<'a> { + type Output = (); + + #[inline(always)] + fn with_simd(self, simd: S) -> Self::Output { + let x = self.x; + let y = self.y; + let a = self.a; + + let (x_out, x_tail) = S::as_simd_f64s(x); + let (y_out, y_tail) = S::as_mut_simd_f64s(y); + + let a_splat = simd.splat_f64s(a); + + let (x_arrays, x_simd_tail) = pulp::as_arrays::<4, _>(x_out); + let (y_arrays, y_simd_tail) = pulp::as_arrays_mut::<4, _>(y_out); + + izip!(x_arrays, y_arrays).for_each(|([x0, x1, x2, x3], [y0, y1, y2, y3])| { + *y0 = simd.mul_add_e_f64s(a_splat, *x0, *y0); + *y1 = simd.mul_add_e_f64s(a_splat, *x1, *y1); + *y2 = simd.mul_add_e_f64s(a_splat, *x2, *y2); + *y3 = simd.mul_add_e_f64s(a_splat, *x3, *y3); + }); + + izip!(x_simd_tail.iter(), y_simd_tail.iter_mut()).for_each(|(&x, y)| { + *y = simd.mul_add_e_f64s(a_splat, x, *y); + }); + + izip!(x_tail.iter(), y_tail.iter_mut()).for_each(|(&x, y)| { + *y = a.mul_add(x, *y); + }); + } +} +pub fn axpy(arch: Arch, x: &[f64], y: &mut [f64], a: f64) { let n = x.len(); assert!(y.len() == n); - assert!(out.len() == n); - let (x, x_tail) = x.as_chunks(); - let (y, y_tail) = y.as_chunks(); - let (out, out_tail) = out.as_chunks_mut(); - - let a_splat = f64x4::splat(a); + arch.dispatch(Axpy { x, y, a }); +} - izip!(x, y, out).for_each(|(&x, &y, out)| { - let x = f64x4::from_array(x); - let y_val = f64x4::from_array(y); +struct AxpyOut<'a> { + x: &'a [f64], + y: &'a [f64], + out: &'a mut [f64], + a: f64, +} - *out = x.mul_add(a_splat, y_val).to_array(); - }); +impl<'a> WithSimd for AxpyOut<'a> { + type Output = (); + + #[inline(always)] + fn with_simd(self, simd: S) -> Self::Output { + let x = self.x; + let y = self.y; + let out = self.out; + let a = self.a; + + let (x_out, x_tail) = S::as_simd_f64s(x); + let (y_out, y_tail) = S::as_simd_f64s(y); + let (out_head, out_tail) = S::as_mut_simd_f64s(out); + + let a_splat = simd.splat_f64s(a); + + let (out_arrays, out_simd_tail) = pulp::as_arrays_mut::<4, _>(out_head); + let (x_arrays, x_simd_tail) = pulp::as_arrays::<4, _>(x_out); + let (y_arrays, y_simd_tail) = pulp::as_arrays::<4, _>(y_out); + + izip!(out_arrays, x_arrays, y_arrays).for_each( + |([out0, out1, out2, out3], [x0, x1, x2, x3], [y0, y1, y2, y3])| { + *out0 = simd.mul_add_e_f64s(a_splat, *x0, *y0); + *out1 = simd.mul_add_e_f64s(a_splat, *x1, *y1); + *out2 = simd.mul_add_e_f64s(a_splat, *x2, *y2); + *out3 = simd.mul_add_e_f64s(a_splat, *x3, *y3); + }, + ); + + izip!( + out_simd_tail.iter_mut(), + x_simd_tail.iter(), + y_simd_tail.iter() + ) + .for_each(|(out, &x, &y)| { + *out = simd.mul_add_e_f64s(a_splat, x, y); + }); - izip!(x_tail, y_tail, out_tail) - .take(3) - .for_each(|(&x, &y, out)| { + izip!(x_tail.iter(), y_tail.iter(), out_tail.iter_mut()).for_each(|(&x, &y, out)| { *out = a.mul_add(x, y); }); + } } -#[cfg(not(feature = "simd_support"))] -#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))] -pub fn axpy_out(x: &[f64], y: &[f64], a: f64, out: &mut [f64]) { +pub fn axpy_out(arch: Arch, x: &[f64], y: &[f64], a: f64, out: &mut [f64]) { let n = x.len(); assert!(y.len() == n); assert!(out.len() == n); - izip!(x, y, out).for_each(|(&x, &y, out)| { - *out = a.mul_add(x, y); - }); + arch.dispatch(AxpyOut { x, y, out, a }); } #[cfg(test)] @@ -365,9 +540,10 @@ mod tests { #[test] fn test_axpy((x, y) in array2(10), a in prop::num::f64::ANY) { + let arch = pulp::Arch::default(); let orig = y.clone(); let mut y = y.clone(); - axpy(&x[..], &mut y[..], a); + axpy(arch, &x[..], &mut y[..], a); for ((&x, y), out) in x.iter().zip(orig).zip(y) { assert_approx_eq(out, a.mul_add(x, y)); } @@ -375,7 +551,8 @@ mod tests { #[test] fn test_scalar_prods2((x1, x2, y1, y2) in array4(10)) { - let (p1, p2) = scalar_prods2(&x1[..], &x2[..], &y1[..], &y2[..]); + let arch = pulp::Arch::default(); + let (p1, p2) = scalar_prods2(arch, &x1[..], &x2[..], &y1[..], &y2[..]); let x1 = ndarray::Array1::from_vec(x1); let x2 = ndarray::Array1::from_vec(x2); let y1 = ndarray::Array1::from_vec(y1); @@ -386,7 +563,8 @@ mod tests { #[test] fn test_scalar_prods3((x1, x2, x3, y1, y2) in array5(10)) { - let (p1, p2) = scalar_prods3(&x1[..], &x2[..], &x3[..], &y1[..], &y2[..]); + let arch = Arch::default(); + let (p1, p2) = scalar_prods3(arch, &x1[..], &x2[..], &x3[..], &y1[..], &y2[..]); let x1 = ndarray::Array1::from_vec(x1); let x2 = ndarray::Array1::from_vec(x2); let x3 = ndarray::Array1::from_vec(x3); @@ -398,8 +576,9 @@ mod tests { #[test] fn test_axpy_out(a in prop::num::f64::ANY, (x, y, out) in array3(10)) { + let arch = Arch::default(); let mut out = out.clone(); - axpy_out(&x[..], &y[..], a, &mut out[..]); + axpy_out(arch, &x[..], &y[..], a, &mut out[..]); let x = ndarray::Array1::from_vec(x); let mut y = ndarray::Array1::from_vec(y); y.scaled_add(a, &x); @@ -410,8 +589,9 @@ mod tests { #[test] fn test_multiplty((x, y, out) in array3(10)) { + let arch = pulp::Arch::default(); let mut out = out.clone(); - multiply(&x[..], &y[..], &mut out[..]); + multiply(arch, &x[..], &y[..], &mut out[..]); let x = ndarray::Array1::from_vec(x); let y = ndarray::Array1::from_vec(y); for (&out1, out2) in out.iter().zip(&x * &y) { @@ -421,7 +601,8 @@ mod tests { #[test] fn test_vector_dot((x, y) in array2(10)) { - let actual = vector_dot(&x[..], &y[..]); + let arch = pulp::Arch::default(); + let actual = vector_dot(arch, &x[..], &y[..]); let x = ndarray::Array1::from_vec(x); let y = ndarray::Array1::from_vec(y); let expected = x.iter().zip(y.iter()).map(|(&x, &y)| x * y).sum(); diff --git a/src/sampler.rs b/src/sampler.rs index e803228..0bb4f7c 100644 --- a/src/sampler.rs +++ b/src/sampler.rs @@ -702,7 +702,7 @@ impl Sampler { let main_thread = spawn(move || { let pool = ThreadPoolBuilder::new() .num_threads(num_cores + 1) // One more thread because the controller also uses one - .thread_name(|i| format!("nutpie-worker-{}", i)) + .thread_name(|i| format!("nutpie-worker-{i}")) .build() .context("Could not start thread pool")?; @@ -927,7 +927,6 @@ pub mod test_logps { array::{Array, ArrayBuilder, FixedSizeListBuilder, PrimitiveBuilder}, datatypes::Float64Type, }; - use multiversion::multiversion; use thiserror::Error; use super::{DrawStorage, Model}; @@ -954,63 +953,18 @@ pub mod test_logps { fn dim(&self) -> usize { self.dim } + fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result { let n = position.len(); assert!(gradient.len() == n); - #[cfg(feature = "simd_support")] - #[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))] - fn logp_inner(mu: f64, position: &[f64], gradient: &mut [f64]) -> f64 { - use std::simd::f64x4; - use std::simd::num::SimdFloat; - - let n = position.len(); - assert!(gradient.len() == n); - - let head_length = n - n % 4; - - let (pos, pos_tail) = position.split_at(head_length); - let (grad, grad_tail) = gradient.split_at_mut(head_length); - - let mu_splat = f64x4::splat(mu); - - let mut logp = f64x4::splat(0f64); - - for (p, g) in pos.chunks_exact(4).zip(grad.chunks_exact_mut(4)) { - let p = f64x4::from_slice(p); - let val = mu_splat - p; - logp -= val * val * f64x4::splat(0.5); - g.copy_from_slice(&val.to_array()); - } - - let mut logp_tail = 0f64; - for (p, g) in pos_tail.iter().zip(grad_tail.iter_mut()).take(3) { - let val = mu - p; - logp_tail -= val * val / 2.; - *g = val; - } - - logp.reduce_sum() + logp_tail - } - - #[cfg(not(feature = "simd_support"))] - #[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))] - fn logp_inner(mu: f64, position: &[f64], gradient: &mut [f64]) -> f64 { - let n = position.len(); - assert!(gradient.len() == n); - - let mut logp = 0f64; - for (p, g) in position.iter().zip(gradient.iter_mut()) { - let val = mu - p; - logp -= val * val / 2.; - *g = val; - } - - logp + let mut logp = 0f64; + for (p, g) in position.iter().zip(gradient.iter_mut()) { + let val = self.mu - p; + logp -= val * val / 2.; + *g = val; } - let logp = logp_inner(self.mu, position, gradient); - Ok(logp) }