Skip to content

Commit 0a0cde0

Browse files
committed
Add Neon, Neonx2, Neonx4 and Neonx8 SIMD implementations
When ran with one thread, neonx2 is ~31% faster on my Raspberry Pi 3 B than the previous fastest implementation (array128). When ran with one thread, neonx8 is ~34% faster on my Raspberry Pi 5 B than the previous fastest implementation (array4096).
1 parent 57055e3 commit 0a0cde0

File tree

3 files changed

+291
-0
lines changed

3 files changed

+291
-0
lines changed

crates/utils/src/multiversion.rs

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,14 @@ macro_rules! multiversion {
8989
AVX2x4 => unsafe { $name::avx2x4::$name($($arg_name),*) },
9090
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
9191
AVX2x8 => unsafe { $name::avx2x8::$name($($arg_name),*) },
92+
#[cfg(all(feature = "unsafe", target_arch = "aarch64"))]
93+
Neon => $name::neon::$name($($arg_name),*),
94+
#[cfg(all(feature = "unsafe", feature = "all-simd", target_arch = "aarch64"))]
95+
Neonx2 => $name::neonx2::$name($($arg_name),*),
96+
#[cfg(all(feature = "unsafe", feature = "all-simd", target_arch = "aarch64"))]
97+
Neonx4 => $name::neonx4::$name($($arg_name),*),
98+
#[cfg(all(feature = "unsafe", feature = "all-simd", target_arch = "aarch64"))]
99+
Neonx8 => $name::neonx8::$name($($arg_name),*),
92100
}
93101
}
94102
};
@@ -184,6 +192,42 @@ macro_rules! multiversion {
184192

185193
$crate::multiversion!{@helper target_feature(enable = "avx2") $($tail)*}
186194
}
195+
196+
/// [`multiversion!`] neon implementation.
197+
#[cfg(all(feature = "unsafe", target_arch = "aarch64"))]
198+
pub mod neon {
199+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
200+
use {super::*, $($($path::)+neon::*),*};
201+
202+
$($tail)*
203+
}
204+
205+
/// [`multiversion!`] neonx2 implementation.
206+
#[cfg(all(feature = "unsafe", feature = "all-simd", target_arch = "aarch64"))]
207+
pub mod neonx2 {
208+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
209+
use {super::*, $($($path::)+neonx2::*),*};
210+
211+
$($tail)*
212+
}
213+
214+
/// [`multiversion!`] neonx4 implementation.
215+
#[cfg(all(feature = "unsafe", feature = "all-simd", target_arch = "aarch64"))]
216+
pub mod neonx4 {
217+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
218+
use {super::*, $($($path::)+neonx4::*),*};
219+
220+
$($tail)*
221+
}
222+
223+
/// [`multiversion!`] neonx8 implementation.
224+
#[cfg(all(feature = "unsafe", feature = "all-simd", target_arch = "aarch64"))]
225+
pub mod neonx8 {
226+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
227+
use {super::*, $($($path::)+neonx8::*),*};
228+
229+
$($tail)*
230+
}
187231
};
188232

