11use std:: arch:: x86_64:: {
22 __m128i, __m256i, __m512i, _mm256_add_epi8, _mm256_cmpeq_epi8, _mm256_cmpgt_epi8, _mm256_load_si256,
33 _mm256_loadu_si256, _mm256_movemask_epi8, _mm256_or_si256, _mm256_set1_epi8,
4- _mm512_add_epi8 , _mm512_cmpeq_epi8_mask, _mm512_cmpgt_epi8_mask , _mm512_load_si512,
5- _mm512_loadu_si512 , _mm512_set1_epi8,
4+ _mm512_cmpeq_epi8_mask, _mm512_cmplt_epu8_mask , _mm512_load_si512, _mm512_loadu_si512 ,
5+ _mm512_set1_epi8,
66 _mm_add_epi8, _mm_cmpeq_epi8, _mm_cmpgt_epi8, _mm_load_si128, _mm_loadu_si128,
77 _mm_movemask_epi8, _mm_or_si128, _mm_prefetch, _mm_set1_epi8, _MM_HINT_T0,
88} ;
@@ -46,10 +46,9 @@ pub unsafe fn encode_str_avx512<S: AsRef<str>>(input: S) -> String {
4646 let mut start = 0 ;
4747
4848 if len >= M512_VECTOR_SIZE {
49- let v_translation_a = _mm512_set1_epi8 ( TRANSLATION_A ) ;
50- let v_below_a = _mm512_set1_epi8 ( BELOW_A ) ;
5149 let v_b = _mm512_set1_epi8 ( B ) ;
5250 let v_c = _mm512_set1_epi8 ( C ) ;
51+ let v_ctrl_limit = _mm512_set1_epi8 ( 0x20 ) ;
5352
5453 // Handle alignment - skip if already aligned
5554 const M512_VECTOR_ALIGN : usize = M512_VECTOR_SIZE - 1 ;
@@ -61,29 +60,27 @@ pub unsafe fn encode_str_avx512<S: AsRef<str>>(input: S) -> String {
6160 // Check for quotes, backslash, and control characters
6261 let quote_mask = _mm512_cmpeq_epi8_mask ( a, v_b) ;
6362 let slash_mask = _mm512_cmpeq_epi8_mask ( a, v_c) ;
64- let ctrl_mask = _mm512_cmpgt_epi8_mask ( _mm512_add_epi8 ( a, v_translation_a ) , v_below_a ) ;
63+ let ctrl_mask = _mm512_cmplt_epu8_mask ( a, v_ctrl_limit ) ;
6564
6665 let mut mask = ( quote_mask | slash_mask | ctrl_mask) as u64 ;
66+ if align < 64 {
67+ mask &= ( 1u64 << align) - 1 ;
68+ }
6769
6870 if mask != 0 {
6971 let at = sub ( ptr, start_ptr) ;
70- let mut cur = mask . trailing_zeros ( ) as usize ;
71- while cur < align {
72+ while mask != 0 {
73+ let cur = mask . trailing_zeros ( ) as usize ;
7274 let c = * ptr. add ( cur) ;
7375 let escape_byte = ESCAPE [ c as usize ] ;
74- if escape_byte != 0 {
75- let i = at + cur;
76- if start < i {
77- result. extend_from_slice ( & bytes[ start..i] ) ;
78- }
79- write_escape ( & mut result, escape_byte, c) ;
80- start = i + 1 ;
76+ debug_assert ! ( escape_byte != 0 ) ;
77+ let i = at + cur;
78+ if start < i {
79+ result. extend_from_slice ( & bytes[ start..i] ) ;
8180 }
82- mask ^= 1 << cur;
83- if mask == 0 {
84- break ;
85- }
86- cur = mask. trailing_zeros ( ) as usize ;
81+ write_escape ( & mut result, escape_byte, c) ;
82+ start = i + 1 ;
83+ mask &= mask - 1 ;
8784 }
8885 }
8986 ptr = ptr. add ( align) ;
@@ -118,10 +115,10 @@ pub unsafe fn encode_str_avx512<S: AsRef<str>>(input: S) -> String {
118115 let slash_3 = _mm512_cmpeq_epi8_mask ( a3, v_c) ;
119116
120117 // Check for control characters (< 0x20) in all vectors
121- let ctrl_0 = _mm512_cmpgt_epi8_mask ( _mm512_add_epi8 ( a0, v_translation_a ) , v_below_a ) ;
122- let ctrl_1 = _mm512_cmpgt_epi8_mask ( _mm512_add_epi8 ( a1, v_translation_a ) , v_below_a ) ;
123- let ctrl_2 = _mm512_cmpgt_epi8_mask ( _mm512_add_epi8 ( a2, v_translation_a ) , v_below_a ) ;
124- let ctrl_3 = _mm512_cmpgt_epi8_mask ( _mm512_add_epi8 ( a3, v_translation_a ) , v_below_a ) ;
118+ let ctrl_0 = _mm512_cmplt_epu8_mask ( a0, v_ctrl_limit ) ;
119+ let ctrl_1 = _mm512_cmplt_epu8_mask ( a1, v_ctrl_limit ) ;
120+ let ctrl_2 = _mm512_cmplt_epu8_mask ( a2, v_ctrl_limit ) ;
121+ let ctrl_3 = _mm512_cmplt_epu8_mask ( a3, v_ctrl_limit ) ;
125122
126123 // Combine all masks
127124 let mask_a = quote_0 | slash_0 | ctrl_0;
@@ -158,29 +155,24 @@ pub unsafe fn encode_str_avx512<S: AsRef<str>>(input: S) -> String {
158155
159156 let quote_mask = _mm512_cmpeq_epi8_mask ( a, v_b) ;
160157 let slash_mask = _mm512_cmpeq_epi8_mask ( a, v_c) ;
161- let ctrl_mask = _mm512_cmpgt_epi8_mask ( _mm512_add_epi8 ( a, v_translation_a ) , v_below_a ) ;
158+ let ctrl_mask = _mm512_cmplt_epu8_mask ( a, v_ctrl_limit ) ;
162159
163160 let mut mask = ( quote_mask | slash_mask | ctrl_mask) as u64 ;
164161
165162 if mask != 0 {
166163 let at = sub ( ptr, start_ptr) ;
167- let mut cur = mask . trailing_zeros ( ) as usize ;
168- loop {
164+ while mask != 0 {
165+ let cur = mask . trailing_zeros ( ) as usize ;
169166 let c = * ptr. add ( cur) ;
170167 let escape_byte = ESCAPE [ c as usize ] ;
171- if escape_byte != 0 {
172- let i = at + cur;
173- if start < i {
174- result. extend_from_slice ( & bytes[ start..i] ) ;
175- }
176- write_escape ( & mut result, escape_byte, c) ;
177- start = i + 1 ;
178- }
179- mask ^= 1 << cur;
180- if mask == 0 {
181- break ;
168+ debug_assert ! ( escape_byte != 0 ) ;
169+ let i = at + cur;
170+ if start < i {
171+ result. extend_from_slice ( & bytes[ start..i] ) ;
182172 }
183- cur = mask. trailing_zeros ( ) as usize ;
173+ write_escape ( & mut result, escape_byte, c) ;
174+ start = i + 1 ;
175+ mask &= mask - 1 ;
184176 }
185177 }
186178 ptr = ptr. add ( M512_VECTOR_SIZE ) ;
@@ -193,30 +185,25 @@ pub unsafe fn encode_str_avx512<S: AsRef<str>>(input: S) -> String {
193185
194186 let quote_mask = _mm512_cmpeq_epi8_mask ( a, v_b) ;
195187 let slash_mask = _mm512_cmpeq_epi8_mask ( a, v_c) ;
196- let ctrl_mask = _mm512_cmpgt_epi8_mask ( _mm512_add_epi8 ( a, v_translation_a ) , v_below_a ) ;
188+ let ctrl_mask = _mm512_cmplt_epu8_mask ( a, v_ctrl_limit ) ;
197189
198190 let mut mask = ( ( quote_mask | slash_mask | ctrl_mask) as u64 )
199191 . wrapping_shr ( d as u32 ) ;
200192
201193 if mask != 0 {
202194 let at = sub ( ptr, start_ptr) ;
203- let mut cur = mask . trailing_zeros ( ) as usize ;
204- loop {
195+ while mask != 0 {
196+ let cur = mask . trailing_zeros ( ) as usize ;
205197 let c = * ptr. add ( cur) ;
206198 let escape_byte = ESCAPE [ c as usize ] ;
207- if escape_byte != 0 {
208- let i = at + cur;
209- if start < i {
210- result. extend_from_slice ( & bytes[ start..i] ) ;
211- }
212- write_escape ( & mut result, escape_byte, c) ;
213- start = i + 1 ;
199+ debug_assert ! ( escape_byte != 0 ) ;
200+ let i = at + cur;
201+ if start < i {
202+ result. extend_from_slice ( & bytes[ start..i] ) ;
214203 }
215- mask ^= 1 << cur;
216- if mask == 0 {
217- break ;
218- }
219- cur = mask. trailing_zeros ( ) as usize ;
204+ write_escape ( & mut result, escape_byte, c) ;
205+ start = i + 1 ;
206+ mask &= mask - 1 ;
220207 }
221208 }
222209 }
@@ -626,17 +613,16 @@ unsafe fn process_mask_avx(
626613 let cur = remaining. trailing_zeros ( ) as usize ;
627614 let c = * ptr. add ( cur) ;
628615 let escape_byte = ESCAPE [ c as usize ] ;
616+ debug_assert ! ( escape_byte != 0 ) ;
629617
630- if escape_byte != 0 {
631- let i = at + cur;
632- // Copy unescaped portion if needed
633- if * start < i {
634- result. extend_from_slice ( & bytes[ * start..i] ) ;
635- }
636- // Write escape sequence
637- write_escape ( result, escape_byte, c) ;
638- * start = i + 1 ;
618+ let i = at + cur;
619+ // Copy unescaped portion if needed
620+ if * start < i {
621+ result. extend_from_slice ( & bytes[ * start..i] ) ;
639622 }
623+ // Write escape sequence
624+ write_escape ( result, escape_byte, c) ;
625+ * start = i + 1 ;
640626
641627 // Clear the lowest set bit
642628 remaining &= remaining - 1 ;
@@ -666,17 +652,16 @@ unsafe fn process_mask_avx512(
666652 let cur = remaining. trailing_zeros ( ) as usize ;
667653 let c = * ptr. add ( cur) ;
668654 let escape_byte = ESCAPE [ c as usize ] ;
655+ debug_assert ! ( escape_byte != 0 ) ;
669656
670- if escape_byte != 0 {
671- let i = at + cur;
672- // Copy unescaped portion if needed
673- if * start < i {
674- result. extend_from_slice ( & bytes[ * start..i] ) ;
675- }
676- // Write escape sequence
677- write_escape ( result, escape_byte, c) ;
678- * start = i + 1 ;
657+ let i = at + cur;
658+ // Copy unescaped portion if needed
659+ if * start < i {
660+ result. extend_from_slice ( & bytes[ * start..i] ) ;
679661 }
662+ // Write escape sequence
663+ write_escape ( result, escape_byte, c) ;
664+ * start = i + 1 ;
680665
681666 // Clear the lowest set bit
682667 remaining &= remaining - 1 ;
@@ -697,4 +682,3 @@ fn write_escape(result: &mut Vec<u8>, escape_byte: u8, c: u8) {
697682 result. push ( escape_byte) ;
698683 }
699684}
700-
0 commit comments