@@ -1,6 +1,8 @@
 use std::arch::x86_64::{
-    __m128i, __m256i, _mm256_add_epi8, _mm256_cmpeq_epi8, _mm256_cmpgt_epi8, _mm256_load_si256,
+    __m128i, __m256i, __m512i, _mm256_add_epi8, _mm256_cmpeq_epi8, _mm256_cmpgt_epi8, _mm256_load_si256,
     _mm256_loadu_si256, _mm256_movemask_epi8, _mm256_or_si256, _mm256_set1_epi8,
+    _mm512_add_epi8, _mm512_cmpeq_epi8_mask, _mm512_cmpgt_epi8_mask, _mm512_load_si512,
+    _mm512_loadu_si512, _mm512_set1_epi8,
     _mm_add_epi8, _mm_cmpeq_epi8, _mm_cmpgt_epi8, _mm_load_si128, _mm_loadu_si128,
     _mm_movemask_epi8, _mm_or_si128, _mm_prefetch, _mm_set1_epi8, _MM_HINT_T0,
 };
@@ -13,9 +15,11 @@ const BELOW_A: i8 = i8::MAX - (31i8 - 0i8) - 1;
 const B: i8 = 34i8; // '"'
 const C: i8 = 92i8; // '\\'
 
+const M512_VECTOR_SIZE: usize = std::mem::size_of::<__m512i>();
 const M256_VECTOR_SIZE: usize = std::mem::size_of::<__m256i>();
 const M128_VECTOR_SIZE: usize = std::mem::size_of::<__m128i>();
-const LOOP_SIZE: usize = 4 * M256_VECTOR_SIZE; // Process 128 bytes at a time
+const LOOP_SIZE_AVX2: usize = 4 * M256_VECTOR_SIZE; // Process 128 bytes at a time
+const LOOP_SIZE_AVX512: usize = 4 * M512_VECTOR_SIZE; // Process 256 bytes at a time
 const PREFETCH_DISTANCE: usize = 256; // Prefetch 256 bytes ahead
 
 #[inline(always)]
@@ -24,6 +28,212 @@ fn sub(a: *const u8, b: *const u8) -> usize {
     (a as usize) - (b as usize)
 }
 
+
+#[target_feature(enable = "avx512f", enable = "avx512bw")]
+pub unsafe fn encode_str_avx512<S: AsRef<str>>(input: S) -> String {
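+    // Safety: the caller must confirm AVX-512F/BW support first (e.g. via
+    // is_x86_feature_detected!("avx512bw")); calling this on a CPU without
+    // these features is undefined behavior.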
+    let s = input.as_ref();
+    let bytes = s.as_bytes();
+    let len = bytes.len();
+
+    // Pre-allocate with estimated capacity
+    let estimated_capacity = len + len / 2 + 2;
+    let mut result = Vec::with_capacity(estimated_capacity);
+
+    result.push(b'"');
+
+    let start_ptr = bytes.as_ptr();
+    let end_ptr = bytes[len..].as_ptr();
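+    // bytes[len..].as_ptr() is a valid one-past-the-end pointer; it is used
+    // only for distance arithmetic and never dereferenced.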
+    let mut ptr = start_ptr;
+    let mut start = 0;
+
+    if len >= M512_VECTOR_SIZE {
+        let v_translation_a = _mm512_set1_epi8(TRANSLATION_A);
+        let v_below_a = _mm512_set1_epi8(BELOW_A);
+        let v_b = _mm512_set1_epi8(B);
+        let v_c = _mm512_set1_epi8(C);
+
+        // Handle alignment - skip if already aligned
+        const M512_VECTOR_ALIGN: usize = M512_VECTOR_SIZE - 1;
+        let misalignment = start_ptr as usize & M512_VECTOR_ALIGN;
+        if misalignment != 0 {
+            let align = M512_VECTOR_SIZE - misalignment;
+            let a = _mm512_loadu_si512(ptr as *const __m512i);
+
+            // Check for quotes, backslash, and control characters
+            let quote_mask = _mm512_cmpeq_epi8_mask(a, v_b);
+            let slash_mask = _mm512_cmpeq_epi8_mask(a, v_c);
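+            // Signed-range trick: adding TRANSLATION_A wraps bytes 0x00..=0x1F
+            // above BELOW_A while all other values land at or below it, so a
+            // single signed compare flags every control character.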
+            let ctrl_mask = _mm512_cmpgt_epi8_mask(_mm512_add_epi8(a, v_translation_a), v_below_a);
+
+            let mut mask = (quote_mask | slash_mask | ctrl_mask) as u64;
+
+            if mask != 0 {
+                let at = sub(ptr, start_ptr);
+                let mut cur = mask.trailing_zeros() as usize;
+                while cur < align {
+                    let c = *ptr.add(cur);
+                    let escape_byte = ESCAPE[c as usize];
+                    if escape_byte != 0 {
+                        let i = at + cur;
+                        if start < i {
+                            result.extend_from_slice(&bytes[start..i]);
+                        }
+                        write_escape(&mut result, escape_byte, c);
+                        start = i + 1;
+                    }
+                    mask ^= 1 << cur;
+                    if mask == 0 {
+                        break;
+                    }
+                    cur = mask.trailing_zeros() as usize;
+                }
+            }
+            ptr = ptr.add(align);
+        }
+
+        // Main loop processing 256 bytes at a time
+        if LOOP_SIZE_AVX512 <= len {
+            while ptr <= end_ptr.sub(LOOP_SIZE_AVX512) {
+                debug_assert_eq!(0, (ptr as usize) % M512_VECTOR_SIZE);
+
+                // Prefetch next iteration's data
+                if ptr.add(LOOP_SIZE_AVX512 + PREFETCH_DISTANCE) < end_ptr {
+                    _mm_prefetch(ptr.add(LOOP_SIZE_AVX512 + PREFETCH_DISTANCE) as *const i8, _MM_HINT_T0);
+                }
+
+                // Load all 4 vectors at once for better pipelining
+                let a0 = _mm512_load_si512(ptr as *const __m512i);
+                let a1 = _mm512_load_si512(ptr.add(M512_VECTOR_SIZE) as *const __m512i);
+                let a2 = _mm512_load_si512(ptr.add(M512_VECTOR_SIZE * 2) as *const __m512i);
+                let a3 = _mm512_load_si512(ptr.add(M512_VECTOR_SIZE * 3) as *const __m512i);
+
+                // Check for quotes (") in all vectors
+                let quote_0 = _mm512_cmpeq_epi8_mask(a0, v_b);
+                let quote_1 = _mm512_cmpeq_epi8_mask(a1, v_b);
+                let quote_2 = _mm512_cmpeq_epi8_mask(a2, v_b);
+                let quote_3 = _mm512_cmpeq_epi8_mask(a3, v_b);
+
+                // Check for backslash (\) in all vectors
+                let slash_0 = _mm512_cmpeq_epi8_mask(a0, v_c);
+                let slash_1 = _mm512_cmpeq_epi8_mask(a1, v_c);
+                let slash_2 = _mm512_cmpeq_epi8_mask(a2, v_c);
+                let slash_3 = _mm512_cmpeq_epi8_mask(a3, v_c);
+
+                // Check for control characters (< 0x20) in all vectors
+                let ctrl_0 = _mm512_cmpgt_epi8_mask(_mm512_add_epi8(a0, v_translation_a), v_below_a);
+                let ctrl_1 = _mm512_cmpgt_epi8_mask(_mm512_add_epi8(a1, v_translation_a), v_below_a);
+                let ctrl_2 = _mm512_cmpgt_epi8_mask(_mm512_add_epi8(a2, v_translation_a), v_below_a);
+                let ctrl_3 = _mm512_cmpgt_epi8_mask(_mm512_add_epi8(a3, v_translation_a), v_below_a);
+
+                // Combine all masks
+                let mask_a = quote_0 | slash_0 | ctrl_0;
+                let mask_b = quote_1 | slash_1 | ctrl_1;
+                let mask_c = quote_2 | slash_2 | ctrl_2;
+                let mask_d = quote_3 | slash_3 | ctrl_3;
+
+                // Fast path: check if any escaping needed
+                let any_escape = mask_a | mask_b | mask_c | mask_d;
+
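+                // Unlike the AVX2 path, these compares return __mmask64 values
+                // directly, so no _mm256_movemask_epi8 step is needed before
+                // this branch.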
+                if any_escape == 0 {
+                    // No escapes needed, copy whole chunk
+                    if start < sub(ptr, start_ptr) {
+                        result.extend_from_slice(&bytes[start..sub(ptr, start_ptr)]);
+                    }
+                    result.extend_from_slice(std::slice::from_raw_parts(ptr, LOOP_SIZE_AVX512));
+                    start = sub(ptr, start_ptr) + LOOP_SIZE_AVX512;
+                } else {
+                    // Process each 64-byte chunk that has escapes
+                    process_mask_avx512(ptr, start_ptr, &mut result, &mut start, bytes, mask_a, 0);
+                    process_mask_avx512(ptr, start_ptr, &mut result, &mut start, bytes, mask_b, M512_VECTOR_SIZE);
+                    process_mask_avx512(ptr, start_ptr, &mut result, &mut start, bytes, mask_c, M512_VECTOR_SIZE * 2);
+                    process_mask_avx512(ptr, start_ptr, &mut result, &mut start, bytes, mask_d, M512_VECTOR_SIZE * 3);
+                }
+
+                ptr = ptr.add(LOOP_SIZE_AVX512);
+            }
+        }
+
+        // Process remaining aligned chunks
+        while ptr <= end_ptr.sub(M512_VECTOR_SIZE) {
+            debug_assert_eq!(0, (ptr as usize) % M512_VECTOR_SIZE);
+            let a = _mm512_load_si512(ptr as *const __m512i);
+
+            let quote_mask = _mm512_cmpeq_epi8_mask(a, v_b);
+            let slash_mask = _mm512_cmpeq_epi8_mask(a, v_c);
+            let ctrl_mask = _mm512_cmpgt_epi8_mask(_mm512_add_epi8(a, v_translation_a), v_below_a);
+
+            let mut mask = (quote_mask | slash_mask | ctrl_mask) as u64;
+
+            if mask != 0 {
+                let at = sub(ptr, start_ptr);
+                let mut cur = mask.trailing_zeros() as usize;
+                loop {
+                    let c = *ptr.add(cur);
+                    let escape_byte = ESCAPE[c as usize];
+                    if escape_byte != 0 {
+                        let i = at + cur;
+                        if start < i {
+                            result.extend_from_slice(&bytes[start..i]);
+                        }
+                        write_escape(&mut result, escape_byte, c);
+                        start = i + 1;
+                    }
+                    mask ^= 1 << cur;
+                    if mask == 0 {
+                        break;
+                    }
+                    cur = mask.trailing_zeros() as usize;
+                }
+            }
+            ptr = ptr.add(M512_VECTOR_SIZE);
+        }
+
+        // Handle tail
+        if ptr < end_ptr {
+            let d = M512_VECTOR_SIZE - sub(end_ptr, ptr);
+            let a = _mm512_loadu_si512(ptr.sub(d) as *const __m512i);
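+            // Deliberate overlap: this unaligned load ends exactly at end_ptr,
+            // re-reading up to d already-processed bytes; wrapping_shr below
+            // drops those d low mask bits so bit 0 maps to *ptr again.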
+
+            let quote_mask = _mm512_cmpeq_epi8_mask(a, v_b);
+            let slash_mask = _mm512_cmpeq_epi8_mask(a, v_c);
+            let ctrl_mask = _mm512_cmpgt_epi8_mask(_mm512_add_epi8(a, v_translation_a), v_below_a);
+
+            let mut mask = ((quote_mask | slash_mask | ctrl_mask) as u64)
+                .wrapping_shr(d as u32);
+
+            if mask != 0 {
+                let at = sub(ptr, start_ptr);
+                let mut cur = mask.trailing_zeros() as usize;
+                loop {
+                    let c = *ptr.add(cur);
+                    let escape_byte = ESCAPE[c as usize];
+                    if escape_byte != 0 {
+                        let i = at + cur;
+                        if start < i {
+                            result.extend_from_slice(&bytes[start..i]);
+                        }
+                        write_escape(&mut result, escape_byte, c);
+                        start = i + 1;
+                    }
+                    mask ^= 1 << cur;
+                    if mask == 0 {
+                        break;
+                    }
+                    cur = mask.trailing_zeros() as usize;
+                }
+            }
+        }
+    } else {
+        // Fall back to AVX2 for small strings
+        return encode_str_avx2(input);
+    }
+
+    // Copy any remaining bytes
+    if start < len {
+        result.extend_from_slice(&bytes[start..]);
+    }
+
+    result.push(b'"');
+    unsafe { String::from_utf8_unchecked(result) }
+}
+
 #[target_feature(enable = "avx2")]
 pub unsafe fn encode_str_avx2<S: AsRef<str>>(input: S) -> String {
     let s = input.as_ref();
@@ -85,13 +295,13 @@ pub unsafe fn encode_str_avx2<S: AsRef<str>>(input: S) -> String {
         }
 
         // Main loop processing 128 bytes at a time
-        if LOOP_SIZE <= len {
-            while ptr <= end_ptr.sub(LOOP_SIZE) {
+        if LOOP_SIZE_AVX2 <= len {
+            while ptr <= end_ptr.sub(LOOP_SIZE_AVX2) {
                 debug_assert_eq!(0, (ptr as usize) % M256_VECTOR_SIZE);
 
                 // Prefetch next iteration's data
-                if ptr.add(LOOP_SIZE + PREFETCH_DISTANCE) < end_ptr {
-                    _mm_prefetch(ptr.add(LOOP_SIZE + PREFETCH_DISTANCE) as *const i8, _MM_HINT_T0);
+                if ptr.add(LOOP_SIZE_AVX2 + PREFETCH_DISTANCE) < end_ptr {
+                    _mm_prefetch(ptr.add(LOOP_SIZE_AVX2 + PREFETCH_DISTANCE) as *const i8, _MM_HINT_T0);
                 }
 
                 // Load all 4 vectors at once for better pipelining
@@ -135,8 +345,8 @@ pub unsafe fn encode_str_avx2<S: AsRef<str>>(input: S) -> String {
                     if start < sub(ptr, start_ptr) {
                         result.extend_from_slice(&bytes[start..sub(ptr, start_ptr)]);
                     }
-                    result.extend_from_slice(std::slice::from_raw_parts(ptr, LOOP_SIZE));
-                    start = sub(ptr, start_ptr) + LOOP_SIZE;
+                    result.extend_from_slice(std::slice::from_raw_parts(ptr, LOOP_SIZE_AVX2));
+                    start = sub(ptr, start_ptr) + LOOP_SIZE_AVX2;
                 } else {
                     // Get individual masks only when needed
                     let mask_a = _mm256_movemask_epi8(cmp_a);
@@ -151,7 +361,7 @@ pub unsafe fn encode_str_avx2<S: AsRef<str>>(input: S) -> String {
                     process_mask_avx(ptr, start_ptr, &mut result, &mut start, bytes, mask_d, M256_VECTOR_SIZE * 3);
                 }
 
-                ptr = ptr.add(LOOP_SIZE);
+                ptr = ptr.add(LOOP_SIZE_AVX2);
             }
         }
 
@@ -433,6 +643,46 @@ unsafe fn process_mask_avx(
     }
 }
 
+#[inline(always)]
+unsafe fn process_mask_avx512(
+    ptr: *const u8,
+    start_ptr: *const u8,
+    result: &mut Vec<u8>,
+    start: &mut usize,
+    bytes: &[u8],
+    mask: u64,
+    offset: usize,
+) {
+    if mask == 0 {
+        return;
+    }
+
+    let ptr = ptr.add(offset);
+    let at = sub(ptr, start_ptr);
+
+    // Process mask bits using bit manipulation
+    let mut remaining = mask;
+    while remaining != 0 {
+        let cur = remaining.trailing_zeros() as usize;
+        let c = *ptr.add(cur);
+        let escape_byte = ESCAPE[c as usize];
+
+        if escape_byte != 0 {
+            let i = at + cur;
+            // Copy unescaped portion if needed
+            if *start < i {
+                result.extend_from_slice(&bytes[*start..i]);
+            }
+            // Write escape sequence
+            write_escape(result, escape_byte, c);
+            *start = i + 1;
+        }
+
+        // Clear the lowest set bit
+        remaining &= remaining - 1;
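+        // (When BMI1 is enabled this typically compiles to a single BLSR,
+        // avoiding the shift/xor used in the inline loops above.)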
+    }
+}
+
 #[inline(always)]
 fn write_escape(result: &mut Vec<u8>, escape_byte: u8, c: u8) {
     result.push(b'\\');