 use std::arch::x86_64::{
-    __m128i, __m256i, _mm256_add_epi8, _mm256_cmpeq_epi8, _mm256_cmpgt_epi8, _mm256_load_si256,
+    __m128i, __m256i, __m512i, _mm256_add_epi8, _mm256_cmpeq_epi8, _mm256_cmpgt_epi8, _mm256_load_si256,
     _mm256_loadu_si256, _mm256_movemask_epi8, _mm256_or_si256, _mm256_set1_epi8,
+    _mm512_add_epi8, _mm512_cmpeq_epi8_mask, _mm512_cmpgt_epi8_mask, _mm512_load_si512,
+    _mm512_loadu_si512, _mm512_set1_epi8,
     _mm_add_epi8, _mm_cmpeq_epi8, _mm_cmpgt_epi8, _mm_load_si128, _mm_loadu_si128,
     _mm_movemask_epi8, _mm_or_si128, _mm_prefetch, _mm_set1_epi8, _MM_HINT_T0,
 };
@@ -13,9 +15,11 @@ const BELOW_A: i8 = i8::MAX - (31i8 - 0i8) - 1;
 const B: i8 = 34i8; // '"'
 const C: i8 = 92i8; // '\\'
 
+const M512_VECTOR_SIZE: usize = std::mem::size_of::<__m512i>();
 const M256_VECTOR_SIZE: usize = std::mem::size_of::<__m256i>();
 const M128_VECTOR_SIZE: usize = std::mem::size_of::<__m128i>();
-const LOOP_SIZE: usize = 4 * M256_VECTOR_SIZE; // Process 128 bytes at a time
+const LOOP_SIZE_AVX2: usize = 4 * M256_VECTOR_SIZE; // Process 128 bytes at a time
+const LOOP_SIZE_AVX512: usize = 4 * M512_VECTOR_SIZE; // Process 256 bytes at a time
 const PREFETCH_DISTANCE: usize = 256; // Prefetch 256 bytes ahead
 
 #[inline(always)]
@@ -24,6 +28,212 @@ fn sub(a: *const u8, b: *const u8) -> usize {
     (a as usize) - (b as usize)
 }
 
+#[target_feature(enable = "avx512f", enable = "avx512bw")]
+pub unsafe fn encode_str_avx512<S: AsRef<str>>(input: S) -> String {
+    let s = input.as_ref();
+    let bytes = s.as_bytes();
+    let len = bytes.len();
+
+    // Pre-allocate with estimated capacity
+    let estimated_capacity = len + len / 2 + 2;
+    let mut result = Vec::with_capacity(estimated_capacity);
+
+    result.push(b'"');
+
+    let start_ptr = bytes.as_ptr();
+    let end_ptr = bytes[len..].as_ptr();
+    let mut ptr = start_ptr;
+    let mut start = 0;
+
+    if len >= M512_VECTOR_SIZE {
+        let v_translation_a = _mm512_set1_epi8(TRANSLATION_A);
+        let v_below_a = _mm512_set1_epi8(BELOW_A);
+        let v_b = _mm512_set1_epi8(B);
+        let v_c = _mm512_set1_epi8(C);
+
+        // Handle alignment - skip if already aligned
+        const M512_VECTOR_ALIGN: usize = M512_VECTOR_SIZE - 1;
+        let misalignment = start_ptr as usize & M512_VECTOR_ALIGN;
+        if misalignment != 0 {
+            let align = M512_VECTOR_SIZE - misalignment;
+            let a = _mm512_loadu_si512(ptr as *const __m512i);
+
+            // Check for quotes, backslash, and control characters
+            let quote_mask = _mm512_cmpeq_epi8_mask(a, v_b);
+            let slash_mask = _mm512_cmpeq_epi8_mask(a, v_c);
+            let ctrl_mask = _mm512_cmpgt_epi8_mask(_mm512_add_epi8(a, v_translation_a), v_below_a);
+
+            let mut mask = (quote_mask | slash_mask | ctrl_mask) as u64;
+
+            if mask != 0 {
+                let at = sub(ptr, start_ptr);
+                let mut cur = mask.trailing_zeros() as usize;
+                while cur < align {
+                    let c = *ptr.add(cur);
+                    let escape_byte = ESCAPE[c as usize];
+                    if escape_byte != 0 {
+                        let i = at + cur;
+                        if start < i {
+                            result.extend_from_slice(&bytes[start..i]);
+                        }
+                        write_escape(&mut result, escape_byte, c);
+                        start = i + 1;
+                    }
+                    mask ^= 1 << cur;
+                    if mask == 0 {
+                        break;
+                    }
+                    cur = mask.trailing_zeros() as usize;
+                }
+            }
+            ptr = ptr.add(align);
+        }
+
+        // Main loop processing 256 bytes at a time
+        if LOOP_SIZE_AVX512 <= len {
+            while ptr <= end_ptr.sub(LOOP_SIZE_AVX512) {
+                debug_assert_eq!(0, (ptr as usize) % M512_VECTOR_SIZE);
+
+                // Prefetch next iteration's data
+                if ptr.add(LOOP_SIZE_AVX512 + PREFETCH_DISTANCE) < end_ptr {
+                    _mm_prefetch(ptr.add(LOOP_SIZE_AVX512 + PREFETCH_DISTANCE) as *const i8, _MM_HINT_T0);
+                }
+
+                // Load all 4 vectors at once for better pipelining
+                let a0 = _mm512_load_si512(ptr as *const __m512i);
+                let a1 = _mm512_load_si512(ptr.add(M512_VECTOR_SIZE) as *const __m512i);
+                let a2 = _mm512_load_si512(ptr.add(M512_VECTOR_SIZE * 2) as *const __m512i);
+                let a3 = _mm512_load_si512(ptr.add(M512_VECTOR_SIZE * 3) as *const __m512i);
+
+                // Check for quotes (") in all vectors
+                let quote_0 = _mm512_cmpeq_epi8_mask(a0, v_b);
+                let quote_1 = _mm512_cmpeq_epi8_mask(a1, v_b);
+                let quote_2 = _mm512_cmpeq_epi8_mask(a2, v_b);
+                let quote_3 = _mm512_cmpeq_epi8_mask(a3, v_b);
+
+                // Check for backslash (\) in all vectors
+                let slash_0 = _mm512_cmpeq_epi8_mask(a0, v_c);
+                let slash_1 = _mm512_cmpeq_epi8_mask(a1, v_c);
+                let slash_2 = _mm512_cmpeq_epi8_mask(a2, v_c);
+                let slash_3 = _mm512_cmpeq_epi8_mask(a3, v_c);
+
+                // Check for control characters (< 0x20) in all vectors
+                let ctrl_0 = _mm512_cmpgt_epi8_mask(_mm512_add_epi8(a0, v_translation_a), v_below_a);
+                let ctrl_1 = _mm512_cmpgt_epi8_mask(_mm512_add_epi8(a1, v_translation_a), v_below_a);
+                let ctrl_2 = _mm512_cmpgt_epi8_mask(_mm512_add_epi8(a2, v_translation_a), v_below_a);
+                let ctrl_3 = _mm512_cmpgt_epi8_mask(_mm512_add_epi8(a3, v_translation_a), v_below_a);
+
+                // Combine all masks
+                let mask_a = quote_0 | slash_0 | ctrl_0;
+                let mask_b = quote_1 | slash_1 | ctrl_1;
+                let mask_c = quote_2 | slash_2 | ctrl_2;
+                let mask_d = quote_3 | slash_3 | ctrl_3;
+
+                // Fast path: check if any escaping needed
+                let any_escape = mask_a | mask_b | mask_c | mask_d;
+
+                if any_escape == 0 {
+                    // No escapes needed, copy whole chunk
+                    if start < sub(ptr, start_ptr) {
+                        result.extend_from_slice(&bytes[start..sub(ptr, start_ptr)]);
+                    }
+                    result.extend_from_slice(std::slice::from_raw_parts(ptr, LOOP_SIZE_AVX512));
+                    start = sub(ptr, start_ptr) + LOOP_SIZE_AVX512;
+                } else {
+                    // Process each 64-byte chunk that has escapes
+                    process_mask_avx512(ptr, start_ptr, &mut result, &mut start, bytes, mask_a, 0);
+                    process_mask_avx512(ptr, start_ptr, &mut result, &mut start, bytes, mask_b, M512_VECTOR_SIZE);
+                    process_mask_avx512(ptr, start_ptr, &mut result, &mut start, bytes, mask_c, M512_VECTOR_SIZE * 2);
+                    process_mask_avx512(ptr, start_ptr, &mut result, &mut start, bytes, mask_d, M512_VECTOR_SIZE * 3);
+                }
+
+                ptr = ptr.add(LOOP_SIZE_AVX512);
+            }
+        }
+
+        // Process remaining aligned chunks
+        while ptr <= end_ptr.sub(M512_VECTOR_SIZE) {
+            debug_assert_eq!(0, (ptr as usize) % M512_VECTOR_SIZE);
+            let a = _mm512_load_si512(ptr as *const __m512i);
+
+            let quote_mask = _mm512_cmpeq_epi8_mask(a, v_b);
+            let slash_mask = _mm512_cmpeq_epi8_mask(a, v_c);
+            let ctrl_mask = _mm512_cmpgt_epi8_mask(_mm512_add_epi8(a, v_translation_a), v_below_a);
+
+            let mut mask = (quote_mask | slash_mask | ctrl_mask) as u64;
+
+            if mask != 0 {
+                let at = sub(ptr, start_ptr);
+                let mut cur = mask.trailing_zeros() as usize;
+                loop {
+                    let c = *ptr.add(cur);
+                    let escape_byte = ESCAPE[c as usize];
+                    if escape_byte != 0 {
+                        let i = at + cur;
+                        if start < i {
+                            result.extend_from_slice(&bytes[start..i]);
+                        }
+                        write_escape(&mut result, escape_byte, c);
+                        start = i + 1;
+                    }
+                    mask ^= 1 << cur;
+                    if mask == 0 {
+                        break;
+                    }
+                    cur = mask.trailing_zeros() as usize;
+                }
+            }
+            ptr = ptr.add(M512_VECTOR_SIZE);
+        }
+
+        // Handle tail
+        if ptr < end_ptr {
+            let d = M512_VECTOR_SIZE - sub(end_ptr, ptr);
+            let a = _mm512_loadu_si512(ptr.sub(d) as *const __m512i);
+
+            let quote_mask = _mm512_cmpeq_epi8_mask(a, v_b);
+            let slash_mask = _mm512_cmpeq_epi8_mask(a, v_c);
+            let ctrl_mask = _mm512_cmpgt_epi8_mask(_mm512_add_epi8(a, v_translation_a), v_below_a);
+
+            let mut mask = ((quote_mask | slash_mask | ctrl_mask) as u64)
+                .wrapping_shr(d as u32);
+
+            if mask != 0 {
+                let at = sub(ptr, start_ptr);
+                let mut cur = mask.trailing_zeros() as usize;
+                loop {
+                    let c = *ptr.add(cur);
+                    let escape_byte = ESCAPE[c as usize];
+                    if escape_byte != 0 {
+                        let i = at + cur;
+                        if start < i {
+                            result.extend_from_slice(&bytes[start..i]);
+                        }
+                        write_escape(&mut result, escape_byte, c);
+                        start = i + 1;
+                    }
+                    mask ^= 1 << cur;
+                    if mask == 0 {
+                        break;
+                    }
+                    cur = mask.trailing_zeros() as usize;
+                }
+            }
+        }
+    } else {
+        // Fall back to AVX2 for small strings
+        return encode_str_avx2(input);
+    }
+
+    // Copy any remaining bytes
+    if start < len {
+        result.extend_from_slice(&bytes[start..]);
+    }
+
+    result.push(b'"');
+    unsafe { String::from_utf8_unchecked(result) }
+}
+
 #[target_feature(enable = "avx2")]
 pub unsafe fn encode_str_avx2<S: AsRef<str>>(input: S) -> String {
     let s = input.as_ref();
@@ -85,13 +295,13 @@ pub unsafe fn encode_str_avx2<S: AsRef<str>>(input: S) -> String {
         }
 
         // Main loop processing 128 bytes at a time
-        if LOOP_SIZE <= len {
-            while ptr <= end_ptr.sub(LOOP_SIZE) {
+        if LOOP_SIZE_AVX2 <= len {
+            while ptr <= end_ptr.sub(LOOP_SIZE_AVX2) {
                 debug_assert_eq!(0, (ptr as usize) % M256_VECTOR_SIZE);
 
                 // Prefetch next iteration's data
-                if ptr.add(LOOP_SIZE + PREFETCH_DISTANCE) < end_ptr {
-                    _mm_prefetch(ptr.add(LOOP_SIZE + PREFETCH_DISTANCE) as *const i8, _MM_HINT_T0);
+                if ptr.add(LOOP_SIZE_AVX2 + PREFETCH_DISTANCE) < end_ptr {
+                    _mm_prefetch(ptr.add(LOOP_SIZE_AVX2 + PREFETCH_DISTANCE) as *const i8, _MM_HINT_T0);
                 }
 
                 // Load all 4 vectors at once for better pipelining
@@ -135,8 +345,8 @@ pub unsafe fn encode_str_avx2<S: AsRef<str>>(input: S) -> String {
                     if start < sub(ptr, start_ptr) {
                         result.extend_from_slice(&bytes[start..sub(ptr, start_ptr)]);
                     }
-                    result.extend_from_slice(std::slice::from_raw_parts(ptr, LOOP_SIZE));
-                    start = sub(ptr, start_ptr) + LOOP_SIZE;
+                    result.extend_from_slice(std::slice::from_raw_parts(ptr, LOOP_SIZE_AVX2));
+                    start = sub(ptr, start_ptr) + LOOP_SIZE_AVX2;
                 } else {
                     // Get individual masks only when needed
                     let mask_a = _mm256_movemask_epi8(cmp_a);
@@ -151,7 +361,7 @@ pub unsafe fn encode_str_avx2<S: AsRef<str>>(input: S) -> String {
                     process_mask_avx(ptr, start_ptr, &mut result, &mut start, bytes, mask_d, M256_VECTOR_SIZE * 3);
                 }
 
-                ptr = ptr.add(LOOP_SIZE);
+                ptr = ptr.add(LOOP_SIZE_AVX2);
             }
         }
 
@@ -433,6 +643,46 @@ unsafe fn process_mask_avx(
     }
 }
 
+#[inline(always)]
+unsafe fn process_mask_avx512(
+    ptr: *const u8,
+    start_ptr: *const u8,
+    result: &mut Vec<u8>,
+    start: &mut usize,
+    bytes: &[u8],
+    mask: u64,
+    offset: usize,
+) {
+    if mask == 0 {
+        return;
+    }
+
+    let ptr = ptr.add(offset);
+    let at = sub(ptr, start_ptr);
+
+    // Process mask bits using bit manipulation
+    let mut remaining = mask;
+    while remaining != 0 {
+        let cur = remaining.trailing_zeros() as usize;
+        let c = *ptr.add(cur);
+        let escape_byte = ESCAPE[c as usize];
+
+        if escape_byte != 0 {
+            let i = at + cur;
+            // Copy unescaped portion if needed
+            if *start < i {
+                result.extend_from_slice(&bytes[*start..i]);
+            }
+            // Write escape sequence
+            write_escape(result, escape_byte, c);
+            *start = i + 1;
+        }
+
+        // Clear the lowest set bit
+        remaining &= remaining - 1;
+    }
+}
+
 #[inline(always)]
 fn write_escape(result: &mut Vec<u8>, escape_byte: u8, c: u8) {
     result.push(b'\\');
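Both encode_str_avx512 and encode_str_avx2 are `unsafe fn`s gated by `#[target_feature]`, so a caller must verify CPU support at runtime before invoking them. The following is a minimal dispatch sketch, not part of this commit; the scalar `encode_str_fallback` it calls is a hypothetical name for whatever non-SIMD path the crate provides.

// Runtime-dispatch sketch (not from this commit). `encode_str_fallback` is a
// hypothetical placeholder for the crate's scalar implementation.
pub fn encode_str<S: AsRef<str>>(input: S) -> String {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512bw") {
            // SAFETY: the required AVX-512 features were verified at runtime above.
            return unsafe { encode_str_avx512(input) };
        }
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support was verified at runtime above.
            return unsafe { encode_str_avx2(input) };
        }
    }
    // Scalar fallback for CPUs without AVX2/AVX-512 (hypothetical helper).
    encode_str_fallback(input)
}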