@@ -25,33 +25,6 @@ trait NeedleWithSize: Needle {
2525
2626impl < N : Needle + ?Sized > NeedleWithSize for N { }
2727
/// Rolling hash for the simple Rabin-Karp implementation. As a hashing
/// function, the (wrapping) sum of all the bytes is computed, so sliding a
/// window by one byte costs one `push` and one `pop`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
struct ScalarHash(usize);

impl From<&[u8]> for ScalarHash {
    /// Hashes `bytes` by pushing every byte, in order, onto an empty hash.
    ///
    /// An empty slice yields `ScalarHash::default()`.
    #[inline]
    fn from(bytes: &[u8]) -> Self {
        let mut hash = Self::default();
        for &b in bytes {
            hash.push(b);
        }
        hash
    }
}

impl ScalarHash {
    /// Adds byte `b` to the hash.
    ///
    /// Wrapping addition is used so that arbitrarily long inputs can never
    /// panic on overflow.
    #[inline]
    fn push(&mut self, b: u8) {
        self.0 = self.0.wrapping_add(usize::from(b));
    }

    /// Removes byte `b` from the hash.
    ///
    /// Wrapping subtraction is the exact inverse of `push`, so pushing and
    /// then popping the same byte restores the previous value even when the
    /// sum wrapped around.
    #[inline]
    fn pop(&mut self, b: u8) {
        self.0 = self.0.wrapping_sub(usize::from(b));
    }
}
54-
5528/// Represents an SIMD register type that is x86-specific (but could be used
5629/// more generically) in order to share functionality between SSE2, AVX2 and
5730/// possibly future implementations.
@@ -69,6 +42,45 @@ trait Vector: Copy {
6942 unsafe fn movemask_epi8 ( a : Self ) -> i32 ;
7043}
7144
45+ #[ derive( Clone , Copy ) ]
46+ #[ repr( transparent) ]
47+ #[ allow( non_camel_case_types) ]
48+ struct __m16i ( __m128i ) ;
49+
50+ impl Vector for __m16i {
51+ const LANES : usize = 2 ;
52+
53+ #[ inline]
54+ #[ target_feature( enable = "avx2" ) ]
55+ unsafe fn set1_epi8 ( a : i8 ) -> Self {
56+ __m16i ( _mm_set1_epi8 ( a) )
57+ }
58+
59+ #[ inline]
60+ #[ target_feature( enable = "avx2" ) ]
61+ unsafe fn loadu_si ( a : * const Self ) -> Self {
62+ __m16i ( _mm_set1_epi16 ( std:: ptr:: read_unaligned ( a as * const i16 ) ) )
63+ }
64+
65+ #[ inline]
66+ #[ target_feature( enable = "avx2" ) ]
67+ unsafe fn cmpeq_epi8 ( a : Self , b : Self ) -> Self {
68+ __m16i ( _mm_cmpeq_epi8 ( a. 0 , b. 0 ) )
69+ }
70+
71+ #[ inline]
72+ #[ target_feature( enable = "avx2" ) ]
73+ unsafe fn and_si ( a : Self , b : Self ) -> Self {
74+ __m16i ( _mm_and_si128 ( a. 0 , b. 0 ) )
75+ }
76+
77+ #[ inline]
78+ #[ target_feature( enable = "avx2" ) ]
79+ unsafe fn movemask_epi8 ( a : Self ) -> i32 {
80+ _mm_movemask_epi8 ( a. 0 ) & 0x3
81+ }
82+ }
83+
7284#[ derive( Clone , Copy ) ]
7385#[ repr( transparent) ]
7486#[ allow( non_camel_case_types) ]
@@ -254,6 +266,16 @@ impl From<&VectorHash<__m128i>> for VectorHash<__m32i> {
254266 }
255267}
256268
269+ impl From < & VectorHash < __m128i > > for VectorHash < __m16i > {
270+ #[ inline]
271+ fn from ( hash : & VectorHash < __m128i > ) -> Self {
272+ Self {
273+ first : __m16i ( hash. first ) ,
274+ last : __m16i ( hash. last ) ,
275+ }
276+ }
277+ }
278+
257279/// Single-substring searcher using an AVX2 algorithm based on the "Generic
258280/// SIMD" algorithm [presented by Wojciech
259281/// Muła](http://0x80.pl/articles/simd-strfind.html).
@@ -285,7 +307,6 @@ impl From<&VectorHash<__m128i>> for VectorHash<__m32i> {
285307/// Rabin-Karp implementation.
286308pub struct Avx2Searcher < N : Needle > {
287309 position : usize ,
288- scalar_hash : ScalarHash ,
289310 sse2_hash : VectorHash < __m128i > ,
290311 avx2_hash : VectorHash < __m256i > ,
291312 needle : N ,
@@ -325,41 +346,17 @@ impl<N: Needle> Avx2Searcher<N> {
325346 assert_eq ! ( size, bytes. len( ) ) ;
326347 }
327348
328- let scalar_hash = ScalarHash :: from ( bytes) ;
329349 let sse2_hash = VectorHash :: new ( bytes[ 0 ] , bytes[ position] ) ;
330350 let avx2_hash = VectorHash :: new ( bytes[ 0 ] , bytes[ position] ) ;
331351
332352 Self {
333353 position,
334- scalar_hash,
335354 sse2_hash,
336355 avx2_hash,
337356 needle,
338357 }
339358 }
340359
341- #[ inline]
342- fn scalar_search_in ( & self , haystack : & [ u8 ] ) -> bool {
343- debug_assert ! ( haystack. len( ) >= self . needle. size( ) ) ;
344-
345- let mut end = self . needle . size ( ) - 1 ;
346- let mut hash = ScalarHash :: from ( & haystack[ ..end] ) ;
347-
348- while end < haystack. len ( ) {
349- hash. push ( * unsafe { haystack. get_unchecked ( end) } ) ;
350- end += 1 ;
351-
352- let start = end - self . needle . size ( ) ;
353- if hash == self . scalar_hash && haystack[ start..end] == * self . needle . as_bytes ( ) {
354- return true ;
355- }
356-
357- hash. pop ( * unsafe { haystack. get_unchecked ( start) } ) ;
358- }
359-
360- false
361- }
362-
363360 #[ inline]
364361 #[ target_feature( enable = "avx2" ) ]
365362 unsafe fn vector_search_in_chunk < V : Vector > (
@@ -419,17 +416,11 @@ impl<N: Needle> Avx2Searcher<N> {
419416 unsafe fn vector_search_in < V : Vector > (
420417 & self ,
421418 haystack : & [ u8 ] ,
419+ end : usize ,
422420 hash : & VectorHash < V > ,
423- next : unsafe fn ( & Self , & [ u8 ] ) -> bool ,
424421 ) -> bool {
425422 debug_assert ! ( haystack. len( ) >= self . needle. size( ) ) ;
426423
427- let end = haystack. len ( ) - self . needle . size ( ) + 1 ;
428-
429- if end < V :: LANES {
430- return next ( self , haystack) ;
431- }
432-
433424 let mut chunks = haystack[ ..end] . chunks_exact ( V :: LANES ) ;
434425 while let Some ( chunk) = chunks. next ( ) {
435426 if self . vector_search_in_chunk ( haystack, hash, chunk. as_ptr ( ) , -1 ) {
@@ -452,28 +443,35 @@ impl<N: Needle> Avx2Searcher<N> {
452443
453444 #[ inline]
454445 #[ target_feature( enable = "avx2" ) ]
455- unsafe fn sse2_4_search_in ( & self , haystack : & [ u8 ] ) -> bool {
446+ unsafe fn sse2_2_search_in ( & self , haystack : & [ u8 ] , end : usize ) -> bool {
447+ let hash = VectorHash :: < __m16i > :: from ( & self . sse2_hash ) ;
448+ self . vector_search_in ( haystack, end, & hash)
449+ }
450+
451+ #[ inline]
452+ #[ target_feature( enable = "avx2" ) ]
453+ unsafe fn sse2_4_search_in ( & self , haystack : & [ u8 ] , end : usize ) -> bool {
456454 let hash = VectorHash :: < __m32i > :: from ( & self . sse2_hash ) ;
457- self . vector_search_in ( haystack, & hash , Self :: scalar_search_in )
455+ self . vector_search_in ( haystack, end , & hash )
458456 }
459457
460458 #[ inline]
461459 #[ target_feature( enable = "avx2" ) ]
462- unsafe fn sse2_8_search_in ( & self , haystack : & [ u8 ] ) -> bool {
460+ unsafe fn sse2_8_search_in ( & self , haystack : & [ u8 ] , end : usize ) -> bool {
463461 let hash = VectorHash :: < __m64i > :: from ( & self . sse2_hash ) ;
464- self . vector_search_in ( haystack, & hash , Self :: sse2_4_search_in )
462+ self . vector_search_in ( haystack, end , & hash )
465463 }
466464
467465 #[ inline]
468466 #[ target_feature( enable = "avx2" ) ]
469- unsafe fn sse2_16_search_in ( & self , haystack : & [ u8 ] ) -> bool {
470- self . vector_search_in ( haystack, & self . sse2_hash , Self :: sse2_8_search_in )
467+ unsafe fn sse2_16_search_in ( & self , haystack : & [ u8 ] , end : usize ) -> bool {
468+ self . vector_search_in ( haystack, end , & self . sse2_hash )
471469 }
472470
473471 #[ inline]
474472 #[ target_feature( enable = "avx2" ) ]
475- unsafe fn avx2_search_in ( & self , haystack : & [ u8 ] ) -> bool {
476- self . vector_search_in ( haystack, & self . avx2_hash , Self :: sse2_16_search_in )
473+ unsafe fn avx2_search_in ( & self , haystack : & [ u8 ] , end : usize ) -> bool {
474+ self . vector_search_in ( haystack, end , & self . avx2_hash )
477475 }
478476
479477 /// Inlined version of `search_in` for hot call sites.
@@ -484,7 +482,21 @@ impl<N: Needle> Avx2Searcher<N> {
484482 return haystack == self . needle . as_bytes ( ) ;
485483 }
486484
487- self . avx2_search_in ( haystack)
485+ let end = haystack. len ( ) - self . needle . size ( ) + 1 ;
486+
487+ if end < __m16i:: LANES {
488+ unreachable ! ( ) ;
489+ } else if end < __m32i:: LANES {
490+ self . sse2_2_search_in ( haystack, end)
491+ } else if end < __m64i:: LANES {
492+ self . sse2_4_search_in ( haystack, end)
493+ } else if end < __m128i:: LANES {
494+ self . sse2_8_search_in ( haystack, end)
495+ } else if end < __m256i:: LANES {
496+ self . sse2_16_search_in ( haystack, end)
497+ } else {
498+ self . avx2_search_in ( haystack, end)
499+ }
488500 }
489501
490502 /// Performs a substring search for the `needle` within `haystack`.
0 commit comments