From c5f93df92eb9b35c2af46f94a34d32e28df8b1a5 Mon Sep 17 00:00:00 2001 From: CGQAQ Date: Thu, 1 Aug 2024 16:34:28 +0800 Subject: [PATCH 1/3] WIP: not working yet --- src/lib.rs | 15 +++++++++++- src/x86_64.rs | 68 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 1 deletion(-) create mode 100644 src/x86_64.rs diff --git a/src/lib.rs b/src/lib.rs index a313f7f..7bf06bc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,12 @@ pub use aarch64::encode_str; #[cfg(target_arch = "aarch64")] mod aarch64; +#[cfg(target_arch = "x86_64")] +mod x86_64; + +#[cfg(target_arch = "x86_64")] +pub use x86_64::encode_str; + const BB: u8 = b'b'; // \x08 const TT: u8 = b't'; // \x09 const NN: u8 = b'n'; // \x0A @@ -104,7 +110,7 @@ pub fn encode_str_fallback>(input: S) -> String { output } -#[cfg(not(target_arch = "aarch64"))] +#[cfg(all(not(target_arch = "aarch64"), not(target_arch = "x86_64")))] pub fn encode_str>(input: S) -> String { encode_str_fallback(input) } @@ -196,3 +202,10 @@ fn test_escape_json_string() { fixture ); } + + +#[test] +fn test() { + let x = ESCAPE[b'\\' as usize]; + println!("{x}") +} \ No newline at end of file diff --git a/src/x86_64.rs b/src/x86_64.rs new file mode 100644 index 0000000..66e9b6b --- /dev/null +++ b/src/x86_64.rs @@ -0,0 +1,68 @@ +use std::arch::x86_64::{ + __m128i, + _mm_adds_epu8, + _mm_cmpeq_epi8, _mm_loadu_si128, _mm_set1_epi8, + _mm_shuffle_epi8, _mm_test_all_zeros, +}; + +use std::mem::transmute; + +use crate::{encode_str_inner, write_char_escape, CharEscape, ESCAPE, REVERSE_SOLIDUS}; + +const CHUNK_SIZE: usize = 16; + +pub fn encode_str>(input: S) -> String { + let input_str = input.as_ref(); + let mut output = Vec::with_capacity(input_str.len() + 2); + let bytes = input_str.as_bytes(); + let len = bytes.len(); + let writer = &mut output; + writer.push(b'"'); + + // Safety: SIMD instructions + unsafe { + let zero = _mm_set1_epi8(-1); + + let mut start = 0; + while start + CHUNK_SIZE < len { + let next_chunk = start + CHUNK_SIZE; + let current_chunk_slice: &[u8] = &bytes[start..next_chunk]; + let table_low = _mm_loadu_si128(ESCAPE.as_ptr() as *const __m128i); + let table_high = _mm_set1_epi8(transmute::(b'\\')); + let chunk = _mm_loadu_si128(current_chunk_slice.as_ptr() as *const __m128i); + let low_mask = _mm_shuffle_epi8(table_low, chunk); + let high_mask = _mm_cmpeq_epi8(table_high,chunk); + // check every bits of mask is zero + if _mm_test_all_zeros(low_mask, zero) != 0 && _mm_test_all_zeros(high_mask, zero) != 0 { + writer.extend_from_slice(current_chunk_slice); + start = next_chunk; + continue; + } + + // Vector add the masks to get a single mask + // add low_mask and high_mask to get a single mask + let escape_table_mask = _mm_adds_epu8(low_mask, high_mask); + let escape_table_mask_slice = transmute::<__m128i, [u8; 16]>(escape_table_mask); + for (index, value) in escape_table_mask_slice.into_iter().enumerate() { + if value == 0 { + writer.push(bytes[start + index]); + } else if value == 255 { + // value is in the high table mask, which means it's `\` + writer.extend_from_slice(REVERSE_SOLIDUS); + } else { + let char_escape = + CharEscape::from_escape_table(value, current_chunk_slice[index]); + write_char_escape(writer, char_escape); + } + } + start = next_chunk; + } + if start < len { + encode_str_inner(&bytes[start..], writer); + } + } + + writer.push(b'"'); + // Safety: the bytes are valid UTF-8 + unsafe { String::from_utf8_unchecked(output) } +} \ No newline at end of file From f352cc53aa92985bc96426664697f8ec31052472 Mon Sep 17 00:00:00 2001 From: CGQAQ Date: Fri, 2 Aug 2024 11:06:09 +0800 Subject: [PATCH 2/3] working, but too slow --- src/x86_64.rs | 71 +++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 58 insertions(+), 13 deletions(-) diff --git a/src/x86_64.rs b/src/x86_64.rs index 66e9b6b..3d466d7 100644 --- a/src/x86_64.rs +++ b/src/x86_64.rs @@ -1,10 +1,4 @@ -use std::arch::x86_64::{ - __m128i, - _mm_adds_epu8, - _mm_cmpeq_epi8, _mm_loadu_si128, _mm_set1_epi8, - _mm_shuffle_epi8, _mm_test_all_zeros, -}; - +use std::arch::x86_64::*; use std::mem::transmute; use crate::{encode_str_inner, write_char_escape, CharEscape, ESCAPE, REVERSE_SOLIDUS}; @@ -21,27 +15,39 @@ pub fn encode_str>(input: S) -> String { // Safety: SIMD instructions unsafe { - let zero = _mm_set1_epi8(-1); + let table_low = [ + _mm_loadu_si128(ESCAPE[0..16].as_ptr() as *const __m128i), + _mm_loadu_si128(ESCAPE[16..32].as_ptr() as *const __m128i), + _mm_loadu_si128(ESCAPE[32..48].as_ptr() as *const __m128i), + _mm_loadu_si128(ESCAPE[48..64].as_ptr() as *const __m128i), + ]; + // let ones = _mm_set1_epi8(1); let mut start = 0; while start + CHUNK_SIZE < len { let next_chunk = start + CHUNK_SIZE; let current_chunk_slice: &[u8] = &bytes[start..next_chunk]; - let table_low = _mm_loadu_si128(ESCAPE.as_ptr() as *const __m128i); - let table_high = _mm_set1_epi8(transmute::(b'\\')); + let table_high = _mm_set1_epi8(b'\\' as i8); let chunk = _mm_loadu_si128(current_chunk_slice.as_ptr() as *const __m128i); - let low_mask = _mm_shuffle_epi8(table_low, chunk); + let low_mask = table_lookup_sse42(chunk, table_low); let high_mask = _mm_cmpeq_epi8(table_high,chunk); // check every bits of mask is zero - if _mm_test_all_zeros(low_mask, zero) != 0 && _mm_test_all_zeros(high_mask, zero) != 0 { + if horizontal_add_u8_sse42(low_mask) == 0 && horizontal_add_u8_sse42(high_mask) == 0 { writer.extend_from_slice(current_chunk_slice); start = next_chunk; continue; } + // check every bits of mask is zero + // if _mm_testz_si128(low_mask, ones) == 1 && _mm_testz_si128(high_mask, ones) == 1 { + // writer.extend_from_slice(current_chunk_slice); + // start = next_chunk; + // continue; + // } + // Vector add the masks to get a single mask // add low_mask and high_mask to get a single mask - let escape_table_mask = _mm_adds_epu8(low_mask, high_mask); + let escape_table_mask = _mm_add_epi8(low_mask, high_mask); let escape_table_mask_slice = transmute::<__m128i, [u8; 16]>(escape_table_mask); for (index, value) in escape_table_mask_slice.into_iter().enumerate() { if value == 0 { @@ -65,4 +71,43 @@ pub fn encode_str>(input: S) -> String { writer.push(b'"'); // Safety: the bytes are valid UTF-8 unsafe { String::from_utf8_unchecked(output) } +} + + +fn table_lookup_sse42(indices: __m128i, table: [__m128i; 4]) -> __m128i { + unsafe { + // Compute the lookup results for each 16-byte chunk + let lookup0 = _mm_shuffle_epi8(table[0], indices); + let lookup1 = _mm_shuffle_epi8(table[1], indices); + let lookup2 = _mm_shuffle_epi8(table[2], indices); + let lookup3 = _mm_shuffle_epi8(table[3], indices); + + // Calculate masks to determine which lookup result to use + let cmp0 = _mm_cmplt_epi8(indices, _mm_set1_epi8(16)); + let cmp1 = _mm_and_si128(_mm_cmplt_epi8(indices, _mm_set1_epi8(32)), _mm_cmpgt_epi8(indices, _mm_set1_epi8(15))); + let cmp2 = _mm_and_si128(_mm_cmplt_epi8(indices, _mm_set1_epi8(48)), _mm_cmpgt_epi8(indices, _mm_set1_epi8(31))); + let cmp3 = _mm_cmpgt_epi8(indices, _mm_set1_epi8(47)); + + // Blend the lookup results based on the masks + let result0 = _mm_blendv_epi8(_mm_setzero_si128(), lookup0, cmp0); + let result1 = _mm_blendv_epi8(result0, lookup1, cmp1); + let result2 = _mm_blendv_epi8(result1, lookup2, cmp2); + let final_result = _mm_blendv_epi8(result2, lookup3, cmp3); + + final_result + } +} + +fn horizontal_add_u8_sse42(vector: __m128i) -> u8 { + unsafe { + // Compute the sum of the absolute differences + let sum = _mm_sad_epu8(vector, _mm_setzero_si128()); + + // Extract the sums from the resulting __m128i + let sum_array = std::mem::transmute::<__m128i, [u64; 2]>(sum); + let total_sum = sum_array[0] + sum_array[1]; + + // Cast the result to u8 (sum cannot exceed 255*16 = 4080) + total_sum as u8 + } } \ No newline at end of file From 8ee741b45623ff877f56425a5a87329a69ca7cac Mon Sep 17 00:00:00 2001 From: cgqaq Date: Sat, 3 Aug 2024 13:43:50 +0800 Subject: [PATCH 3/3] remove unused test --- src/lib.rs | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 7bf06bc..4837c62 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -202,10 +202,3 @@ fn test_escape_json_string() { fixture ); } - - -#[test] -fn test() { - let x = ESCAPE[b'\\' as usize]; - println!("{x}") -} \ No newline at end of file