
Commit c98293b

Rewrite AVX2Encoder to not use generics
1 parent c8ff7da · commit c98293b

File tree

1 file changed: +95 −116 lines changed

src/engine/avx2/mod.rs

Lines changed: 95 additions & 116 deletions
@@ -3,9 +3,6 @@ use crate::engine::Config;
 use crate::engine::DecodeEstimate;
 use crate::{DecodeError, PAD_BYTE};
 
-use core::marker::PhantomData;
-
-
 #[cfg(target_arch = "x86")]
 use core::arch::x86::*;
 #[cfg(target_arch = "x86_64")]
@@ -28,41 +25,75 @@ const DECODED_CHUNK_LEN: usize = 24;
 /// - It has to use unsafe code because intrinsics are always unsafe in Rust.
 /// - The algorithm in use makes specific assumptions about the alphabet, so it's only implemented
 /// for the STANDARD and URL_SAFE Alphabet
-pub struct AVX2Encoder<A> {
+pub struct AVX2Encoder {
     config: AVX2Config,
-    alp: PhantomData<A>,
-}
 
-impl<A> AVX2Encoder<A> {
-    /// Create an AVX2Encoder from a given config.
-    ///
-    /// You can either select the Alphabet by defining the type specifically:
-    /// ```rust
-    /// let engine: AVX2Encoder<Standard> = AVX2Encoder::from(AVX2Config::default());
-    /// ```
-    /// or by calling one of the associated functions [`from_standard`] and [`from_urlsafe`].
-    pub const fn from(config: AVX2Config) -> Self {
-        Self {
-            config,
-            alp: PhantomData,
-        }
-    }
+    // Alphabet LUT for serial steps
+    encode_table: [u8; 64],
+    decode_table: [u8; 256],
+
+    // Alphabet LUT for vectorized steps
+    encode_offsets: __m256i,
+    decode_offsets: __m256i,
 }
-impl AVX2Encoder<Standard> {
-    /// Create an AVX2Encoder for the STANDARD alphabet with the given config.
-    pub const fn from_standard(config: AVX2Config) -> Self {
+
+impl AVX2Encoder {
+    /// Create an AVX2Encoder for the standard Alphabet from a given config.
+    /// You can create one for urlsafe with the associated function [`from_urlsafe`].
+    pub fn from_standard(config: AVX2Config) -> Self {
+        let encode_offsets = unsafe {
+            _mm256_setr_epi8(
+                71, -4, -4, -4, -4, -4, -4, -4,
+                -4, -4, -4,-19,-16, 65, 0, 0,
+                71, -4, -4, -4, -4, -4, -4, -4,
+                -4, -4, -4,-19,-16, 65, 0, 0,
+            )
+        };
+
+        let decode_offsets = unsafe {
+            _mm256_setr_epi8(
+                0, 16, 19, 4, -65, -65, -71, -71, 0, 0, 0, 0, 0, 0, 0, 0,
+                0, 16, 19, 4, -65, -65, -71, -71, 0, 0, 0, 0, 0, 0, 0, 0
+            )
+        };
+
         Self {
             config,
-            alp: PhantomData,
+
+            encode_table: ENCODE_TABLE,
+            decode_table: DECODE_TABLE,
+
+            encode_offsets,
+            decode_offsets,
         }
     }
-}
-impl AVX2Encoder<Urlsafe> {
-    /// Create an AVX2Encoder for the STANDARD alphabet with the given config.
-    pub const fn from_url_safe(config: AVX2Config) -> Self {
+
+    /// Create an AVX2Encoder for the urlsafe alphabet with the given config.
+    /// You can create one for standard with the associated function [`from_standard`].
+    pub fn from_url_safe(config: AVX2Config) -> Self {
+        let encode_offsets = unsafe {
+            _mm256_setr_epi8(
+                71, -4, -4, -4, -4, -4, -4, -4,
+                -4, -4, -4,-17, 32, 65, 0, 0,
+                71, -4, -4, -4, -4, -4, -4, -4,
+                -4, -4, -4,-17, 32, 65, 0, 0,
+            )
+        };
+
+        let decode_offsets = unsafe {
+            _mm256_setr_epi8(
+                0, -32, 17, 4, -65, -65, -71, -71, 0, 0, 0, 0, 0, 0, 0, 0,
+                0, -32, 17, 4, -65, -65, -71, -71, 0, 0, 0, 0, 0, 0, 0, 0
+            )
+        };
+
         Self {
             config,
-            alp: PhantomData,
+
+            encode_table: URL_ENCODE_TABLE,
+            decode_table: URL_DECODE_TABLE,
+
+            encode_offsets,
+            decode_offsets,
         }
     }
 }
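
For reference, the call-site change this hunk implies — a brief sketch only, since it needs the surrounding crate to compile; the commented-out line is quoted from the removed doc comment, the two live lines use the constructors added here:

```rust
// Before: the alphabet was selected through a type parameter (from the removed docs).
// let engine: AVX2Encoder<Standard> = AVX2Encoder::from(AVX2Config::default());

// After: the alphabet is selected by the constructor, which stores the scalar
// lookup tables and the AVX2 offset vectors directly in the struct.
let standard = AVX2Encoder::from_standard(AVX2Config::default());
let url_safe = AVX2Encoder::from_url_safe(AVX2Config::default());
```

This trades a little per-instance storage (two tables plus two `__m256i` values) for a non-generic public type.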
@@ -95,8 +126,9 @@ impl DecodeEstimate for AVX2Estimate {
 }
 
 
-#[inline]
+#[inline(always)]
 unsafe fn load_block(input: __m256i) -> __m256i {
+    // TODO: Explain this load shuffle
     let i: __m256i = _mm256_shuffle_epi8(input, _mm256_set_epi8(
         10, 11, 9, 10, 7, 8, 6, 7, 4, 5, 3, 4, 1, 2, 0, 1,
         14, 15, 13, 14, 11, 12, 10, 11, 8, 9, 7, 8, 5, 6, 4, 5
@@ -117,6 +149,7 @@ unsafe fn decode(
     mask_2f: __m256i,
     block: __m256i
 ) -> __m256i {
+    // TODO: Explain this decode step
     let hi_nibbles = _mm256_srli_epi32(block, 4);
     let lo_nibbles = _mm256_and_si256(block, mask_2f);
     let lo = _mm256_shuffle_epi8(lut_lo, lo_nibbles);
@@ -146,6 +179,10 @@ unsafe fn decode(
 }
 
 #[inline(always)]
+/// decode_masked is a version of decode specialized for partial input.
+/// The only difference between it and the unmasked version is that the test that checks for
+/// invalid bytes (which is `a AND b` over a,b := 256-bit vector) gets the same input mask applied,
+/// since `0` bytes would in fact be an invalid input.
 unsafe fn decode_masked(
     invalid: &mut bool,
     lut_lo: __m256i,
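
A scalar analogue of the masking rule described in the new doc comment, as a sketch only (the helper name and shapes are illustrative, not from this crate): lanes dropped by the masked load read as `0`, and `0x00` is not in any base64 alphabet, so the validity test must only consider lanes selected by the input mask.

```rust
// Illustrative scalar stand-in for the vectorized test: `classified[i]` plays the
// role of the per-byte AND of the two nibble LUT lookups (non-zero means invalid).
fn chunk_has_invalid_byte(classified: &[u8; 32], lane_selected: &[bool; 32]) -> bool {
    classified
        .iter()
        .zip(lane_selected.iter())
        // Without the mask, the zeroed tail lanes would classify as invalid and
        // produce false errors on perfectly valid partial input.
        .any(|(&c, &selected)| selected && c != 0)
}
```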
@@ -188,65 +225,14 @@ unsafe fn decode_masked(
     ))
 }
 
-#[doc(hidden)]
-pub trait AvxAlp: Send + Sync {
-    unsafe fn encode(input: __m256i) -> __m256i;
-    fn encode_table() -> &'static [u8; 64];
-    fn decode_table() -> &'static [u8; 256];
-}
-
-#[doc(hidden)]
-pub struct Standard;
-impl AvxAlp for Standard {
-    #[inline]
-    unsafe fn encode(input: __m256i) -> __m256i {
-        let mut result: __m256i = _mm256_subs_epu8(input, _mm256_set1_epi8(51));
-        let less: __m256i = _mm256_cmpgt_epi8(_mm256_set1_epi8(26), input);
-        result = _mm256_or_si256(result, _mm256_and_si256(less, _mm256_set1_epi8(13)));
-        let offsets: __m256i = _mm256_setr_epi8(
-            71, -4, -4, -4, -4, -4, -4, -4,
-            -4, -4, -4,-19,-16, 65, 0, 0,
-            71, -4, -4, -4, -4, -4, -4, -4,
-            -4, -4, -4,-19,-16, 65, 0, 0,
-        );
-        result = _mm256_shuffle_epi8(offsets, result);
-        return _mm256_add_epi8(result, input);
-    }
-
-    fn encode_table() -> &'static [u8; 64] {
-        &ENCODE_TABLE
-    }
-
-    fn decode_table() -> &'static [u8; 256] {
-        &DECODE_TABLE
-    }
-}
-
-#[doc(hidden)]
-pub struct Urlsafe;
-impl AvxAlp for Urlsafe {
-    #[inline]
-    unsafe fn encode(input: __m256i) -> __m256i {
-        let mut result: __m256i = _mm256_subs_epu8(input, _mm256_set1_epi8(51));
-        let less: __m256i = _mm256_cmpgt_epi8(_mm256_set1_epi8(26), input);
-        result = _mm256_or_si256(result, _mm256_and_si256(less, _mm256_set1_epi8(13)));
-        let offsets: __m256i = _mm256_setr_epi8(
-            71, -4, -4, -4, -4, -4, -4, -4,
-            -4, -4, -4,-17, 32, 65, 0, 0,
-            71, -4, -4, -4, -4, -4, -4, -4,
-            -4, -4, -4,-17, 32, 65, 0, 0,
-        );
-        result = _mm256_shuffle_epi8(offsets, result);
-        return _mm256_add_epi8(result, input);
-    }
-
-    fn encode_table() -> &'static [u8; 64] {
-        &URL_ENCODE_TABLE
-    }
 
-    fn decode_table() -> &'static [u8; 256] {
-        &URL_DECODE_TABLE
-    }
+#[inline]
+unsafe fn encode(offsets: __m256i, input: __m256i) -> __m256i {
+    let mut result: __m256i = _mm256_subs_epu8(input, _mm256_set1_epi8(51));
+    let less: __m256i = _mm256_cmpgt_epi8(_mm256_set1_epi8(26), input);
+    result = _mm256_or_si256(result, _mm256_and_si256(less, _mm256_set1_epi8(13)));
+    result = _mm256_shuffle_epi8(offsets, result);
+    return _mm256_add_epi8(result, input);
 }
 
 const ENCODE_TABLE: [u8; 64] =
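
The offset vector now passed into `encode` is the usual per-class ASCII-offset trick: `_mm256_subs_epu8`/`_mm256_cmpgt_epi8` turn each 6-bit value into a small class index and `_mm256_shuffle_epi8` looks up the offset to add. A scalar sketch of that mapping for the standard alphabet (illustrative only; the constants are the ones visible in `encode_offsets` above):

```rust
// Scalar view of the per-class offsets (65, 71, -4, -19, -16) used for the
// standard alphabet: add the class offset to the 6-bit value to get its ASCII code.
fn encode_sixbit_standard(v: u8) -> u8 {
    assert!(v < 64);
    let offset: i16 = match v {
        0..=25 => 65,  // 'A'..='Z':  0 + 65 = b'A'
        26..=51 => 71, // 'a'..='z': 26 + 71 = b'a'
        52..=61 => -4, // '0'..='9': 52 -  4 = b'0'
        62 => -19,     // '+':       62 - 19 = b'+'
        _ => -16,      // '/':       63 - 16 = b'/'
    };
    (v as i16 + offset) as u8
}
```

The url-safe variant only swaps the last two offsets (-17 and 32), mapping 62 to '-' and 63 to '_'.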
@@ -260,7 +246,7 @@ const URL_DECODE_TABLE: [u8; 256] =
 
 const MASKLOAD: [i32; 16] = [-1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0];
 
-impl<A: AvxAlp> super::Engine for AVX2Encoder<A> {
+impl super::Engine for AVX2Encoder {
     type Config = AVX2Config;
     type DecodeEstimate = AVX2Estimate;
 
@@ -306,7 +292,7 @@ impl<A: AvxAlp> super::Engine for AVX2Encoder<A> {
                 _mm256_set_epi32(SKIP,LOAD,LOAD,LOAD,LOAD,LOAD,LOAD,SKIP));
 
             let expanded: __m256i = load_block(block);
-            let outblock: __m256i = A::encode(expanded);
+            let outblock: __m256i = encode(self.encode_offsets, expanded);
             _mm256_storeu_si256(output_chunk.as_mut_ptr().cast(), outblock);
 
             output_index += 32;
@@ -328,7 +314,7 @@ impl<A: AvxAlp> super::Engine for AVX2Encoder<A> {
             // First step: Expand the 24 input bytes into 32 bytes ready for encoding.
             let expanded: __m256i = load_block(block);
             // Second step: Do the actual conversion
-            let outblock: __m256i = A::encode(expanded);
+            let outblock: __m256i = encode(self.encode_offsets, expanded);
             // Third step: Write the data into the output
             _mm256_storeu_si256(output_chunk.as_mut_ptr().cast(), outblock);
 
@@ -356,18 +342,16 @@ impl<A: AvxAlp> super::Engine for AVX2Encoder<A> {
 
         const LOW_SIX_BITS_U8: u8 = 0b111111;
 
-        let encode_table = A::encode_table();
-
         while input_index < start_of_rem {
             let input_chunk = &input[input_index..(input_index + 3)];
             let output_chunk = &mut output[output_index..(output_index + 4)];
 
-            output_chunk[0] = encode_table[(input_chunk[0] >> 2) as usize];
-            output_chunk[1] = encode_table
+            output_chunk[0] = self.encode_table[(input_chunk[0] >> 2) as usize];
+            output_chunk[1] = self.encode_table
                 [((input_chunk[0] << 4 | input_chunk[1] >> 4) & LOW_SIX_BITS_U8) as usize];
-            output_chunk[2] = encode_table
+            output_chunk[2] = self.encode_table
                 [((input_chunk[1] << 2 | input_chunk[2] >> 6) & LOW_SIX_BITS_U8) as usize];
-            output_chunk[3] = encode_table[(input_chunk[2] & LOW_SIX_BITS_U8) as usize];
+            output_chunk[3] = self.encode_table[(input_chunk[2] & LOW_SIX_BITS_U8) as usize];
 
             input_index += 3;
             output_index += 4;
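
The serial remainder loop above (now indexing `self.encode_table`) splits every 3-byte group into four 6-bit table indices. A compact, self-contained sketch of just that bit-twiddling, using the same shifts and mask as the diff (the function name is illustrative):

```rust
// Each 3-byte input chunk yields four 6-bit indices into the 64-entry encode table.
fn six_bit_indices(chunk: [u8; 3]) -> [u8; 4] {
    const LOW_SIX_BITS_U8: u8 = 0b111111;
    [
        chunk[0] >> 2,                                     // top 6 bits of byte 0
        (chunk[0] << 4 | chunk[1] >> 4) & LOW_SIX_BITS_U8, // low 2 of byte 0, high 4 of byte 1
        (chunk[1] << 2 | chunk[2] >> 6) & LOW_SIX_BITS_U8, // low 4 of byte 1, high 2 of byte 2
        chunk[2] & LOW_SIX_BITS_U8,                        // low 6 bits of byte 2
    ]
}
```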
@@ -377,18 +361,18 @@ impl<A: AvxAlp> super::Engine for AVX2Encoder<A> {
 
         if rem == 2 {
             let final_input = input.len()-2;
-            output[output_index] = encode_table[(input[final_input] >> 2) as usize];
+            output[output_index] = self.encode_table[(input[final_input] >> 2) as usize];
             output[output_index + 1] =
-                encode_table[((input[final_input] << 4 | input[final_input + 1] >> 4)
+                self.encode_table[((input[final_input] << 4 | input[final_input + 1] >> 4)
                     & LOW_SIX_BITS_U8) as usize];
             output[output_index + 2] =
-                encode_table[((input[final_input + 1] << 2) & LOW_SIX_BITS_U8) as usize];
+                self.encode_table[((input[final_input + 1] << 2) & LOW_SIX_BITS_U8) as usize];
             output_index += 3;
         } else if rem == 1 {
             let final_input = input.len()-1;
-            output[output_index] = encode_table[(input[final_input] >> 2) as usize];
+            output[output_index] = self.encode_table[(input[final_input] >> 2) as usize];
             output[output_index + 1] =
-                encode_table[((input[final_input] << 4) & LOW_SIX_BITS_U8) as usize];
+                self.encode_table[((input[final_input] << 4) & LOW_SIX_BITS_U8) as usize];
             output_index += 2;
         }
 
@@ -405,14 +389,13 @@ impl<A: AvxAlp> super::Engine for AVX2Encoder<A> {
         output: &mut [u8],
         _estimate: Self::DecodeEstimate,
     ) -> Result<usize, DecodeError> {
-        let decode_table = A::decode_table();
         // TODO: Check if LLVM optimizes this modulo into an &
         let skip_stage_2 = match input.len() % 4 {
             1 => {
                 // trailing whitespace is so common that it's worth it to check the last byte to
                 // possibly return a better error message
                 if let Some(b) = input.last() {
-                    if *b != PAD_BYTE && decode_table[*b as usize] == INVALID_VALUE {
+                    if *b != PAD_BYTE && self.decode_table[*b as usize] == INVALID_VALUE {
                         return Err(DecodeError::InvalidByte(input.len() - 1, *b));
                     }
                 }
@@ -487,10 +470,6 @@ impl<A: AvxAlp> super::Engine for AVX2Encoder<A> {
             0x10, 0x10, 0x01, 0x02, 0x04, 0x08, 0x04, 0x08,
             0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10
         )};
-        let lut_roll = unsafe {_mm256_setr_epi8(
-            0, 16, 19, 4, -65, -65, -71, -71, 0, 0, 0, 0, 0, 0, 0, 0,
-            0, 16, 19, 4, -65, -65, -71, -71, 0, 0, 0, 0, 0, 0, 0, 0
-        )};
         let mask_2f = unsafe { _mm256_set1_epi8(0x2F) };
 
         // This will only evaluate to true if we have an input of 33 bytes or more;
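
The `lut_roll` vector deleted here is the same data the constructors now store as `self.decode_offsets`. Its values (0, 16, 19, 4, -65, -65, -71, -71) are per-class offsets that roll an ASCII character back to its 6-bit value; a scalar sketch for the standard alphabet (illustrative, not crate code):

```rust
// Scalar view of the decode offsets: add one class offset to the ASCII code to
// recover the 6-bit value, or reject the byte as invalid.
fn decode_ascii_standard(c: u8) -> Option<u8> {
    let offset: i16 = match c {
        b'A'..=b'Z' => -65, // 'A' -> 0
        b'a'..=b'z' => -71, // 'a' -> 26
        b'0'..=b'9' => 4,   // '0' -> 52
        b'+' => 19,         // '+' -> 62
        b'/' => 16,         // '/' -> 63
        _ => return None,   // anything else is not in the alphabet
    };
    Some((c as i16 + offset) as u8)
}
```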
@@ -514,10 +493,10 @@ impl<A: AvxAlp> super::Engine for AVX2Encoder<A> {
 
             unsafe {
                 block = _mm256_loadu_si256(input_chunk.as_ptr().cast());
-                block = decode(&mut invalid, lut_lo, lut_hi, lut_roll, mask_2f, block);
+                block = decode(&mut invalid, lut_lo, lut_hi, self.decode_offsets, mask_2f, block);
 
                 if invalid {
-                    return Err(find_invalid_input(input_index, input_chunk, decode_table));
+                    return Err(find_invalid_input(input_index, input_chunk, &self.decode_table));
                 }
 
                 _mm256_storeu_si256(output_chunk.as_mut_ptr().cast(), block);
@@ -551,12 +530,12 @@ impl<A: AvxAlp> super::Engine for AVX2Encoder<A> {
                 let mask_output = _mm256_loadu_si256(MASKLOAD[2..10].as_ptr().cast());
 
                 block = _mm256_loadu_si256(input_chunk.as_ptr().cast());
-                block = decode(&mut invalid, lut_lo, lut_hi, lut_roll, mask_2f, block);
+                block = decode(&mut invalid, lut_lo, lut_hi, self.decode_offsets, mask_2f, block);
 
                 _mm256_maskstore_epi32(output_chunk.as_mut_ptr().cast(), mask_output, block);
             }
             if invalid {
-                return Err(find_invalid_input(input_index, input_chunk, decode_table));
+                return Err(find_invalid_input(input_index, input_chunk, &self.decode_table));
             }
 
             input_index += 32;
@@ -604,10 +583,10 @@ impl<A: AvxAlp> super::Engine for AVX2Encoder<A> {
             block = _mm256_maskload_epi32(input_chunk.as_ptr().cast(), mask_input);
             let outblock
                 = decode_masked(&mut invalid,
-                    lut_lo, lut_hi, lut_roll, mask_2f, mask_input, block);
+                    lut_lo, lut_hi, self.decode_offsets, mask_2f, mask_input, block);
 
             if invalid {
-                return Err(find_invalid_input(input_index, input_chunk, decode_table));
+                return Err(find_invalid_input(input_index, input_chunk, &self.decode_table));
             }
 
             _mm256_maskstore_epi32(output_chunk.as_mut_ptr().cast(), mask_output, outblock);
@@ -677,7 +656,7 @@ impl<A: AvxAlp> super::Engine for AVX2Encoder<A> {
             // can use up to 8 * 6 = 48 bits of the u64, if last chunk has no padding.
             // Pack the leftovers from left to right.
             let shift = 64 - (morsels_in_leftover + 1) * 6;
-            let morsel = decode_table[*b as usize];
+            let morsel = self.decode_table[*b as usize];
             if morsel == INVALID_VALUE {
                 return Err(DecodeError::InvalidByte(start_of_leftovers + i, *b));
             }
@@ -729,7 +708,7 @@ impl<A: AvxAlp> super::Engine for AVX2Encoder<A> {
     }
 }
 
-fn find_invalid_input(input_index: usize, input: &[u8], decode_table: &'static [u8; 256]) -> DecodeError {
+fn find_invalid_input(input_index: usize, input: &[u8], decode_table: &[u8; 256]) -> DecodeError {
     // Figure out which byte was invalid exactly.
     for i in 0..input.len() {
         let byte = input[i];
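
The hunk above ends mid-function, but the visible change is the point: dropping the `&'static` bound lets callers pass `&self.decode_table`. For orientation, a hedged sketch of what such a serial rescan typically looks like once the SIMD path has flagged a chunk — the body below and the sentinel value are assumptions, not lines from this commit:

```rust
// Hypothetical shape of the scan (the real function returns a DecodeError):
// walk the flagged chunk and report the first byte the decode table rejects.
const INVALID_VALUE: u8 = 255; // assumed sentinel, common in base64 decode tables

fn find_invalid_input_sketch(
    input_index: usize,
    input: &[u8],
    decode_table: &[u8; 256],
) -> Option<(usize, u8)> {
    for (i, &byte) in input.iter().enumerate() {
        if decode_table[byte as usize] == INVALID_VALUE {
            return Some((input_index + i, byte)); // absolute offset + offending byte
        }
    }
    None
}
```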
