@@ -3,9 +3,6 @@ use crate::engine::Config;
33use crate :: engine:: DecodeEstimate ;
44use crate :: { DecodeError , PAD_BYTE } ;
55
6- use core:: marker:: PhantomData ;
7-
8-
96#[ cfg( target_arch = "x86" ) ]
107use core:: arch:: x86:: * ;
118#[ cfg( target_arch = "x86_64" ) ]
@@ -28,41 +25,75 @@ const DECODED_CHUNK_LEN: usize = 24;
2825/// - It has to use unsafe code because intrinsics are always unsafe in Rust.
2926/// - The algorithm in use makes specific assumptions about the alphabet, so it's only implemented
3027/// for the STANDARD and URL_SAFE Alphabet
31- pub struct AVX2Encoder < A > {
28+ pub struct AVX2Encoder {
3229 config : AVX2Config ,
33- alp : PhantomData < A > ,
34- }
3530
36- impl < A > AVX2Encoder < A > {
37- /// Create an AVX2Encoder from a given config.
38- ///
39- /// You can either select the Alphabet by defining the type specifically:
40- /// ```rust
41- /// let engine: AVX2Encoder<Standard> = AVX2Encoder::from(AVX2Config::default());
42- /// ```
43- /// or by calling one of the associated functions [`from_standard`] and [`from_urlsafe`].
44- pub const fn from ( config : AVX2Config ) -> Self {
45- Self {
46- config,
47- alp : PhantomData ,
48- }
49- }
31+ // Alphabet LUT for serial steps
32+ encode_table : [ u8 ; 64 ] ,
33+ decode_table : [ u8 ; 256 ] ,
34+
35+ // Alphabet LUT for vectorized steps
36+ encode_offsets : __m256i ,
37+ decode_offsets : __m256i ,
5038}
51- impl AVX2Encoder < Standard > {
52- /// Create an AVX2Encoder for the STANDARD alphabet with the given config.
53- pub const fn from_standard ( config : AVX2Config ) -> Self {
39+
40+ impl AVX2Encoder {
41+ /// Create an AVX2Encoder for the standard Alphabet from a given config.
42+ /// You can create one for urlsafe with the associated function [`from_urlsafe`].
43+ pub fn from_standard ( config : AVX2Config ) -> Self {
44+ let encode_offsets = unsafe {
45+ _mm256_setr_epi8 (
46+ 71 , -4 , -4 , -4 , -4 , -4 , -4 , -4 ,
47+ -4 , -4 , -4 , -19 , -16 , 65 , 0 , 0 ,
48+ 71 , -4 , -4 , -4 , -4 , -4 , -4 , -4 ,
49+ -4 , -4 , -4 , -19 , -16 , 65 , 0 , 0 ,
50+ )
51+ } ;
52+
53+ let decode_offsets = unsafe {
54+ _mm256_setr_epi8 (
55+ 0 , 16 , 19 , 4 , -65 , -65 , -71 , -71 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
56+ 0 , 16 , 19 , 4 , -65 , -65 , -71 , -71 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0
57+ )
58+ } ;
59+
5460 Self {
5561 config,
56- alp : PhantomData ,
62+
63+ encode_table : ENCODE_TABLE ,
64+ decode_table : DECODE_TABLE ,
65+
66+ encode_offsets,
67+ decode_offsets,
5768 }
5869 }
59- }
60- impl AVX2Encoder < Urlsafe > {
61- /// Create an AVX2Encoder for the STANDARD alphabet with the given config.
62- pub const fn from_url_safe ( config : AVX2Config ) -> Self {
70+ /// Create an AVX2Encoder for the urlsafe alphabet with the given config.
71+ /// You can create one for standard with the associated function [`from_standard`].
72+ pub fn from_url_safe ( config : AVX2Config ) -> Self {
73+ let encode_offsets = unsafe {
74+ _mm256_setr_epi8 (
75+ 71 , -4 , -4 , -4 , -4 , -4 , -4 , -4 ,
76+ -4 , -4 , -4 , -17 , 32 , 65 , 0 , 0 ,
77+ 71 , -4 , -4 , -4 , -4 , -4 , -4 , -4 ,
78+ -4 , -4 , -4 , -17 , 32 , 65 , 0 , 0 ,
79+ )
80+ } ;
81+
82+ let decode_offsets = unsafe {
83+ _mm256_setr_epi8 (
84+ 0 , -32 , 17 , 4 , -65 , -65 , -71 , -71 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
85+ 0 , -32 , 17 , 4 , -65 , -65 , -71 , -71 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0
86+ )
87+ } ;
88+
6389 Self {
6490 config,
65- alp : PhantomData ,
91+
92+ encode_table : URL_ENCODE_TABLE ,
93+ decode_table : URL_DECODE_TABLE ,
94+
95+ encode_offsets,
96+ decode_offsets,
6697 }
6798 }
6899}
@@ -95,8 +126,9 @@ impl DecodeEstimate for AVX2Estimate {
95126}
96127
97128
98- #[ inline]
129+ #[ inline( always ) ]
99130unsafe fn load_block ( input : __m256i ) -> __m256i {
131+ // TODO: Explain this load shuffle
100132 let i: __m256i = _mm256_shuffle_epi8 ( input, _mm256_set_epi8 (
101133 10 , 11 , 9 , 10 , 7 , 8 , 6 , 7 , 4 , 5 , 3 , 4 , 1 , 2 , 0 , 1 ,
102134 14 , 15 , 13 , 14 , 11 , 12 , 10 , 11 , 8 , 9 , 7 , 8 , 5 , 6 , 4 , 5
@@ -117,6 +149,7 @@ unsafe fn decode(
117149 mask_2f : __m256i ,
118150 block : __m256i
119151) -> __m256i {
152+ // TODO: Explain this decode step
120153 let hi_nibbles = _mm256_srli_epi32 ( block, 4 ) ;
121154 let lo_nibbles = _mm256_and_si256 ( block, mask_2f) ;
122155 let lo = _mm256_shuffle_epi8 ( lut_lo, lo_nibbles) ;
@@ -146,6 +179,10 @@ unsafe fn decode(
146179}
147180
148181#[ inline( always) ]
182+ /// decode_masked is a version of decode specialized for partial input.
183+ /// The only difference between it and the unmasked version is that the test that checks for
184+ /// invalid bytes (which is `a AND b` over a,b := 256-bit vector) gets the same input mask applied,
185+ /// since `0` bytes would in fact be an invalid input.
149186unsafe fn decode_masked (
150187 invalid : & mut bool ,
151188 lut_lo : __m256i ,
@@ -188,65 +225,14 @@ unsafe fn decode_masked(
188225 ) )
189226}
190227
191- #[ doc( hidden) ]
192- pub trait AvxAlp : Send + Sync {
193- unsafe fn encode ( input : __m256i ) -> __m256i ;
194- fn encode_table ( ) -> & ' static [ u8 ; 64 ] ;
195- fn decode_table ( ) -> & ' static [ u8 ; 256 ] ;
196- }
197-
198- #[ doc( hidden) ]
199- pub struct Standard ;
200- impl AvxAlp for Standard {
201- #[ inline]
202- unsafe fn encode ( input : __m256i ) -> __m256i {
203- let mut result: __m256i = _mm256_subs_epu8 ( input, _mm256_set1_epi8 ( 51 ) ) ;
204- let less: __m256i = _mm256_cmpgt_epi8 ( _mm256_set1_epi8 ( 26 ) , input) ;
205- result = _mm256_or_si256 ( result, _mm256_and_si256 ( less, _mm256_set1_epi8 ( 13 ) ) ) ;
206- let offsets: __m256i = _mm256_setr_epi8 (
207- 71 , -4 , -4 , -4 , -4 , -4 , -4 , -4 ,
208- -4 , -4 , -4 , -19 , -16 , 65 , 0 , 0 ,
209- 71 , -4 , -4 , -4 , -4 , -4 , -4 , -4 ,
210- -4 , -4 , -4 , -19 , -16 , 65 , 0 , 0 ,
211- ) ;
212- result = _mm256_shuffle_epi8 ( offsets, result) ;
213- return _mm256_add_epi8 ( result, input) ;
214- }
215-
216- fn encode_table ( ) -> & ' static [ u8 ; 64 ] {
217- & ENCODE_TABLE
218- }
219-
220- fn decode_table ( ) -> & ' static [ u8 ; 256 ] {
221- & DECODE_TABLE
222- }
223- }
224-
225- #[ doc( hidden) ]
226- pub struct Urlsafe ;
227- impl AvxAlp for Urlsafe {
228- #[ inline]
229- unsafe fn encode ( input : __m256i ) -> __m256i {
230- let mut result: __m256i = _mm256_subs_epu8 ( input, _mm256_set1_epi8 ( 51 ) ) ;
231- let less: __m256i = _mm256_cmpgt_epi8 ( _mm256_set1_epi8 ( 26 ) , input) ;
232- result = _mm256_or_si256 ( result, _mm256_and_si256 ( less, _mm256_set1_epi8 ( 13 ) ) ) ;
233- let offsets: __m256i = _mm256_setr_epi8 (
234- 71 , -4 , -4 , -4 , -4 , -4 , -4 , -4 ,
235- -4 , -4 , -4 , -17 , 32 , 65 , 0 , 0 ,
236- 71 , -4 , -4 , -4 , -4 , -4 , -4 , -4 ,
237- -4 , -4 , -4 , -17 , 32 , 65 , 0 , 0 ,
238- ) ;
239- result = _mm256_shuffle_epi8 ( offsets, result) ;
240- return _mm256_add_epi8 ( result, input) ;
241- }
242-
243- fn encode_table ( ) -> & ' static [ u8 ; 64 ] {
244- & URL_ENCODE_TABLE
245- }
246228
247- fn decode_table ( ) -> & ' static [ u8 ; 256 ] {
248- & URL_DECODE_TABLE
249- }
229+ #[ inline]
230+ unsafe fn encode ( offsets : __m256i , input : __m256i ) -> __m256i {
231+ let mut result: __m256i = _mm256_subs_epu8 ( input, _mm256_set1_epi8 ( 51 ) ) ;
232+ let less: __m256i = _mm256_cmpgt_epi8 ( _mm256_set1_epi8 ( 26 ) , input) ;
233+ result = _mm256_or_si256 ( result, _mm256_and_si256 ( less, _mm256_set1_epi8 ( 13 ) ) ) ;
234+ result = _mm256_shuffle_epi8 ( offsets, result) ;
235+ return _mm256_add_epi8 ( result, input) ;
250236}
251237
252238const ENCODE_TABLE : [ u8 ; 64 ] =
@@ -260,7 +246,7 @@ const URL_DECODE_TABLE: [u8; 256] =
260246
261247const MASKLOAD : [ i32 ; 16 ] = [ -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ] ;
262248
263- impl < A : AvxAlp > super :: Engine for AVX2Encoder < A > {
249+ impl super :: Engine for AVX2Encoder {
264250 type Config = AVX2Config ;
265251 type DecodeEstimate = AVX2Estimate ;
266252
@@ -306,7 +292,7 @@ impl<A: AvxAlp> super::Engine for AVX2Encoder<A> {
306292 _mm256_set_epi32 ( SKIP , LOAD , LOAD , LOAD , LOAD , LOAD , LOAD , SKIP ) ) ;
307293
308294 let expanded: __m256i = load_block ( block) ;
309- let outblock: __m256i = A :: encode ( expanded) ;
295+ let outblock: __m256i = encode ( self . encode_offsets , expanded) ;
310296 _mm256_storeu_si256 ( output_chunk. as_mut_ptr ( ) . cast ( ) , outblock) ;
311297
312298 output_index += 32 ;
@@ -328,7 +314,7 @@ impl<A: AvxAlp> super::Engine for AVX2Encoder<A> {
328314 // First step: Expand the 24 input bytes into 32 bytes ready for encoding.
329315 let expanded: __m256i = load_block ( block) ;
330316 // Second step: Do the actual conversion
331- let outblock: __m256i = A :: encode ( expanded) ;
317+ let outblock: __m256i = encode ( self . encode_offsets , expanded) ;
332318 // Third step: Write the data into the output
333319 _mm256_storeu_si256 ( output_chunk. as_mut_ptr ( ) . cast ( ) , outblock) ;
334320
@@ -356,18 +342,16 @@ impl<A: AvxAlp> super::Engine for AVX2Encoder<A> {
356342
357343 const LOW_SIX_BITS_U8 : u8 = 0b111111 ;
358344
359- let encode_table = A :: encode_table ( ) ;
360-
361345 while input_index < start_of_rem {
362346 let input_chunk = & input[ input_index..( input_index + 3 ) ] ;
363347 let output_chunk = & mut output[ output_index..( output_index + 4 ) ] ;
364348
365- output_chunk[ 0 ] = encode_table[ ( input_chunk[ 0 ] >> 2 ) as usize ] ;
366- output_chunk[ 1 ] = encode_table
349+ output_chunk[ 0 ] = self . encode_table [ ( input_chunk[ 0 ] >> 2 ) as usize ] ;
350+ output_chunk[ 1 ] = self . encode_table
367351 [ ( ( input_chunk[ 0 ] << 4 | input_chunk[ 1 ] >> 4 ) & LOW_SIX_BITS_U8 ) as usize ] ;
368- output_chunk[ 2 ] = encode_table
352+ output_chunk[ 2 ] = self . encode_table
369353 [ ( ( input_chunk[ 1 ] << 2 | input_chunk[ 2 ] >> 6 ) & LOW_SIX_BITS_U8 ) as usize ] ;
370- output_chunk[ 3 ] = encode_table[ ( input_chunk[ 2 ] & LOW_SIX_BITS_U8 ) as usize ] ;
354+ output_chunk[ 3 ] = self . encode_table [ ( input_chunk[ 2 ] & LOW_SIX_BITS_U8 ) as usize ] ;
371355
372356 input_index += 3 ;
373357 output_index += 4 ;
@@ -377,18 +361,18 @@ impl<A: AvxAlp> super::Engine for AVX2Encoder<A> {
377361
378362 if rem == 2 {
379363 let final_input = input. len ( ) -2 ;
380- output[ output_index] = encode_table[ ( input[ final_input] >> 2 ) as usize ] ;
364+ output[ output_index] = self . encode_table [ ( input[ final_input] >> 2 ) as usize ] ;
381365 output[ output_index + 1 ] =
382- encode_table[ ( ( input[ final_input] << 4 | input[ final_input + 1 ] >> 4 )
366+ self . encode_table [ ( ( input[ final_input] << 4 | input[ final_input + 1 ] >> 4 )
383367 & LOW_SIX_BITS_U8 ) as usize ] ;
384368 output[ output_index + 2 ] =
385- encode_table[ ( ( input[ final_input + 1 ] << 2 ) & LOW_SIX_BITS_U8 ) as usize ] ;
369+ self . encode_table [ ( ( input[ final_input + 1 ] << 2 ) & LOW_SIX_BITS_U8 ) as usize ] ;
386370 output_index += 3 ;
387371 } else if rem == 1 {
388372 let final_input = input. len ( ) -1 ;
389- output[ output_index] = encode_table[ ( input[ final_input] >> 2 ) as usize ] ;
373+ output[ output_index] = self . encode_table [ ( input[ final_input] >> 2 ) as usize ] ;
390374 output[ output_index + 1 ] =
391- encode_table[ ( ( input[ final_input] << 4 ) & LOW_SIX_BITS_U8 ) as usize ] ;
375+ self . encode_table [ ( ( input[ final_input] << 4 ) & LOW_SIX_BITS_U8 ) as usize ] ;
392376 output_index += 2 ;
393377 }
394378
@@ -405,14 +389,13 @@ impl<A: AvxAlp> super::Engine for AVX2Encoder<A> {
405389 output : & mut [ u8 ] ,
406390 _estimate : Self :: DecodeEstimate ,
407391 ) -> Result < usize , DecodeError > {
408- let decode_table = A :: decode_table ( ) ;
409392 // TODO: Check if LLVM optimizes this modulo into an &
410393 let skip_stage_2 = match input. len ( ) % 4 {
411394 1 => {
412395 // trailing whitespace is so common that it's worth it to check the last byte to
413396 // possibly return a better error message
414397 if let Some ( b) = input. last ( ) {
415- if * b != PAD_BYTE && decode_table[ * b as usize ] == INVALID_VALUE {
398+ if * b != PAD_BYTE && self . decode_table [ * b as usize ] == INVALID_VALUE {
416399 return Err ( DecodeError :: InvalidByte ( input. len ( ) - 1 , * b) ) ;
417400 }
418401 }
@@ -487,10 +470,6 @@ impl<A: AvxAlp> super::Engine for AVX2Encoder<A> {
487470 0x10 , 0x10 , 0x01 , 0x02 , 0x04 , 0x08 , 0x04 , 0x08 ,
488471 0x10 , 0x10 , 0x10 , 0x10 , 0x10 , 0x10 , 0x10 , 0x10
489472 ) } ;
490- let lut_roll = unsafe { _mm256_setr_epi8 (
491- 0 , 16 , 19 , 4 , -65 , -65 , -71 , -71 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
492- 0 , 16 , 19 , 4 , -65 , -65 , -71 , -71 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0
493- ) } ;
494473 let mask_2f = unsafe { _mm256_set1_epi8 ( 0x2F ) } ;
495474
496475 // This will only evaluate to true if we have an input of 33 bytes or more;
@@ -514,10 +493,10 @@ impl<A: AvxAlp> super::Engine for AVX2Encoder<A> {
514493
515494 unsafe {
516495 block = _mm256_loadu_si256 ( input_chunk. as_ptr ( ) . cast ( ) ) ;
517- block = decode ( & mut invalid, lut_lo, lut_hi, lut_roll , mask_2f, block) ;
496+ block = decode ( & mut invalid, lut_lo, lut_hi, self . decode_offsets , mask_2f, block) ;
518497
519498 if invalid {
520- return Err ( find_invalid_input ( input_index, input_chunk, decode_table) ) ;
499+ return Err ( find_invalid_input ( input_index, input_chunk, & self . decode_table ) ) ;
521500 }
522501
523502 _mm256_storeu_si256 ( output_chunk. as_mut_ptr ( ) . cast ( ) , block) ;
@@ -551,12 +530,12 @@ impl<A: AvxAlp> super::Engine for AVX2Encoder<A> {
551530 let mask_output = _mm256_loadu_si256 ( MASKLOAD [ 2 ..10 ] . as_ptr ( ) . cast ( ) ) ;
552531
553532 block = _mm256_loadu_si256 ( input_chunk. as_ptr ( ) . cast ( ) ) ;
554- block = decode ( & mut invalid, lut_lo, lut_hi, lut_roll , mask_2f, block) ;
533+ block = decode ( & mut invalid, lut_lo, lut_hi, self . decode_offsets , mask_2f, block) ;
555534
556535 _mm256_maskstore_epi32 ( output_chunk. as_mut_ptr ( ) . cast ( ) , mask_output, block) ;
557536 }
558537 if invalid {
559- return Err ( find_invalid_input ( input_index, input_chunk, decode_table) ) ;
538+ return Err ( find_invalid_input ( input_index, input_chunk, & self . decode_table ) ) ;
560539 }
561540
562541 input_index += 32 ;
@@ -604,10 +583,10 @@ impl<A: AvxAlp> super::Engine for AVX2Encoder<A> {
604583 block = _mm256_maskload_epi32 ( input_chunk. as_ptr ( ) . cast ( ) , mask_input) ;
605584 let outblock
606585 = decode_masked ( & mut invalid,
607- lut_lo, lut_hi, lut_roll , mask_2f, mask_input, block) ;
586+ lut_lo, lut_hi, self . decode_offsets , mask_2f, mask_input, block) ;
608587
609588 if invalid {
610- return Err ( find_invalid_input ( input_index, input_chunk, decode_table) ) ;
589+ return Err ( find_invalid_input ( input_index, input_chunk, & self . decode_table ) ) ;
611590 }
612591
613592 _mm256_maskstore_epi32 ( output_chunk. as_mut_ptr ( ) . cast ( ) , mask_output, outblock) ;
@@ -677,7 +656,7 @@ impl<A: AvxAlp> super::Engine for AVX2Encoder<A> {
677656 // can use up to 8 * 6 = 48 bits of the u64, if last chunk has no padding.
678657 // Pack the leftovers from left to right.
679658 let shift = 64 - ( morsels_in_leftover + 1 ) * 6 ;
680- let morsel = decode_table[ * b as usize ] ;
659+ let morsel = self . decode_table [ * b as usize ] ;
681660 if morsel == INVALID_VALUE {
682661 return Err ( DecodeError :: InvalidByte ( start_of_leftovers + i, * b) ) ;
683662 }
@@ -729,7 +708,7 @@ impl<A: AvxAlp> super::Engine for AVX2Encoder<A> {
729708 }
730709}
731710
732- fn find_invalid_input ( input_index : usize , input : & [ u8 ] , decode_table : & ' static [ u8 ; 256 ] ) -> DecodeError {
711+ fn find_invalid_input ( input_index : usize , input : & [ u8 ] , decode_table : & [ u8 ; 256 ] ) -> DecodeError {
733712 // Figure out which byte was invalid exactly.
734713 for i in 0 ..input. len ( ) {
735714 let byte = input[ i] ;
0 commit comments