Skip to content

Commit da52daa

Browse files
committed
Add initial AVX512 SIMD implementation
1 parent 57055e3 commit da52daa

File tree

5 files changed

+357
-0
lines changed

5 files changed

+357
-0
lines changed

crates/utils/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
//! Common utilities used by the [`aoc`](../aoc/) and year crates.
22
#![cfg_attr(not(feature = "unsafe"), forbid(unsafe_code))]
3+
#![feature(stdarch_x86_avx512)]
4+
#![feature(avx512_target_feature)]
35

46
pub mod array;
57
pub mod bit;

crates/utils/src/multiversion.rs

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,14 @@ macro_rules! multiversion {
8989
AVX2x4 => unsafe { $name::avx2x4::$name($($arg_name),*) },
9090
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
9191
AVX2x8 => unsafe { $name::avx2x8::$name($($arg_name),*) },
92+
#[cfg(all(feature = "unsafe", any(target_arch = "x86", target_arch = "x86_64")))]
93+
AVX512 => unsafe { $name::avx512::$name($($arg_name),*) },
94+
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
95+
AVX512x2 => unsafe { $name::avx512x2::$name($($arg_name),*) },
96+
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
97+
AVX512x4 => unsafe { $name::avx512x4::$name($($arg_name),*) },
98+
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
99+
AVX512x8 => unsafe { $name::avx512x8::$name($($arg_name),*) },
92100
}
93101
}
94102
};
@@ -184,6 +192,50 @@ macro_rules! multiversion {
184192

185193
$crate::multiversion!{@helper target_feature(enable = "avx2") $($tail)*}
186194
}
195+
196+
/// [`multiversion!`] avx512 implementation.
197+
#[cfg(all(feature = "unsafe", any(target_arch = "x86", target_arch = "x86_64")))]
198+
pub mod avx512 {
199+
#![allow(clippy::missing_safety_doc)]
200+
201+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
202+
use {super::*, $($($path::)+avx512::*),*};
203+
204+
$crate::multiversion!{@helper target_feature(enable = "avx512f") $($tail)*}
205+
}
206+
207+
/// [`multiversion!`] avx512x2 implementation.
208+
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
209+
pub mod avx512x2 {
210+
#![allow(clippy::missing_safety_doc)]
211+
212+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
213+
use {super::*, $($($path::)+avx512x2::*),*};
214+
215+
$crate::multiversion!{@helper target_feature(enable = "avx512f") $($tail)*}
216+
}
217+
218+
/// [`multiversion!`] avx512x4 implementation.
219+
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
220+
pub mod avx512x4 {
221+
#![allow(clippy::missing_safety_doc)]
222+
223+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
224+
use {super::*, $($($path::)+avx512x4::*),*};
225+
226+
$crate::multiversion!{@helper target_feature(enable = "avx512f") $($tail)*}
227+
}
228+
229+
/// [`multiversion!`] avx512x8 implementation.
230+
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
231+
pub mod avx512x8 {
232+
#![allow(clippy::missing_safety_doc)]
233+
234+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
235+
use {super::*, $($($path::)+avx512x8::*),*};
236+
237+
$crate::multiversion!{@helper target_feature(enable = "avx512f") $($tail)*}
238+
}
187239
};
188240

