@@ -6,16 +6,14 @@ pub use self::{original::*, rust::*};
 use crate::{bits, memchr::MemchrSearcher, memcmp};
 use std::{
     arch::x86_64::*,
+    mem,
     ops::{AddAssign, SubAssign},
 };
 
-const AVX2_LANES: usize = 32;
-const SSE2_LANES: usize = 16;
-
 #[derive(Clone, Copy, Default, PartialEq)]
-struct Hash(usize);
+struct ScalarHash(usize);
 
-impl From<&[u8]> for Hash {
+impl From<&[u8]> for ScalarHash {
     #[inline(always)]
     fn from(bytes: &[u8]) -> Self {
         bytes.iter().fold(Default::default(), |mut hash, &b| {
@@ -25,30 +23,108 @@ impl From<&[u8]> for Hash {
     }
 }
 
-impl AddAssign<u8> for Hash {
+impl AddAssign<u8> for ScalarHash {
     #[inline(always)]
     fn add_assign(&mut self, b: u8) {
         self.0 += usize::from(b);
     }
 }
 
-impl SubAssign<u8> for Hash {
+impl SubAssign<u8> for ScalarHash {
     #[inline(always)]
     fn sub_assign(&mut self, b: u8) {
         self.0 -= usize::from(b);
     }
 }
 
+trait Vector: Copy {
+    unsafe fn set1_epi8(a: i8) -> Self;
+
+    unsafe fn loadu_si(a: *const Self) -> Self;
+
+    unsafe fn cmpeq_epi8(a: Self, b: Self) -> Self;
+
+    unsafe fn and_si(a: Self, b: Self) -> Self;
+
+    unsafe fn movemask_epi8(a: Self) -> i32;
+}
+
+impl Vector for __m128i {
+    #[inline(always)]
+    unsafe fn set1_epi8(a: i8) -> Self {
+        _mm_set1_epi8(a)
+    }
+
+    #[inline(always)]
+    unsafe fn loadu_si(a: *const Self) -> Self {
+        _mm_loadu_si128(a)
+    }
+
+    #[inline(always)]
+    unsafe fn cmpeq_epi8(a: Self, b: Self) -> Self {
+        _mm_cmpeq_epi8(a, b)
+    }
+
+    #[inline(always)]
+    unsafe fn and_si(a: Self, b: Self) -> Self {
+        _mm_and_si128(a, b)
+    }
+
+    #[inline(always)]
+    unsafe fn movemask_epi8(a: Self) -> i32 {
+        _mm_movemask_epi8(a)
+    }
+}
+
+impl Vector for __m256i {
+    #[inline(always)]
+    unsafe fn set1_epi8(a: i8) -> Self {
+        _mm256_set1_epi8(a)
+    }
+
+    #[inline(always)]
+    unsafe fn loadu_si(a: *const Self) -> Self {
+        _mm256_loadu_si256(a)
+    }
+
+    #[inline(always)]
+    unsafe fn cmpeq_epi8(a: Self, b: Self) -> Self {
+        _mm256_cmpeq_epi8(a, b)
+    }
+
+    #[inline(always)]
+    unsafe fn and_si(a: Self, b: Self) -> Self {
+        _mm256_and_si256(a, b)
+    }
+
+    #[inline(always)]
+    unsafe fn movemask_epi8(a: Self) -> i32 {
+        _mm256_movemask_epi8(a)
+    }
+}
+
+struct VectorHash<V: Vector> {
+    first: V,
+    last: V,
+}
+
+impl<V: Vector> VectorHash<V> {
+    fn new(first: u8, last: u8) -> Self {
+        Self {
+            first: unsafe { Vector::set1_epi8(first as i8) },
+            last: unsafe { Vector::set1_epi8(last as i8) },
+        }
+    }
+}
+
 macro_rules! avx2_searcher {
     ($name:ident, $size:literal, $memcmp:path) => {
         pub struct $name {
             needle: Box<[u8]>,
             position: usize,
-            hash: Hash,
-            sse2_first: __m128i,
-            sse2_last: __m128i,
-            avx2_first: __m256i,
-            avx2_last: __m256i,
+            scalar_hash: ScalarHash,
+            sse2_hash: VectorHash<__m128i>,
+            avx2_hash: VectorHash<__m256i>,
         }
 
         impl $name {
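For orientation, here is a hedged sketch (not part of the commit) of how the new `Vector` trait lets one body serve both register widths: the same code monomorphizes to SSE2 for `V = __m128i` and to AVX2 for `V = __m256i`. The helper name `candidate_mask` and its signature are illustrative only; it assumes the `Vector` trait and `VectorHash` defined above.

```rust
// Illustrative only: assumes the `Vector` trait and `VectorHash<V>` from this diff.
// Loads size_of::<V>() haystack bytes at `chunk` and at `chunk + position`, compares
// them against the broadcast first/last needle bytes, and returns a bitmask with one
// bit per lane where both bytes matched.
#[inline(always)]
unsafe fn candidate_mask<V: Vector>(hash: &VectorHash<V>, chunk: *const u8, position: usize) -> u32 {
    let first = V::loadu_si(chunk.cast());
    let last = V::loadu_si(chunk.add(position).cast());

    let eq_first = V::cmpeq_epi8(hash.first, first);
    let eq_last = V::cmpeq_epi8(hash.last, last);

    V::movemask_epi8(V::and_si(eq_first, eq_last)) as u32
}
```

The same call site can then pick the width at compile time: `candidate_mask::<__m128i>(&self.sse2_hash, ptr, self.position)` scans 16 bytes per iteration, `candidate_mask::<__m256i>(&self.avx2_hash, ptr, self.position)` scans 32.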
@@ -61,20 +137,16 @@ macro_rules! avx2_searcher {
                 assert!(!needle.is_empty());
                 assert!(position < needle.len());
 
-                let hash = Hash::from(needle.as_ref());
-                let sse2_first = unsafe { _mm_set1_epi8(needle[0] as i8) };
-                let sse2_last = unsafe { _mm_set1_epi8(needle[position] as i8) };
-                let avx2_first = unsafe { _mm256_set1_epi8(needle[0] as i8) };
-                let avx2_last = unsafe { _mm256_set1_epi8(needle[position] as i8) };
+                let scalar_hash = ScalarHash::from(needle.as_ref());
+                let sse2_hash = VectorHash::new(needle[0], needle[position]);
+                let avx2_hash = VectorHash::new(needle[0], needle[position]);
 
                 Self {
                     needle,
                     position,
-                    hash,
-                    sse2_first,
-                    sse2_last,
-                    avx2_first,
-                    avx2_last,
+                    scalar_hash,
+                    sse2_hash,
+                    avx2_hash,
                 }
             }
 
@@ -93,14 +165,14 @@ macro_rules! avx2_searcher {
                 debug_assert!(haystack.len() >= self.size());
 
                 let mut end = self.size() - 1;
-                let mut hash = Hash::from(&haystack[..end]);
+                let mut hash = ScalarHash::from(&haystack[..end]);
 
                 while end < haystack.len() {
                     hash += *unsafe { haystack.get_unchecked(end) };
                     end += 1;
 
                     let start = end - self.size();
-                    if hash == self.hash && haystack[start..end] == *self.needle {
+                    if hash == self.scalar_hash && haystack[start..end] == *self.needle {
                         return true;
                     }
 
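The renamed `ScalarHash` is still the same rolling byte sum used as a cheap filter before the byte-for-byte comparison. As a reference point, here is a minimal standalone sketch of the loop above, written against a plain needle slice instead of the searcher struct; it assumes the `ScalarHash` type from this diff, and the function name is illustrative only.

```rust
// Minimal rolling-sum substring check, mirroring scalar_search_in above (sketch only).
fn scalar_contains(haystack: &[u8], needle: &[u8]) -> bool {
    if needle.is_empty() || haystack.len() < needle.len() {
        return false;
    }

    let needle_hash = ScalarHash::from(needle);
    let mut hash = ScalarHash::from(&haystack[..needle.len() - 1]);
    let mut end = needle.len() - 1;

    while end < haystack.len() {
        hash += haystack[end]; // slide the window right by one byte
        end += 1;

        let start = end - needle.len();
        // The cheap sum filters candidates; the byte comparison confirms them.
        if hash == needle_hash && &haystack[start..end] == needle {
            return true;
        }

        hash -= haystack[start]; // drop the byte that left the window
    }

    false
}
```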
@@ -111,22 +183,28 @@ macro_rules! avx2_searcher {
             }
 
             #[inline(always)]
-            fn sse2_search_in(&self, haystack: &[u8]) -> bool {
-                if haystack.len() < SSE2_LANES {
-                    return self.scalar_search_in(haystack);
+            fn vector_search_in<V: Vector>(
+                &self,
+                haystack: &[u8],
+                hash: &VectorHash<V>,
+                next: fn(&Self, &[u8]) -> bool,
+            ) -> bool {
+                let lanes = mem::size_of::<V>();
+                if haystack.len() < lanes {
+                    return next(self, haystack);
                 }
 
-                let mut chunks = haystack[..=haystack.len() - self.size()].chunks_exact(SSE2_LANES);
+                let mut chunks = haystack[..=haystack.len() - self.size()].chunks_exact(lanes);
                 while let Some(chunk) = chunks.next() {
                     let start = chunk.as_ptr();
-                    let first = unsafe { _mm_loadu_si128(start.cast()) };
-                    let last = unsafe { _mm_loadu_si128(start.add(self.position).cast()) };
+                    let first = unsafe { Vector::loadu_si(start.cast()) };
+                    let last = unsafe { Vector::loadu_si(start.add(self.position).cast()) };
 
-                    let mask_first = unsafe { _mm_cmpeq_epi8(self.sse2_first, first) };
-                    let mask_last = unsafe { _mm_cmpeq_epi8(self.sse2_last, last) };
+                    let mask_first = unsafe { Vector::cmpeq_epi8(hash.first, first) };
+                    let mask_last = unsafe { Vector::cmpeq_epi8(hash.last, last) };
 
-                    let mask = unsafe { _mm_and_si128(mask_first, mask_last) };
-                    let mut mask = unsafe { _mm_movemask_epi8(mask) } as u32;
+                    let mask = unsafe { Vector::and_si(mask_first, mask_last) };
+                    let mut mask = unsafe { Vector::movemask_epi8(mask) } as u32;
 
                     let start = start as usize - haystack.as_ptr() as usize;
                     while mask != 0 {
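The hard-coded `SSE2_LANES`/`AVX2_LANES` constants removed earlier are now implied by the vector type itself, which is what the new `let lanes = mem::size_of::<V>()` line relies on. A standalone check of that assumption (not code from the commit):

```rust
use std::arch::x86_64::{__m128i, __m256i};
use std::mem;

fn main() {
    assert_eq!(mem::size_of::<__m128i>(), 16); // one SSE2 register holds 16 byte lanes
    assert_eq!(mem::size_of::<__m256i>(), 32); // one AVX2 register holds 32 byte lanes
}
```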
@@ -140,46 +218,20 @@ macro_rules! avx2_searcher {
                     }
                 }
                 let remainder = chunks.remainder();
-                debug_assert!(remainder.len() < SSE2_LANES);
+                debug_assert!(remainder.len() < lanes);
 
                 let chunk = &haystack[remainder.as_ptr() as usize - haystack.as_ptr() as usize..];
-                self.scalar_search_in(chunk)
+                next(self, chunk)
             }
 
             #[inline(always)]
-            fn avx2_search_in(&self, haystack: &[u8]) -> bool {
-                if haystack.len() < AVX2_LANES {
-                    return self.sse2_search_in(haystack);
-                }
-
-                let mut chunks = haystack[..=haystack.len() - self.size()].chunks_exact(AVX2_LANES);
-                while let Some(chunk) = chunks.next() {
-                    let start = chunk.as_ptr();
-                    let first = unsafe { _mm256_loadu_si256(start.cast()) };
-                    let last = unsafe { _mm256_loadu_si256(start.add(self.position).cast()) };
-
-                    let mask_first = unsafe { _mm256_cmpeq_epi8(self.avx2_first, first) };
-                    let mask_last = unsafe { _mm256_cmpeq_epi8(self.avx2_last, last) };
-
-                    let mask = unsafe { _mm256_and_si256(mask_first, mask_last) };
-                    let mut mask = unsafe { _mm256_movemask_epi8(mask) } as u32;
-
-                    let start = start as usize - haystack.as_ptr() as usize;
-                    while mask != 0 {
-                        let chunk = &haystack[start + mask.trailing_zeros() as usize..];
-                        if unsafe { $memcmp(&chunk[1..self.size()], &self.needle[1..]) } {
-                            return true;
-                        }
-
-                        mask = bits::clear_leftmost_set(mask);
-                    }
-                }
-
-                let remainder = chunks.remainder();
-                debug_assert!(remainder.len() < AVX2_LANES);
+            fn sse2_search_in(&self, haystack: &[u8]) -> bool {
+                self.vector_search_in(haystack, &self.sse2_hash, Self::scalar_search_in)
+            }
 
-                let chunk = &haystack[remainder.as_ptr() as usize - haystack.as_ptr() as usize..];
-                self.sse2_search_in(chunk)
+            #[inline(always)]
+            fn avx2_search_in(&self, haystack: &[u8]) -> bool {
+                self.vector_search_in(haystack, &self.avx2_hash, Self::sse2_search_in)
             }
 
             #[inline(always)]
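The unchanged inner loop (kept as context above) walks the candidate mask with `trailing_zeros` and `bits::clear_leftmost_set`. That helper's implementation is not shown in this diff; the sketch below assumes it is the usual clear-lowest-set-bit step, `mask & (mask - 1)`, and uses plain offsets instead of the searcher.

```rust
// Illustrative only: visit every set bit of a candidate mask, lowest offset first.
fn candidate_offsets(mut mask: u32) -> Vec<usize> {
    let mut offsets = Vec::new();
    while mask != 0 {
        offsets.push(mask.trailing_zeros() as usize); // offset of the next candidate lane
        mask &= mask - 1; // clear the lowest set bit, as bits::clear_leftmost_set presumably does
    }
    offsets
}

fn main() {
    // Lanes 0, 3 and 17 matched both the first and the last needle byte.
    assert_eq!(candidate_offsets(0b10_0000_0000_0000_1001), vec![0, 3, 17]);
}
```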