Skip to content

Commit 62bcd44

Browse files
committed
Implement SIMD searcher for haystacks between 8 and 16 bytes long
Benchmark results summary: * `short_haystack`: -6.5% instructions * `long_haystack`: +0% instructions (no change) * `random_haystack`: +0% instructions (no change)
1 parent 84543bf commit 62bcd44

File tree

1 file changed

+67
-15
lines changed

1 file changed

+67
-15
lines changed

src/x86.rs

Lines changed: 67 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ use seq_macro::seq;
66
use std::arch::x86::*;
77
#[cfg(target_arch = "x86_64")]
88
use std::arch::x86_64::*;
9-
use std::mem;
109

1110
trait NeedleWithSize: Needle {
1211
#[inline]
@@ -57,6 +56,8 @@ impl ScalarHash {
5756
/// more generically) in order to share functionality between SSE2, AVX2 and
5857
/// possibly future implementations.
5958
trait Vector: Copy {
59+
const LANES: usize;
60+
6061
unsafe fn set1_epi8(a: i8) -> Self;
6162

6263
unsafe fn loadu_si(a: *const Self) -> Self;
@@ -68,7 +69,48 @@ trait Vector: Copy {
6869
unsafe fn movemask_epi8(a: Self) -> i32;
6970
}
7071

72+
#[derive(Clone, Copy)]
73+
#[repr(transparent)]
74+
#[allow(non_camel_case_types)]
75+
struct __m64i(__m128i);
76+
77+
impl Vector for __m64i {
78+
const LANES: usize = 8;
79+
80+
#[inline]
81+
#[target_feature(enable = "avx2")]
82+
unsafe fn set1_epi8(a: i8) -> Self {
83+
__m64i(_mm_set1_epi8(a))
84+
}
85+
86+
#[inline]
87+
#[target_feature(enable = "avx2")]
88+
unsafe fn loadu_si(a: *const Self) -> Self {
89+
__m64i(_mm_loadu_si128(a as *const std::arch::x86_64::__m128i))
90+
}
91+
92+
#[inline]
93+
#[target_feature(enable = "avx2")]
94+
unsafe fn cmpeq_epi8(a: Self, b: Self) -> Self {
95+
__m64i(_mm_cmpeq_epi8(a.0, b.0))
96+
}
97+
98+
#[inline]
99+
#[target_feature(enable = "avx2")]
100+
unsafe fn and_si(a: Self, b: Self) -> Self {
101+
__m64i(_mm_and_si128(a.0, b.0))
102+
}
103+
104+
#[inline]
105+
#[target_feature(enable = "avx2")]
106+
unsafe fn movemask_epi8(a: Self) -> i32 {
107+
_mm_movemask_epi8(a.0) & 0xFF
108+
}
109+
}
110+
71111
impl Vector for __m128i {
112+
const LANES: usize = 16;
113+
72114
#[inline]
73115
#[target_feature(enable = "avx2")]
74116
unsafe fn set1_epi8(a: i8) -> Self {
@@ -101,6 +143,8 @@ impl Vector for __m128i {
101143
}
102144

103145
impl Vector for __m256i {
146+
const LANES: usize = 32;
147+
104148
#[inline]
105149
#[target_feature(enable = "avx2")]
106150
unsafe fn set1_epi8(a: i8) -> Self {
@@ -183,6 +227,7 @@ impl<V: Vector> VectorHash<V> {
183227
pub struct Avx2Searcher<N: Needle> {
184228
position: usize,
185229
scalar_hash: ScalarHash,
230+
u64_hash: VectorHash<__m64i>,
186231
sse2_hash: VectorHash<__m128i>,
187232
avx2_hash: VectorHash<__m256i>,
188233
needle: N,
@@ -223,12 +268,14 @@ impl<N: Needle> Avx2Searcher<N> {
223268
}
224269

225270
let scalar_hash = ScalarHash::from(bytes);
271+
let u64_hash = VectorHash::new(bytes[0], bytes[position]);
226272
let sse2_hash = VectorHash::new(bytes[0], bytes[position]);
227273
let avx2_hash = VectorHash::new(bytes[0], bytes[position]);
228274

229275
Self {
230276
position,
231277
scalar_hash,
278+
u64_hash,
232279
sse2_hash,
233280
avx2_hash,
234281
needle,
@@ -321,14 +368,13 @@ impl<N: Needle> Avx2Searcher<N> {
321368
) -> bool {
322369
debug_assert!(haystack.len() >= self.needle.size());
323370

324-
let lanes = mem::size_of::<V>();
325371
let end = haystack.len() - self.needle.size() + 1;
326372

327-
if end < lanes {
373+
if end < V::LANES {
328374
return next(self, haystack);
329375
}
330376

331-
let mut chunks = haystack[..end].chunks_exact(lanes);
377+
let mut chunks = haystack[..end].chunks_exact(V::LANES);
332378
while let Some(chunk) = chunks.next() {
333379
if self.vector_search_in_chunk(haystack, hash, chunk.as_ptr(), -1) {
334380
return true;
@@ -337,8 +383,8 @@ impl<N: Needle> Avx2Searcher<N> {
337383

338384
let remainder = chunks.remainder().len();
339385
if remainder > 0 {
340-
let start = haystack.as_ptr().add(end - lanes);
341-
let mask = -1 << (lanes - remainder);
386+
let start = haystack.as_ptr().add(end - V::LANES);
387+
let mask = -1 << (V::LANES - remainder);
342388

343389
if self.vector_search_in_chunk(haystack, hash, start, mask) {
344390
return true;
@@ -348,10 +394,16 @@ impl<N: Needle> Avx2Searcher<N> {
348394
false
349395
}
350396

397+
/// Searches `haystack` with the 8-lane `u64_hash` by delegating to
/// `vector_search_in`.
///
/// `Self::scalar_search_in` is passed as the `next` callback:
/// `vector_search_in` calls it instead of running the vector loop when
/// `haystack.len() - needle.size() + 1` is smaller than `V::LANES`.
///
/// # Safety
///
/// Declared with `#[target_feature(enable = "avx2")]`, so the caller must
/// ensure AVX2 is available on the running CPU.
/// NOTE(review): `vector_search_in` debug-asserts that `haystack` is at least
/// as long as the needle — confirm callers uphold this in release builds.
#[inline]
398+
#[target_feature(enable = "avx2")]
399+
unsafe fn u64_search_in(&self, haystack: &[u8]) -> bool {
400+
self.vector_search_in(haystack, &self.u64_hash, Self::scalar_search_in)
401+
}
402+
351403
#[inline]
352404
#[target_feature(enable = "avx2")]
353405
unsafe fn sse2_search_in(&self, haystack: &[u8]) -> bool {
354-
self.vector_search_in(haystack, &self.sse2_hash, Self::scalar_search_in)
406+
self.vector_search_in(haystack, &self.sse2_hash, Self::u64_search_in)
355407
}
356408

357409
#[inline]
@@ -701,20 +753,20 @@ mod tests {
701753
fn size_of_avx2_searcher() {
702754
use std::mem::size_of;
703755

704-
assert_eq!(size_of::<Avx2Searcher::<&[u8]>>(), 128);
705-
assert_eq!(size_of::<Avx2Searcher::<[u8; 0]>>(), 128);
706-
assert_eq!(size_of::<Avx2Searcher::<[u8; 16]>>(), 128);
707-
assert_eq!(size_of::<Avx2Searcher::<Box<[u8]>>>(), 128);
756+
assert_eq!(size_of::<Avx2Searcher::<&[u8]>>(), 160);
757+
assert_eq!(size_of::<Avx2Searcher::<[u8; 0]>>(), 160);
758+
assert_eq!(size_of::<Avx2Searcher::<[u8; 16]>>(), 160);
759+
assert_eq!(size_of::<Avx2Searcher::<Box<[u8]>>>(), 160);
708760
}
709761

710762
#[test]
711763
#[cfg(target_pointer_width = "64")]
712764
fn size_of_dynamic_avx2_searcher() {
713765
use std::mem::size_of;
714766

715-
assert_eq!(size_of::<DynamicAvx2Searcher::<&[u8]>>(), 160);
716-
assert_eq!(size_of::<DynamicAvx2Searcher::<[u8; 0]>>(), 160);
717-
assert_eq!(size_of::<DynamicAvx2Searcher::<[u8; 16]>>(), 160);
718-
assert_eq!(size_of::<DynamicAvx2Searcher::<Box<[u8]>>>(), 160);
767+
assert_eq!(size_of::<DynamicAvx2Searcher::<&[u8]>>(), 192);
768+
assert_eq!(size_of::<DynamicAvx2Searcher::<[u8; 0]>>(), 192);
769+
assert_eq!(size_of::<DynamicAvx2Searcher::<[u8; 16]>>(), 192);
770+
assert_eq!(size_of::<DynamicAvx2Searcher::<Box<[u8]>>>(), 192);
719771
}
720772
}

0 commit comments

Comments
 (0)