Skip to content

Commit 8333d03

Browse files
committed
Add initial AVX512 SIMD implementation
The AVX512 target feature and intrinsics have been stabilized in nightly!
1 parent ac9a83f commit 8333d03

File tree

5 files changed

+351
-1
lines changed

5 files changed

+351
-1
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ edition = "2024"
1010
license = "MIT"
1111
publish = false
1212
repository = "https://github.com/ictrobot/aoc-rs"
13-
rust-version = "1.87.0"
13+
rust-version = "1.89.0"
1414

1515
[workspace.lints.clippy]
1616
pedantic = { level = "warn", priority = -1 }

crates/utils/src/multiversion.rs

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,14 @@ macro_rules! multiversion {
8888
AVX2x4 => unsafe { $name::avx2x4::$name($($arg_name),*) },
8989
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
9090
AVX2x8 => unsafe { $name::avx2x8::$name($($arg_name),*) },
91+
#[cfg(all(feature = "unsafe", any(target_arch = "x86", target_arch = "x86_64")))]
92+
AVX512 => unsafe { $name::avx512::$name($($arg_name),*) },
93+
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
94+
AVX512x2 => unsafe { $name::avx512x2::$name($($arg_name),*) },
95+
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
96+
AVX512x4 => unsafe { $name::avx512x4::$name($($arg_name),*) },
97+
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
98+
AVX512x8 => unsafe { $name::avx512x8::$name($($arg_name),*) },
9199
#[cfg(all(feature = "unsafe", target_arch = "aarch64"))]
92100
Neon => unsafe { $name::neon::$name($($arg_name),*) },
93101
#[cfg(all(feature = "unsafe", feature = "all-simd", target_arch = "aarch64"))]
@@ -184,6 +192,42 @@ macro_rules! multiversion {
184192
$crate::multiversion!{@enable target_feature(enable = "avx2") $($tail)*}
185193
}
186194

195+
/// [`multiversion!`] avx512 implementation.
196+
#[cfg(all(feature = "unsafe", any(target_arch = "x86", target_arch = "x86_64")))]
197+
pub mod avx512 {
198+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
199+
use {super::*, $($($path::)+avx512::*),*};
200+
201+
$crate::multiversion!{@enable target_feature(enable = "avx512f") $($tail)*}
202+
}
203+
204+
/// [`multiversion!`] avx512x2 implementation.
205+
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
206+
pub mod avx512x2 {
207+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
208+
use {super::*, $($($path::)+avx512x2::*),*};
209+
210+
$crate::multiversion!{@enable target_feature(enable = "avx512f") $($tail)*}
211+
}
212+
213+
/// [`multiversion!`] avx512x4 implementation.
214+
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
215+
pub mod avx512x4 {
216+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
217+
use {super::*, $($($path::)+avx512x4::*),*};
218+
219+
$crate::multiversion!{@enable target_feature(enable = "avx512f") $($tail)*}
220+
}
221+
222+
/// [`multiversion!`] avx512x8 implementation.
223+
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
224+
pub mod avx512x8 {
225+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
226+
use {super::*, $($($path::)+avx512x8::*),*};
227+
228+
$crate::multiversion!{@enable target_feature(enable = "avx512f") $($tail)*}
229+
}
230+
187231
/// [`multiversion!`] neon implementation.
188232
#[cfg(all(feature = "unsafe", target_arch = "aarch64"))]
189233
pub mod neon {
@@ -254,6 +298,14 @@ macro_rules! multiversion {
254298
AVX2x4 => unsafe { avx2x4::$name() },
255299
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
256300
AVX2x8 => unsafe { avx2x8::$name() },
301+
#[cfg(all(feature = "unsafe", any(target_arch = "x86", target_arch = "x86_64")))]
302+
AVX512 => unsafe { avx512::$name() },
303+
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
304+
AVX512x2 => unsafe { avx512x2::$name() },
305+
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
306+
AVX512x4 => unsafe { avx512x4::$name() },
307+
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
308+
AVX512x8 => unsafe { avx512x8::$name() },
257309
#[cfg(all(feature = "unsafe", target_arch = "aarch64"))]
258310
Neon => unsafe { neon::$name() },
259311
#[cfg(all(feature = "unsafe", feature = "all-simd", target_arch = "aarch64"))]
@@ -433,6 +485,70 @@ macro_rules! multiversion_test {
433485
unsafe { $body }
434486
}
435487

488+
#[test]
489+
#[cfg(all(feature = "unsafe", any(target_arch = "x86", target_arch = "x86_64")))]
490+
$(#[$m])*
491+
fn avx512() {
492+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
493+
use {$($($path::)+avx512::*),*};
494+
495+
if !$crate::multiversion::Version::AVX512.supported() {
496+
use std::io::{stdout, Write};
497+
let _ = writeln!(&mut stdout(), "warning: skipping test in {}::avx512 due to missing avx512 support", module_path!());
498+
return;
499+
}
500+
501+
unsafe { $body }
502+
}
503+
504+
#[test]
505+
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
506+
$(#[$m])*
507+
fn avx512x2() {
508+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
509+
use {$($($path::)+avx512x2::*),*};
510+
511+
if !$crate::multiversion::Version::AVX512x2.supported() {
512+
use std::io::{stdout, Write};
513+
let _ = writeln!(&mut stdout(), "warning: skipping test in {}::avx512x2 due to missing avx512 support", module_path!());
514+
return;
515+
}
516+
517+
unsafe { $body }
518+
}
519+
520+
#[test]
521+
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
522+
$(#[$m])*
523+
fn avx512x4() {
524+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
525+
use {$($($path::)+avx512x4::*),*};
526+
527+
if !$crate::multiversion::Version::AVX512x4.supported() {
528+
use std::io::{stdout, Write};
529+
let _ = writeln!(&mut stdout(), "warning: skipping test in {}::avx512x4 due to missing avx512 support", module_path!());
530+
return;
531+
}
532+
533+
unsafe { $body }
534+
}
535+
536+
#[test]
537+
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
538+
$(#[$m])*
539+
fn avx512x8() {
540+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
541+
use {$($($path::)+avx512x8::*),*};
542+
543+
if !$crate::multiversion::Version::AVX512x8.supported() {
544+
use std::io::{stdout, Write};
545+
let _ = writeln!(&mut stdout(), "warning: skipping test in {}::avx512x8 due to missing avx512 support", module_path!());
546+
return;
547+
}
548+
549+
unsafe { $body }
550+
}
551+
436552
#[test]
437553
#[cfg(all(feature = "unsafe", target_arch = "aarch64"))]
438554
$(#[$m])*
@@ -549,6 +665,40 @@ macro_rules! multiversion_test {
549665
}
550666
}
551667

668+
#[cfg(all(feature = "unsafe", any(target_arch = "x86", target_arch = "x86_64")))]
669+
if $crate::multiversion::Version::AVX512.supported() {
670+
unsafe {
671+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
672+
use {$($($path::)+avx512::*),*};
673+
674+
$crate::multiversion_test!(@expr { $($tail)+ });
675+
}
676+
677+
#[cfg(feature = "all-simd")]
678+
unsafe {
679+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
680+
use {$($($path::)+avx512x2::*),*};
681+
682+
$crate::multiversion_test!(@expr { $($tail)+ });
683+
}
684+
685+
#[cfg(feature = "all-simd")]
686+
unsafe {
687+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
688+
use {$($($path::)+avx512x4::*),*};
689+
690+
$crate::multiversion_test!(@expr { $($tail)+ });
691+
}
692+
693+
#[cfg(feature = "all-simd")]
694+
unsafe {
695+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
696+
use {$($($path::)+avx512x8::*),*};
697+
698+
$crate::multiversion_test!(@expr { $($tail)+ });
699+
}
700+
}
701+
552702
#[cfg(all(feature = "unsafe", target_arch = "aarch64"))]
553703
{
554704
unsafe {
@@ -647,6 +797,14 @@ versions_impl! {
647797
AVX2x4 if std::arch::is_x86_feature_detected!("avx2"),
648798
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
649799
AVX2x8 if std::arch::is_x86_feature_detected!("avx2"),
800+
#[cfg(all(feature = "unsafe", any(target_arch = "x86", target_arch = "x86_64")))]
801+
AVX512 if std::arch::is_x86_feature_detected!("avx512f"),
802+
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
803+
AVX512x2 if std::arch::is_x86_feature_detected!("avx512f"),
804+
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
805+
AVX512x4 if std::arch::is_x86_feature_detected!("avx512f"),
806+
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
807+
AVX512x8 if std::arch::is_x86_feature_detected!("avx512f"),
650808
#[cfg(all(feature = "unsafe", target_arch = "aarch64"))]
651809
Neon,
652810
#[cfg(all(feature = "unsafe", feature = "all-simd", target_arch = "aarch64"))]

crates/utils/src/simd/avx512.rs

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
//! AVX512 vector implementations.
2+
//!
3+
//! Currently only requires AVX-512F.
4+
5+
use std::array::from_fn;
6+
use std::ops::{Add, BitAnd, BitOr, BitXor, Not};
7+
8+
#[cfg(target_arch = "x86_64")]
9+
#[allow(clippy::allow_attributes, clippy::wildcard_imports)]
10+
use std::arch::x86_64::*;
11+
12+
#[cfg(target_arch = "x86")]
13+
#[allow(clippy::allow_attributes, clippy::wildcard_imports)]
14+
use std::arch::x86::*;
15+
16+
/// AVX512 [u32] vector implementation.
17+
#[derive(Clone, Copy)]
18+
#[repr(transparent)]
19+
pub struct U32Vector<const V: usize, const L: usize>([__m512i; V]);
20+
21+
impl<const V: usize, const L: usize> From<[u32; L]> for U32Vector<V, L> {
22+
#[inline]
23+
fn from(value: [u32; L]) -> Self {
24+
Self(from_fn(|i| unsafe {
25+
#[expect(
26+
clippy::cast_ptr_alignment,
27+
reason = "_mm512_loadu_si512 is an unaligned load which requires no alignment"
28+
)]
29+
_mm512_loadu_si512(value[i * 16..].as_ptr().cast::<__m512i>())
30+
}))
31+
}
32+
}
33+
34+
impl<const V: usize, const L: usize> From<U32Vector<V, L>> for [u32; L] {
35+
#[inline]
36+
fn from(value: U32Vector<V, L>) -> Self {
37+
let mut result = [0; L];
38+
for (&v, r) in value.0.iter().zip(result.chunks_exact_mut(16)) {
39+
unsafe {
40+
#[expect(
41+
clippy::cast_ptr_alignment,
42+
reason = "_mm512_storeu_si512 is an unaligned store which requires no alignment"
43+
)]
44+
_mm512_storeu_si512(r.as_mut_ptr().cast::<__m512i>(), v);
45+
}
46+
}
47+
result
48+
}
49+
}
50+
51+
impl<const V: usize, const L: usize> Add for U32Vector<V, L> {
52+
type Output = Self;
53+
54+
#[inline]
55+
fn add(self, rhs: Self) -> Self::Output {
56+
Self(from_fn(|i| unsafe {
57+
_mm512_add_epi32(self.0[i], rhs.0[i])
58+
}))
59+
}
60+
}
61+
62+
impl<const V: usize, const L: usize> BitAnd for U32Vector<V, L> {
63+
type Output = Self;
64+
65+
#[inline]
66+
fn bitand(self, rhs: Self) -> Self::Output {
67+
Self(from_fn(|i| unsafe {
68+
_mm512_and_si512(self.0[i], rhs.0[i])
69+
}))
70+
}
71+
}
72+
73+
impl<const V: usize, const L: usize> BitOr for U32Vector<V, L> {
74+
type Output = Self;
75+
76+
#[inline]
77+
fn bitor(self, rhs: Self) -> Self::Output {
78+
Self(from_fn(|i| unsafe { _mm512_or_si512(self.0[i], rhs.0[i]) }))
79+
}
80+
}
81+
82+
impl<const V: usize, const L: usize> BitXor for U32Vector<V, L> {
83+
type Output = Self;
84+
85+
#[inline]
86+
fn bitxor(self, rhs: Self) -> Self::Output {
87+
Self(from_fn(|i| unsafe {
88+
_mm512_xor_si512(self.0[i], rhs.0[i])
89+
}))
90+
}
91+
}
92+
93+
impl<const V: usize, const L: usize> Not for U32Vector<V, L> {
94+
type Output = Self;
95+
96+
#[inline]
97+
fn not(self) -> Self::Output {
98+
Self(from_fn(|i| unsafe {
99+
_mm512_xor_si512(self.0[i], _mm512_set1_epi8(!0))
100+
}))
101+
}
102+
}
103+
104+
impl<const V: usize, const L: usize> U32Vector<V, L> {
105+
pub const LANES: usize = {
106+
assert!(V * 16 == L);
107+
L
108+
};
109+
110+
#[inline]
111+
#[must_use]
112+
#[target_feature(enable = "avx512f")]
113+
pub fn andnot(self, rhs: Self) -> Self {
114+
Self(from_fn(|i| _mm512_andnot_si512(rhs.0[i], self.0[i])))
115+
}
116+
117+
#[inline]
118+
#[must_use]
119+
#[target_feature(enable = "avx512f")]
120+
pub fn splat(v: u32) -> Self {
121+
Self(
122+
#[expect(clippy::cast_possible_wrap)]
123+
[_mm512_set1_epi32(v as i32); V],
124+
)
125+
}
126+
127+
#[inline]
128+
#[must_use]
129+
#[target_feature(enable = "avx512f")]
130+
pub fn rotate_left(self, n: u32) -> Self {
131+
Self(from_fn(|i| {
132+
#[expect(clippy::cast_possible_wrap)]
133+
_mm512_or_si512(
134+
_mm512_sll_epi32(self.0[i], _mm_cvtsi32_si128(n as i32)),
135+
_mm512_srl_epi32(self.0[i], _mm_cvtsi32_si128(32 - n as i32)),
136+
)
137+
}))
138+
}
139+
}
140+
141+
/// Vector implementations using a single AVX512 vector.
142+
pub mod avx512 {
143+
/// The name of this backend.
144+
pub const SIMD_BACKEND: &str = "avx512";
145+
146+
/// AVX512 vector with sixteen [u32] lanes.
147+
pub type U32Vector = super::U32Vector<1, 16>;
148+
}
149+
150+
/// Vector implementations using two AVX512 vectors.
151+
#[cfg(feature = "all-simd")]
152+
pub mod avx512x2 {
153+
/// The name of this backend.
154+
pub const SIMD_BACKEND: &str = "avx512x2";
155+
156+
/// Two AVX512 vectors with thirty-two total [u32] lanes.
157+
pub type U32Vector = super::U32Vector<2, 32>;
158+
}
159+
160+
/// Vector implementations using four AVX512 vectors.
161+
#[cfg(feature = "all-simd")]
162+
pub mod avx512x4 {
163+
/// The name of this backend.
164+
pub const SIMD_BACKEND: &str = "avx512x4";
165+
166+
/// Four AVX512 vectors with sixty-four total [u32] lanes.
167+
pub type U32Vector = super::U32Vector<4, 64>;
168+
}
169+
170+
/// Vector implementations using eight AVX512 vectors.
171+
#[cfg(feature = "all-simd")]
172+
pub mod avx512x8 {
173+
/// The name of this backend.
174+
pub const SIMD_BACKEND: &str = "avx512x8";
175+
176+
/// Eight AVX512 vectors with 128 total [u32] lanes.
177+
pub type U32Vector = super::U32Vector<8, 128>;
178+
}

0 commit comments

Comments
 (0)