Skip to content

Commit d1053b6

Browse files
committed
Add 2-lanes searcher and get rid of ScalarSearcher
Benchmark results summary: * `short_haystack`: -10% instructions * `long_haystack`: -0.2% instructions * `random_haystack`: +1.4% instructions
1 parent 2aa5491 commit d1053b6

File tree

1 file changed

+80
-68
lines changed

1 file changed

+80
-68
lines changed

src/x86.rs

Lines changed: 80 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -25,33 +25,6 @@ trait NeedleWithSize: Needle {
2525

2626
impl<N: Needle + ?Sized> NeedleWithSize for N {}
2727

28-
/// Rolling hash for the simple Rabin-Karp implementation. As a hashing
/// function, the sum of all the bytes is computed.
#[derive(Clone, Copy, Default, PartialEq)]
struct ScalarHash(usize);

impl From<&[u8]> for ScalarHash {
    /// Hashes a whole byte slice by folding every byte into the sum via
    /// `push`, starting from the default (zero) hash.
    #[inline]
    fn from(bytes: &[u8]) -> Self {
        bytes.iter().fold(Default::default(), |mut hash, &b| {
            hash.push(b);
            hash
        })
    }
}

impl ScalarHash {
    /// Adds a byte entering the rolling window.
    ///
    /// `wrapping_add` keeps the sum well-defined on overflow (no debug-build
    /// panic); since `pop` wraps the same way, equality comparisons between
    /// hashes remain consistent.
    #[inline]
    fn push(&mut self, b: u8) {
        self.0 = self.0.wrapping_add(b.into());
    }

    /// Removes a byte leaving the rolling window; exact inverse of `push`.
    #[inline]
    fn pop(&mut self, b: u8) {
        self.0 = self.0.wrapping_sub(b.into());
    }
}
54-
5528
/// Represents an SIMD register type that is x86-specific (but could be used
5629
/// more generically) in order to share functionality between SSE2, AVX2 and
5730
/// possibly future implementations.
@@ -69,6 +42,45 @@ trait Vector: Copy {
6942
unsafe fn movemask_epi8(a: Self) -> i32;
7043
}
7144

45+
/// Pseudo 16-bit vector with two usable 8-bit lanes, backed by a full SSE2
/// `__m128i` register. Only the two low bytes carry meaningful data; the
/// remaining lanes are masked off in `movemask_epi8`.
#[derive(Clone, Copy)]
#[repr(transparent)]
#[allow(non_camel_case_types)]
struct __m16i(__m128i);

// NOTE(review): the intrinsics below are SSE2, but the functions are gated on
// `avx2` like the rest of the searcher's code paths — presumably so they can
// all be inlined into the avx2-gated callers; confirm against the other
// `Vector` impls.
impl Vector for __m16i {
    // Two valid byte lanes per "register".
    const LANES: usize = 2;

    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn set1_epi8(a: i8) -> Self {
        __m16i(_mm_set1_epi8(a))
    }

    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn loadu_si(a: *const Self) -> Self {
        // Read exactly two bytes as an unaligned `i16` — never touching
        // memory past the two-byte chunk — then broadcast them across the
        // SSE2 register. The upper lanes are duplicate copies whose compare
        // results are discarded by the mask in `movemask_epi8`.
        __m16i(_mm_set1_epi16(std::ptr::read_unaligned(a as *const i16)))
    }

    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn cmpeq_epi8(a: Self, b: Self) -> Self {
        __m16i(_mm_cmpeq_epi8(a.0, b.0))
    }

    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn and_si(a: Self, b: Self) -> Self {
        __m16i(_mm_and_si128(a.0, b.0))
    }

    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn movemask_epi8(a: Self) -> i32 {
        // Keep only the bits belonging to the two valid lanes.
        _mm_movemask_epi8(a.0) & 0x3
    }
}
83+
7284
#[derive(Clone, Copy)]
7385
#[repr(transparent)]
7486
#[allow(non_camel_case_types)]
@@ -254,6 +266,16 @@ impl From<&VectorHash<__m128i>> for VectorHash<__m32i> {
254266
}
255267
}
256268

