Skip to content

Commit 9a628d7

Browse files
committed
Armv8.7 ST64B
1 parent 35c21dd commit 9a628d7

File tree

1 file changed

+74
-102
lines changed

1 file changed

+74
-102
lines changed

src/aarch64.rs

Lines changed: 74 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1,116 +1,88 @@
1-
use std::arch::aarch64::{
2-
vceqq_u8, vdupq_n_u8, vld1q_u8, vld1q_u8_x4, vmaxvq_u8, vorrq_u8, vqtbl4q_u8, vst1q_u8,
1+
use std::arch::{
2+
aarch64::{vceqq_u8, vdupq_n_u8, vld1q_u8, vld1q_u8_x4, vmaxvq_u8, vorrq_u8, vqtbl4q_u8},
3+
asm,
34
};
45

56
use crate::{encode_str_inner, write_char_escape, CharEscape, ESCAPE, REVERSE_SOLIDUS};
67

7-
/// Four contiguous 16-byte NEON registers (64 B) per loop.
8-
const CHUNK: usize = 64;
8+
/// We now chew 64 bytes (one cache line) at a time.
9+
const CHUNK_SIZE: usize = 64;
10+
11+
#[inline(always)]
12+
unsafe fn ld64b_to_stack(src: *const u8, dst: *mut u8) {
13+
// Loads 64 bytes atomically with LD64B and immediately writes them
14+
// to `dst` with ST64B so the following NEON code can work from L1.
15+
asm!(
16+
// x0 – x7 must be consecutive for LD64B/ST64B; we declare them
17+
// explicitly and ignore the values after the store.
18+
"ld64b x0, [{inptr}]",
19+
"st64b x0, [{outptr}]",
20+
inptr = in(reg) src,
21+
outptr = in(reg) dst,
22+
out("x0") _, out("x1") _, out("x2") _, out("x3") _,
23+
out("x4") _, out("x5") _, out("x6") _, out("x7") _,
24+
options(nostack, preserves_flags)
25+
);
26+
}
927

1028
pub fn encode_str<S: AsRef<str>>(input: S) -> String {
11-
let s = input.as_ref();
12-
let mut out = Vec::with_capacity(s.len() + 2);
13-
let b = s.as_bytes();
14-
let n = b.len();
15-
out.push(b'"');
29+
let input_str = input.as_ref();
30+
let mut output = Vec::with_capacity(input_str.len() + 2);
31+
let bytes = input_str.as_bytes();
32+
let len = bytes.len();
33+
let writer = &mut output;
34+
writer.push(b'"');
1635

1736
unsafe {
18-
let tbl = vld1q_u8_x4(ESCAPE.as_ptr()); // first 64 B of the escape table
19-
let slash = vdupq_n_u8(b'\\');
20-
let mut i = 0;
21-
22-
while i + CHUNK <= n {
23-
let ptr = b.as_ptr().add(i);
24-
25-
/* ---- L1 prefetch: CHUNK size ahead ---- */
26-
core::arch::asm!("prfm pldl1keep, [{0}, #64]", in(reg) ptr);
27-
/* ------------------------------------------ */
28-
29-
// load 64 B (four q-regs)
30-
let a = vld1q_u8(ptr);
31-
let m1 = vqtbl4q_u8(tbl, a);
32-
let m2 = vceqq_u8(slash, a);
33-
34-
let b2 = vld1q_u8(ptr.add(16));
35-
let m3 = vqtbl4q_u8(tbl, b2);
36-
let m4 = vceqq_u8(slash, b2);
37-
38-
let c = vld1q_u8(ptr.add(32));
39-
let m5 = vqtbl4q_u8(tbl, c);
40-
let m6 = vceqq_u8(slash, c);
41-
42-
let d = vld1q_u8(ptr.add(48));
43-
let m7 = vqtbl4q_u8(tbl, d);
44-
let m8 = vceqq_u8(slash, d);
45-
46-
let mask_1 = vorrq_u8(m1, m2);
47-
let mask_2 = vorrq_u8(m3, m4);
48-
let mask_3 = vorrq_u8(m5, m6);
49-
let mask_4 = vorrq_u8(m7, m8);
50-
51-
let mask_r_1 = vmaxvq_u8(mask_1);
52-
let mask_r_2 = vmaxvq_u8(mask_2);
53-
let mask_r_3 = vmaxvq_u8(mask_3);
54-
let mask_r_4 = vmaxvq_u8(mask_4);
55-
56-
// fast path: nothing needs escaping
57-
if mask_r_1 | mask_r_2 | mask_r_3 | mask_r_4 == 0 {
58-
out.extend_from_slice(std::slice::from_raw_parts(ptr, CHUNK));
59-
i += CHUNK;
60-
continue;
61-
}
62-
let mut tmp: [u8; 16] = core::mem::zeroed();
63-
64-
if mask_r_1 == 0 {
65-
out.extend_from_slice(std::slice::from_raw_parts(ptr, 16));
66-
} else {
67-
vst1q_u8(tmp.as_mut_ptr(), mask_1);
68-
handle_block(&b[i..i + 16], &tmp, &mut out);
69-
}
70-
71-
if mask_r_2 == 0 {
72-
out.extend_from_slice(std::slice::from_raw_parts(ptr.add(16), 16));
73-
} else {
74-
vst1q_u8(tmp.as_mut_ptr(), mask_2);
75-
handle_block(&b[i + 16..i + 32], &tmp, &mut out);
76-
}
77-
78-
if mask_r_3 == 0 {
79-
out.extend_from_slice(std::slice::from_raw_parts(ptr.add(32), 16));
80-
} else {
81-
vst1q_u8(tmp.as_mut_ptr(), mask_3);
82-
handle_block(&b[i + 32..i + 48], &tmp, &mut out);
83-
}
84-
85-
if mask_r_4 == 0 {
86-
out.extend_from_slice(std::slice::from_raw_parts(ptr.add(48), 16));
87-
} else {
88-
vst1q_u8(tmp.as_mut_ptr(), mask_4);
89-
handle_block(&b[i + 48..i + 64], &tmp, &mut out);
37+
let mut start = 0;
38+
let escape_low = vld1q_u8_x4(ESCAPE.as_ptr()); // first 64 B of table
39+
let escape_high = vdupq_n_u8(b'\\');
40+
41+
// === LS64-accelerated main loop =====================================
42+
while start + CHUNK_SIZE <= len {
43+
// 1. Pull 64 bytes to a stack buffer via LD64B/ST64B
44+
let mut block = [0u8; CHUNK_SIZE];
45+
ld64b_to_stack(bytes.as_ptr().add(start), block.as_mut_ptr());
46+
47+
// 2. Process the 64 B in four 16 B slices, **unchanged** logic
48+
let mut slice_idx = 0;
49+
while slice_idx < CHUNK_SIZE {
50+
let chunk_ptr = block.as_ptr().add(slice_idx);
51+
let chunk = vld1q_u8(chunk_ptr);
52+
let low_mask = vqtbl4q_u8(escape_low, chunk);
53+
let high_mask = vceqq_u8(escape_high, chunk);
54+
55+
if vmaxvq_u8(low_mask) == 0 && vmaxvq_u8(high_mask) == 0 {
56+
writer.extend_from_slice(std::slice::from_raw_parts(chunk_ptr, 16));
57+
slice_idx += 16;
58+
continue;
59+
}
60+
61+
// Combine masks and fall back to scalar per-byte handling
62+
let escape_mask = vorrq_u8(low_mask, high_mask);
63+
let mask_arr: [u8; 16] = core::mem::transmute(escape_mask);
64+
65+
for (i, &m) in mask_arr.iter().enumerate() {
66+
let b = *chunk_ptr.add(i);
67+
if m == 0 {
68+
writer.push(b);
69+
} else if m == 0xFF {
70+
writer.extend_from_slice(REVERSE_SOLIDUS);
71+
} else {
72+
let ce = CharEscape::from_escape_table(m, b);
73+
write_char_escape(writer, ce);
74+
}
75+
}
76+
slice_idx += 16;
9077
}
91-
92-
i += CHUNK;
93-
}
94-
if i < n {
95-
encode_str_inner(&b[i..], &mut out);
78+
start += CHUNK_SIZE;
9679
}
97-
}
98-
out.push(b'"');
99-
// SAFETY: we only emit valid UTF-8
100-
unsafe { String::from_utf8_unchecked(out) }
101-
}
10280

103-
#[inline(always)]
104-
unsafe fn handle_block(src: &[u8], mask: &[u8; 16], dst: &mut Vec<u8>) {
105-
for (j, &m) in mask.iter().enumerate() {
106-
let c = src[j];
107-
if m == 0 {
108-
dst.push(c);
109-
} else if m == 0xFF {
110-
dst.extend_from_slice(REVERSE_SOLIDUS);
111-
} else {
112-
let e = CharEscape::from_escape_table(m, c);
113-
write_char_escape(dst, e);
81+
if start < len {
82+
encode_str_inner(&bytes[start..], writer);
11483
}
11584
}
85+
86+
writer.push(b'"');
87+
unsafe { String::from_utf8_unchecked(output) }
11688
}

0 commit comments

Comments
 (0)