Skip to content

Commit e9440e9

Browse files
committed
avx512
1 parent cbe264d commit e9440e9

File tree

2 files changed

+262
-10
lines changed

2 files changed

+262
-10
lines changed

src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,9 @@ pub fn encode_str<S: AsRef<str>>(input: S) -> String {
137137
#[cfg(target_arch = "x86_64")]
138138
{
139139
// Runtime CPU feature detection for x86_64
140-
if is_x86_feature_detected!("avx2") {
140+
if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512bw") {
141+
unsafe { return x86::encode_str_avx512(input) }
142+
} else if is_x86_feature_detected!("avx2") {
141143
unsafe { return x86::encode_str_avx2(input) }
142144
} else if is_x86_feature_detected!("sse2") {
143145
unsafe { return x86::encode_str_sse2(input) }

src/x86.rs

Lines changed: 259 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
use std::arch::x86_64::{
2-
__m128i, __m256i, _mm256_add_epi8, _mm256_cmpeq_epi8, _mm256_cmpgt_epi8, _mm256_load_si256,
2+
__m128i, __m256i, __m512i, _mm256_add_epi8, _mm256_cmpeq_epi8, _mm256_cmpgt_epi8, _mm256_load_si256,
33
_mm256_loadu_si256, _mm256_movemask_epi8, _mm256_or_si256, _mm256_set1_epi8,
4+
_mm512_add_epi8, _mm512_cmpeq_epi8_mask, _mm512_cmpgt_epi8_mask, _mm512_load_si512,
5+
_mm512_loadu_si512, _mm512_set1_epi8,
46
_mm_add_epi8, _mm_cmpeq_epi8, _mm_cmpgt_epi8, _mm_load_si128, _mm_loadu_si128,
57
_mm_movemask_epi8, _mm_or_si128, _mm_prefetch, _mm_set1_epi8, _MM_HINT_T0,
68
};
@@ -13,9 +15,11 @@ const BELOW_A: i8 = i8::MAX - (31i8 - 0i8) - 1;
1315
const B: i8 = 34i8; // '"'
1416
const C: i8 = 92i8; // '\\'
1517

18+
const M512_VECTOR_SIZE: usize = std::mem::size_of::<__m512i>();
1619
const M256_VECTOR_SIZE: usize = std::mem::size_of::<__m256i>();
1720
const M128_VECTOR_SIZE: usize = std::mem::size_of::<__m128i>();
18-
const LOOP_SIZE: usize = 4 * M256_VECTOR_SIZE; // Process 128 bytes at a time
21+
const LOOP_SIZE_AVX2: usize = 4 * M256_VECTOR_SIZE; // Process 128 bytes at a time
22+
const LOOP_SIZE_AVX512: usize = 4 * M512_VECTOR_SIZE; // Process 256 bytes at a time
1923
const PREFETCH_DISTANCE: usize = 256; // Prefetch 256 bytes ahead
2024

2125
#[inline(always)]
@@ -24,6 +28,212 @@ fn sub(a: *const u8, b: *const u8) -> usize {
2428
(a as usize) - (b as usize)
2529
}
2630

31+
#[target_feature(enable = "avx512f", enable = "avx512bw")]
32+
pub unsafe fn encode_str_avx512<S: AsRef<str>>(input: S) -> String {
33+
let s = input.as_ref();
34+
let bytes = s.as_bytes();
35+
let len = bytes.len();
36+
37+
// Pre-allocate with estimated capacity
38+
let estimated_capacity = len + len / 2 + 2;
39+
let mut result = Vec::with_capacity(estimated_capacity);
40+
41+
result.push(b'"');
42+
43+
let start_ptr = bytes.as_ptr();
44+
let end_ptr = bytes[len..].as_ptr();
45+
let mut ptr = start_ptr;
46+
let mut start = 0;
47+
48+
if len >= M512_VECTOR_SIZE {
49+
let v_translation_a = _mm512_set1_epi8(TRANSLATION_A);
50+
let v_below_a = _mm512_set1_epi8(BELOW_A);
51+
let v_b = _mm512_set1_epi8(B);
52+
let v_c = _mm512_set1_epi8(C);
53+
54+
// Handle alignment - skip if already aligned
55+
const M512_VECTOR_ALIGN: usize = M512_VECTOR_SIZE - 1;
56+
let misalignment = start_ptr as usize & M512_VECTOR_ALIGN;
57+
if misalignment != 0 {
58+
let align = M512_VECTOR_SIZE - misalignment;
59+
let a = _mm512_loadu_si512(ptr as *const __m512i);
60+
61+
// Check for quotes, backslash, and control characters
62+
let quote_mask = _mm512_cmpeq_epi8_mask(a, v_b);
63+
let slash_mask = _mm512_cmpeq_epi8_mask(a, v_c);
64+
let ctrl_mask = _mm512_cmpgt_epi8_mask(_mm512_add_epi8(a, v_translation_a), v_below_a);
65+
66+
let mut mask = (quote_mask | slash_mask | ctrl_mask) as u64;
67+
68+
if mask != 0 {
69+
let at = sub(ptr, start_ptr);
70+
let mut cur = mask.trailing_zeros() as usize;
71+
while cur < align {
72+
let c = *ptr.add(cur);
73+
let escape_byte = ESCAPE[c as usize];
74+
if escape_byte != 0 {
75+
let i = at + cur;
76+
if start < i {
77+
result.extend_from_slice(&bytes[start..i]);
78+
}
79+
write_escape(&mut result, escape_byte, c);
80+
start = i + 1;
81+
}
82+
mask ^= 1 << cur;
83+
if mask == 0 {
84+
break;
85+
}
86+
cur = mask.trailing_zeros() as usize;
87+
}
88+
}
89+
ptr = ptr.add(align);
90+
}
91+
92+
// Main loop processing 256 bytes at a time
93+
if LOOP_SIZE_AVX512 <= len {
94+
while ptr <= end_ptr.sub(LOOP_SIZE_AVX512) {
95+
debug_assert_eq!(0, (ptr as usize) % M512_VECTOR_SIZE);
96+
97+
// Prefetch next iteration's data
98+
if ptr.add(LOOP_SIZE_AVX512 + PREFETCH_DISTANCE) < end_ptr {
99+
_mm_prefetch(ptr.add(LOOP_SIZE_AVX512 + PREFETCH_DISTANCE) as *const i8, _MM_HINT_T0);
100+
}
101+
102+
// Load all 4 vectors at once for better pipelining
103+
let a0 = _mm512_load_si512(ptr as *const __m512i);
104+
let a1 = _mm512_load_si512(ptr.add(M512_VECTOR_SIZE) as *const __m512i);
105+
let a2 = _mm512_load_si512(ptr.add(M512_VECTOR_SIZE * 2) as *const __m512i);
106+
let a3 = _mm512_load_si512(ptr.add(M512_VECTOR_SIZE * 3) as *const __m512i);
107+
108+
// Check for quotes (") in all vectors
109+
let quote_0 = _mm512_cmpeq_epi8_mask(a0, v_b);
110+
let quote_1 = _mm512_cmpeq_epi8_mask(a1, v_b);
111+
let quote_2 = _mm512_cmpeq_epi8_mask(a2, v_b);
112+
let quote_3 = _mm512_cmpeq_epi8_mask(a3, v_b);
113+
114+
// Check for backslash (\) in all vectors
115+
let slash_0 = _mm512_cmpeq_epi8_mask(a0, v_c);
116+
let slash_1 = _mm512_cmpeq_epi8_mask(a1, v_c);
117+
let slash_2 = _mm512_cmpeq_epi8_mask(a2, v_c);
118+
let slash_3 = _mm512_cmpeq_epi8_mask(a3, v_c);
119+
120+
// Check for control characters (< 0x20) in all vectors
121+
let ctrl_0 = _mm512_cmpgt_epi8_mask(_mm512_add_epi8(a0, v_translation_a), v_below_a);
122+
let ctrl_1 = _mm512_cmpgt_epi8_mask(_mm512_add_epi8(a1, v_translation_a), v_below_a);
123+
let ctrl_2 = _mm512_cmpgt_epi8_mask(_mm512_add_epi8(a2, v_translation_a), v_below_a);
124+
let ctrl_3 = _mm512_cmpgt_epi8_mask(_mm512_add_epi8(a3, v_translation_a), v_below_a);
125+
126+
// Combine all masks
127+
let mask_a = quote_0 | slash_0 | ctrl_0;
128+
let mask_b = quote_1 | slash_1 | ctrl_1;
129+
let mask_c = quote_2 | slash_2 | ctrl_2;
130+
let mask_d = quote_3 | slash_3 | ctrl_3;
131+
132+
// Fast path: check if any escaping needed
133+
let any_escape = mask_a | mask_b | mask_c | mask_d;
134+
135+
if any_escape == 0 {
136+
// No escapes needed, copy whole chunk
137+
if start < sub(ptr, start_ptr) {
138+
result.extend_from_slice(&bytes[start..sub(ptr, start_ptr)]);
139+
}
140+
result.extend_from_slice(std::slice::from_raw_parts(ptr, LOOP_SIZE_AVX512));
141+
start = sub(ptr, start_ptr) + LOOP_SIZE_AVX512;
142+
} else {
143+
// Process each 64-byte chunk that has escapes
144+
process_mask_avx512(ptr, start_ptr, &mut result, &mut start, bytes, mask_a, 0);
145+
process_mask_avx512(ptr, start_ptr, &mut result, &mut start, bytes, mask_b, M512_VECTOR_SIZE);
146+
process_mask_avx512(ptr, start_ptr, &mut result, &mut start, bytes, mask_c, M512_VECTOR_SIZE * 2);
147+
process_mask_avx512(ptr, start_ptr, &mut result, &mut start, bytes, mask_d, M512_VECTOR_SIZE * 3);
148+
}
149+
150+
ptr = ptr.add(LOOP_SIZE_AVX512);
151+
}
152+
}
153+
154+
// Process remaining aligned chunks
155+
while ptr <= end_ptr.sub(M512_VECTOR_SIZE) {
156+
debug_assert_eq!(0, (ptr as usize) % M512_VECTOR_SIZE);
157+
let a = _mm512_load_si512(ptr as *const __m512i);
158+
159+
let quote_mask = _mm512_cmpeq_epi8_mask(a, v_b);
160+
let slash_mask = _mm512_cmpeq_epi8_mask(a, v_c);
161+
let ctrl_mask = _mm512_cmpgt_epi8_mask(_mm512_add_epi8(a, v_translation_a), v_below_a);
162+
163+
let mut mask = (quote_mask | slash_mask | ctrl_mask) as u64;
164+
165+
if mask != 0 {
166+
let at = sub(ptr, start_ptr);
167+
let mut cur = mask.trailing_zeros() as usize;
168+
loop {
169+
let c = *ptr.add(cur);
170+
let escape_byte = ESCAPE[c as usize];
171+
if escape_byte != 0 {
172+
let i = at + cur;
173+
if start < i {
174+
result.extend_from_slice(&bytes[start..i]);
175+
}
176+
write_escape(&mut result, escape_byte, c);
177+
start = i + 1;
178+
}
179+
mask ^= 1 << cur;
180+
if mask == 0 {
181+
break;
182+
}
183+
cur = mask.trailing_zeros() as usize;
184+
}
185+
}
186+
ptr = ptr.add(M512_VECTOR_SIZE);
187+
}
188+
189+
// Handle tail
190+
if ptr < end_ptr {
191+
let d = M512_VECTOR_SIZE - sub(end_ptr, ptr);
192+
let a = _mm512_loadu_si512(ptr.sub(d) as *const __m512i);
193+
194+
let quote_mask = _mm512_cmpeq_epi8_mask(a, v_b);
195+
let slash_mask = _mm512_cmpeq_epi8_mask(a, v_c);
196+
let ctrl_mask = _mm512_cmpgt_epi8_mask(_mm512_add_epi8(a, v_translation_a), v_below_a);
197+
198+
let mut mask = ((quote_mask | slash_mask | ctrl_mask) as u64)
199+
.wrapping_shr(d as u32);
200+
201+
if mask != 0 {
202+
let at = sub(ptr, start_ptr);
203+
let mut cur = mask.trailing_zeros() as usize;
204+
loop {
205+
let c = *ptr.add(cur);
206+
let escape_byte = ESCAPE[c as usize];
207+
if escape_byte != 0 {
208+
let i = at + cur;
209+
if start < i {
210+
result.extend_from_slice(&bytes[start..i]);
211+
}
212+
write_escape(&mut result, escape_byte, c);
213+
start = i + 1;
214+
}
215+
mask ^= 1 << cur;
216+
if mask == 0 {
217+
break;
218+
}
219+
cur = mask.trailing_zeros() as usize;
220+
}
221+
}
222+
}
223+
} else {
224+
// Fall back to AVX2 for small strings
225+
return encode_str_avx2(input);
226+
}
227+
228+
// Copy any remaining bytes
229+
if start < len {
230+
result.extend_from_slice(&bytes[start..]);
231+
}
232+
233+
result.push(b'"');
234+
unsafe { String::from_utf8_unchecked(result) }
235+
}
236+
27237
#[target_feature(enable = "avx2")]
28238
pub unsafe fn encode_str_avx2<S: AsRef<str>>(input: S) -> String {
29239
let s = input.as_ref();
@@ -85,13 +295,13 @@ pub unsafe fn encode_str_avx2<S: AsRef<str>>(input: S) -> String {
85295
}
86296

87297
// Main loop processing 128 bytes at a time
88-
if LOOP_SIZE <= len {
89-
while ptr <= end_ptr.sub(LOOP_SIZE) {
298+
if LOOP_SIZE_AVX2 <= len {
299+
while ptr <= end_ptr.sub(LOOP_SIZE_AVX2) {
90300
debug_assert_eq!(0, (ptr as usize) % M256_VECTOR_SIZE);
91301

92302
// Prefetch next iteration's data
93-
if ptr.add(LOOP_SIZE + PREFETCH_DISTANCE) < end_ptr {
94-
_mm_prefetch(ptr.add(LOOP_SIZE + PREFETCH_DISTANCE) as *const i8, _MM_HINT_T0);
303+
if ptr.add(LOOP_SIZE_AVX2 + PREFETCH_DISTANCE) < end_ptr {
304+
_mm_prefetch(ptr.add(LOOP_SIZE_AVX2 + PREFETCH_DISTANCE) as *const i8, _MM_HINT_T0);
95305
}
96306

97307
// Load all 4 vectors at once for better pipelining
@@ -135,8 +345,8 @@ pub unsafe fn encode_str_avx2<S: AsRef<str>>(input: S) -> String {
135345
if start < sub(ptr, start_ptr) {
136346
result.extend_from_slice(&bytes[start..sub(ptr, start_ptr)]);
137347
}
138-
result.extend_from_slice(std::slice::from_raw_parts(ptr, LOOP_SIZE));
139-
start = sub(ptr, start_ptr) + LOOP_SIZE;
348+
result.extend_from_slice(std::slice::from_raw_parts(ptr, LOOP_SIZE_AVX2));
349+
start = sub(ptr, start_ptr) + LOOP_SIZE_AVX2;
140350
} else {
141351
// Get individual masks only when needed
142352
let mask_a = _mm256_movemask_epi8(cmp_a);
@@ -151,7 +361,7 @@ pub unsafe fn encode_str_avx2<S: AsRef<str>>(input: S) -> String {
151361
process_mask_avx(ptr, start_ptr, &mut result, &mut start, bytes, mask_d, M256_VECTOR_SIZE * 3);
152362
}
153363

154-
ptr = ptr.add(LOOP_SIZE);
364+
ptr = ptr.add(LOOP_SIZE_AVX2);
155365
}
156366
}
157367

@@ -433,6 +643,46 @@ unsafe fn process_mask_avx(
433643
}
434644
}
435645

646+
#[inline(always)]
647+
unsafe fn process_mask_avx512(
648+
ptr: *const u8,
649+
start_ptr: *const u8,
650+
result: &mut Vec<u8>,
651+
start: &mut usize,
652+
bytes: &[u8],
653+
mask: u64,
654+
offset: usize,
655+
) {
656+
if mask == 0 {
657+
return;
658+
}
659+
660+
let ptr = ptr.add(offset);
661+
let at = sub(ptr, start_ptr);
662+
663+
// Process mask bits using bit manipulation
664+
let mut remaining = mask;
665+
while remaining != 0 {
666+
let cur = remaining.trailing_zeros() as usize;
667+
let c = *ptr.add(cur);
668+
let escape_byte = ESCAPE[c as usize];
669+
670+
if escape_byte != 0 {
671+
let i = at + cur;
672+
// Copy unescaped portion if needed
673+
if *start < i {
674+
result.extend_from_slice(&bytes[*start..i]);
675+
}
676+
// Write escape sequence
677+
write_escape(result, escape_byte, c);
678+
*start = i + 1;
679+
}
680+
681+
// Clear the lowest set bit
682+
remaining &= remaining - 1;
683+
}
684+
}
685+
436686
#[inline(always)]
437687
fn write_escape(result: &mut Vec<u8>, escape_byte: u8, c: u8) {
438688
result.push(b'\\');

0 commit comments

Comments
 (0)