269+
impl From<&VectorHash<__m128i>> for VectorHash<__m16i> {
270+
#[inline]
271+
fn from(hash: &VectorHash<__m128i>) -> Self {
272+
Self {
273+
first: __m16i(hash.first),
274+
last: __m16i(hash.last),
275+
}
276+
}
277+
}
278+
257279
/// Single-substring searcher using an AVX2 algorithm based on the "Generic
258280
/// SIMD" algorithm [presented by Wojciech
259281
/// Muła](http://0x80.pl/articles/simd-strfind.html).
@@ -285,7 +307,6 @@ impl From<&VectorHash<__m128i>> for VectorHash<__m32i> {
285307
/// Rabin-Karp implementation.
286308
pub struct Avx2Searcher<N: Needle> {
287309
position: usize,
288-
scalar_hash: ScalarHash,
289310
sse2_hash: VectorHash<__m128i>,
290311
avx2_hash: VectorHash<__m256i>,
291312
needle: N,
@@ -325,41 +346,17 @@ impl<N: Needle> Avx2Searcher<N> {
325346
assert_eq!(size, bytes.len());
326347
}
327348

328-
let scalar_hash = ScalarHash::from(bytes);
329349
let sse2_hash = VectorHash::new(bytes[0], bytes[position]);
330350
let avx2_hash = VectorHash::new(bytes[0], bytes[position]);
331351

332352
Self {
333353
position,
334-
scalar_hash,
335354
sse2_hash,
336355
avx2_hash,
337356
needle,
338357
}
339358
}
340359

341-
#[inline]
342-
fn scalar_search_in(&self, haystack: &[u8]) -> bool {
343-
debug_assert!(haystack.len() >= self.needle.size());
344-
345-
let mut end = self.needle.size() - 1;
346-
let mut hash = ScalarHash::from(&haystack[..end]);
347-
348-
while end < haystack.len() {
349-
hash.push(*unsafe { haystack.get_unchecked(end) });
350-
end += 1;
351-
352-
let start = end - self.needle.size();
353-
if hash == self.scalar_hash && haystack[start..end] == *self.needle.as_bytes() {
354-
return true;
355-
}
356-
357-
hash.pop(*unsafe { haystack.get_unchecked(start) });
358-
}
359-
360-
false
361-
}
362-
363360
#[inline]
364361
#[target_feature(enable = "avx2")]
365362
unsafe fn vector_search_in_chunk<V: Vector>(
@@ -419,17 +416,11 @@ impl<N: Needle> Avx2Searcher<N> {
419416
unsafe fn vector_search_in<V: Vector>(
420417
&self,
421418
haystack: &[u8],
419+
end: usize,
422420
hash: &VectorHash<V>,
423-
next: unsafe fn(&Self, &[u8]) -> bool,
424421
) -> bool {
425422
debug_assert!(haystack.len() >= self.needle.size());
426423

427-
let end = haystack.len() - self.needle.size() + 1;
428-
429-
if end < V::LANES {
430-
return next(self, haystack);
431-
}
432-
433424
let mut chunks = haystack[..end].chunks_exact(V::LANES);
434425
while let Some(chunk) = chunks.next() {
435426
if self.vector_search_in_chunk(haystack, hash, chunk.as_ptr(), -1) {
@@ -452,28 +443,35 @@ impl<N: Needle> Avx2Searcher<N> {
452443

453444
/// Searches `haystack` using 2-lane vectors; `end` is the number of candidate
/// window start positions, computed by the caller. Selected for the shortest
/// haystacks, where fewer than `__m32i::LANES` windows fit.
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn sse2_2_search_in(&self, haystack: &[u8], end: usize) -> bool {
    // Narrow the precomputed 16-lane SSE2 hash down to 2 lanes.
    let hash = VectorHash::<__m16i>::from(&self.sse2_hash);
    self.vector_search_in(haystack, end, &hash)
}
450+
451+
/// Searches `haystack` using 4-lane vectors; `end` is the number of candidate
/// window start positions, computed by the caller.
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn sse2_4_search_in(&self, haystack: &[u8], end: usize) -> bool {
    // Narrow the precomputed 16-lane SSE2 hash down to 4 lanes.
    let hash = VectorHash::<__m32i>::from(&self.sse2_hash);
    self.vector_search_in(haystack, end, &hash)
}
459457

460458
/// Searches `haystack` using 8-lane vectors; `end` is the number of candidate
/// window start positions, computed by the caller.
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn sse2_8_search_in(&self, haystack: &[u8], end: usize) -> bool {
    // Narrow the precomputed 16-lane SSE2 hash down to 8 lanes.
    let hash = VectorHash::<__m64i>::from(&self.sse2_hash);
    self.vector_search_in(haystack, end, &hash)
}
466464

467465
/// Searches `haystack` using full 16-lane SSE2 vectors; `end` is the number
/// of candidate window start positions, computed by the caller. Uses the
/// stored hash directly — no narrowing needed.
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn sse2_16_search_in(&self, haystack: &[u8], end: usize) -> bool {
    self.vector_search_in(haystack, end, &self.sse2_hash)
}
472470

473471
/// Searches `haystack` using 32-lane AVX2 vectors; `end` is the number of
/// candidate window start positions, computed by the caller. Selected for the
/// longest haystacks, where at least `__m256i::LANES` windows fit.
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn avx2_search_in(&self, haystack: &[u8], end: usize) -> bool {
    self.vector_search_in(haystack, end, &self.avx2_hash)
}
478476

479477
/// Inlined version of `search_in` for hot call sites.
@@ -484,7 +482,21 @@ impl<N: Needle> Avx2Searcher<N> {
484482
return haystack == self.needle.as_bytes();
485483
}
486484

487-
self.avx2_search_in(haystack)
485+
let end = haystack.len() - self.needle.size() + 1;
486+
487+
if end < __m16i::LANES {
488+
unreachable!();
489+
} else if end < __m32i::LANES {
490+
self.sse2_2_search_in(haystack, end)
491+
} else if end < __m64i::LANES {
492+
self.sse2_4_search_in(haystack, end)
493+
} else if end < __m128i::LANES {
494+
self.sse2_8_search_in(haystack, end)
495+
} else if end < __m256i::LANES {
496+
self.sse2_16_search_in(haystack, end)
497+
} else {
498+
self.avx2_search_in(haystack, end)
499+
}
488500
}
489501

490502
/// Performs a substring search for the `needle` within `haystack`.

0 commit comments

Comments
 (0)