Skip to content

Commit 5b9a02a

Browse files
committed
optimize avx512
1 parent e9440e9 commit 5b9a02a

File tree

1 file changed

+56
-72
lines changed

1 file changed

+56
-72
lines changed

src/x86.rs

Lines changed: 56 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
use std::arch::x86_64::{
22
__m128i, __m256i, __m512i, _mm256_add_epi8, _mm256_cmpeq_epi8, _mm256_cmpgt_epi8, _mm256_load_si256,
33
_mm256_loadu_si256, _mm256_movemask_epi8, _mm256_or_si256, _mm256_set1_epi8,
4-
_mm512_add_epi8, _mm512_cmpeq_epi8_mask, _mm512_cmpgt_epi8_mask, _mm512_load_si512,
5-
_mm512_loadu_si512, _mm512_set1_epi8,
4+
_mm512_cmpeq_epi8_mask, _mm512_cmplt_epu8_mask, _mm512_load_si512, _mm512_loadu_si512,
5+
_mm512_set1_epi8,
66
_mm_add_epi8, _mm_cmpeq_epi8, _mm_cmpgt_epi8, _mm_load_si128, _mm_loadu_si128,
77
_mm_movemask_epi8, _mm_or_si128, _mm_prefetch, _mm_set1_epi8, _MM_HINT_T0,
88
};
@@ -46,10 +46,9 @@ pub unsafe fn encode_str_avx512<S: AsRef<str>>(input: S) -> String {
4646
let mut start = 0;
4747

4848
if len >= M512_VECTOR_SIZE {
49-
let v_translation_a = _mm512_set1_epi8(TRANSLATION_A);
50-
let v_below_a = _mm512_set1_epi8(BELOW_A);
5149
let v_b = _mm512_set1_epi8(B);
5250
let v_c = _mm512_set1_epi8(C);
51+
let v_ctrl_limit = _mm512_set1_epi8(0x20);
5352

5453
// Handle alignment - skip if already aligned
5554
const M512_VECTOR_ALIGN: usize = M512_VECTOR_SIZE - 1;
@@ -61,29 +60,27 @@ pub unsafe fn encode_str_avx512<S: AsRef<str>>(input: S) -> String {
6160
// Check for quotes, backslash, and control characters
6261
let quote_mask = _mm512_cmpeq_epi8_mask(a, v_b);
6362
let slash_mask = _mm512_cmpeq_epi8_mask(a, v_c);
64-
let ctrl_mask = _mm512_cmpgt_epi8_mask(_mm512_add_epi8(a, v_translation_a), v_below_a);
63+
let ctrl_mask = _mm512_cmplt_epu8_mask(a, v_ctrl_limit);
6564

6665
let mut mask = (quote_mask | slash_mask | ctrl_mask) as u64;
66+
if align < 64 {
67+
mask &= (1u64 << align) - 1;
68+
}
6769

6870
if mask != 0 {
6971
let at = sub(ptr, start_ptr);
70-
let mut cur = mask.trailing_zeros() as usize;
71-
while cur < align {
72+
while mask != 0 {
73+
let cur = mask.trailing_zeros() as usize;
7274
let c = *ptr.add(cur);
7375
let escape_byte = ESCAPE[c as usize];
74-
if escape_byte != 0 {
75-
let i = at + cur;
76-
if start < i {
77-
result.extend_from_slice(&bytes[start..i]);
78-
}
79-
write_escape(&mut result, escape_byte, c);
80-
start = i + 1;
76+
debug_assert!(escape_byte != 0);
77+
let i = at + cur;
78+
if start < i {
79+
result.extend_from_slice(&bytes[start..i]);
8180
}
82-
mask ^= 1 << cur;
83-
if mask == 0 {
84-
break;
85-
}
86-
cur = mask.trailing_zeros() as usize;
81+
write_escape(&mut result, escape_byte, c);
82+
start = i + 1;
83+
mask &= mask - 1;
8784
}
8885
}
8986
ptr = ptr.add(align);
@@ -118,10 +115,10 @@ pub unsafe fn encode_str_avx512<S: AsRef<str>>(input: S) -> String {
118115
let slash_3 = _mm512_cmpeq_epi8_mask(a3, v_c);
119116

120117
// Check for control characters (< 0x20) in all vectors
121-
let ctrl_0 = _mm512_cmpgt_epi8_mask(_mm512_add_epi8(a0, v_translation_a), v_below_a);
122-
let ctrl_1 = _mm512_cmpgt_epi8_mask(_mm512_add_epi8(a1, v_translation_a), v_below_a);
123-
let ctrl_2 = _mm512_cmpgt_epi8_mask(_mm512_add_epi8(a2, v_translation_a), v_below_a);
124-
let ctrl_3 = _mm512_cmpgt_epi8_mask(_mm512_add_epi8(a3, v_translation_a), v_below_a);
118+
let ctrl_0 = _mm512_cmplt_epu8_mask(a0, v_ctrl_limit);
119+
let ctrl_1 = _mm512_cmplt_epu8_mask(a1, v_ctrl_limit);
120+
let ctrl_2 = _mm512_cmplt_epu8_mask(a2, v_ctrl_limit);
121+
let ctrl_3 = _mm512_cmplt_epu8_mask(a3, v_ctrl_limit);
125122

126123
// Combine all masks
127124
let mask_a = quote_0 | slash_0 | ctrl_0;
@@ -158,29 +155,24 @@ pub unsafe fn encode_str_avx512<S: AsRef<str>>(input: S) -> String {
158155

159156
let quote_mask = _mm512_cmpeq_epi8_mask(a, v_b);
160157
let slash_mask = _mm512_cmpeq_epi8_mask(a, v_c);
161-
let ctrl_mask = _mm512_cmpgt_epi8_mask(_mm512_add_epi8(a, v_translation_a), v_below_a);
158+
let ctrl_mask = _mm512_cmplt_epu8_mask(a, v_ctrl_limit);
162159

163160
let mut mask = (quote_mask | slash_mask | ctrl_mask) as u64;
164161

165162
if mask != 0 {
166163
let at = sub(ptr, start_ptr);
167-
let mut cur = mask.trailing_zeros() as usize;
168-
loop {
164+
while mask != 0 {
165+
let cur = mask.trailing_zeros() as usize;
169166
let c = *ptr.add(cur);
170167
let escape_byte = ESCAPE[c as usize];
171-
if escape_byte != 0 {
172-
let i = at + cur;
173-
if start < i {
174-
result.extend_from_slice(&bytes[start..i]);
175-
}
176-
write_escape(&mut result, escape_byte, c);
177-
start = i + 1;
178-
}
179-
mask ^= 1 << cur;
180-
if mask == 0 {
181-
break;
168+
debug_assert!(escape_byte != 0);
169+
let i = at + cur;
170+
if start < i {
171+
result.extend_from_slice(&bytes[start..i]);
182172
}
183-
cur = mask.trailing_zeros() as usize;
173+
write_escape(&mut result, escape_byte, c);
174+
start = i + 1;
175+
mask &= mask - 1;
184176
}
185177
}
186178
ptr = ptr.add(M512_VECTOR_SIZE);
@@ -193,30 +185,25 @@ pub unsafe fn encode_str_avx512<S: AsRef<str>>(input: S) -> String {
193185

194186
let quote_mask = _mm512_cmpeq_epi8_mask(a, v_b);
195187
let slash_mask = _mm512_cmpeq_epi8_mask(a, v_c);
196-
let ctrl_mask = _mm512_cmpgt_epi8_mask(_mm512_add_epi8(a, v_translation_a), v_below_a);
188+
let ctrl_mask = _mm512_cmplt_epu8_mask(a, v_ctrl_limit);
197189

198190
let mut mask = ((quote_mask | slash_mask | ctrl_mask) as u64)
199191
.wrapping_shr(d as u32);
200192

201193
if mask != 0 {
202194
let at = sub(ptr, start_ptr);
203-
let mut cur = mask.trailing_zeros() as usize;
204-
loop {
195+
while mask != 0 {
196+
let cur = mask.trailing_zeros() as usize;
205197
let c = *ptr.add(cur);
206198
let escape_byte = ESCAPE[c as usize];
207-
if escape_byte != 0 {
208-
let i = at + cur;
209-
if start < i {
210-
result.extend_from_slice(&bytes[start..i]);
211-
}
212-
write_escape(&mut result, escape_byte, c);
213-
start = i + 1;
199+
debug_assert!(escape_byte != 0);
200+
let i = at + cur;
201+
if start < i {
202+
result.extend_from_slice(&bytes[start..i]);
214203
}
215-
mask ^= 1 << cur;
216-
if mask == 0 {
217-
break;
218-
}
219-
cur = mask.trailing_zeros() as usize;
204+
write_escape(&mut result, escape_byte, c);
205+
start = i + 1;
206+
mask &= mask - 1;
220207
}
221208
}
222209
}
@@ -626,17 +613,16 @@ unsafe fn process_mask_avx(
626613
let cur = remaining.trailing_zeros() as usize;
627614
let c = *ptr.add(cur);
628615
let escape_byte = ESCAPE[c as usize];
616+
debug_assert!(escape_byte != 0);
629617

630-
if escape_byte != 0 {
631-
let i = at + cur;
632-
// Copy unescaped portion if needed
633-
if *start < i {
634-
result.extend_from_slice(&bytes[*start..i]);
635-
}
636-
// Write escape sequence
637-
write_escape(result, escape_byte, c);
638-
*start = i + 1;
618+
let i = at + cur;
619+
// Copy unescaped portion if needed
620+
if *start < i {
621+
result.extend_from_slice(&bytes[*start..i]);
639622
}
623+
// Write escape sequence
624+
write_escape(result, escape_byte, c);
625+
*start = i + 1;
640626

641627
// Clear the lowest set bit
642628
remaining &= remaining - 1;
@@ -666,17 +652,16 @@ unsafe fn process_mask_avx512(
666652
let cur = remaining.trailing_zeros() as usize;
667653
let c = *ptr.add(cur);
668654
let escape_byte = ESCAPE[c as usize];
655+
debug_assert!(escape_byte != 0);
669656

670-
if escape_byte != 0 {
671-
let i = at + cur;
672-
// Copy unescaped portion if needed
673-
if *start < i {
674-
result.extend_from_slice(&bytes[*start..i]);
675-
}
676-
// Write escape sequence
677-
write_escape(result, escape_byte, c);
678-
*start = i + 1;
657+
let i = at + cur;
658+
// Copy unescaped portion if needed
659+
if *start < i {
660+
result.extend_from_slice(&bytes[*start..i]);
679661
}
662+
// Write escape sequence
663+
write_escape(result, escape_byte, c);
664+
*start = i + 1;
680665

681666
// Clear the lowest set bit
682667
remaining &= remaining - 1;
@@ -697,4 +682,3 @@ fn write_escape(result: &mut Vec<u8>, escape_byte: u8, c: u8) {
697682
result.push(escape_byte);
698683
}
699684
}
700-

0 commit comments

Comments
 (0)