4 changes: 3 additions & 1 deletion .github/workflows/CI.yml
@@ -18,7 +18,9 @@ jobs:
targets: 'aarch64-unknown-linux-gnu'
env:
CARGO_INCREMENTAL: '1'
- name: Print rustc target cpus
run: rustc --print=target-cpus
- name: Run benchmarks
run: cargo bench
env:
RUSTFLAGS: '-C target-cpu=native'
RUSTFLAGS: '-C target-cpu=native -C target-feature=+ls64'
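The new RUSTFLAGS line opts the benchmark build into FEAT_LS64 code generation, which the inline LD64B/ST64B assembly in src/aarch64.rs below likely needs in order to assemble. A minimal sketch of keeping a portable path for builds without that flag, assuming the `ls64` target feature is visible to `cfg(target_feature = ...)` on the toolchain in use (it is an unstable feature, so a cargo feature flag may be needed instead); `copy_block_64` is a hypothetical name and `ld64b_to_stack` is the helper added below:

#[cfg(target_feature = "ls64")]
unsafe fn copy_block_64(src: *const u8, dst: *mut u8) {
    ld64b_to_stack(src, dst);
}

#[cfg(not(target_feature = "ls64"))]
unsafe fn copy_block_64(src: *const u8, dst: *mut u8) {
    // Plain 64-byte copy: the same bytes land in `dst`, no FEAT_LS64 required.
    core::ptr::copy_nonoverlapping(src, dst, 64);
}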
176 changes: 74 additions & 102 deletions src/aarch64.rs
@@ -1,116 +1,88 @@
use std::arch::aarch64::{
vceqq_u8, vdupq_n_u8, vld1q_u8, vld1q_u8_x4, vmaxvq_u8, vorrq_u8, vqtbl4q_u8, vst1q_u8,
use std::arch::{
aarch64::{vceqq_u8, vdupq_n_u8, vld1q_u8, vld1q_u8_x4, vmaxvq_u8, vorrq_u8, vqtbl4q_u8},
asm,
};

use crate::{encode_str_inner, write_char_escape, CharEscape, ESCAPE, REVERSE_SOLIDUS};

/// Four contiguous 16-byte NEON registers (64 B) per loop.
const CHUNK: usize = 64;
/// Process 64 bytes (one cache line) per iteration.
const CHUNK_SIZE: usize = 64;

#[inline(always)]
unsafe fn ld64b_to_stack(src: *const u8, dst: *mut u8) {
// Loads 64 bytes with a single LD64B and immediately writes them
// to `dst` with ST64B so the following NEON code can work from L1.
asm!(
// LD64B/ST64B move data through eight consecutive registers (x0-x7 here);
// we clobber them explicitly and ignore the values after the store.
"ld64b x0, [{inptr}]",
"st64b x0, [{outptr}]",
inptr = in(reg) src,
outptr = in(reg) dst,
out("x0") _, out("x1") _, out("x2") _, out("x3") _,
out("x4") _, out("x5") _, out("x6") _, out("x7") _,
options(nostack, preserves_flags)
);
}
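// Hedged note, an assumption drawn from the FEAT_LS64 definition rather than
// from anything this crate states: LD64B/ST64B operate on 64-byte-aligned
// addresses, so both `src` and `dst` need that alignment. A hypothetical
// wrapper type that a 64-byte-aligned stack buffer could be declared with:
#[repr(align(64))]
struct Aligned64([u8; CHUNK_SIZE]);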

pub fn encode_str<S: AsRef<str>>(input: S) -> String {
let s = input.as_ref();
let mut out = Vec::with_capacity(s.len() + 2);
let b = s.as_bytes();
let n = b.len();
out.push(b'"');
let input_str = input.as_ref();
let mut output = Vec::with_capacity(input_str.len() + 2);
let bytes = input_str.as_bytes();
let len = bytes.len();
let writer = &mut output;
writer.push(b'"');

unsafe {
let tbl = vld1q_u8_x4(ESCAPE.as_ptr()); // first 64 B of the escape table
let slash = vdupq_n_u8(b'\\');
let mut i = 0;

while i + CHUNK <= n {
let ptr = b.as_ptr().add(i);

/* ---- L1 prefetch: CHUNK size ahead ---- */
core::arch::asm!("prfm pldl1keep, [{0}, #64]", in(reg) ptr);
/* ------------------------------------------ */

// load 64 B (four q-regs)
let a = vld1q_u8(ptr);
let m1 = vqtbl4q_u8(tbl, a);
let m2 = vceqq_u8(slash, a);

let b2 = vld1q_u8(ptr.add(16));
let m3 = vqtbl4q_u8(tbl, b2);
let m4 = vceqq_u8(slash, b2);

let c = vld1q_u8(ptr.add(32));
let m5 = vqtbl4q_u8(tbl, c);
let m6 = vceqq_u8(slash, c);

let d = vld1q_u8(ptr.add(48));
let m7 = vqtbl4q_u8(tbl, d);
let m8 = vceqq_u8(slash, d);

let mask_1 = vorrq_u8(m1, m2);
let mask_2 = vorrq_u8(m3, m4);
let mask_3 = vorrq_u8(m5, m6);
let mask_4 = vorrq_u8(m7, m8);

let mask_r_1 = vmaxvq_u8(mask_1);
let mask_r_2 = vmaxvq_u8(mask_2);
let mask_r_3 = vmaxvq_u8(mask_3);
let mask_r_4 = vmaxvq_u8(mask_4);

// fast path: nothing needs escaping
if mask_r_1 | mask_r_2 | mask_r_3 | mask_r_4 == 0 {
out.extend_from_slice(std::slice::from_raw_parts(ptr, CHUNK));
i += CHUNK;
continue;
}
let mut tmp: [u8; 16] = core::mem::zeroed();

if mask_r_1 == 0 {
out.extend_from_slice(std::slice::from_raw_parts(ptr, 16));
} else {
vst1q_u8(tmp.as_mut_ptr(), mask_1);
handle_block(&b[i..i + 16], &tmp, &mut out);
}

if mask_r_2 == 0 {
out.extend_from_slice(std::slice::from_raw_parts(ptr.add(16), 16));
} else {
vst1q_u8(tmp.as_mut_ptr(), mask_2);
handle_block(&b[i + 16..i + 32], &tmp, &mut out);
}

if mask_r_3 == 0 {
out.extend_from_slice(std::slice::from_raw_parts(ptr.add(32), 16));
} else {
vst1q_u8(tmp.as_mut_ptr(), mask_3);
handle_block(&b[i + 32..i + 48], &tmp, &mut out);
}

if mask_r_4 == 0 {
out.extend_from_slice(std::slice::from_raw_parts(ptr.add(48), 16));
} else {
vst1q_u8(tmp.as_mut_ptr(), mask_4);
handle_block(&b[i + 48..i + 64], &tmp, &mut out);
let mut start = 0;
let escape_low = vld1q_u8_x4(ESCAPE.as_ptr()); // first 64 B of table
let escape_high = vdupq_n_u8(b'\\');

// === LS64-accelerated main loop =====================================
while start + CHUNK_SIZE <= len {
// 1. Pull 64 bytes to a stack buffer via LD64B/ST64B
let mut block = [0u8; CHUNK_SIZE];
ld64b_to_stack(bytes.as_ptr().add(start), block.as_mut_ptr());

// 2. Process the 64 B in four 16 B slices with the same NEON escape logic
let mut slice_idx = 0;
while slice_idx < CHUNK_SIZE {
let chunk_ptr = block.as_ptr().add(slice_idx);
let chunk = vld1q_u8(chunk_ptr);
let low_mask = vqtbl4q_u8(escape_low, chunk);
let high_mask = vceqq_u8(escape_high, chunk);

if vmaxvq_u8(low_mask) == 0 && vmaxvq_u8(high_mask) == 0 {
writer.extend_from_slice(std::slice::from_raw_parts(chunk_ptr, 16));
slice_idx += 16;
continue;
}

// Combine masks and fall back to scalar per-byte handling
let escape_mask = vorrq_u8(low_mask, high_mask);
let mask_arr: [u8; 16] = core::mem::transmute(escape_mask);

for (i, &m) in mask_arr.iter().enumerate() {
let b = *chunk_ptr.add(i);
if m == 0 {
writer.push(b);
} else if m == 0xFF {
writer.extend_from_slice(REVERSE_SOLIDUS);
} else {
let ce = CharEscape::from_escape_table(m, b);
write_char_escape(writer, ce);
}
}
slice_idx += 16;
}

i += CHUNK;
}
if i < n {
encode_str_inner(&b[i..], &mut out);
start += CHUNK_SIZE;
}
}
out.push(b'"');
// SAFETY: we only emit valid UTF-8
unsafe { String::from_utf8_unchecked(out) }
}

#[inline(always)]
unsafe fn handle_block(src: &[u8], mask: &[u8; 16], dst: &mut Vec<u8>) {
for (j, &m) in mask.iter().enumerate() {
let c = src[j];
if m == 0 {
dst.push(c);
} else if m == 0xFF {
dst.extend_from_slice(REVERSE_SOLIDUS);
} else {
let e = CharEscape::from_escape_table(m, c);
write_char_escape(dst, e);
if start < len {
encode_str_inner(&bytes[start..], writer);
}
}

writer.push(b'"');
unsafe { String::from_utf8_unchecked(output) }
}
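A minimal usage sketch, assuming the ESCAPE table and write_char_escape implement standard JSON string escaping (as the CharEscape and REVERSE_SOLIDUS names suggest); this short input takes the encode_str_inner tail path, while inputs of 64 bytes or more go through the LS64/NEON loop:

// In the same module as `encode_str`:
let json = encode_str(r#"say "hi" \ bye"#);
assert_eq!(json, r#""say \"hi\" \\ bye""#);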