Skip to content

Commit 7039695

Browse files
committed
Finish the urlsafe coding engine
1 parent 1046070 commit 7039695

File tree

1 file changed

+135
-94
lines changed

1 file changed

+135
-94
lines changed

src/engine/avx2/mod.rs

Lines changed: 135 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ pub struct AVX2Encoder {
4141
// For STANDARD these are '+' and '/' and the engine matches against '/' i.e. 0x2F
4242
// For URL_SAFE these are '-' and '_' and the engine matches against '_' i.e. 0x5F
4343
singleton_mask: __m256i,
44+
hi_witnesses: __m256i,
45+
lo_witnesses: __m256i,
4446
}
4547

4648
impl AVX2Encoder {
@@ -55,15 +57,64 @@ impl AVX2Encoder {
5557
)
5658
};
5759

60+
// These decode offsets are accessed by the high nibble of the ASCII character being
61+
// decoded so for example 'A' (0x41) is offset -65 since it encodes 0b000000.
62+
// The one exception to that is the value '/' (0x2F) which has to be handled specifically.
5863
let decode_offsets = unsafe {
5964
_mm256_setr_epi8(
60-
// 00 01 02 03 04 05 06 07 08 09 10 11 12 13 14 15
61-
0, 16, 19, 4, -65, -65, -71, -71, 0, 0, 0, 0, 0, 0, 0, 0,
62-
0, 16, 19, 4, -65, -65, -71, -71, 0, 0, 0, 0, 0, 0, 0, 0
65+
// 00 01 02 03 04 05 06 07 08 09 10 11 12 13 14 15
66+
0, 0, 19, 4, -65, -65, -71, -71, 16, 0, 0, 0, 0, 0, 0, 0,
67+
0, 0, 19, 4, -65, -65, -71, -71, 16, 0, 0, 0, 0, 0, 0, 0
6368
)
6469
};
6570

6671
let singleton_mask = unsafe { _mm256_set1_epi8(0x2F) };
72+
// Witnesses for the high nibbles:
73+
// 0x0 and 0x1 are never valid, no matter what the low nibble is.
74+
// 0x2 is valid for the characters '+' (0x2B), '/' (0x2F) and '-' (0x2D), depending on the
75+
// alphabet.
76+
// 0x3 contains numerals but the only valid inputs are 0x30 to 0x39, so we need to make
77+
// sure that everything from 0xA to 0xF is rejected.
78+
// 0x4 and 0x5 contain [A-Z] and also the special character '_' (0x5F) from the urlsafe
79+
// alphabet.
80+
// 0x6 and 0x7 contain [a-z].
81+
// 0x8 and above are never valid; a high nibble of 0x8 or higher means the byte is not ASCII.
82+
//
83+
// We use -0x1 as "always invalid" value so that the low witness has to only return
84+
// something != 0 for the invalid test to trip.
85+
let hi_witnesses = unsafe {
86+
_mm256_setr_epi8(
87+
// 0 1 2 3 4 5 6 7
88+
-0x1, -0x1, 0x01, 0x02, 0x04, 0x08, 0x04, 0x08,
89+
// 8 9 10 11 12 13 14 15
90+
-0x1, -0x1, -0x1, -0x1, -0x1, -0x1, -0x1, -0x1,
91+
// 0 1 2 3 4 5 6 7
92+
-0x1, -0x1, 0x01, 0x02, 0x04, 0x08, 0x04, 0x08,
93+
// 8 9 10 11 12 13 14 15
94+
-0x1, -0x1, -0x1, -0x1, -0x1, -0x1, -0x1, -0x1
95+
)
96+
};
97+
// Witnesses for the low nibbles.
98+
// ASCII has the advantage that A-Z and a-z are 0x20 away from each other so you can use
99+
// the same lo witnesses for both of those ranges.
100+
// The easiest way to create these witness tables and what is done here is to use the hi
101+
// witness to select a bit to probe and set the bit in the low witness for invalid nibbles
102+
// in that range. E.g. the hi witness sets bit 1 for high nibble 0x2, bit 2 for 0x3, and
103+
// bit 3 for 0x4 and 0x6. The lo witness then sets bit 2 for 0xA..0xF (since those are
104+
// invalids in the numeric range), bit 1 for everything invalid in the special bytes range
105+
// (i.e. everything but 0x2F, 0x2B etc.), bit 3 for 0x0 and bit 4 for 0xB..0xF.
106+
let lo_witnesses = unsafe {
107+
_mm256_setr_epi8(
108+
// 0 1 2 3 4 5 6 7
109+
0x75, 0x71, 0x71, 0x71, 0x71, 0x71, 0x71, 0x71,
110+
// 8 9 10 11 12 13 14 15
111+
0x71, 0x71, 0x73, 0x7A, 0x7B, 0x7B, 0x7B, 0x7A,
112+
// 0 1 2 3 4 5 6 7
113+
0x75, 0x71, 0x71, 0x71, 0x71, 0x71, 0x71, 0x71,
114+
// 8 9 10 11 12 13 14 15
115+
0x71, 0x71, 0x73, 0x7A, 0x7B, 0x7B, 0x7B, 0x7A,
116+
)
117+
};
67118