189233
// Microbenchmark for dynamic dispatch
@@ -219,6 +263,14 @@ macro_rules! multiversion {
219263
AVX2x4 => unsafe { avx2x4::$name() },
220264
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
221265
AVX2x8 => unsafe { avx2x8::$name() },
266+
#[cfg(all(feature = "unsafe", target_arch = "aarch64"))]
267+
Neon => neon::$name(),
268+
#[cfg(all(feature = "unsafe", feature = "all-simd", target_arch = "aarch64"))]
269+
Neonx2 => neonx2::$name(),
270+
#[cfg(all(feature = "unsafe", feature = "all-simd", target_arch = "aarch64"))]
271+
Neonx4 => neonx4::$name(),
272+
#[cfg(all(feature = "unsafe", feature = "all-simd", target_arch = "aarch64"))]
273+
Neonx8 => neonx8::$name(),
222274
});
223275
(start.elapsed(), x)
224276
})
@@ -395,6 +447,46 @@ macro_rules! multiversion_test {
395447

396448
unsafe { $body }
397449
}
450+
451+
#[test]
452+
#[cfg(all(feature = "unsafe", target_arch = "aarch64"))]
453+
$(#[$m])*
454+
fn neon() {
455+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
456+
use {$($($path::)+neon::*),*};
457+
458+
$body
459+
}
460+
461+
#[test]
462+
#[cfg(all(feature = "unsafe", feature = "all-simd", target_arch = "aarch64"))]
463+
$(#[$m])*
464+
fn neonx2() {
465+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
466+
use {$($($path::)+neonx2::*),*};
467+
468+
$body
469+
}
470+
471+
#[test]
472+
#[cfg(all(feature = "unsafe", feature = "all-simd", target_arch = "aarch64"))]
473+
$(#[$m])*
474+
fn neonx4() {
475+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
476+
use {$($($path::)+neonx4::*),*};
477+
478+
$body
479+
}
480+
481+
#[test]
482+
#[cfg(all(feature = "unsafe", feature = "all-simd", target_arch = "aarch64"))]
483+
$(#[$m])*
484+
fn neonx8() {
485+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
486+
use {$($($path::)+neonx8::*),*};
487+
488+
$body
489+
}
398490
};
399491

400492
(
@@ -471,6 +563,40 @@ macro_rules! multiversion_test {
471563
$crate::multiversion_test!(@expr { $($tail)+ });
472564
}
473565
}
566+
567+
#[cfg(all(feature = "unsafe", target_arch = "aarch64"))]
568+
{
569+
{
570+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
571+
use {$($($path::)+neon::*),*};
572+
573+
$crate::multiversion_test!(@expr { $($tail)+ });
574+
}
575+
576+
#[cfg(feature = "all-simd")]
577+
{
578+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
579+
use {$($($path::)+neonx2::*),*};
580+
581+
$crate::multiversion_test!(@expr { $($tail)+ });
582+
}
583+
584+
#[cfg(feature = "all-simd")]
585+
{
586+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
587+
use {$($($path::)+neonx4::*),*};
588+
589+
$crate::multiversion_test!(@expr { $($tail)+ });
590+
}
591+
592+
#[cfg(feature = "all-simd")]
593+
{
594+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
595+
use {$($($path::)+neonx8::*),*};
596+
597+
$crate::multiversion_test!(@expr { $($tail)+ });
598+
}
599+
}
474600
};
475601
(@expr $e:expr) => { $e }
476602
}
@@ -536,6 +662,14 @@ versions_impl! {
536662
AVX2x4 if std::arch::is_x86_feature_detected!("avx2"),
537663
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
538664
AVX2x8 if std::arch::is_x86_feature_detected!("avx2"),
665+
#[cfg(all(feature = "unsafe", target_arch = "aarch64"))]
666+
Neon,
667+
#[cfg(all(feature = "unsafe", feature = "all-simd", target_arch = "aarch64"))]
668+
Neonx2,
669+
#[cfg(all(feature = "unsafe", feature = "all-simd", target_arch = "aarch64"))]
670+
Neonx4,
671+
#[cfg(all(feature = "unsafe", feature = "all-simd", target_arch = "aarch64"))]
672+
Neonx8,
539673
}
540674

541675
static OVERRIDE: OnceLock<Option<Version>> = OnceLock::new();

crates/utils/src/simd/mod.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,12 @@ pub use avx2_impl::avx2;
2020
))]
2121
pub use avx2_impl::{avx2x2, avx2x4, avx2x8};
2222

23+
#[cfg(all(feature = "unsafe", target_arch = "aarch64"))]
24+
#[path = "neon.rs"]
25+
mod neon_impl;
26+
#[cfg(all(feature = "unsafe", target_arch = "aarch64"))]
27+
pub use neon_impl::neon;
28+
#[cfg(all(feature = "unsafe", feature = "all-simd", target_arch = "aarch64"))]
29+
pub use neon_impl::{neonx2, neonx4, neonx8};
30+
2331
pub mod scalar;

