@@ -35,6 +35,12 @@ pub struct AVX2Encoder {
3535 // Alphabet LUT for vectorized steps
3636 encode_offsets : __m256i ,
3737 decode_offsets : __m256i ,
38+
39+ // The algorithm in use needs to be able to distinguish between the two singletons outside the
40+ // [A-Za-z] ranges.
41+ // For STANDARD these are '+' and '/' and the engine matches against '/' i.e. 0x2F
42+ // For URL_SAFE these are '-' and '_' and the engine matches against '_' i.e. 0x5F
43+ singleton_mask : __m256i ,
3844}
3945
4046impl AVX2Encoder {
@@ -43,20 +49,22 @@ impl AVX2Encoder {
4349 pub fn from_standard ( config : AVX2Config ) -> Self {
4450 let encode_offsets = unsafe {
4551 _mm256_setr_epi8 (
46- 71 , -4 , -4 , -4 , -4 , -4 , -4 , -4 ,
47- -4 , -4 , -4 , -19 , -16 , 65 , 0 , 0 ,
48- 71 , -4 , -4 , -4 , -4 , -4 , -4 , -4 ,
49- -4 , -4 , -4 , -19 , -16 , 65 , 0 , 0 ,
52+ // 00 01 02 03 04 05 06 07 08 09 10 11 12 13 14 15
53+ 71 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -19 , -16 , 65 , 0 , 0 ,
54+ 71 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -19 , -16 , 65 , 0 , 0 ,
5055 )
5156 } ;
5257
5358 let decode_offsets = unsafe {
5459 _mm256_setr_epi8 (
60+ // 00 01 02 03 04 05 06 07 08 09 10 11 12 13 14 15
5561 0 , 16 , 19 , 4 , -65 , -65 , -71 , -71 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
5662 0 , 16 , 19 , 4 , -65 , -65 , -71 , -71 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0
5763 )
5864 } ;
5965
66+ let singleton_mask = unsafe { _mm256_set1_epi8 ( 0x2F ) } ;
67+
6068 Self {
6169 config,
6270
@@ -65,27 +73,30 @@ impl AVX2Encoder {
6573
6674 encode_offsets,
6775 decode_offsets,
76+ singleton_mask,
6877 }
6978 }
7079 /// Create an AVX2Encoder for the urlsafe alphabet with the given config.
7180 /// You can create one for standard with the associated function [`from_standard`].
7281 pub fn from_url_safe ( config : AVX2Config ) -> Self {
7382 let encode_offsets = unsafe {
7483 _mm256_setr_epi8 (
75- 71 , -4 , -4 , -4 , -4 , -4 , -4 , -4 ,
76- -4 , -4 , -4 , -17 , 32 , 65 , 0 , 0 ,
77- 71 , -4 , -4 , -4 , -4 , -4 , -4 , -4 ,
78- -4 , -4 , -4 , -17 , 32 , 65 , 0 , 0 ,
84+ // 00 01 02 03 04 05 06 07 08 09 10 11 12 13 14 15
85+ 71 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -17 , 32 , 65 , 0 , 0 ,
86+ 71 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -17 , 32 , 65 , 0 , 0 ,
7987 )
8088 } ;
8189
8290 let decode_offsets = unsafe {
8391 _mm256_setr_epi8 (
84- 0 , -32 , 17 , 4 , -65 , -65 , -71 , -71 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
85- 0 , -32 , 17 , 4 , -65 , -65 , -71 , -71 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0
92+ // 00 01 02 03 04 05 06 07 08 09 10 11 12 13 14 15
93+ 0 , -32 , 17 , 4 , -65 , -65 , -71 , -71 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
94+ 0 , -32 , 17 , 4 , -65 , -65 , -71 , -71 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0
8695 )
8796 } ;
8897
98+ let singleton_mask = unsafe { _mm256_set1_epi8 ( 0x2B ) } ;
99+
89100 Self {
90101 config,
91102
@@ -94,6 +105,7 @@ impl AVX2Encoder {
94105
95106 encode_offsets,
96107 decode_offsets,
108+ singleton_mask,
97109 }
98110 }
99111}
@@ -143,23 +155,67 @@ unsafe fn load_block(input: __m256i) -> __m256i {
143155#[ inline( always) ]
144156unsafe fn decode (
145157 invalid : & mut bool ,
146- lut_lo : __m256i ,
147- lut_hi : __m256i ,
148- lut_roll : __m256i ,
149- mask_2f : __m256i ,
158+ lo_witness_lut : __m256i ,
159+ hi_witness_lut : __m256i ,
160+ offsets : __m256i ,
161+ mask_singleton : __m256i ,
150162 block : __m256i
151163) -> __m256i {
152- // TODO: Explain this decode step
153- let hi_nibbles = _mm256_srli_epi32 ( block, 4 ) ;
154- let lo_nibbles = _mm256_and_si256 ( block, mask_2f) ;
155- let lo = _mm256_shuffle_epi8 ( lut_lo, lo_nibbles) ;
156- let eq_2f = _mm256_cmpeq_epi8 ( block, mask_2f) ;
157- let hi_nibbles = _mm256_and_si256 ( hi_nibbles, mask_2f) ;
158- let hi = _mm256_shuffle_epi8 ( lut_hi, hi_nibbles) ;
159- let roll = _mm256_shuffle_epi8 ( lut_roll, _mm256_add_epi8 ( eq_2f, hi_nibbles) ) ;
160- if _mm256_testz_si256 ( lo, hi) == 0 {
164+ // The most relevant information to understand this algorithm is this tidbit:
165+ // AVX shuffle conveniently work like table lookups; c = _mm256_shuffle_epi8(a,b) behaves* like
166+ // for i in 0..16 {
167+ // c[i] = a[b[i]];
168+ // c[i+16] = a[b[i+16]];
169+ // }
170+ // This is the reason why lo_witness_lut, hi_witness_lut, encode_offsets and decode_offets all have the exact
171+ // same values set for each 16-byte half; they are used as Look-Up tables in shuffles.
172+ // (* it additionally sets c[i] and c[i] to 0 if b[i] >= 128 but that is not used here)
173+ //
174+ // As a first step, since the indexes available in shuffles are only 0.16 or in other words one
175+ // nibble's worth, split each input byte into high and low nibble.
176+ // The high nibbles are retrieved by shifting the input by 4 bits and then applying a mask of
177+ // 0b1111 to it. The low bits are retrieved by not shifting and applying the very same map.
178+ // The "standard" algorithm happens to look for 0x2F ('/') which *also* just happens to have the
179+ // lowest 4 bits set to 1, so it can use that. The urlsafe one can't.
180+ let mask_nib = _mm256_set1_epi8 ( 0b00001111 ) ;
181+ let block_shifted = _mm256_srli_epi32 ( block, 4 ) ;
182+ let hi_nibbles = _mm256_and_si256 ( block_shifted, mask_nib) ;
183+ let lo_nibbles = _mm256_and_si256 ( block, mask_nib) ;
184+
185+ // This algorithm uses offsets for decoding. e.g. in the standard and url-safe alphabet the
186+ // ASCII letter 'A' encodes 0b000000, the letter 'B' 0b000001, and so on. The ASCII value of
187+ // 'A' is 65. So to get from a capital letter in the input to the value it encodes you have to
188+ // substract 65. Similarly, the letter 'a' encodes 0b011010, or 26 in decimal. 'b' encodes 27
189+ // and so on. But the ASCII value of 'a' is 97, so to get from a miniscule to it's value you
190+ // don't substract 65 but 71 instead.
191+ // The main optimization this algorithm makes and the source for it's assumptions is that it
192+ // relies on the fact that the alphabet used has continous ordered ranges of inputs that thus
193+ // share an offset, and that these ranges are distinguishable by their upper nibble.
194+ // In other words because for [A-Z] substracting 65 gets you to the correct value and for [a-z]
195+ // substracting 71 does as well. While decoding we just have to figure out which range an input
196+ // belongs to and directly know what offset to apply.
197+ // However, we need to check for invalid inputs. The algorithm again optimizes that by using
198+ // the fact that valid input is in one of the ranges or one of two special bytes ('+' and '/'
199+ // or '-' and '_' specifically)
200+ // [A-Z] for example is the range of 0b100_0001 to 0b101_1010, so the high nibbles 0b100 (4)
201+ // and 0b101 (5). But not every input with these high nibbles is valid, e.g. the character '@'
202+ // encoded as 0b100_0000 or the character '[', i.e. 0b101_1011. So we need to check if the low
203+ // nibble is valid for a given high nibble. AVX2 has an instructions for bitwise comparing two
204+ // vectors which is exposed as `test` instrinsics which return a different CPU flag for
205+ // conditionals.
206+ // _mm256_testz_si256 used here bitwise AND's both input vectors and returns 1 if the result is
207+ // zero and 0 if the result has any bit set.
208+ // So we need to now generate a `witness` for the high and low nibble each so that
209+ // `witness_hi & witness_lo == 0` iff the input is valid.
210+ let witness_lo = _mm256_shuffle_epi8 ( lo_witness_lut, lo_nibbles) ;
211+ let witness_hi = _mm256_shuffle_epi8 ( hi_witness_lut, hi_nibbles) ;
212+ if _mm256_testz_si256 ( witness_lo, witness_hi) == 0 {
161213 * invalid = true ;
214+ return _mm256_and_si256 ( witness_hi, witness_lo) ;
162215 }
216+
217+ let eq_singleton = _mm256_cmpeq_epi8 ( block, mask_singleton) ;
218+ let roll = _mm256_shuffle_epi8 ( offsets, _mm256_add_epi8 ( eq_singleton, hi_nibbles) ) ;
163219 let shuffeled = _mm256_add_epi8 ( block, roll) ;
164220
165221 let merge_ab_and_bc = _mm256_maddubs_epi16 ( shuffeled,
@@ -185,28 +241,28 @@ unsafe fn decode(
185241/// since `0` bytes would in fact be an invalid input.
186242unsafe fn decode_masked (
187243 invalid : & mut bool ,
188- lut_lo : __m256i ,
189- lut_hi : __m256i ,
244+ lo_witness_lut : __m256i ,
245+ hi_witness_lut : __m256i ,
190246 lut_roll : __m256i ,
191- mask_2f : __m256i ,
247+ mask_singleton : __m256i ,
192248 mask_input : __m256i ,
193249 block : __m256i
194250) -> __m256i {
195251 let hi_nibbles = _mm256_srli_epi32 ( block, 4 ) ;
196- let lo_nibbles = _mm256_and_si256 ( block, mask_2f ) ;
197- let eq_2f = _mm256_cmpeq_epi8 ( block, mask_2f ) ;
198- let hi_nibbles = _mm256_and_si256 ( hi_nibbles, mask_2f ) ;
252+ let lo_nibbles = _mm256_and_si256 ( block, mask_singleton ) ;
253+ let eq_singleton = _mm256_cmpeq_epi8 ( block, mask_singleton ) ;
254+ let hi_nibbles = _mm256_and_si256 ( hi_nibbles, mask_singleton ) ;
199255
200- let lo = _mm256_shuffle_epi8 ( lut_lo , lo_nibbles) ;
201- let hi = _mm256_shuffle_epi8 ( lut_hi , hi_nibbles) ;
256+ let lo = _mm256_shuffle_epi8 ( lo_witness_lut , lo_nibbles) ;
257+ let hi = _mm256_shuffle_epi8 ( hi_witness_lut , hi_nibbles) ;
202258 // Special case: If we have a masked input we need to forward this mask here to not
203259 // trip the test below
204260 let hi = _mm256_and_si256 ( hi, mask_input) ;
205261 if _mm256_testz_si256 ( lo, hi) == 0 {
206262 * invalid = true ;
207263 }
208264
209- let roll = _mm256_shuffle_epi8 ( lut_roll, _mm256_add_epi8 ( eq_2f , hi_nibbles) ) ;
265+ let roll = _mm256_shuffle_epi8 ( lut_roll, _mm256_add_epi8 ( eq_singleton , hi_nibbles) ) ;
210266 let shuffeled = _mm256_add_epi8 ( block, roll) ;
211267
212268 let merge_ab_and_bc = _mm256_maddubs_epi16 ( shuffeled,
@@ -457,20 +513,69 @@ impl super::Engine for AVX2Encoder {
457513 let mut block: __m256i ;
458514 let mut invalid: bool = false ;
459515
460- // Initialize the four required vectors for all avx decoding operations
461- let lut_lo = unsafe { _mm256_setr_epi8 (
516+ // Witnesses for the high nibbles:
517+ // 0x0 and 0x1 are never valid, no matter what the low nibble is.
518+ // 0x2 is valid for the characters '+' (0x2B), '/' (0x2F) and '-' (0x2D), depending on the
519+ // alphabet.
520+ // 0x3 contains numerals but the only valid inputs are 0x30 to 0x39, so we need to make
521+ // sure that everything from 0xA to 0xF is rejected.
522+ // 0x4 and 0x5 contain [A-Z] and also the special character '_' (0x5F) from the urlsafe
523+ // alphabet.
524+ // 0x6 and 0x7 contain [a-z].
525+ // 0x7 and 0x8 are never valid; 0x8 or higher especially means invalid ASCII.
526+ //
527+ // We use -0x1 as "always invalid" value so that the low witness has to only return
528+ // something != 0 for the invalid test to trip.
529+ let hi_witness_lut = unsafe { _mm256_setr_epi8 (
530+ // 0 1 2 3 4 5 6 7
531+ -0x1 , -0x1 , 0x01 , 0x02 , 0x04 , 0x08 , 0x04 , 0x08 ,
532+ // 8 9 10 11 12 13 14 15
533+ -0x1 , -0x1 , -0x1 , -0x1 , -0x1 , -0x1 , -0x1 , -0x1 ,
534+ // 0 1 2 3 4 5 6 7
535+ -0x1 , -0x1 , 0x01 , 0x02 , 0x04 , 0x08 , 0x04 , 0x08 ,
536+ // 8 9 10 11 12 13 14 15
537+ -0x1 , -0x1 , -0x1 , -0x1 , -0x1 , -0x1 , -0x1 , -0x1
538+ ) } ;
539+ // Witnesses for the low nibbles. The requirements for the given hi witnesses are then:
540+ // // Be invalid if hi is.
541+ // - lo[..] & -0x1 == 1
542+ // // Numerals
543+ // - lo[0..9] & 0x2 == 0
544+ // - lo[10..15] & 0x2 == 1
545+ // // Capitals
546+ // - lo[0] & 0x4 == 1
547+ // - lo[1..] & 0x4 == 0
548+ // - lo[0..10] & 0x8 == 0
549+ // - lo[11..15] & 0x8 == 1
550+ // // Miniscules
551+ // - lo[0] & 0x4 == 1
552+ // - lo[1..] & 0x4 == 1
553+ // - lo[..10] & 0x8 == 1
554+ // - lo[11..15] & 0x8 == 1
555+ // // Special, depending on the alphabet
556+ // // standard
557+ // - lo[15] & 0x1 == 1
558+ // - lo[11] & 0x1 == 1
559+ // // urlsafe
560+ // - lo[13] & 0x1 == 1
561+ // - lo[15] & 0x8
562+ // ASCII has the advantage that A-Z and a-z are 0x20 away from each other so you can use
563+ // the same lo witnesses.
564+ // The easiest way to create these witness tables and what is done here is to use the hi
565+ // witness to select a bit to probe and set the bit in the low witness for valid nibbles in
566+ // that range. E.g. the hi witness sets bit 1 for high nibble 0x2 and bit 3 for 0x4 and
567+ // 0x6, and the lo witness only sets bit 1 for valid inputs with high nibble 0x2 (like
568+ // 0x2F, 0x2B etc.) and bit 3 for valid letters [A-Za-z].
569+ let lo_witness_lut = unsafe { _mm256_setr_epi8 (
570+ // 0 1 2 3 4 5 6 7
462571 0x15 , 0x11 , 0x11 , 0x11 , 0x11 , 0x11 , 0x11 , 0x11 ,
572+ // 8 9 10 11 12 13 14 15
463573 0x11 , 0x11 , 0x13 , 0x1A , 0x1B , 0x1B , 0x1B , 0x1A ,
574+ // 0 1 2 3 4 5 6 7
464575 0x15 , 0x11 , 0x11 , 0x11 , 0x11 , 0x11 , 0x11 , 0x11 ,
576+ // 8 9 10 11 12 13 14 15
465577 0x11 , 0x11 , 0x13 , 0x1A , 0x1B , 0x1B , 0x1B , 0x1A
466578 ) } ;
467- let lut_hi = unsafe { _mm256_setr_epi8 (
468- 0x10 , 0x10 , 0x01 , 0x02 , 0x04 , 0x08 , 0x04 , 0x08 ,
469- 0x10 , 0x10 , 0x10 , 0x10 , 0x10 , 0x10 , 0x10 , 0x10 ,
470- 0x10 , 0x10 , 0x01 , 0x02 , 0x04 , 0x08 , 0x04 , 0x08 ,
471- 0x10 , 0x10 , 0x10 , 0x10 , 0x10 , 0x10 , 0x10 , 0x10
472- ) } ;
473- let mask_2f = unsafe { _mm256_set1_epi8 ( 0x2F ) } ;
474579
475580 // This will only evaluate to true if we have an input of 33 bytes or more;
476581 // skip_final_bytes is at least input.len() otherwise.
@@ -493,7 +598,12 @@ impl super::Engine for AVX2Encoder {
493598
494599 unsafe {
495600 block = _mm256_loadu_si256 ( input_chunk. as_ptr ( ) . cast ( ) ) ;
496- block = decode ( & mut invalid, lut_lo, lut_hi, self . decode_offsets , mask_2f, block) ;
601+ block = decode ( & mut invalid,
602+ lo_witness_lut,
603+ hi_witness_lut,
604+ self . decode_offsets ,
605+ self . singleton_mask ,
606+ block) ;
497607
498608 if invalid {
499609 return Err ( find_invalid_input ( input_index, input_chunk, & self . decode_table ) ) ;
@@ -530,7 +640,7 @@ impl super::Engine for AVX2Encoder {
530640 let mask_output = _mm256_loadu_si256 ( MASKLOAD [ 2 ..10 ] . as_ptr ( ) . cast ( ) ) ;
531641
532642 block = _mm256_loadu_si256 ( input_chunk. as_ptr ( ) . cast ( ) ) ;
533- block = decode ( & mut invalid, lut_lo , lut_hi , self . decode_offsets , mask_2f , block) ;
643+ block = decode ( & mut invalid, lo_witness_lut , hi_witness_lut , self . decode_offsets , self . singleton_mask , block) ;
534644
535645 _mm256_maskstore_epi32 ( output_chunk. as_mut_ptr ( ) . cast ( ) , mask_output, block) ;
536646 }
@@ -583,7 +693,7 @@ impl super::Engine for AVX2Encoder {
583693 block = _mm256_maskload_epi32 ( input_chunk. as_ptr ( ) . cast ( ) , mask_input) ;
584694 let outblock
585695 = decode_masked ( & mut invalid,
586- lut_lo , lut_hi , self . decode_offsets , mask_2f , mask_input, block) ;
696+ lo_witness_lut , hi_witness_lut , self . decode_offsets , self . singleton_mask , mask_input, block) ;
587697
588698 if invalid {
589699 return Err ( find_invalid_input ( input_index, input_chunk, & self . decode_table ) ) ;
@@ -717,7 +827,7 @@ fn find_invalid_input(input_index: usize, input: &[u8], decode_table: &[u8; 256]
717827 }
718828 }
719829
720- unreachable ! ( "Called find_invalid_input on valid input! { }, {:? }" , input_index , input ) ;
830+ unreachable ! ( "find_invalid_input was given valid input {:? }, global index { }" , input , input_index ) ;
721831}
722832
723833
0 commit comments