68119
Self {
69120
config,
@@ -74,6 +125,8 @@ impl AVX2Encoder {
74125
encode_offsets,
75126
decode_offsets,
76127
singleton_mask,
128+
hi_witnesses,
129+
lo_witnesses,
77130
}
78131
}
79132
/// Create an AVX2Encoder for the urlsafe alphabet with the given config.
@@ -89,13 +142,39 @@ impl AVX2Encoder {
89142

90143
let decode_offsets = unsafe {
91144
_mm256_setr_epi8(
92-
// 00 01 02 03 04 05 06 07 08 09 10 11 12 13 14 15
93-
0, -32, 17, 4, -65, -65, -71, -71, 0, 0, 0, 0, 0, 0, 0, 0,
94-
0, -32, 17, 4, -65, -65, -71, -71, 0, 0, 0, 0, 0, 0, 0, 0
145+
// 00 01 02 03 04 05 06 07 08 09 10 11 12 13 14 15
146+
0, 0, 17, 4, -65, -65, -71, -71, 0, 0, 0,-32, 0, 0, 0, 0,
147+
0, 0, 17, 4, -65, -65, -71, -71, 0, 0, 0,-32, 0, 0, 0, 0
95148
)
96149
};
97150

98-
let singleton_mask = unsafe { _mm256_set1_epi8(0x2B) };
151+
let singleton_mask = unsafe { _mm256_set1_epi8(0x5F) };
152+
let hi_witnesses = unsafe {
153+
_mm256_setr_epi8(
154+
// 0 1 2 3 4 5 6 7
155+
-0x1, -0x1, 0x01, 0x02, 0x04, 0x08, 0x04, 0x08,
156+
// 8 9 10 11 12 13 14 15
157+
-0x1, -0x1, -0x1, -0x1, -0x1, -0x1, -0x1, -0x1,
158+
// 0 1 2 3 4 5 6 7
159+
-0x1, -0x1, 0x01, 0x02, 0x04, 0x08, 0x04, 0x08,
160+
// 8 9 10 11 12 13 14 15
161+
-0x1, -0x1, -0x1, -0x1, -0x1, -0x1, -0x1, -0x1
162+
)
163+
};
164+
// Lo witnesses for url-safe are slightly different than for standard:
165+
// Inputs 0x5F ('_') and 0x2D ('-') are valid, inputs 0x2F ('/') and 0x2B ('+') are not.
166+
let lo_witnesses = unsafe {
167+
_mm256_setr_epi8(
168+
// 0 1 2 3 4 5 6 7
169+
0x75, 0x71, 0x71, 0x71, 0x71, 0x71, 0x71, 0x71,
170+
// 8 9 A B C D E F
171+
0x71, 0x71, 0x73, 0x7B, 0x7B, 0x7A, 0x7B, 0x73,
172+
// 0 1 2 3 4 5 6 7
173+
0x75, 0x71, 0x71, 0x71, 0x71, 0x71, 0x71, 0x71,
174+
// 8 9 A B C D E F
175+
0x71, 0x71, 0x73, 0x7B, 0x7B, 0x7A, 0x7B, 0x73,
176+
)
177+
};
99178

100179
Self {
101180
config,
@@ -106,6 +185,8 @@ impl AVX2Encoder {
106185
encode_offsets,
107186
decode_offsets,
108187
singleton_mask,
188+
hi_witnesses,
189+
lo_witnesses,
109190
}
110191
}
111192
}
@@ -191,7 +272,7 @@ unsafe fn decode(
191272
// The main optimization this algorithm makes and the source for its assumptions is that it
192273
// relies on the fact that the alphabet used has continuous ordered ranges of inputs that thus
193274
// share an offset, and that these ranges are distinguishable by their upper nibble.
194-
// In other words because for [A-Z] substracting 65 gets you to the correct value and for [a-z]
275+
// In other words, for [A-Z] subtracting 65 gets you to the correct value and for [a-z]
195276
// subtracting 71 does as well. While decoding we just have to figure out which range an input
196277
// belongs to and directly know what offset to apply.
197278
// However, we need to check for invalid inputs. The algorithm again optimizes that by using
@@ -214,10 +295,21 @@ unsafe fn decode(
214295
return _mm256_and_si256(witness_hi, witness_lo);
215296
}
216297

298+
// Next we check for one of the singleton bytes. In neither the standard nor the url-safe
299+
// alphabet do the two special characters share an offset to their encoded value, and they
300+
// can't be told apart from other ranges by their high nibble alone ('_' has high nibble 5
301+
// like P-Z, '/' and '+' both have 2), so we need to explicitly match against one of them.
217302
let eq_singleton = _mm256_cmpeq_epi8(block, mask_singleton);
218-
let roll = _mm256_shuffle_epi8(offsets, _mm256_add_epi8(eq_singleton, hi_nibbles));
219-
let shuffeled = _mm256_add_epi8(block, roll);
220303

304+
// In the last decoding step we do two things: Add 0x6 to all hi nibbles where we found our
305+
// singleton. This makes input 0x2F check for offset in offsets[8] and 0x5F in offsets[11].
306+
// Then, get the actual offset amount from `offsets` and add it to our input.
307+
let offsetidxs = _mm256_add_epi8(hi_nibbles, _mm256_and_si256(eq_singleton, _mm256_set1_epi8(0x6)));
308+
let offsetvals = _mm256_shuffle_epi8(offsets, offsetidxs);
309+
let shuffeled = _mm256_add_epi8(block, offsetvals);
310+
311+
// This merges the 16 byte-aligned, 6-bit wide values in each half into a packed 12-byte
312+
// block of data each.
221313
let merge_ab_and_bc = _mm256_maddubs_epi16(shuffeled,
222314
_mm256_set1_epi32(0x01400140));
223315
let madd = _mm256_madd_epi16(merge_ab_and_bc, _mm256_set1_epi32(0x00011000));
@@ -243,27 +335,29 @@ unsafe fn decode_masked(
243335
invalid: &mut bool,
244336
lo_witness_lut: __m256i,
245337
hi_witness_lut: __m256i,
246-
lut_roll: __m256i,
338+
offsets: __m256i,
247339
mask_singleton: __m256i,
248340
mask_input: __m256i,
249341
block: __m256i
250342
) -> __m256i {
251-
let hi_nibbles = _mm256_srli_epi32(block, 4);
252-
let lo_nibbles = _mm256_and_si256(block, mask_singleton);
253-
let eq_singleton = _mm256_cmpeq_epi8(block, mask_singleton);
254-
let hi_nibbles = _mm256_and_si256(hi_nibbles, mask_singleton);
255-
256-
let lo = _mm256_shuffle_epi8(lo_witness_lut, lo_nibbles);
257-
let hi = _mm256_shuffle_epi8(hi_witness_lut, hi_nibbles);
258-
// Special case: If we have a masked input we need to forward this mask here to not
259-
// trip the test below
260-
let hi = _mm256_and_si256(hi, mask_input);
261-
if _mm256_testz_si256(lo, hi) == 0 {
343+
let mask_nib = _mm256_set1_epi8(0b00001111);
344+
let block_shifted = _mm256_srli_epi32(block, 4);
345+
let hi_nibbles = _mm256_and_si256(block_shifted, mask_nib);
346+
let lo_nibbles = _mm256_and_si256(block, mask_nib);
347+
348+
let witness_lo = _mm256_shuffle_epi8(lo_witness_lut, lo_nibbles);
349+
let witness_hi = _mm256_shuffle_epi8(hi_witness_lut, hi_nibbles);
350+
351+
let witness_hi = _mm256_and_si256(witness_hi, mask_input);
352+
if _mm256_testz_si256(witness_lo, witness_hi) == 0 {
262353
*invalid = true;
354+
return _mm256_and_si256(witness_hi, witness_lo);
263355
}
264356

265-
let roll = _mm256_shuffle_epi8(lut_roll, _mm256_add_epi8(eq_singleton, hi_nibbles));
266-
let shuffeled = _mm256_add_epi8(block, roll);
357+
let eq_singleton = _mm256_cmpeq_epi8(block, mask_singleton);
358+
let offsetidxs = _mm256_add_epi8(hi_nibbles, _mm256_and_si256(eq_singleton, _mm256_set1_epi8(0x6)));
359+
let offsetvals = _mm256_shuffle_epi8(offsets, offsetidxs);
360+
let shuffeled = _mm256_add_epi8(block, offsetvals);
267361

268362
let merge_ab_and_bc = _mm256_maddubs_epi16(shuffeled,
269363
_mm256_set1_epi32(0x01400140));
@@ -513,70 +607,6 @@ impl super::Engine for AVX2Encoder {
513607
let mut block: __m256i;
514608
let mut invalid: bool = false;
515609

516-
// Witnesses for the high nibbles:
517-
// 0x0 and 0x1 are never valid, no matter what the low nibble is.
518-
// 0x2 is valid for the characters '+' (0x2B), '/' (0x2F) and '-' (0x2D), depending on the
519-
// alphabet.
520-
// 0x3 contains numerals but the only valid inputs are 0x30 to 0x39, so we need to make
521-
// sure that everything from 0xA to 0xF is rejected.
522-
// 0x4 and 0x5 contain [A-Z] and also the special character '_' (0x5F) from the urlsafe
523-
// alphabet.
524-
// 0x6 and 0x7 contain [a-z].
525-
// 0x7 and 0x8 are never valid; 0x8 or higher especially means invalid ASCII.
526-
//
527-
// We use -0x1 as "always invalid" value so that the low witness has to only return
528-
// something != 0 for the invalid test to trip.
529-
let hi_witness_lut = unsafe { _mm256_setr_epi8(
530-
// 0 1 2 3 4 5 6 7
531-
-0x1, -0x1, 0x01, 0x02, 0x04, 0x08, 0x04, 0x08,
532-
// 8 9 10 11 12 13 14 15
533-
-0x1, -0x1, -0x1, -0x1, -0x1, -0x1, -0x1, -0x1,
534-
// 0 1 2 3 4 5 6 7
535-
-0x1, -0x1, 0x01, 0x02, 0x04, 0x08, 0x04, 0x08,
536-
// 8 9 10 11 12 13 14 15
537-
-0x1, -0x1, -0x1, -0x1, -0x1, -0x1, -0x1, -0x1
538-
)};
539-
// Witnesses for the low nibbles. The requirements for the given hi witnesses are then:
540-
// // Be invalid if hi is.
541-
// - lo[..] & -0x1 == 1
542-
// // Numerals
543-
// - lo[0..9] & 0x2 == 0
544-
// - lo[10..15] & 0x2 == 1
545-
// // Capitals
546-
// - lo[0] & 0x4 == 1
547-
// - lo[1..] & 0x4 == 0
548-
// - lo[0..10] & 0x8 == 0
549-
// - lo[11..15] & 0x8 == 1
550-
// // Miniscules
551-
// - lo[0] & 0x4 == 1
552-
// - lo[1..] & 0x4 == 1
553-
// - lo[..10] & 0x8 == 1
554-
// - lo[11..15] & 0x8 == 1
555-
// // Special, depending on the alphabet
556-
// // standard
557-
// - lo[15] & 0x1 == 1
558-
// - lo[11] & 0x1 == 1
559-
// // urlsafe
560-
// - lo[13] & 0x1 == 1
561-
// - lo[15] & 0x8
562-
// ASCII has the advantage that A-Z and a-z are 0x20 away from each other so you can use
563-
// the same lo witnesses.
564-
// The easiest way to create these witness tables and what is done here is to use the hi
565-
// witness to select a bit to probe and set the bit in the low witness for valid nibbles in
566-
// that range. E.g. the hi witness sets bit 1 for high nibble 0x2 and bit 3 for 0x4 and
567-
// 0x6, and the lo witness only sets bit 1 for valid inputs with high nibble 0x2 (like
568-
// 0x2F, 0x2B etc.) and bit 3 for valid letters [A-Za-z].
569-
let lo_witness_lut = unsafe { _mm256_setr_epi8(
570-
// 0 1 2 3 4 5 6 7
571-
0x15, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
572-
// 8 9 10 11 12 13 14 15
573-
0x11, 0x11, 0x13, 0x1A, 0x1B, 0x1B, 0x1B, 0x1A,
574-
// 0 1 2 3 4 5 6 7
575-
0x15, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
576-
// 8 9 10 11 12 13 14 15
577-
0x11, 0x11, 0x13, 0x1A, 0x1B, 0x1B, 0x1B, 0x1A
578-
)};
579-
580610
// This will only evaluate to true if we have an input of 33 bytes or more;
581611
// skip_final_bytes is at least input.len() otherwise.
582612
if last_fast_index > 0 {
@@ -599,8 +629,8 @@ impl super::Engine for AVX2Encoder {
599629
unsafe {
600630
block = _mm256_loadu_si256(input_chunk.as_ptr().cast());
601631
block = decode(&mut invalid,
602-
lo_witness_lut,
603-
hi_witness_lut,
632+
self.lo_witnesses,
633+
self.hi_witnesses,
604634
self.decode_offsets,
605635
self.singleton_mask,
606636
block);
@@ -640,7 +670,12 @@ impl super::Engine for AVX2Encoder {
640670
let mask_output = _mm256_loadu_si256(MASKLOAD[2..10].as_ptr().cast());
641671

642672
block = _mm256_loadu_si256(input_chunk.as_ptr().cast());
643-
block = decode(&mut invalid, lo_witness_lut, hi_witness_lut, self.decode_offsets, self.singleton_mask, block);
673+
block = decode(&mut invalid,
674+
self.lo_witnesses,
675+
self.hi_witnesses,
676+
self.decode_offsets,
677+
self.singleton_mask,
678+
block);
644679

645680
_mm256_maskstore_epi32(output_chunk.as_mut_ptr().cast(), mask_output, block);
646681
}
@@ -691,9 +726,15 @@ impl super::Engine for AVX2Encoder {
691726
let mask_output = _mm256_loadu_si256(output_mask.as_ptr().cast());
692727

693728
block = _mm256_maskload_epi32(input_chunk.as_ptr().cast(), mask_input);
694-
let outblock
695-
= decode_masked(&mut invalid,
696-
lo_witness_lut, hi_witness_lut, self.decode_offsets, self.singleton_mask, mask_input, block);
729+
let outblock = decode_masked(
730+
&mut invalid,
731+
self.lo_witnesses,
732+
self.hi_witnesses,
733+
self.decode_offsets,
734+
self.singleton_mask,
735+
mask_input,
736+
block
737+
);
697738

698739
if invalid {
699740
return Err(find_invalid_input(input_index, input_chunk, &self.decode_table));

0 commit comments

Comments
 (0)