Skip to content

Commit 420f790

Browse files
committed
feat: avx2 only
1 parent afddf46 commit 420f790

File tree

4 files changed

+289
-214
lines changed

4 files changed

+289
-214
lines changed

.github/workflows/CI.yml

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,18 @@ jobs:
1212
strategy:
1313
matrix:
1414
settings:
15-
- target: aarch64-unknown-linux-gnu
16-
os: ubuntu-24.04-arm
17-
- target: aarch64-pc-windows-msvc
18-
os: windows-11-arm
19-
- target: aarch64-apple-darwin
20-
os: macos-latest
15+
- target: x86_64-unknown-linux-gnu
16+
os: ubuntu-latest
17+
- target: x86_64-pc-windows-msvc
18+
os: windows-latest
2119
fail-fast: false
2220
runs-on: ${{ matrix.settings.os }}
2321
steps:
24-
- uses: actions/checkout@v4
22+
- uses: actions/checkout@v5
2523
- name: Setup Rust
2624
uses: dtolnay/rust-toolchain@stable
2725
with:
2826
targets: ${{ matrix.settings.target }}
29-
env:
30-
CARGO_INCREMENTAL: '1'
3127
- name: Run benchmarks
3228
run: cargo bench
3329
env:

src/aarch64.rs

Lines changed: 0 additions & 107 deletions
This file was deleted.

