
Commit 8ab1b3c

sve2
1 parent 4a5eaab commit 8ab1b3c

File tree: 2 files changed (+128, -68 lines)


.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion

@@ -21,4 +21,4 @@ jobs:
       - name: Run benchmarks
         run: cargo bench
         env:
-          RUSTFLAGS: '-C target-cpu=native'
+          RUSTFLAGS: '-C target-cpu=neoverse-n2 -C target-feature=+sve2,+ls64'
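
Note that `-C target-cpu=neoverse-n2` tunes for a core that implements SVE2, and `-C target-feature=+sve2,+ls64` lets the compiler emit SVE2/LS64 instructions anywhere in the build, so the benchmark binaries from this workflow assume hardware (or an emulator) with both features. A common alternative, sketched below with hypothetical names (not part of this commit), gates the codegen per function and keeps the rest of the binary portable:

    use std::arch::is_aarch64_feature_detected;

    // Hypothetical sketch: compile only the hot routine with SVE2 enabled
    // and dispatch at runtime, instead of baking +sve2 into global RUSTFLAGS.
    #[target_feature(enable = "sve2")]
    unsafe fn escape_chunk_sve2(chunk: &[u8], out: &mut Vec<u8>) {
        // A real SVE2 body would go here; a plain copy keeps the sketch compilable.
        out.extend_from_slice(chunk);
    }

    fn escape_chunk(chunk: &[u8], out: &mut Vec<u8>) {
        if is_aarch64_feature_detected!("sve2") {
            // Safe to call: the feature was just verified on this CPU.
            unsafe { escape_chunk_sve2(chunk, out) }
        } else {
            out.extend_from_slice(chunk); // portable fallback path
        }
    }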

src/aarch64.rs

Lines changed: 127 additions & 67 deletions

@@ -1,93 +1,86 @@
-use std::arch::aarch64::{
-    vceqq_u8, vdupq_n_u8, vld1q_u8_x4, vmaxvq_u8, vorrq_u8, vqtbl4q_u8, vst1q_u8,
+use std::arch::{
+    aarch64::{vceqq_u8, vdupq_n_u8, vld1q_u8_x4, vmaxvq_u8, vorrq_u8, vqtbl4q_u8, vst1q_u8},
+    asm, is_aarch64_feature_detected,
 };
 
 use crate::{encode_str_inner, write_char_escape, CharEscape, ESCAPE, REVERSE_SOLIDUS};
 
-/// Four contiguous 16-byte NEON registers (64 B) per loop.
+/// Bytes handled per *outer* iteration in the new 8.7 path.
+/// (Still 64 B in the NEON fallback.)
 const CHUNK: usize = 64;
-/// Distance (in bytes) to prefetch ahead. Must be a multiple of 8 for PRFM.
-/// Keeping ~4 iterations (4 × CHUNK = 256 B) ahead strikes a good balance
-/// between hiding memory latency and not evicting useful cache lines.
+/// Prefetch distance (works for both paths).
 const PREFETCH_DISTANCE: usize = CHUNK * 4;
 
 pub fn encode_str<S: AsRef<str>>(input: S) -> String {
     let s = input.as_ref();
     let mut out = Vec::with_capacity(s.len() + 2);
-    let bytes = s.as_bytes();
-    let n = bytes.len();
+    let b = s.as_bytes();
+    let n = b.len();
     out.push(b'"');
 
     unsafe {
-        let tbl = vld1q_u8_x4(ESCAPE.as_ptr()); // first 64 B of the escape table
-        let slash = vdupq_n_u8(b'\\');
-        let mut i = 0;
-        // Re-usable scratch – *uninitialised*, so no memset in the loop.
-        // Using MaybeUninit instead of mem::zeroed() prevents the compiler from inserting an implicit memset (observable with -Cllvm-args=-print-after=expand-memcmp).
-        // This is a proven micro-optimisation in Rust's standard library I/O stack.
         #[allow(invalid_value)]
-        let mut placeholder: [u8; 16] = core::mem::MaybeUninit::uninit().assume_init();
+        let mut scratch: [u8; 16] = core::mem::MaybeUninit::uninit().assume_init();
+
+        let mut i = 0;
+
+        /* ------------------------------------------------------------------ */
+        /* === Arm v8.7 fast path: LS64 + SVE2 =============================== */
+
+        let tbl = vld1q_u8_x4(ESCAPE.as_ptr());
+        let slash = vdupq_n_u8(b'\\');
 
         while i + CHUNK <= n {
-            let ptr = bytes.as_ptr().add(i);
-
-            /* ---- L1 prefetch: PREFETCH_DISTANCE bytes ahead ---- */
-            core::arch::asm!(
-                "prfm pldl1keep, [{0}, #{1}]",
-                in(reg) ptr,
-                const PREFETCH_DISTANCE,
-            );
-            /* ------------------------------------------ */
-
-            let quad = vld1q_u8_x4(ptr);
-
-            // load 64 B (four q-regs)
-            let a = quad.0;
-            let b = quad.1;
-            let c = quad.2;
-            let d = quad.3;
-
-            let mask_1 = vorrq_u8(vqtbl4q_u8(tbl, a), vceqq_u8(slash, a));
-            let mask_2 = vorrq_u8(vqtbl4q_u8(tbl, b), vceqq_u8(slash, b));
-            let mask_3 = vorrq_u8(vqtbl4q_u8(tbl, c), vceqq_u8(slash, c));
-            let mask_4 = vorrq_u8(vqtbl4q_u8(tbl, d), vceqq_u8(slash, d));
-
-            let mask_r_1 = vmaxvq_u8(mask_1);
-            let mask_r_2 = vmaxvq_u8(mask_2);
-            let mask_r_3 = vmaxvq_u8(mask_3);
-            let mask_r_4 = vmaxvq_u8(mask_4);
-
-            // fast path: nothing needs escaping
-            if mask_r_1 | mask_r_2 | mask_r_3 | mask_r_4 == 0 {
-                out.extend_from_slice(std::slice::from_raw_parts(ptr, CHUNK));
-                i += CHUNK;
+            let ptr = b.as_ptr().add(i);
+            if is_aarch64_feature_detected!("sve2") {
+                i += escape_block_sve(ptr, &mut out);
                 continue;
-            }
+            } else {
+                asm!("prfm pldl1keep, [{0}, #{1}]",
+                    in(reg) ptr, const PREFETCH_DISTANCE);
 
-            macro_rules! handle {
-                ($mask:expr, $mask_r:expr, $off:expr) => {
-                    if $mask_r == 0 {
-                        out.extend_from_slice(std::slice::from_raw_parts(ptr.add($off), 16));
-                    } else {
-                        vst1q_u8(placeholder.as_mut_ptr(), $mask);
-                        handle_block(&bytes[i + $off..i + $off + 16], &placeholder, &mut out);
-                    }
-                };
-            }
+                let quad = vld1q_u8_x4(ptr);
+                let a = quad.0;
+                let b1 = quad.1;
+                let c = quad.2;
+                let d = quad.3;
+
+                let m1 = vorrq_u8(vqtbl4q_u8(tbl, a), vceqq_u8(slash, a));
+                let m2 = vorrq_u8(vqtbl4q_u8(tbl, b1), vceqq_u8(slash, b1));
+                let m3 = vorrq_u8(vqtbl4q_u8(tbl, c), vceqq_u8(slash, c));
+                let m4 = vorrq_u8(vqtbl4q_u8(tbl, d), vceqq_u8(slash, d));
 
-            handle!(mask_1, mask_r_1, 0);
-            handle!(mask_2, mask_r_2, 16);
-            handle!(mask_3, mask_r_3, 32);
-            handle!(mask_4, mask_r_4, 48);
+                if vmaxvq_u8(m1) | vmaxvq_u8(m2) | vmaxvq_u8(m3) | vmaxvq_u8(m4) == 0 {
+                    out.extend_from_slice(std::slice::from_raw_parts(ptr, CHUNK));
+                    i += CHUNK;
+                    continue;
+                }
 
-            i += CHUNK;
+                macro_rules! handle {
+                    ($m:expr,$r:expr,$off:expr) => {
+                        if $r == 0 {
+                            out.extend_from_slice(std::slice::from_raw_parts(ptr.add($off), 16));
+                        } else {
+                            vst1q_u8(scratch.as_mut_ptr(), $m);
+                            handle_block(&b[i + $off..i + $off + 16], &scratch, &mut out);
+                        }
+                    };
+                }
+                handle!(m1, vmaxvq_u8(m1), 0);
+                handle!(m2, vmaxvq_u8(m2), 16);
+                handle!(m3, vmaxvq_u8(m3), 32);
+                handle!(m4, vmaxvq_u8(m4), 48);
+
+                i += CHUNK;
+            }
         }
+        /* ------------------------------------------------------------------ */
+
         if i < n {
-            encode_str_inner(&bytes[i..], &mut out);
+            encode_str_inner(&b[i..], &mut out);
         }
     }
     out.push(b'"');
-    // SAFETY: we only emit valid UTF-8
     unsafe { String::from_utf8_unchecked(out) }
 }
 
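
Both the NEON path above and the SVE2 routine added in the next hunk reduce every input byte to the same classification value: a table lookup into the first 64 entries of ESCAPE, OR'd with an all-ones compare against the backslash. A scalar sketch of that per-byte rule (illustrative only, assuming ESCAPE holds at least 64 entries; 0 means copy through, 0xFF means backslash, anything else selects an escape via CharEscape::from_escape_table):

    // Scalar model of the vector mask (a sketch, not the committed hot path).
    // TBL yields 0 for indices >= 64, so bytes past the table can only match
    // through the explicit '\\' (0x5C) compare, which is exactly why the OR exists.
    fn classify(byte: u8) -> u8 {
        let table = if byte < 64 { ESCAPE[byte as usize] } else { 0 };
        let slash = if byte == b'\\' { 0xFF } else { 0 };
        table | slash
    }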

@@ -100,8 +93,75 @@ unsafe fn handle_block(src: &[u8], mask: &[u8; 16], dst: &mut Vec<u8>) {
         } else if m == 0xFF {
             dst.extend_from_slice(REVERSE_SOLIDUS);
         } else {
-            let e = CharEscape::from_escape_table(m, c);
-            write_char_escape(dst, e);
+            write_char_escape(dst, CharEscape::from_escape_table(m, c));
+        }
+    }
+}
+
+#[inline(always)]
+unsafe fn escape_block_sve(ptr: *const u8, dst: &mut Vec<u8>) -> usize {
+    /* ------------------------------------------------------------------ */
+    /* One-shot: copy ESCAPE[0..64] into z4-z7                             */
+    /* Each LD1B uses an in-range offset and bumps x9 by 16 bytes.         */
+    core::arch::asm!(
+        "ptrue p0.b",
+        "mov x9, {tbl}",
+        "ld1b z4.b, p0/z, [x9]",
+        "add x9, x9, #16",
+        "ld1b z5.b, p0/z, [x9]",
+        "add x9, x9, #16",
+        "ld1b z6.b, p0/z, [x9]",
+        "add x9, x9, #16",
+        "ld1b z7.b, p0/z, [x9]",
+        tbl = in(reg) crate::ESCAPE.as_ptr(),
+        out("x9") _,
+        options(readonly, nostack, preserves_flags)
+    );
+    /* ------------------------------------------------------------------ */
+
+    /* 1️⃣ Single-copy 64-byte fetch into L1 */
+    core::arch::asm!(
+        "ld64b x0, [{src}]",
+        src = in(reg) ptr,
+        out("x0") _, out("x1") _, out("x2") _, out("x3") _,
+        out("x4") _, out("x5") _, out("x6") _, out("x7") _,
+        options(nostack)
+    );
+
+    /* 2️⃣ Build escape mask */
+    let mut mask: u32;
+    core::arch::asm!(
+        "ptrue p0.b",
+        "ld1b z0.b, p0/z, [{src}]",
+        "tbl z1.b, {{z4.b, z5.b, z6.b, z7.b}}, z0.b",
+        "dup z2.b, {slash}",
+        "cmeq z2.b, p0/m, z0.b, z2.b",
+        "orr z3.b, z1.b, z2.b",
+        "umaxv {mask:w}, p0, z3.b", // scalar result → wMask
+        src = in(reg) ptr,
+        slash = const b'\\',
+        mask = lateout(reg) mask,
+        options(preserves_flags, nostack, readonly)
+    );
+
+    if mask == 0 {
+        dst.extend_from_slice(std::slice::from_raw_parts(ptr, CHUNK));
+        return CHUNK;
+    }
+
+    /* 3️⃣ Spill z3 and escape the bad bytes */
+    let mut m = [0u8; CHUNK];
+    core::arch::asm!("ptrue p0.b", "st1b z3.b, p0, [{buf}]",
+        buf = in(reg) m.as_mut_ptr(), options(nostack));
+    for (i, &bit) in m.iter().enumerate() {
+        let c = *ptr.add(i);
+        if bit == 0 {
+            dst.push(c);
+        } else if bit == 0xFF {
+            dst.extend_from_slice(crate::REVERSE_SOLIDUS);
+        } else {
+            crate::write_char_escape(dst, CharEscape::from_escape_table(bit, c));
         }
     }
+    CHUNK
 }
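
The observable output is unchanged by this commit: encode_str still wraps the input in double quotes and escapes control characters, quotes, and backslashes. A usage sketch with a made-up input (the expected string assumes ESCAPE maps '\n' to the usual two-character JSON escape):

    // Hypothetical round trip, not from the commit's test suite.
    let encoded = encode_str("say \"hi\"\n");
    assert_eq!(encoded, r#""say \"hi\"\n""#);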
