1
+ use std:: arch:: x86_64:: {
2
+ __m256i, _mm256_cmpeq_epi8, _mm256_cmpgt_epi8, _mm256_loadu_si256, _mm256_or_si256,
3
+ _mm256_set1_epi8, _mm256_storeu_si256, _mm256_testz_si256,
4
+ } ;
5
+
6
+ use crate :: { encode_str_fallback, ESCAPE , HEX_BYTES , UU } ;
7
+
8
+ /// Four contiguous 32-byte AVX2 registers (128 B) per loop.
9
+ const CHUNK : usize = 128 ;
10
+ /// Distance (in bytes) to prefetch ahead.
11
+ /// Keeping ~4 iterations (4 × CHUNK = 512 B) ahead strikes a good balance
12
+ /// between hiding memory latency and not evicting useful cache lines.
13
+ const PREFETCH_DISTANCE : usize = CHUNK * 4 ;
14
+
15
+ pub fn encode_str < S : AsRef < str > > ( input : S ) -> String {
16
+ let s = input. as_ref ( ) ;
17
+ let mut out = Vec :: with_capacity ( s. len ( ) + 2 ) ;
18
+ let bytes = s. as_bytes ( ) ;
19
+ let n = bytes. len ( ) ;
20
+ out. push ( b'"' ) ;
21
+
22
+ unsafe {
23
+ let slash = _mm256_set1_epi8 ( b'\\' as i8 ) ;
24
+ let quote = _mm256_set1_epi8 ( b'"' as i8 ) ;
25
+ let tab = _mm256_set1_epi8 ( b'\t' as i8 ) ;
26
+ let newline = _mm256_set1_epi8 ( b'\n' as i8 ) ;
27
+ let carriage = _mm256_set1_epi8 ( b'\r' as i8 ) ;
28
+ let backspace = _mm256_set1_epi8 ( 0x08i8 ) ;
29
+ let formfeed = _mm256_set1_epi8 ( 0x0ci8 ) ;
30
+ let ctrl_upper_bound = _mm256_set1_epi8 ( 0x20i8 ) ;
31
+
32
+ let mut i = 0 ;
33
+
34
+ // Re-usable scratch – *uninitialised*, so no memset in the loop.
35
+ #[ allow( invalid_value) ]
36
+ let mut placeholder: [ u8 ; 32 ] = core:: mem:: MaybeUninit :: uninit ( ) . assume_init ( ) ;
37
+
38
+ while i + CHUNK <= n {
39
+ let ptr = bytes. as_ptr ( ) . add ( i) ;
40
+
41
+ // Prefetch data ahead
42
+ #[ cfg( any( target_arch = "x86" , target_arch = "x86_64" ) ) ]
43
+ {
44
+ core:: arch:: x86_64:: _mm_prefetch (
45
+ ptr. add ( PREFETCH_DISTANCE ) as * const i8 ,
46
+ core:: arch:: x86_64:: _MM_HINT_T0,
47
+ ) ;
48
+ }
49
+
50
+ // Load 128 bytes (four 32-byte chunks)
51
+ let a = _mm256_loadu_si256 ( ptr as * const __m256i ) ;
52
+ let b = _mm256_loadu_si256 ( ptr. add ( 32 ) as * const __m256i ) ;
53
+ let c = _mm256_loadu_si256 ( ptr. add ( 64 ) as * const __m256i ) ;
54
+ let d = _mm256_loadu_si256 ( ptr. add ( 96 ) as * const __m256i ) ;
55
+
56
+ // For each chunk, check if it needs escaping
57
+ let mask_1 = process_chunk (
58
+ a, slash, quote, tab, newline, carriage, backspace, formfeed, ctrl_upper_bound,
59
+ ) ;
60
+ let mask_2 = process_chunk (
61
+ b, slash, quote, tab, newline, carriage, backspace, formfeed, ctrl_upper_bound,
62
+ ) ;
63
+ let mask_3 = process_chunk (
64
+ c, slash, quote, tab, newline, carriage, backspace, formfeed, ctrl_upper_bound,
65
+ ) ;
66
+ let mask_4 = process_chunk (
67
+ d, slash, quote, tab, newline, carriage, backspace, formfeed, ctrl_upper_bound,
68
+ ) ;
69
+
70
+ // Check if any chunk needs escaping
71
+ let any_escape = _mm256_or_si256 (
72
+ _mm256_or_si256 ( mask_1, mask_2) ,
73
+ _mm256_or_si256 ( mask_3, mask_4) ,
74
+ ) ;
75
+
76
+ // Fast path: nothing needs escaping
77
+ if _mm256_testz_si256 ( any_escape, any_escape) != 0 {
78
+ out. extend_from_slice ( std:: slice:: from_raw_parts ( ptr, CHUNK ) ) ;
79
+ i += CHUNK ;
80
+ continue ;
81
+ }
82
+
83
+ // Slow path: handle each 32-byte chunk
84
+ macro_rules! handle {
85
+ ( $mask: expr, $off: expr) => {
86
+ if _mm256_testz_si256( $mask, $mask) != 0 {
87
+ // No escapes in this chunk
88
+ out. extend_from_slice( std:: slice:: from_raw_parts( ptr. add( $off) , 32 ) ) ;
89
+ } else {
90
+ // Store mask and process byte by byte
91
+ _mm256_storeu_si256( placeholder. as_mut_ptr( ) as * mut __m256i, $mask) ;
92
+ handle_block( & bytes[ i + $off..i + $off + 32 ] , & placeholder, & mut out) ;
93
+ }
94
+ } ;
95
+ }
96
+
97
+ handle ! ( mask_1, 0 ) ;
98
+ handle ! ( mask_2, 32 ) ;
99
+ handle ! ( mask_3, 64 ) ;
100
+ handle ! ( mask_4, 96 ) ;
101
+
102
+ i += CHUNK ;
103
+ }
104
+
105
+ // Handle remaining bytes using the fallback
106
+ if i < n {
107
+ let remaining_str = std:: str:: from_utf8 ( & bytes[ i..] ) . unwrap ( ) ;
108
+ let escaped = encode_str_fallback ( remaining_str) ;
109
+ // Remove the quotes that encode_str_fallback adds
110
+ let escaped_bytes = escaped. as_bytes ( ) ;
111
+ out. extend_from_slice ( & escaped_bytes[ 1 ..escaped_bytes. len ( ) - 1 ] ) ;
112
+ }
113
+ }
114
+ out. push ( b'"' ) ;
115
+ // SAFETY: we only emit valid UTF-8
116
+ unsafe { String :: from_utf8_unchecked ( out) }
117
+ }
118
+
119
+ #[ inline( always) ]
120
+ unsafe fn process_chunk (
121
+ data : __m256i ,
122
+ slash : __m256i ,
123
+ quote : __m256i ,
124
+ tab : __m256i ,
125
+ newline : __m256i ,
126
+ carriage : __m256i ,
127
+ backspace : __m256i ,
128
+ formfeed : __m256i ,
129
+ ctrl_upper_bound : __m256i ,
130
+ ) -> __m256i {
131
+ // Check for each special character
132
+ let slash_mask = _mm256_cmpeq_epi8 ( data, slash) ;
133
+ let quote_mask = _mm256_cmpeq_epi8 ( data, quote) ;
134
+ let tab_mask = _mm256_cmpeq_epi8 ( data, tab) ;
135
+ let newline_mask = _mm256_cmpeq_epi8 ( data, newline) ;
136
+ let carriage_mask = _mm256_cmpeq_epi8 ( data, carriage) ;
137
+ let backspace_mask = _mm256_cmpeq_epi8 ( data, backspace) ;
138
+ let formfeed_mask = _mm256_cmpeq_epi8 ( data, formfeed) ;
139
+
140
+ // Check for control characters (< 0x20)
141
+ // Note: AVX2 doesn't have unsigned comparison, so we use signed comparison
142
+ // This works because ASCII control characters are all < 0x20 (positive signed values)
143
+ let ctrl_mask = _mm256_cmpgt_epi8 ( ctrl_upper_bound, data) ;
144
+
145
+ // Combine all masks
146
+ let combined = _mm256_or_si256 (
147
+ _mm256_or_si256 (
148
+ _mm256_or_si256 ( slash_mask, quote_mask) ,
149
+ _mm256_or_si256 ( tab_mask, newline_mask) ,
150
+ ) ,
151
+ _mm256_or_si256 (
152
+ _mm256_or_si256 ( carriage_mask, backspace_mask) ,
153
+ _mm256_or_si256 ( formfeed_mask, ctrl_mask) ,
154
+ ) ,
155
+ ) ;
156
+
157
+ combined
158
+ }
159
+
160
+ #[ inline( always) ]
161
+ unsafe fn handle_block ( src : & [ u8 ] , mask : & [ u8 ; 32 ] , dst : & mut Vec < u8 > ) {
162
+ for ( j, & m) in mask. iter ( ) . enumerate ( ) {
163
+ let c = src[ j] ;
164
+ if m == 0 {
165
+ dst. push ( c) ;
166
+ } else {
167
+ let escape_byte = ESCAPE [ c as usize ] ;
168
+ if escape_byte != 0 {
169
+ // Handle the escape
170
+ dst. push ( b'\\' ) ;
171
+ if escape_byte == UU {
172
+ // Unicode escape for control characters
173
+ dst. extend_from_slice ( b"u00" ) ;
174
+ let hex_digits = & HEX_BYTES [ c as usize ] ;
175
+ dst. push ( hex_digits. 0 ) ;
176
+ dst. push ( hex_digits. 1 ) ;
177
+ } else {
178
+ // Simple escape
179
+ dst. push ( escape_byte) ;
180
+ }
181
+ } else if c == b'\\' {
182
+ // Backslash needs escaping
183
+ dst. extend_from_slice ( b"\\ \\ " ) ;
184
+ } else {
185
+ // Should not happen if mask is correct
186
+ dst. push ( c) ;
187
+ }
188
+ }
189
+ }
190
+ }
0 commit comments