189241
// Microbenchmark for dynamic dispatch
@@ -219,6 +271,14 @@ macro_rules! multiversion {
219271
AVX2x4 => unsafe { avx2x4::$name() },
220272
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
221273
AVX2x8 => unsafe { avx2x8::$name() },
274+
#[cfg(all(feature = "unsafe", any(target_arch = "x86", target_arch = "x86_64")))]
275+
AVX512 => unsafe { avx512::$name() },
276+
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
277+
AVX512x2 => unsafe { avx512x2::$name() },
278+
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
279+
AVX512x4 => unsafe { avx512x4::$name() },
280+
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
281+
AVX512x8 => unsafe { avx512x8::$name() },
222282
});
223283
(start.elapsed(), x)
224284
})
@@ -395,6 +455,70 @@ macro_rules! multiversion_test {
395455

396456
unsafe { $body }
397457
}
458+
459+
#[test]
460+
#[cfg(all(feature = "unsafe", any(target_arch = "x86", target_arch = "x86_64")))]
461+
$(#[$m])*
462+
fn avx512() {
463+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
464+
use {$($($path::)+avx512::*),*};
465+
466+
if !$crate::multiversion::Version::AVX512.supported() {
467+
use std::io::{stdout, Write};
468+
let _ = writeln!(&mut stdout(), "warning: skipping test in {}::avx512 due to missing avx512 support", module_path!());
469+
return;
470+
}
471+
472+
unsafe { $body }
473+
}
474+
475+
#[test]
476+
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
477+
$(#[$m])*
478+
fn avx512x2() {
479+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
480+
use {$($($path::)+avx512x2::*),*};
481+
482+
if !$crate::multiversion::Version::AVX512x2.supported() {
483+
use std::io::{stdout, Write};
484+
let _ = writeln!(&mut stdout(), "warning: skipping test in {}::avx512x2 due to missing avx512 support", module_path!());
485+
return;
486+
}
487+
488+
unsafe { $body }
489+
}
490+
491+
#[test]
492+
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
493+
$(#[$m])*
494+
fn avx512x4() {
495+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
496+
use {$($($path::)+avx512x4::*),*};
497+
498+
if !$crate::multiversion::Version::AVX512x4.supported() {
499+
use std::io::{stdout, Write};
500+
let _ = writeln!(&mut stdout(), "warning: skipping test in {}::avx512x4 due to missing avx512 support", module_path!());
501+
return;
502+
}
503+
504+
unsafe { $body }
505+
}
506+
507+
#[test]
508+
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
509+
$(#[$m])*
510+
fn avx512x8() {
511+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
512+
use {$($($path::)+avx512x8::*),*};
513+
514+
if !$crate::multiversion::Version::AVX512x8.supported() {
515+
use std::io::{stdout, Write};
516+
let _ = writeln!(&mut stdout(), "warning: skipping test in {}::avx512x8 due to missing avx512 support", module_path!());
517+
return;
518+
}
519+
520+
unsafe { $body }
521+
}
398522
};
399523

400524
(
@@ -471,6 +595,40 @@ macro_rules! multiversion_test {
471595
$crate::multiversion_test!(@expr { $($tail)+ });
472596
}
473597
}
598+
599+
#[cfg(all(feature = "unsafe", any(target_arch = "x86", target_arch = "x86_64")))]
600+
if $crate::multiversion::Version::AVX512.supported() {
601+
unsafe {
602+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
603+
use {$($($path::)+avx512::*),*};
604+
605+
$crate::multiversion_test!(@expr { $($tail)+ });
606+
}
607+
608+
#[cfg(feature = "all-simd")]
609+
unsafe {
610+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
611+
use {$($($path::)+avx512x2::*),*};
612+
613+
$crate::multiversion_test!(@expr { $($tail)+ });
614+
}
615+
616+
#[cfg(feature = "all-simd")]
617+
unsafe {
618+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
619+
use {$($($path::)+avx512x4::*),*};
620+
621+
$crate::multiversion_test!(@expr { $($tail)+ });
622+
}
623+
624+
#[cfg(feature = "all-simd")]
625+
unsafe {
626+
#[allow(clippy::allow_attributes, unused_imports, clippy::wildcard_imports)]
627+
use {$($($path::)+avx512x8::*),*};
628+
629+
$crate::multiversion_test!(@expr { $($tail)+ });
630+
}
631+
}
474632
};
475633
(@expr $e:expr) => { $e }
476634
}
@@ -536,6 +694,14 @@ versions_impl! {
536694
AVX2x4 if std::arch::is_x86_feature_detected!("avx2"),
537695
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
538696
AVX2x8 if std::arch::is_x86_feature_detected!("avx2"),
697+
#[cfg(all(feature = "unsafe", any(target_arch = "x86", target_arch = "x86_64")))]
698+
AVX512 if std::arch::is_x86_feature_detected!("avx512f"),
699+
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
700+
AVX512x2 if std::arch::is_x86_feature_detected!("avx512f"),
701+
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
702+
AVX512x4 if std::arch::is_x86_feature_detected!("avx512f"),
703+
#[cfg(all(feature = "unsafe", feature = "all-simd", any(target_arch = "x86", target_arch = "x86_64")))]
704+
AVX512x8 if std::arch::is_x86_feature_detected!("avx512f"),
539705
}
540706

541707
static OVERRIDE: OnceLock<Option<Version>> = OnceLock::new();

crates/utils/src/simd/avx512.rs

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
//! AVX512 vector implementations.
2+
//!
3+
//! Currently only requires AVX-512F.
4+
5+
use std::array::from_fn;
6+
use std::ops::{Add, BitAnd, BitOr, BitXor, Not};
7+
8+
#[cfg(target_arch = "x86_64")]
9+
#[allow(clippy::allow_attributes, clippy::wildcard_imports)]
10+
use std::arch::x86_64::*;
11+
12+
#[cfg(target_arch = "x86")]
13+
#[allow(clippy::allow_attributes, clippy::wildcard_imports)]
14+
use std::arch::x86::*;
15+
16+
/// A group of `V` AVX512 registers treated as `L` packed [u32] lanes.
///
/// Every [`__m512i`] carries sixteen 32-bit lanes, so `L` is expected to be `V * 16`;
/// that relationship is asserted when [`U32Vector::LANES`] is evaluated.
#[repr(transparent)]
#[derive(Clone, Copy)]
pub struct U32Vector<const V: usize, const L: usize>([__m512i; V]);
20+
21+
impl<const V: usize, const L: usize> From<[u32; L]> for U32Vector<V, L> {
22+
#[inline]
23+
fn from(value: [u32; L]) -> Self {
24+
Self(from_fn(|i| unsafe {
25+
_mm512_loadu_si512(value[i * 16..].as_ptr().cast::<i32>())
26+
}))
27+
}
28+
}
29+
30+
impl<const V: usize, const L: usize> From<U32Vector<V, L>> for [u32; L] {
31+
#[inline]
32+
fn from(value: U32Vector<V, L>) -> Self {
33+
let mut result = [0; L];
34+
for (&v, r) in value.0.iter().zip(result.chunks_exact_mut(16)) {
35+
unsafe {
36+
#[expect(
37+
clippy::cast_ptr_alignment,
38+
reason = "_mm512_storeu_si512 is an unaligned store which requires no alignment"
39+
)]
40+
_mm512_storeu_si512(r.as_mut_ptr().cast::<__m512i>(), v);
41+
}
42+
}
43+
result
44+
}
45+
}
46+
47+
impl<const V: usize, const L: usize> Add for U32Vector<V, L> {
48+
type Output = Self;
49+
50+
#[inline]
51+
fn add(self, rhs: Self) -> Self::Output {
52+
Self(from_fn(|i| unsafe {
53+
_mm512_add_epi32(self.0[i], rhs.0[i])
54+
}))
55+
}
56+
}
57+
58+
impl<const V: usize, const L: usize> BitAnd for U32Vector<V, L> {
59+
type Output = Self;
60+
61+
#[inline]
62+
fn bitand(self, rhs: Self) -> Self::Output {
63+
Self(from_fn(|i| unsafe {
64+
_mm512_and_si512(self.0[i], rhs.0[i])
65+
}))
66+
}
67+
}
68+
69+
impl<const V: usize, const L: usize> BitOr for U32Vector<V, L> {
70+
type Output = Self;
71+
72+
#[inline]
73+
fn bitor(self, rhs: Self) -> Self::Output {
74+
Self(from_fn(|i| unsafe { _mm512_or_si512(self.0[i], rhs.0[i]) }))
75+
}
76+
}
77+
78+
impl<const V: usize, const L: usize> BitXor for U32Vector<V, L> {
79+
type Output = Self;
80+
81+
#[inline]
82+
fn bitxor(self, rhs: Self) -> Self::Output {
83+
Self(from_fn(|i| unsafe {
84+
_mm512_xor_si512(self.0[i], rhs.0[i])
85+
}))
86+
}
87+
}
88+
89+
impl<const V: usize, const L: usize> Not for U32Vector<V, L> {
90+
type Output = Self;
91+
92+
#[inline]
93+
fn not(self) -> Self::Output {
94+
Self(from_fn(|i| unsafe {
95+
_mm512_xor_si512(self.0[i], _mm512_set1_epi8(!0))
96+
}))
97+
}
98+
}
99+
100+
impl<const V: usize, const L: usize> U32Vector<V, L> {
101+
pub const LANES: usize = {
102+
assert!(V * 16 == L);
103+
L
104+
};
105+
106+
#[inline]
107+
#[must_use]
108+
pub fn andnot(self, rhs: Self) -> Self {
109+
Self(from_fn(|i| unsafe {
110+
_mm512_andnot_si512(rhs.0[i], self.0[i])
111+
}))
112+
}
113+
114+
#[inline]
115+
#[must_use]
116+
pub fn splat(v: u32) -> Self {
117+
Self(
118+
[unsafe {
119+
#[expect(clippy::cast_possible_wrap)]
120+
_mm512_set1_epi32(v as i32)
121+
}; V],
122+
)
123+
}
124+
125+
#[inline]
126+
#[must_use]
127+
pub fn rotate_left(self, n: u32) -> Self {
128+
Self(from_fn(|i| unsafe {
129+
#[expect(clippy::cast_possible_wrap)]
130+
_mm512_or_si512(
131+
_mm512_sll_epi32(self.0[i], _mm_cvtsi32_si128(n as i32)),
132+
_mm512_srl_epi32(self.0[i], _mm_cvtsi32_si128(32 - n as i32)),
133+
)
134+
}))
135+
}
136+
}
137+
138+
/// Vector implementations using a single AVX512 vector.
139+
pub mod avx512 {
140+
/// The name of this backend.
141+
pub const SIMD_BACKEND: &str = "avx512";
142+
143+
/// AVX512 vector with sixteen [u32] lanes.
144+
pub type U32Vector = super::U32Vector<1, 16>;
145+
}
146+
147+
/// Backend built on two AVX512 registers per vector.
#[cfg(feature = "all-simd")]
pub mod avx512x2 {
    /// Identifier reported for this backend.
    pub const SIMD_BACKEND: &str = "avx512x2";

    /// Thirty-two [u32] lanes spread over two AVX512 registers.
    pub type U32Vector = super::U32Vector<2, 32>;
}
156+
157+
/// Backend built on four AVX512 registers per vector.
#[cfg(feature = "all-simd")]
pub mod avx512x4 {
    /// Identifier reported for this backend.
    pub const SIMD_BACKEND: &str = "avx512x4";

    /// Sixty-four [u32] lanes spread over four AVX512 registers.
    pub type U32Vector = super::U32Vector<4, 64>;
}
166+
167+
/// Backend built on eight AVX512 registers per vector.
#[cfg(feature = "all-simd")]
pub mod avx512x8 {
    /// Identifier reported for this backend.
    pub const SIMD_BACKEND: &str = "avx512x8";

    /// 128 [u32] lanes spread over eight AVX512 registers.
    pub type U32Vector = super::U32Vector<8, 128>;
}

0 commit comments

Comments
 (0)