crates/utils/src/simd/neon.rs

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
//! Neon vector implementations.
2+
3+
use std::array::from_fn;
4+
use std::ops::{Add, BitAnd, BitOr, BitXor, Not};
5+
6+
#[expect(clippy::wildcard_imports)]
7+
use std::arch::aarch64::*;
8+
9+
/// Neon [u32] vector implementation.
10+
#[derive(Clone, Copy)]
11+
#[repr(transparent)]
12+
pub struct U32Vector<const V: usize, const L: usize>([uint32x4_t; V]);
13+
14+
impl<const V: usize, const L: usize> From<[u32; L]> for U32Vector<V, L> {
15+
#[inline]
16+
fn from(value: [u32; L]) -> Self {
17+
Self(from_fn(|i| unsafe { vld1q_u32(value[i * 4..].as_ptr()) }))
18+
}
19+
}
20+
21+
impl<const V: usize, const L: usize> From<U32Vector<V, L>> for [u32; L] {
22+
#[inline]
23+
fn from(value: U32Vector<V, L>) -> Self {
24+
let mut result = [0; L];
25+
for (&v, r) in value.0.iter().zip(result.chunks_exact_mut(4)) {
26+
unsafe {
27+
vst1q_u32(r.as_mut_ptr(), v);
28+
}
29+
}
30+
result
31+
}
32+
}
33+
34+
impl<const V: usize, const L: usize> Add for U32Vector<V, L> {
35+
type Output = Self;
36+
37+
#[inline]
38+
fn add(self, rhs: Self) -> Self::Output {
39+
Self(from_fn(|i| unsafe { vaddq_u32(self.0[i], rhs.0[i]) }))
40+
}
41+
}
42+
43+
impl<const V: usize, const L: usize> BitAnd for U32Vector<V, L> {
44+
type Output = Self;
45+
46+
#[inline]
47+
fn bitand(self, rhs: Self) -> Self::Output {
48+
Self(from_fn(|i| unsafe { vandq_u32(self.0[i], rhs.0[i]) }))
49+
}
50+
}
51+
52+
impl<const V: usize, const L: usize> BitOr for U32Vector<V, L> {
53+
type Output = Self;
54+
55+
#[inline]
56+
fn bitor(self, rhs: Self) -> Self::Output {
57+
Self(from_fn(|i| unsafe { vorrq_u32(self.0[i], rhs.0[i]) }))
58+
}
59+
}
60+
61+
impl<const V: usize, const L: usize> BitXor for U32Vector<V, L> {
62+
type Output = Self;
63+
64+
#[inline]
65+
fn bitxor(self, rhs: Self) -> Self::Output {
66+
Self(from_fn(|i| unsafe { veorq_u32(self.0[i], rhs.0[i]) }))
67+
}
68+
}
69+
70+
impl<const V: usize, const L: usize> Not for U32Vector<V, L> {
71+
type Output = Self;
72+
73+
#[inline]
74+
fn not(self) -> Self::Output {
75+
Self(from_fn(|i| unsafe {
76+
veorq_u32(self.0[i], vdupq_n_u32(!0))
77+
}))
78+
}
79+
}
80+
81+
impl<const V: usize, const L: usize> U32Vector<V, L> {
82+
pub const LANES: usize = {
83+
assert!(V * 4 == L);
84+
L
85+
};
86+
87+
#[inline]
88+
#[must_use]
89+
pub fn andnot(self, rhs: Self) -> Self {
90+
Self(from_fn(|i| unsafe { vbicq_u32(self.0[i], rhs.0[i]) }))
91+
}
92+
93+
#[inline]
94+
#[must_use]
95+
pub fn splat(v: u32) -> Self {
96+
Self([unsafe { vdupq_n_u32(v) }; V])
97+
}
98+
99+
#[inline]
100+
#[must_use]
101+
pub fn rotate_left(self, n: u32) -> Self {
102+
Self(from_fn(|i| unsafe {
103+
#[expect(clippy::cast_possible_wrap)]
104+
vorrq_u32(
105+
vshlq_u32(self.0[i], vdupq_n_s32(n as i32)),
106+
vshlq_u32(self.0[i], vdupq_n_s32(-(32 - n as i32))),
107+
)
108+
}))
109+
}
110+
}
111+
112+
/// Vector implementations using a single Neon vector.
113+
pub mod neon {
114+
/// The name of this backend.
115+
pub const SIMD_BACKEND: &str = "neon";
116+
117+
/// Neon vector with four [u32] lanes.
118+
pub type U32Vector = super::U32Vector<1, 4>;
119+
}
120+
121+
/// Vector implementations using two Neon vectors.
122+
#[cfg(feature = "all-simd")]
123+
pub mod neonx2 {
124+
/// The name of this backend.
125+
pub const SIMD_BACKEND: &str = "neonx2";
126+
127+
/// Two Neon vectors with eight total [u32] lanes.
128+
pub type U32Vector = super::U32Vector<2, 8>;
129+
}
130+
131+
/// Vector implementations using four Neon vectors.
132+
#[cfg(feature = "all-simd")]
133+
pub mod neonx4 {
134+
/// The name of this backend.
135+
pub const SIMD_BACKEND: &str = "neonx4";
136+
137+
/// Four Neon vectors with sixteen total [u32] lanes.
138+
pub type U32Vector = super::U32Vector<4, 16>;
139+
}
140+
141+
/// Vector implementations using eight Neon vectors.
142+
#[cfg(feature = "all-simd")]
143+
pub mod neonx8 {
144+
/// The name of this backend.
145+
pub const SIMD_BACKEND: &str = "neonx8";
146+
147+
/// Eight Neon vectors with thirty-two total [u32] lanes.
148+
pub type U32Vector = super::U32Vector<8, 32>;
149+
}

0 commit comments

Comments
 (0)