Skip to content

Commit 1046070

Browse files
committed
Start explainering the algorithm
1 parent ed23d14 commit 1046070

File tree

1 file changed

+155
-45
lines changed

1 file changed

+155
-45
lines changed

src/engine/avx2/mod.rs

Lines changed: 155 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@ pub struct AVX2Encoder {
3535
// Alphabet LUT for vectorized steps
3636
encode_offsets: __m256i,
3737
decode_offsets: __m256i,
38+
39+
// The algorithm in use needs to be able to distinguish between the two singletons outside the
40+
// [A-Za-z] ranges.
41+
// For STANDARD these are '+' and '/' and the engine matches against '/' i.e. 0x2F
42+
// For URL_SAFE these are '-' and '_' and the engine matches against '_' i.e. 0x5F
43+
singleton_mask: __m256i,
3844
}
3945

4046
impl AVX2Encoder {
@@ -43,20 +49,22 @@ impl AVX2Encoder {
4349
pub fn from_standard(config: AVX2Config) -> Self {
4450
let encode_offsets = unsafe {
4551
_mm256_setr_epi8(
46-
71, -4, -4, -4, -4, -4, -4, -4,
47-
-4, -4, -4,-19,-16, 65, 0, 0,
48-
71, -4, -4, -4, -4, -4, -4, -4,
49-
-4, -4, -4,-19,-16, 65, 0, 0,
52+
// 00 01 02 03 04 05 06 07 08 09 10 11 12 13 14 15
53+
71, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,-19,-16, 65, 0, 0,
54+
71, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,-19,-16, 65, 0, 0,
5055
)
5156
};
5257

5358
let decode_offsets = unsafe {
5459
_mm256_setr_epi8(
60+
// 00 01 02 03 04 05 06 07 08 09 10 11 12 13 14 15
5561
0, 16, 19, 4, -65, -65, -71, -71, 0, 0, 0, 0, 0, 0, 0, 0,
5662
0, 16, 19, 4, -65, -65, -71, -71, 0, 0, 0, 0, 0, 0, 0, 0
5763
)
5864
};
5965

66+
let singleton_mask = unsafe { _mm256_set1_epi8(0x2F) };
67+
6068
Self {
6169
config,
6270

@@ -65,27 +73,30 @@ impl AVX2Encoder {
6573

6674
encode_offsets,
6775
decode_offsets,
76+
singleton_mask,
6877
}
6978
}
7079
/// Create an AVX2Encoder for the urlsafe alphabet with the given config.
7180
/// You can create one for standard with the associated function [`from_standard`].
7281
pub fn from_url_safe(config: AVX2Config) -> Self {
7382
let encode_offsets = unsafe {
7483
_mm256_setr_epi8(
75-
71, -4, -4, -4, -4, -4, -4, -4,
76-
-4, -4, -4,-17, 32, 65, 0, 0,
77-
71, -4, -4, -4, -4, -4, -4, -4,
78-
-4, -4, -4,-17, 32, 65, 0, 0,
84+
// 00 01 02 03 04 05 06 07 08 09 10 11 12 13 14 15
85+
71, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,-17, 32, 65, 0, 0,
86+
71, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,-17, 32, 65, 0, 0,
7987
)
8088
};
8189

8290
let decode_offsets = unsafe {
8391
_mm256_setr_epi8(
84-
0, -32, 17, 4, -65, -65, -71, -71, 0, 0, 0, 0, 0, 0, 0, 0,
85-
0, -32, 17, 4, -65, -65, -71, -71, 0, 0, 0, 0, 0, 0, 0, 0
92+
// 00 01 02 03 04 05 06 07 08 09 10 11 12 13 14 15
93+
0, -32, 17, 4, -65, -65, -71, -71, 0, 0, 0, 0, 0, 0, 0, 0,
94+
0, -32, 17, 4, -65, -65, -71, -71, 0, 0, 0, 0, 0, 0, 0, 0
8695
)
8796
};
8897

98+
let singleton_mask = unsafe { _mm256_set1_epi8(0x2B) };
99+
89100
Self {
90101
config,
91102

@@ -94,6 +105,7 @@ impl AVX2Encoder {
94105

95106
encode_offsets,
96107
decode_offsets,
108+
singleton_mask,
97109
}
98110
}
99111
}
@@ -143,23 +155,67 @@ unsafe fn load_block(input: __m256i) -> __m256i {
143155
#[inline(always)]
144156
unsafe fn decode(
145157
invalid: &mut bool,
146-
lut_lo: __m256i,
147-
lut_hi: __m256i,
148-
lut_roll: __m256i,
149-
mask_2f: __m256i,
158+
lo_witness_lut: __m256i,
159+
hi_witness_lut: __m256i,
160+
offsets: __m256i,
161+
mask_singleton: __m256i,
150162
block: __m256i
151163
) -> __m256i {
152-
// TODO: Explain this decode step
153-
let hi_nibbles = _mm256_srli_epi32(block, 4);
154-
let lo_nibbles = _mm256_and_si256(block, mask_2f);
155-
let lo = _mm256_shuffle_epi8(lut_lo, lo_nibbles);
156-
let eq_2f = _mm256_cmpeq_epi8(block, mask_2f);
157-
let hi_nibbles = _mm256_and_si256(hi_nibbles, mask_2f);
158-
let hi = _mm256_shuffle_epi8(lut_hi, hi_nibbles);
159-
let roll = _mm256_shuffle_epi8(lut_roll, _mm256_add_epi8(eq_2f, hi_nibbles));
160-
if _mm256_testz_si256(lo, hi) == 0 {
164+
// The most relevant information to understand this algorithm is this tidbit:
165+
// AVX2 shuffles conveniently work like table lookups; c = _mm256_shuffle_epi8(a,b) behaves* like
166+
// for i in 0..16 {
167+
// c[i] = a[b[i]];
168+
// c[i+16] = a[16 + b[i+16]];
169+
// }
170+
// This is the reason why lo_witness_lut, hi_witness_lut, encode_offsets and decode_offsets all have the exact
171+
// same values set for each 16-byte half; they are used as Look-Up tables in shuffles.
172+
// (* it additionally sets c[i] to 0 if b[i] >= 128, and likewise c[i+16] for b[i+16], but that is not used here)
173+
//
174+
// As a first step, since the indexes available in shuffles are only 0..16 or in other words one
175+
// nibble's worth, split each input byte into high and low nibble.
176+
// The high nibbles are retrieved by shifting the input by 4 bits and then applying a mask of
177+
// 0b1111 to it. The low nibbles are retrieved by not shifting and applying the very same mask.
178+
// The "standard" algorithm happens to look for 0x2F ('/') which *also* just happens to have the
179+
// lowest 4 bits set to 1, so it can use that. The urlsafe one can't.
180+
let mask_nib = _mm256_set1_epi8(0b00001111);
181+
let block_shifted = _mm256_srli_epi32(block, 4);
182+
let hi_nibbles = _mm256_and_si256(block_shifted, mask_nib);
183+
let lo_nibbles = _mm256_and_si256(block, mask_nib);
184+
185+
// This algorithm uses offsets for decoding. e.g. in the standard and url-safe alphabet the
186+
// ASCII letter 'A' encodes 0b000000, the letter 'B' 0b000001, and so on. The ASCII value of
187+
// 'A' is 65. So to get from a capital letter in the input to the value it encodes you have to
188+
// subtract 65. Similarly, the letter 'a' encodes 0b011010, or 26 in decimal. 'b' encodes 27
189+
// and so on. But the ASCII value of 'a' is 97, so to get from a minuscule to its value you
190+
// don't subtract 65 but 71 instead.
191+
// The main optimization this algorithm makes and the source for its assumptions is that it
192+
// relies on the fact that the alphabet used has contiguous ordered ranges of inputs that thus
193+
// share an offset, and that these ranges are distinguishable by their upper nibble.
194+
// In other words because for [A-Z] subtracting 65 gets you to the correct value and for [a-z]
195+
// subtracting 71 does as well. While decoding we just have to figure out which range an input
196+
// belongs to and directly know what offset to apply.
197+
// However, we need to check for invalid inputs. The algorithm again optimizes that by using
198+
// the fact that valid input is in one of the ranges or one of two special bytes ('+' and '/'
199+
// or '-' and '_' specifically)
200+
// [A-Z] for example is the range of 0b100_0001 to 0b101_1010, so the high nibbles 0b100 (4)
201+
// and 0b101 (5). But not every input with these high nibbles is valid, e.g. the character '@'
202+
// encoded as 0b100_0000 or the character '[', i.e. 0b101_1011. So we need to check if the low
203+
// nibble is valid for a given high nibble. AVX2 has an instruction for bitwise comparing two
204+
// vectors which is exposed as `test` intrinsics which return a different CPU flag for
205+
// conditionals.
206+
// _mm256_testz_si256 used here bitwise AND's both input vectors and returns 1 if the result is
207+
// zero and 0 if the result has any bit set.
208+
// So we need to now generate a `witness` for the high and low nibble each so that
209+
// `witness_hi & witness_lo == 0` iff the input is valid.
210+
let witness_lo = _mm256_shuffle_epi8(lo_witness_lut, lo_nibbles);
211+
let witness_hi = _mm256_shuffle_epi8(hi_witness_lut, hi_nibbles);
212+
if _mm256_testz_si256(witness_lo, witness_hi) == 0 {
161213
*invalid = true;
214+
return _mm256_and_si256(witness_hi, witness_lo);
162215
}
216+
217+
let eq_singleton = _mm256_cmpeq_epi8(block, mask_singleton);
218+
let roll = _mm256_shuffle_epi8(offsets, _mm256_add_epi8(eq_singleton, hi_nibbles));
163219
let shuffeled = _mm256_add_epi8(block, roll);
164220

165221
let merge_ab_and_bc = _mm256_maddubs_epi16(shuffeled,
@@ -185,28 +241,28 @@ unsafe fn decode(
185241
/// since `0` bytes would in fact be an invalid input.
186242
unsafe fn decode_masked(
187243
invalid: &mut bool,
188-
lut_lo: __m256i,
189-
lut_hi: __m256i,
244+
lo_witness_lut: __m256i,
245+
hi_witness_lut: __m256i,
190246
lut_roll: __m256i,
191-
mask_2f: __m256i,
247+
mask_singleton: __m256i,
192248
mask_input: __m256i,
193249
block: __m256i
194250
) -> __m256i {
195251
let hi_nibbles = _mm256_srli_epi32(block, 4);
196-
let lo_nibbles = _mm256_and_si256(block, mask_2f);
197-
let eq_2f = _mm256_cmpeq_epi8(block, mask_2f);
198-
let hi_nibbles = _mm256_and_si256(hi_nibbles, mask_2f);
252+
let lo_nibbles = _mm256_and_si256(block, mask_singleton);
253+
let eq_singleton = _mm256_cmpeq_epi8(block, mask_singleton);
254+
let hi_nibbles = _mm256_and_si256(hi_nibbles, mask_singleton);
199255

200-
let lo = _mm256_shuffle_epi8(lut_lo, lo_nibbles);
201-
let hi = _mm256_shuffle_epi8(lut_hi, hi_nibbles);
256+
let lo = _mm256_shuffle_epi8(lo_witness_lut, lo_nibbles);
257+
let hi = _mm256_shuffle_epi8(hi_witness_lut, hi_nibbles);
202258
// Special case: If we have a masked input we need to forward this mask here to not
203259
// trip the test below
204260
let hi = _mm256_and_si256(hi, mask_input);
205261
if _mm256_testz_si256(lo, hi) == 0 {
206262
*invalid = true;
207263
}
208264

209-
let roll = _mm256_shuffle_epi8(lut_roll, _mm256_add_epi8(eq_2f, hi_nibbles));
265+
let roll = _mm256_shuffle_epi8(lut_roll, _mm256_add_epi8(eq_singleton, hi_nibbles));
210266
let shuffeled = _mm256_add_epi8(block, roll);
211267

212268
let merge_ab_and_bc = _mm256_maddubs_epi16(shuffeled,
@@ -457,20 +513,69 @@ impl super::Engine for AVX2Encoder {
457513
let mut block: __m256i;
458514
let mut invalid: bool = false;
459515

460-
// Initialize the four required vectors for all avx decoding operations
461-
let lut_lo = unsafe { _mm256_setr_epi8(
516+
// Witnesses for the high nibbles:
517+
// 0x0 and 0x1 are never valid, no matter what the low nibble is.
518+
// 0x2 is valid for the characters '+' (0x2B), '/' (0x2F) and '-' (0x2D), depending on the
519+
// alphabet.
520+
// 0x3 contains numerals but the only valid inputs are 0x30 to 0x39, so we need to make
521+
// sure that everything from 0xA to 0xF is rejected.
522+
// 0x4 and 0x5 contain [A-Z] and also the special character '_' (0x5F) from the urlsafe
523+
// alphabet.
524+
// 0x6 and 0x7 contain [a-z].
525+
// 0x8 to 0xF are never valid; a high nibble of 0x8 or higher means the byte is not ASCII at all.
526+
//
527+
// We use -0x1 as "always invalid" value so that the low witness has to only return
528+
// something != 0 for the invalid test to trip.
529+
let hi_witness_lut = unsafe { _mm256_setr_epi8(
530+
// 0 1 2 3 4 5 6 7
531+
-0x1, -0x1, 0x01, 0x02, 0x04, 0x08, 0x04, 0x08,
532+
// 8 9 10 11 12 13 14 15
533+
-0x1, -0x1, -0x1, -0x1, -0x1, -0x1, -0x1, -0x1,
534+
// 0 1 2 3 4 5 6 7
535+
-0x1, -0x1, 0x01, 0x02, 0x04, 0x08, 0x04, 0x08,
536+
// 8 9 10 11 12 13 14 15
537+
-0x1, -0x1, -0x1, -0x1, -0x1, -0x1, -0x1, -0x1
538+
)};
539+
// Witnesses for the low nibbles. The requirements for the given hi witnesses are then:
540+
// // Be invalid if hi is.
541+
// - lo[..] & -0x1 == 1
542+
// // Numerals
543+
// - lo[0..9] & 0x2 == 0
544+
// - lo[10..15] & 0x2 == 1
545+
// // Capitals
546+
// - lo[0] & 0x4 == 1
547+
// - lo[1..] & 0x4 == 0
548+
// - lo[0..10] & 0x8 == 0
549+
// - lo[11..15] & 0x8 == 1
550+
// // Miniscules
551+
// - lo[0] & 0x4 == 1
552+
// - lo[1..] & 0x4 == 0
553+
// - lo[0..10] & 0x8 == 0
554+
// - lo[11..15] & 0x8 == 1
555+
// // Special, depending on the alphabet
556+
// // standard
557+
// - lo[15] & 0x1 == 0
558+
// - lo[11] & 0x1 == 0
559+
// // urlsafe
560+
// - lo[13] & 0x1 == 0
561+
// - lo[15] & 0x8 == 0
562+
// ASCII has the advantage that A-Z and a-z are 0x20 away from each other so you can use
563+
// the same lo witnesses.
564+
// The easiest way to create these witness tables and what is done here is to use the hi
565+
// witness to select a bit to probe and set that bit in the low witness for every low nibble
566+
// that is invalid in that range. E.g. the hi witness sets bit 1 (0x01) for high nibble 0x2 and
567+
// bit 3 (0x04) for 0x4 and 0x6, and the lo witness clears bit 1 only for the valid inputs with
568+
// high nibble 0x2 (0x2B and 0x2F) and sets bit 3 only for low nibble 0x0 ('@' and '`').
569+
let lo_witness_lut = unsafe { _mm256_setr_epi8(
570+
// 0 1 2 3 4 5 6 7
462571
0x15, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
572+
// 8 9 10 11 12 13 14 15
463573
0x11, 0x11, 0x13, 0x1A, 0x1B, 0x1B, 0x1B, 0x1A,
574+
// 0 1 2 3 4 5 6 7
464575
0x15, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
576+
// 8 9 10 11 12 13 14 15
465577
0x11, 0x11, 0x13, 0x1A, 0x1B, 0x1B, 0x1B, 0x1A
466578
)};
467-
let lut_hi = unsafe { _mm256_setr_epi8(
468-
0x10, 0x10, 0x01, 0x02, 0x04, 0x08, 0x04, 0x08,
469-
0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10,
470-
0x10, 0x10, 0x01, 0x02, 0x04, 0x08, 0x04, 0x08,
471-
0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10
472-
)};
473-
let mask_2f = unsafe { _mm256_set1_epi8(0x2F) };
474579

475580
// This will only evaluate to true if we have an input of 33 bytes or more;
476581
// skip_final_bytes is at least input.len() otherwise.
@@ -493,7 +598,12 @@ impl super::Engine for AVX2Encoder {
493598

494599
unsafe {
495600
block = _mm256_loadu_si256(input_chunk.as_ptr().cast());
496-
block = decode(&mut invalid, lut_lo, lut_hi, self.decode_offsets, mask_2f, block);
601+
block = decode(&mut invalid,
602+
lo_witness_lut,
603+
hi_witness_lut,
604+
self.decode_offsets,
605+
self.singleton_mask,
606+
block);
497607

498608
if invalid {
499609
return Err(find_invalid_input(input_index, input_chunk, &self.decode_table));
@@ -530,7 +640,7 @@ impl super::Engine for AVX2Encoder {
530640
let mask_output = _mm256_loadu_si256(MASKLOAD[2..10].as_ptr().cast());
531641

532642
block = _mm256_loadu_si256(input_chunk.as_ptr().cast());
533-
block = decode(&mut invalid, lut_lo, lut_hi, self.decode_offsets, mask_2f, block);
643+
block = decode(&mut invalid, lo_witness_lut, hi_witness_lut, self.decode_offsets, self.singleton_mask, block);
534644

535645
_mm256_maskstore_epi32(output_chunk.as_mut_ptr().cast(), mask_output, block);
536646
}
@@ -583,7 +693,7 @@ impl super::Engine for AVX2Encoder {
583693
block = _mm256_maskload_epi32(input_chunk.as_ptr().cast(), mask_input);
584694
let outblock
585695
= decode_masked(&mut invalid,
586-
lut_lo, lut_hi, self.decode_offsets, mask_2f, mask_input, block);
696+
lo_witness_lut, hi_witness_lut, self.decode_offsets, self.singleton_mask, mask_input, block);
587697

588698
if invalid {
589699
return Err(find_invalid_input(input_index, input_chunk, &self.decode_table));
@@ -717,7 +827,7 @@ fn find_invalid_input(input_index: usize, input: &[u8], decode_table: &[u8; 256]
717827
}
718828
}
719829

720-
unreachable!("Called find_invalid_input on valid input! {}, {:?}", input_index, input);
830+
unreachable!("find_invalid_input was given valid input {:?}, global index {}", input, input_index);
721831
}
722832

723833

0 commit comments

Comments
 (0)