Skip to content

Commit c5cc8fa

Browse files
zakcutner authored and marmeladema committed
Share implementation between SSE2 and AVX2 methods
1 parent 07537aa commit c5cc8fa

File tree

1 file changed

+120
-68
lines changed

1 file changed

+120
-68
lines changed

src/avx2/mod.rs

Lines changed: 120 additions & 68 deletions
Original file line number | Diff line number | Diff line change
@@ -6,16 +6,14 @@ pub use self::{original::*, rust::*};
66
use crate::{bits, memchr::MemchrSearcher, memcmp};
77
use std::{
88
arch::x86_64::*,
9+
mem,
910
ops::{AddAssign, SubAssign},
1011
};
1112

12-
const AVX2_LANES: usize = 32;
13-
const SSE2_LANES: usize = 16;
14-
1513
#[derive(Clone, Copy, Default, PartialEq)]
16-
struct Hash(usize);
14+
struct ScalarHash(usize);
1715

18-
impl From<&[u8]> for Hash {
16+
impl From<&[u8]> for ScalarHash {
1917
#[inline(always)]
2018
fn from(bytes: &[u8]) -> Self {
2119
bytes.iter().fold(Default::default(), |mut hash, &b| {
@@ -25,30 +23,108 @@ impl From<&[u8]> for Hash {
2523
}
2624
}
2725

28-
impl AddAssign<u8> for Hash {
26+
impl AddAssign<u8> for ScalarHash {
2927
#[inline(always)]
3028
fn add_assign(&mut self, b: u8) {
3129
self.0 += usize::from(b);
3230
}
3331
}
3432

35-
impl SubAssign<u8> for Hash {
33+
impl SubAssign<u8> for ScalarHash {
3634
#[inline(always)]
3735
fn sub_assign(&mut self, b: u8) {
3836
self.0 -= usize::from(b);
3937
}
4038
}
4139

40+
trait Vector: Copy {
41+
unsafe fn set1_epi8(a: i8) -> Self;
42+
43+
unsafe fn loadu_si(a: *const Self) -> Self;
44+
45+
unsafe fn cmpeq_epi8(a: Self, b: Self) -> Self;
46+
47+
unsafe fn and_si(a: Self, b: Self) -> Self;
48+
49+
unsafe fn movemask_epi8(a: Self) -> i32;
50+
}
51+
52+
impl Vector for __m128i {
53+
#[inline(always)]
54+
unsafe fn set1_epi8(a: i8) -> Self {
55+
_mm_set1_epi8(a)
56+
}
57+
58+
#[inline(always)]
59+
unsafe fn loadu_si(a: *const Self) -> Self {
60+
_mm_loadu_si128(a)
61+
}
62+
63+
#[inline(always)]
64+
unsafe fn cmpeq_epi8(a: Self, b: Self) -> Self {
65+
_mm_cmpeq_epi8(a, b)
66+
}
67+
68+
#[inline(always)]
69+
unsafe fn and_si(a: Self, b: Self) -> Self {
70+
_mm_and_si128(a, b)
71+
}
72+
73+
#[inline(always)]
74+
unsafe fn movemask_epi8(a: Self) -> i32 {
75+
_mm_movemask_epi8(a)
76+
}
77+
}
78+
79+
impl Vector for __m256i {
80+
#[inline(always)]
81+
unsafe fn set1_epi8(a: i8) -> Self {
82+
_mm256_set1_epi8(a)
83+
}
84+
85+
#[inline(always)]
86+
unsafe fn loadu_si(a: *const Self) -> Self {
87+
_mm256_loadu_si256(a)
88+
}
89+
90+
#[inline(always)]
91+
unsafe fn cmpeq_epi8(a: Self, b: Self) -> Self {
92+
_mm256_cmpeq_epi8(a, b)
93+
}
94+
95+
#[inline(always)]
96+
unsafe fn and_si(a: Self, b: Self) -> Self {
97+
_mm256_and_si256(a, b)
98+
}
99+
100+
#[inline(always)]
101+
unsafe fn movemask_epi8(a: Self) -> i32 {
102+
_mm256_movemask_epi8(a)
103+
}
104+
}
105+
106+
struct VectorHash<V: Vector> {
107+
first: V,
108+
last: V,
109+
}
110+
111+
impl<V: Vector> VectorHash<V> {
112+
fn new(first: u8, last: u8) -> Self {
113+
Self {
114+
first: unsafe { Vector::set1_epi8(first as i8) },
115+
last: unsafe { Vector::set1_epi8(last as i8) },
116+
}
117+
}
118+
}
119+
42120
macro_rules! avx2_searcher {
43121
($name:ident, $size:literal, $memcmp:path) => {
44122
pub struct $name {
45123
needle: Box<[u8]>,
46124
position: usize,
47-
hash: Hash,
48-
sse2_first: __m128i,
49-
sse2_last: __m128i,
50-
avx2_first: __m256i,
51-
avx2_last: __m256i,
125+
scalar_hash: ScalarHash,
126+
sse2_hash: VectorHash<__m128i>,
127+
avx2_hash: VectorHash<__m256i>,
52128
}
53129

54130
impl $name {
@@ -61,20 +137,16 @@ macro_rules! avx2_searcher {
61137
assert!(!needle.is_empty());
62138
assert!(position < needle.len());
63139

64-
let hash = Hash::from(needle.as_ref());
65-
let sse2_first = unsafe { _mm_set1_epi8(needle[0] as i8) };
66-
let sse2_last = unsafe { _mm_set1_epi8(needle[position] as i8) };
67-
let avx2_first = unsafe { _mm256_set1_epi8(needle[0] as i8) };
68-
let avx2_last = unsafe { _mm256_set1_epi8(needle[position] as i8) };
140+
let scalar_hash = ScalarHash::from(needle.as_ref());
141+
let sse2_hash = VectorHash::new(needle[0], needle[position]);
142+
let avx2_hash = VectorHash::new(needle[0], needle[position]);
69143

70144
Self {
71145
needle,
72146
position,
73-
hash,
74-
sse2_first,
75-
sse2_last,
76-
avx2_first,
77-
avx2_last,
147+
scalar_hash,
148+
sse2_hash,
149+
avx2_hash,
78150
}
79151
}
80152

@@ -93,14 +165,14 @@ macro_rules! avx2_searcher {
93165
debug_assert!(haystack.len() >= self.size());
94166

95167
let mut end = self.size() - 1;
96-
let mut hash = Hash::from(&haystack[..end]);
168+
let mut hash = ScalarHash::from(&haystack[..end]);
97169

98170
while end < haystack.len() {
99171
hash += *unsafe { haystack.get_unchecked(end) };
100172
end += 1;
101173

102174
let start = end - self.size();
103-
if hash == self.hash && haystack[start..end] == *self.needle {
175+
if hash == self.scalar_hash && haystack[start..end] == *self.needle {
104176
return true;
105177
}
106178

@@ -111,22 +183,28 @@ macro_rules! avx2_searcher {
111183
}
112184

113185
#[inline(always)]
114-
fn sse2_search_in(&self, haystack: &[u8]) -> bool {
115-
if haystack.len() < SSE2_LANES {
116-
return self.scalar_search_in(haystack);
186+
fn vector_search_in<V: Vector>(
187+
&self,
188+
haystack: &[u8],
189+
hash: &VectorHash<V>,
190+
next: fn(&Self, &[u8]) -> bool,
191+
) -> bool {
192+
let lanes = mem::size_of::<V>();
193+
if haystack.len() < lanes {
194+
return next(self, haystack);
117195
}
118196

119-
let mut chunks = haystack[..=haystack.len() - self.size()].chunks_exact(SSE2_LANES);
197+
let mut chunks = haystack[..=haystack.len() - self.size()].chunks_exact(lanes);
120198
while let Some(chunk) = chunks.next() {
121199
let start = chunk.as_ptr();
122-
let first = unsafe { _mm_loadu_si128(start.cast()) };
123-
let last = unsafe { _mm_loadu_si128(start.add(self.position).cast()) };
200+
let first = unsafe { Vector::loadu_si(start.cast()) };
201+
let last = unsafe { Vector::loadu_si(start.add(self.position).cast()) };
124202

125-
let mask_first = unsafe { _mm_cmpeq_epi8(self.sse2_first, first) };
126-
let mask_last = unsafe { _mm_cmpeq_epi8(self.sse2_last, last) };
203+
let mask_first = unsafe { Vector::cmpeq_epi8(hash.first, first) };
204+
let mask_last = unsafe { Vector::cmpeq_epi8(hash.last, last) };
127205

128-
let mask = unsafe { _mm_and_si128(mask_first, mask_last) };
129-
let mut mask = unsafe { _mm_movemask_epi8(mask) } as u32;
206+
let mask = unsafe { Vector::and_si(mask_first, mask_last) };
207+
let mut mask = unsafe { Vector::movemask_epi8(mask) } as u32;
130208

131209
let start = start as usize - haystack.as_ptr() as usize;
132210
while mask != 0 {
@@ -140,46 +218,20 @@ macro_rules! avx2_searcher {
140218
}
141219

142220
let remainder = chunks.remainder();
143-
debug_assert!(remainder.len() < SSE2_LANES);
221+
debug_assert!(remainder.len() < lanes);
144222

145223
let chunk = &haystack[remainder.as_ptr() as usize - haystack.as_ptr() as usize..];
146-
self.scalar_search_in(chunk)
224+
next(self, chunk)
147225
}
148226

149227
#[inline(always)]
150-
fn avx2_search_in(&self, haystack: &[u8]) -> bool {
151-
if haystack.len() < AVX2_LANES {
152-
return self.sse2_search_in(haystack);
153-
}
154-
155-
let mut chunks = haystack[..=haystack.len() - self.size()].chunks_exact(AVX2_LANES);
156-
while let Some(chunk) = chunks.next() {
157-
let start = chunk.as_ptr();
158-
let first = unsafe { _mm256_loadu_si256(start.cast()) };
159-
let last = unsafe { _mm256_loadu_si256(start.add(self.position).cast()) };
160-
161-
let mask_first = unsafe { _mm256_cmpeq_epi8(self.avx2_first, first) };
162-
let mask_last = unsafe { _mm256_cmpeq_epi8(self.avx2_last, last) };
163-
164-
let mask = unsafe { _mm256_and_si256(mask_first, mask_last) };
165-
let mut mask = unsafe { _mm256_movemask_epi8(mask) } as u32;
166-
167-
let start = start as usize - haystack.as_ptr() as usize;
168-
while mask != 0 {
169-
let chunk = &haystack[start + mask.trailing_zeros() as usize..];
170-
if unsafe { $memcmp(&chunk[1..self.size()], &self.needle[1..]) } {
171-
return true;
172-
}
173-
174-
mask = bits::clear_leftmost_set(mask);
175-
}
176-
}
177-
178-
let remainder = chunks.remainder();
179-
debug_assert!(remainder.len() < AVX2_LANES);
228+
fn sse2_search_in(&self, haystack: &[u8]) -> bool {
229+
self.vector_search_in(haystack, &self.sse2_hash, Self::scalar_search_in)
230+
}
180231

181-
let chunk = &haystack[remainder.as_ptr() as usize - haystack.as_ptr() as usize..];
182-
self.sse2_search_in(chunk)
232+
#[inline(always)]
233+
fn avx2_search_in(&self, haystack: &[u8]) -> bool {
234+
self.vector_search_in(haystack, &self.avx2_hash, Self::sse2_search_in)
183235
}
184236

185237
#[inline(always)]

0 commit comments

Comments
 (0)