Skip to content

Commit f51c5e3

Browse files
zakcutnermarmeladema
authored andcommitted
Use custom needle position in original AVX2 implementation for fair comparison
1 parent 0d0ba37 commit f51c5e3

File tree

1 file changed

+25
-23
lines changed

1 file changed

+25
-23
lines changed

src/avx2/rust.rs

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,7 @@ pub fn strstr_avx2_rust_fast_2(haystack: &[u8], needle: &[u8]) -> bool {
302302
#[cfg(target_feature = "avx2")]
303303
pub struct StrStrAVX2Searcher {
304304
needle: Box<[u8]>,
305+
position: usize,
305306
sse_first: __m128i,
306307
sse_last: __m128i,
307308
avx2_first: __m256i,
@@ -312,16 +313,21 @@ pub struct StrStrAVX2Searcher {
312313
#[cfg(target_feature = "avx2")]
313314
impl StrStrAVX2Searcher {
314315
pub fn new(needle: &[u8]) -> Self {
316+
Self::with_position(needle, needle.len() - 1)
317+
}
318+
319+
pub fn with_position(needle: &[u8], position: usize) -> Self {
315320
let mut needle_sum = 0_usize;
316321
for &c in needle {
317322
needle_sum += c as usize;
318323
}
319324
StrStrAVX2Searcher {
320325
needle: needle.to_vec().into_boxed_slice(),
326+
position,
321327
sse_first: unsafe { _mm_set1_epi8(needle[0] as i8) },
322-
sse_last: unsafe { _mm_set1_epi8(needle[needle.len() - 1] as i8) },
328+
sse_last: unsafe { _mm_set1_epi8(needle[position] as i8) },
323329
avx2_first: unsafe { _mm256_set1_epi8(needle[0] as i8) },
324-
avx2_last: unsafe { _mm256_set1_epi8(needle[needle.len() - 1] as i8) },
330+
avx2_last: unsafe { _mm256_set1_epi8(needle[position] as i8) },
325331
needle_sum,
326332
}
327333
}
@@ -333,19 +339,18 @@ impl StrStrAVX2Searcher {
333339
match self.needle.len() {
334340
0 => true,
335341
1 => memchr::memchr(self.needle[0], haystack).is_some(),
336-
2 => unsafe { self.avx2_memcmp(haystack, memcmp0) },
337-
3 => unsafe { self.avx2_memcmp(haystack, memcmp1) },
338-
4 => unsafe { self.avx2_memcmp(haystack, memcmp2) },
342+
2 => unsafe { self.avx2_memcmp(haystack, memcmp1) },
343+
3 => unsafe { self.avx2_memcmp(haystack, memcmp2) },
344+
4 => unsafe { self.avx2_memcmp(haystack, memcmp3) },
339345
5 => unsafe { self.avx2_memcmp(haystack, memcmp4) },
340-
6 => unsafe { self.avx2_memcmp(haystack, memcmp4) },
341-
7 => unsafe { self.avx2_memcmp(haystack, memcmp5) },
342-
8 => unsafe { self.avx2_memcmp(haystack, memcmp6) },
346+
6 => unsafe { self.avx2_memcmp(haystack, memcmp5) },
347+
7 => unsafe { self.avx2_memcmp(haystack, memcmp6) },
348+
8 => unsafe { self.avx2_memcmp(haystack, memcmp7) },
343349
9 => unsafe { self.avx2_memcmp(haystack, memcmp8) },
344-
10 => unsafe { self.avx2_memcmp(haystack, memcmp8) },
345-
11 => unsafe { self.avx2_memcmp(haystack, memcmp9) },
346-
12 => unsafe { self.avx2_memcmp(haystack, memcmp10) },
347-
13 => unsafe { self.avx2_memcmp(haystack, memcmp11) },
348-
14 => unsafe { self.avx2_memcmp(haystack, memcmp12) },
350+
10 => unsafe { self.avx2_memcmp(haystack, memcmp9) },
351+
11 => unsafe { self.avx2_memcmp(haystack, memcmp10) },
352+
12 => unsafe { self.avx2_memcmp(haystack, memcmp11) },
353+
13 => unsafe { self.avx2_memcmp(haystack, memcmp12) },
349354
_ => unsafe { self.avx2_memcmp(haystack, memcmp) },
350355
}
351356
}
@@ -359,22 +364,19 @@ impl StrStrAVX2Searcher {
359364
while let Some(chunk) = chunks.next() {
360365
let i = chunk.as_ptr() as usize - haystack.as_ptr() as usize;
361366
let block_first = _mm_loadu_si128(chunk.as_ptr() as *const __m128i);
362-
let block_last =
363-
_mm_loadu_si128(chunk.as_ptr().add(self.needle.len() - 1) as *const __m128i);
367+
let block_last = _mm_loadu_si128(chunk.as_ptr().add(self.position) as *const __m128i);
364368

365369
let eq_first = _mm_cmpeq_epi8(self.sse_first, block_first);
366370
let eq_last = _mm_cmpeq_epi8(self.sse_last, block_last);
367371

368-
let mut mask = std::mem::transmute::<i32, u32>(_mm_movemask_epi8(_mm_and_si128(
369-
eq_first, eq_last,
370-
)));
372+
let mut mask = _mm_movemask_epi8(_mm_and_si128(eq_first, eq_last)) as u32;
371373
while mask != 0 {
372374
let bitpos = mask.trailing_zeros() as usize;
373375
let startpos = i + bitpos;
374376
if startpos + self.needle.len() <= haystack.len()
375377
&& memcmp(
376-
&haystack[(startpos + 1)..(startpos + self.needle.len() - 1)],
377-
&self.needle[1..self.needle.len() - 1],
378+
&haystack[(startpos + 1)..(startpos + self.needle.len())],
379+
&self.needle[1..self.needle.len()],
378380
)
379381
{
380382
return true;
@@ -403,7 +405,7 @@ impl StrStrAVX2Searcher {
403405
let i = chunk.as_ptr() as usize - haystack.as_ptr() as usize;
404406
let block_first = _mm256_loadu_si256(chunk.as_ptr() as *const __m256i);
405407
let block_last =
406-
_mm256_loadu_si256(chunk.as_ptr().add(self.needle.len() - 1) as *const __m256i);
408+
_mm256_loadu_si256(chunk.as_ptr().add(self.position) as *const __m256i);
407409

408410
let eq_first = _mm256_cmpeq_epi8(self.avx2_first, block_first);
409411
let eq_last = _mm256_cmpeq_epi8(self.avx2_last, block_last);
@@ -416,8 +418,8 @@ impl StrStrAVX2Searcher {
416418
let startpos = i + bitpos;
417419
if startpos + self.needle.len() <= haystack.len()
418420
&& memcmp(
419-
&haystack[(startpos + 1)..(startpos + self.needle.len() - 1)],
420-
&self.needle[1..self.needle.len() - 1],
421+
&haystack[(startpos + 1)..(startpos + self.needle.len())],
422+
&self.needle[1..self.needle.len()],
421423
)
422424
{
423425
return true;

0 commit comments

Comments
 (0)