src/avx2.rs

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
use std::arch::x86_64::{
2+
__m256i, _mm256_cmpeq_epi8, _mm256_cmpgt_epi8, _mm256_loadu_si256, _mm256_or_si256,
3+
_mm256_set1_epi8, _mm256_storeu_si256, _mm256_testz_si256,
4+
};
5+
6+
use crate::{encode_str_fallback, ESCAPE, HEX_BYTES, UU};
7+
8+
/// Four contiguous 32-byte AVX2 registers (128 B) per loop.
9+
const CHUNK: usize = 128;
10+
/// Distance (in bytes) to prefetch ahead.
11+
/// Keeping ~4 iterations (4 × CHUNK = 512 B) ahead strikes a good balance
12+
/// between hiding memory latency and not evicting useful cache lines.
13+
const PREFETCH_DISTANCE: usize = CHUNK * 4;
14+
15+
pub fn encode_str<S: AsRef<str>>(input: S) -> String {
16+
let s = input.as_ref();
17+
let mut out = Vec::with_capacity(s.len() + 2);
18+
let bytes = s.as_bytes();
19+
let n = bytes.len();
20+
out.push(b'"');
21+
22+
unsafe {
23+
let slash = _mm256_set1_epi8(b'\\' as i8);
24+
let quote = _mm256_set1_epi8(b'"' as i8);
25+
let tab = _mm256_set1_epi8(b'\t' as i8);
26+
let newline = _mm256_set1_epi8(b'\n' as i8);
27+
let carriage = _mm256_set1_epi8(b'\r' as i8);
28+
let backspace = _mm256_set1_epi8(0x08i8);
29+
let formfeed = _mm256_set1_epi8(0x0ci8);
30+
let ctrl_upper_bound = _mm256_set1_epi8(0x20i8);
31+
32+
let mut i = 0;
33+
34+
// Re-usable scratch – *uninitialised*, so no memset in the loop.
35+
#[allow(invalid_value)]
36+
let mut placeholder: [u8; 32] = core::mem::MaybeUninit::uninit().assume_init();
37+
38+
while i + CHUNK <= n {
39+
let ptr = bytes.as_ptr().add(i);
40+
41+
// Prefetch data ahead
42+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
43+
{
44+
core::arch::x86_64::_mm_prefetch(
45+
ptr.add(PREFETCH_DISTANCE) as *const i8,
46+
core::arch::x86_64::_MM_HINT_T0,
47+
);
48+
}
49+
50+
// Load 128 bytes (four 32-byte chunks)
51+
let a = _mm256_loadu_si256(ptr as *const __m256i);
52+
let b = _mm256_loadu_si256(ptr.add(32) as *const __m256i);
53+
let c = _mm256_loadu_si256(ptr.add(64) as *const __m256i);
54+
let d = _mm256_loadu_si256(ptr.add(96) as *const __m256i);
55+
56+
// For each chunk, check if it needs escaping
57+
let mask_1 = process_chunk(
58+
a, slash, quote, tab, newline, carriage, backspace, formfeed, ctrl_upper_bound,
59+
);
60+
let mask_2 = process_chunk(
61+
b, slash, quote, tab, newline, carriage, backspace, formfeed, ctrl_upper_bound,
62+
);
63+
let mask_3 = process_chunk(
64+
c, slash, quote, tab, newline, carriage, backspace, formfeed, ctrl_upper_bound,
65+
);
66+
let mask_4 = process_chunk(
67+
d, slash, quote, tab, newline, carriage, backspace, formfeed, ctrl_upper_bound,
68+
);
69+
70+
// Check if any chunk needs escaping
71+
let any_escape = _mm256_or_si256(
72+
_mm256_or_si256(mask_1, mask_2),
73+
_mm256_or_si256(mask_3, mask_4),
74+
);
75+
76+
// Fast path: nothing needs escaping
77+
if _mm256_testz_si256(any_escape, any_escape) != 0 {
78+
out.extend_from_slice(std::slice::from_raw_parts(ptr, CHUNK));
79+
i += CHUNK;
80+
continue;
81+
}
82+
83+
// Slow path: handle each 32-byte chunk
84+
macro_rules! handle {
85+
($mask:expr, $off:expr) => {
86+
if _mm256_testz_si256($mask, $mask) != 0 {
87+
// No escapes in this chunk
88+
out.extend_from_slice(std::slice::from_raw_parts(ptr.add($off), 32));
89+
} else {
90+
// Store mask and process byte by byte
91+
_mm256_storeu_si256(placeholder.as_mut_ptr() as *mut __m256i, $mask);
92+
handle_block(&bytes[i + $off..i + $off + 32], &placeholder, &mut out);
93+
}
94+
};
95+
}
96+
97+
handle!(mask_1, 0);
98+
handle!(mask_2, 32);
99+
handle!(mask_3, 64);
100+
handle!(mask_4, 96);
101+
102+
i += CHUNK;
103+
}
104+
105+
// Handle remaining bytes using the fallback
106+
if i < n {
107+
let remaining_str = std::str::from_utf8(&bytes[i..]).unwrap();
108+
let escaped = encode_str_fallback(remaining_str);
109+
// Remove the quotes that encode_str_fallback adds
110+
let escaped_bytes = escaped.as_bytes();
111+
out.extend_from_slice(&escaped_bytes[1..escaped_bytes.len() - 1]);
112+
}
113+
}
114+
out.push(b'"');
115+
// SAFETY: we only emit valid UTF-8
116+
unsafe { String::from_utf8_unchecked(out) }
117+
}
118+
119+
#[inline(always)]
120+
unsafe fn process_chunk(
121+
data: __m256i,
122+
slash: __m256i,
123+
quote: __m256i,
124+
tab: __m256i,
125+
newline: __m256i,
126+
carriage: __m256i,
127+
backspace: __m256i,
128+
formfeed: __m256i,
129+
ctrl_upper_bound: __m256i,
130+
) -> __m256i {
131+
// Check for each special character
132+
let slash_mask = _mm256_cmpeq_epi8(data, slash);
133+
let quote_mask = _mm256_cmpeq_epi8(data, quote);
134+
let tab_mask = _mm256_cmpeq_epi8(data, tab);
135+
let newline_mask = _mm256_cmpeq_epi8(data, newline);
136+
let carriage_mask = _mm256_cmpeq_epi8(data, carriage);
137+
let backspace_mask = _mm256_cmpeq_epi8(data, backspace);
138+
let formfeed_mask = _mm256_cmpeq_epi8(data, formfeed);
139+
140+
// Check for control characters (< 0x20)
141+
// Note: AVX2 doesn't have unsigned comparison, so we use signed comparison
142+
// This works because ASCII control characters are all < 0x20 (positive signed values)
143+
let ctrl_mask = _mm256_cmpgt_epi8(ctrl_upper_bound, data);
144+
145+
// Combine all masks
146+
let combined = _mm256_or_si256(
147+
_mm256_or_si256(
148+
_mm256_or_si256(slash_mask, quote_mask),
149+
_mm256_or_si256(tab_mask, newline_mask),
150+
),
151+
_mm256_or_si256(
152+
_mm256_or_si256(carriage_mask, backspace_mask),
153+
_mm256_or_si256(formfeed_mask, ctrl_mask),
154+
),
155+
);
156+
157+
combined
158+
}
159+
160+
#[inline(always)]
161+
unsafe fn handle_block(src: &[u8], mask: &[u8; 32], dst: &mut Vec<u8>) {
162+
for (j, &m) in mask.iter().enumerate() {
163+
let c = src[j];
164+
if m == 0 {
165+
dst.push(c);
166+
} else {
167+
let escape_byte = ESCAPE[c as usize];
168+
if escape_byte != 0 {
169+
// Handle the escape
170+
dst.push(b'\\');
171+
if escape_byte == UU {
172+
// Unicode escape for control characters
173+
dst.extend_from_slice(b"u00");
174+
let hex_digits = &HEX_BYTES[c as usize];
175+
dst.push(hex_digits.0);
176+
dst.push(hex_digits.1);
177+
} else {
178+
// Simple escape
179+
dst.push(escape_byte);
180+
}
181+
} else if c == b'\\' {
182+
// Backslash needs escaping
183+
dst.extend_from_slice(b"\\\\");
184+
} else {
185+
// Should not happen if mask is correct
186+
dst.push(c);
187+
}
188+
}
189+
}
190+
}

0 commit comments

Comments
 (0)