diff --git a/src/lib.rs b/src/lib.rs
index 3e3e1ca..73138f1 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -127,12 +127,19 @@ pub fn escape<S: AsRef<str>>(input: S) -> String {
     result.push(b'"');
     let s = input.as_ref();
     let bytes = s.as_bytes();
+    let len = bytes.len();
     // Runtime CPU feature detection for x86_64
-    if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512bw") {
+    if is_x86_feature_detected!("avx512f")
+        && is_x86_feature_detected!("avx512bw")
+        && len >= x86::LOOP_SIZE_AVX512
+    {
         unsafe { x86::escape_avx512(bytes, &mut result) }
-    } else if is_x86_feature_detected!("avx2") {
+    } else if is_x86_feature_detected!("avx2") && len >= x86::LOOP_SIZE_AVX2 {
         unsafe { x86::escape_avx2(bytes, &mut result) }
-    } else if is_x86_feature_detected!("sse2") {
+    } else if is_x86_feature_detected!("sse2")
+        && /* below 128 bytes, SIMD setup cost outweighs the gain */
+        len >= x86::LOOP_SIZE_AVX2
+    {
         unsafe { x86::escape_sse2(bytes, &mut result) }
     } else {
         escape_inner(bytes, &mut result);
diff --git a/src/x86.rs b/src/x86.rs
index 7b59926..c43ae22 100644
--- a/src/x86.rs
+++ b/src/x86.rs
@@ -19,9 +19,10 @@ const C: i8 = 92i8; // '\\'
 const M512_VECTOR_SIZE: usize = std::mem::size_of::<__m512i>();
 const M256_VECTOR_SIZE: usize = std::mem::size_of::<__m256i>();
 const M128_VECTOR_SIZE: usize = std::mem::size_of::<__m128i>();
-const LOOP_SIZE_AVX2: usize = 4 * M256_VECTOR_SIZE; // Process 128 bytes at a time
-const LOOP_SIZE_AVX512: usize = 4 * M512_VECTOR_SIZE; // Process 256 bytes at a time
-const PREFETCH_DISTANCE: usize = 256; // Prefetch 256 bytes ahead
+pub(crate) const LOOP_SIZE_AVX2: usize = 4 * M256_VECTOR_SIZE; // Process 128 bytes at a time
+pub(crate) const LOOP_SIZE_AVX512: usize = 4 * M512_VECTOR_SIZE; // Process 256 bytes at a time
+const PREFETCH_DISTANCE_AVX2: usize = 256; // Prefetch 256 bytes ahead for AVX2
+const PREFETCH_DISTANCE_AVX512: usize = 512; // Prefetch 512 bytes ahead for AVX512
 
 #[inline(always)]
 fn sub(a: *const u8, b: *const u8) -> usize {
@@ -39,197 +40,190 @@ pub unsafe fn escape_avx512(bytes: &[u8], result: &mut Vec<u8>) {
     let mut ptr = start_ptr;
     let mut start = 0;
 
-    if len >= M512_VECTOR_SIZE {
-        let v_b = _mm512_set1_epi8(B);
-        let v_c = _mm512_set1_epi8(C);
-        let v_ctrl_limit = _mm512_set1_epi8(0x20);
-
-        // Handle alignment - skip if already aligned
-        const M512_VECTOR_ALIGN: usize = M512_VECTOR_SIZE - 1;
-        let misalignment = start_ptr as usize & M512_VECTOR_ALIGN;
-        if misalignment != 0 {
-            let align = M512_VECTOR_SIZE - misalignment;
-            let a = _mm512_loadu_si512(ptr as *const __m512i);
-
-            // Check for quotes, backslash, and control characters
-            let quote_mask = _mm512_cmpeq_epi8_mask(a, v_b);
-            let slash_mask = _mm512_cmpeq_epi8_mask(a, v_c);
-            let ctrl_mask = _mm512_cmplt_epu8_mask(a, v_ctrl_limit);
-
-            let mut mask = (quote_mask | slash_mask | ctrl_mask) as u64;
-            if align < 64 {
-                mask &= (1u64 << align) - 1;
-            }
-
-            if mask != 0 {
-                let at = sub(ptr, start_ptr);
-                while mask != 0 {
-                    let cur = mask.trailing_zeros() as usize;
-                    let c = *ptr.add(cur);
-                    let escape_byte = ESCAPE[c as usize];
-                    debug_assert!(escape_byte != 0);
-                    let i = at + cur;
-                    if start < i {
-                        result.extend_from_slice(&bytes[start..i]);
-                    }
-                    write_escape(result, escape_byte, c);
-                    start = i + 1;
-                    mask &= mask - 1;
-                }
-            }
-            ptr = ptr.add(align);
-        }
-
-        // Main loop processing 256 bytes at a time
-        if LOOP_SIZE_AVX512 <= len {
-            while ptr <= end_ptr.sub(LOOP_SIZE_AVX512) {
-                debug_assert_eq!(0, (ptr as usize) % M512_VECTOR_SIZE);
-
-                // Prefetch next iteration's data
-                if ptr.add(LOOP_SIZE_AVX512 + PREFETCH_DISTANCE) < end_ptr {
-                    _mm_prefetch(
-                        ptr.add(LOOP_SIZE_AVX512 + PREFETCH_DISTANCE) as *const i8,
-                        _MM_HINT_T0,
-                    );
-                }
-
-                // Load all 4 vectors at once for better pipelining
-                let a0 = _mm512_load_si512(ptr as *const __m512i);
-                let a1 = _mm512_load_si512(ptr.add(M512_VECTOR_SIZE) as *const __m512i);
-                let a2 = _mm512_load_si512(ptr.add(M512_VECTOR_SIZE * 2) as *const __m512i);
-                let a3 = _mm512_load_si512(ptr.add(M512_VECTOR_SIZE * 3) as *const __m512i);
-
-                // Check for quotes (") in all vectors
-                let quote_0 = _mm512_cmpeq_epi8_mask(a0, v_b);
-                let quote_1 = _mm512_cmpeq_epi8_mask(a1, v_b);
-                let quote_2 = _mm512_cmpeq_epi8_mask(a2, v_b);
-                let quote_3 = _mm512_cmpeq_epi8_mask(a3, v_b);
-
-                // Check for backslash (\) in all vectors
-                let slash_0 = _mm512_cmpeq_epi8_mask(a0, v_c);
-                let slash_1 = _mm512_cmpeq_epi8_mask(a1, v_c);
-                let slash_2 = _mm512_cmpeq_epi8_mask(a2, v_c);
-                let slash_3 = _mm512_cmpeq_epi8_mask(a3, v_c);
-
-                // Check for control characters (< 0x20) in all vectors
-                let ctrl_0 = _mm512_cmplt_epu8_mask(a0, v_ctrl_limit);
-                let ctrl_1 = _mm512_cmplt_epu8_mask(a1, v_ctrl_limit);
-                let ctrl_2 = _mm512_cmplt_epu8_mask(a2, v_ctrl_limit);
-                let ctrl_3 = _mm512_cmplt_epu8_mask(a3, v_ctrl_limit);
-
-                // Combine all masks
-                let mask_a = quote_0 | slash_0 | ctrl_0;
-                let mask_b = quote_1 | slash_1 | ctrl_1;
-                let mask_c = quote_2 | slash_2 | ctrl_2;
-                let mask_d = quote_3 | slash_3 | ctrl_3;
-
-                // Fast path: check if any escaping needed
-                let any_escape = mask_a | mask_b | mask_c | mask_d;
-
-                if any_escape == 0 {
-                    // No escapes needed, copy whole chunk
-                    if start < sub(ptr, start_ptr) {
-                        result.extend_from_slice(&bytes[start..sub(ptr, start_ptr)]);
-                    }
-                    result.extend_from_slice(std::slice::from_raw_parts(ptr, LOOP_SIZE_AVX512));
-                    start = sub(ptr, start_ptr) + LOOP_SIZE_AVX512;
-                } else {
-                    // Process each 64-byte chunk that has escapes
-                    process_mask_avx512(ptr, start_ptr, result, &mut start, bytes, mask_a, 0);
-                    process_mask_avx512(
-                        ptr,
-                        start_ptr,
-                        result,
-                        &mut start,
-                        bytes,
-                        mask_b,
-                        M512_VECTOR_SIZE,
-                    );
-                    process_mask_avx512(
-                        ptr,
-                        start_ptr,
-                        result,
-                        &mut start,
-                        bytes,
-                        mask_c,
-                        M512_VECTOR_SIZE * 2,
-                    );
-                    process_mask_avx512(
-                        ptr,
-                        start_ptr,
-                        result,
-                        &mut start,
-                        bytes,
-                        mask_d,
-                        M512_VECTOR_SIZE * 3,
-                    );
-                }
-
-                ptr = ptr.add(LOOP_SIZE_AVX512);
-            }
-        }
-
-        // Process remaining aligned chunks
-        while ptr <= end_ptr.sub(M512_VECTOR_SIZE) {
-            debug_assert_eq!(0, (ptr as usize) % M512_VECTOR_SIZE);
-            let a = _mm512_load_si512(ptr as *const __m512i);
-
-            let quote_mask = _mm512_cmpeq_epi8_mask(a, v_b);
-            let slash_mask = _mm512_cmpeq_epi8_mask(a, v_c);
-            let ctrl_mask = _mm512_cmplt_epu8_mask(a, v_ctrl_limit);
-
-            let mut mask = (quote_mask | slash_mask | ctrl_mask) as u64;
-
-            if mask != 0 {
-                let at = sub(ptr, start_ptr);
-                while mask != 0 {
-                    let cur = mask.trailing_zeros() as usize;
-                    let c = *ptr.add(cur);
-                    let escape_byte = ESCAPE[c as usize];
-                    debug_assert!(escape_byte != 0);
-                    let i = at + cur;
-                    if start < i {
-                        result.extend_from_slice(&bytes[start..i]);
-                    }
-                    write_escape(result, escape_byte, c);
-                    start = i + 1;
-                    mask &= mask - 1;
-                }
-            }
-            ptr = ptr.add(M512_VECTOR_SIZE);
-        }
-
-        // Handle tail
-        if ptr < end_ptr {
-            let d = M512_VECTOR_SIZE - sub(end_ptr, ptr);
-            let a = _mm512_loadu_si512(ptr.sub(d) as *const __m512i);
-
-            let quote_mask = _mm512_cmpeq_epi8_mask(a, v_b);
-            let slash_mask = _mm512_cmpeq_epi8_mask(a, v_c);
-            let ctrl_mask = _mm512_cmplt_epu8_mask(a, v_ctrl_limit);
-
-            let mut mask = ((quote_mask | slash_mask | ctrl_mask) as u64).wrapping_shr(d as u32);
-
-            if mask != 0 {
-                let at = sub(ptr, start_ptr);
-                while mask != 0 {
-                    let cur = mask.trailing_zeros() as usize;
-                    let c = *ptr.add(cur);
-                    let escape_byte = ESCAPE[c as usize];
-                    debug_assert!(escape_byte != 0);
-                    let i = at + cur;
-                    if start < i {
-                        result.extend_from_slice(&bytes[start..i]);
-                    }
-                    write_escape(result, escape_byte, c);
-                    start = i + 1;
-                    mask &= mask - 1;
-                }
-            }
-        }
-    } else {
-        // Fall back to AVX2 for small strings
-        return escape_avx2(bytes, result);
-    }
+    let v_b = _mm512_set1_epi8(B);
+    let v_c = _mm512_set1_epi8(C);
+    let v_ctrl_limit = _mm512_set1_epi8(0x20);
+
+    // Handle alignment - skip if already aligned
+    const M512_VECTOR_ALIGN: usize = M512_VECTOR_SIZE - 1;
+    let misalignment = start_ptr as usize & M512_VECTOR_ALIGN;
+    if misalignment != 0 {
+        let align = M512_VECTOR_SIZE - misalignment;
+        let a = _mm512_loadu_si512(ptr as *const __m512i);
+
+        // Check for quotes, backslash, and control characters
+        let quote_mask = _mm512_cmpeq_epi8_mask(a, v_b);
+        let slash_mask = _mm512_cmpeq_epi8_mask(a, v_c);
+        let ctrl_mask = _mm512_cmplt_epu8_mask(a, v_ctrl_limit);
+
+        let mut mask = (quote_mask | slash_mask | ctrl_mask) as u64;
+        if align < 64 {
+            mask &= (1u64 << align) - 1;
+        }
+
+        if mask != 0 {
+            let at = sub(ptr, start_ptr);
+            while mask != 0 {
+                let cur = mask.trailing_zeros() as usize;
+                let c = *ptr.add(cur);
+                let escape_byte = ESCAPE[c as usize];
+                debug_assert!(escape_byte != 0);
+                let i = at + cur;
+                if start < i {
+                    result.extend_from_slice(&bytes[start..i]);
+                }
+                write_escape(result, escape_byte, c);
+                start = i + 1;
+                mask &= mask - 1;
+            }
+        }
+        ptr = ptr.add(align);
+    }
+
+    // Main loop processing 256 bytes at a time
+    while ptr <= end_ptr.sub(LOOP_SIZE_AVX512) {
+        debug_assert_eq!(0, (ptr as usize) % M512_VECTOR_SIZE);
+
+        // Prefetch next iteration's data
+        if ptr.add(LOOP_SIZE_AVX512 + PREFETCH_DISTANCE_AVX512) < end_ptr {
+            _mm_prefetch(
+                ptr.add(LOOP_SIZE_AVX512 + PREFETCH_DISTANCE_AVX512) as *const i8,
+                _MM_HINT_T0,
+            );
+        }
+
+        // Load all 4 vectors at once for better pipelining
+        let a0 = _mm512_load_si512(ptr as *const __m512i);
+        let a1 = _mm512_load_si512(ptr.add(M512_VECTOR_SIZE) as *const __m512i);
+        let a2 = _mm512_load_si512(ptr.add(M512_VECTOR_SIZE * 2) as *const __m512i);
+        let a3 = _mm512_load_si512(ptr.add(M512_VECTOR_SIZE * 3) as *const __m512i);
+
+        // Check for quotes (") in all vectors
+        let quote_0 = _mm512_cmpeq_epi8_mask(a0, v_b);
+        let quote_1 = _mm512_cmpeq_epi8_mask(a1, v_b);
+        let quote_2 = _mm512_cmpeq_epi8_mask(a2, v_b);
+        let quote_3 = _mm512_cmpeq_epi8_mask(a3, v_b);
+
+        // Check for backslash (\) in all vectors
+        let slash_0 = _mm512_cmpeq_epi8_mask(a0, v_c);
+        let slash_1 = _mm512_cmpeq_epi8_mask(a1, v_c);
+        let slash_2 = _mm512_cmpeq_epi8_mask(a2, v_c);
+        let slash_3 = _mm512_cmpeq_epi8_mask(a3, v_c);
+
+        // Check for control characters (< 0x20) in all vectors
+        let ctrl_0 = _mm512_cmplt_epu8_mask(a0, v_ctrl_limit);
+        let ctrl_1 = _mm512_cmplt_epu8_mask(a1, v_ctrl_limit);
+        let ctrl_2 = _mm512_cmplt_epu8_mask(a2, v_ctrl_limit);
+        let ctrl_3 = _mm512_cmplt_epu8_mask(a3, v_ctrl_limit);
+
+        // Combine all masks
+        let mask_a = quote_0 | slash_0 | ctrl_0;
+        let mask_b = quote_1 | slash_1 | ctrl_1;
+        let mask_c = quote_2 | slash_2 | ctrl_2;
+        let mask_d = quote_3 | slash_3 | ctrl_3;
+
+        // Fast path: check if any escaping needed
+        let any_escape = mask_a | mask_b | mask_c | mask_d;
+
+        if any_escape == 0 {
+            // No escapes needed, copy whole chunk
+            if start < sub(ptr, start_ptr) {
+                result.extend_from_slice(&bytes[start..sub(ptr, start_ptr)]);
+            }
+            result.extend_from_slice(std::slice::from_raw_parts(ptr, LOOP_SIZE_AVX512));
+            start = sub(ptr, start_ptr) + LOOP_SIZE_AVX512;
+        } else {
+            // Process each 64-byte chunk that has escapes
+            process_mask_avx512(ptr, start_ptr, result, &mut start, bytes, mask_a, 0);
+            process_mask_avx512(
+                ptr,
+                start_ptr,
+                result,
+                &mut start,
+                bytes,
+                mask_b,
+                M512_VECTOR_SIZE,
+            );
+            process_mask_avx512(
+                ptr,
+                start_ptr,
+                result,
+                &mut start,
+                bytes,
+                mask_c,
+                M512_VECTOR_SIZE * 2,
+            );
+            process_mask_avx512(
+                ptr,
+                start_ptr,
+                result,
+                &mut start,
+                bytes,
+                mask_d,
+                M512_VECTOR_SIZE * 3,
+            );
+        }
+
+        ptr = ptr.add(LOOP_SIZE_AVX512);
+    }
+
+    // Process remaining aligned chunks
+    while ptr <= end_ptr.sub(M512_VECTOR_SIZE) {
+        debug_assert_eq!(0, (ptr as usize) % M512_VECTOR_SIZE);
+        let a = _mm512_load_si512(ptr as *const __m512i);
+
+        let quote_mask = _mm512_cmpeq_epi8_mask(a, v_b);
+        let slash_mask = _mm512_cmpeq_epi8_mask(a, v_c);
+        let ctrl_mask = _mm512_cmplt_epu8_mask(a, v_ctrl_limit);
+
+        let mut mask = (quote_mask | slash_mask | ctrl_mask) as u64;
+
+        if mask != 0 {
+            let at = sub(ptr, start_ptr);
+            while mask != 0 {
+                let cur = mask.trailing_zeros() as usize;
+                let c = *ptr.add(cur);
+                let escape_byte = ESCAPE[c as usize];
+                debug_assert!(escape_byte != 0);
+                let i = at + cur;
+                if start < i {
+                    result.extend_from_slice(&bytes[start..i]);
+                }
+                write_escape(result, escape_byte, c);
+                start = i + 1;
+                mask &= mask - 1;
+            }
+        }
+        ptr = ptr.add(M512_VECTOR_SIZE);
+    }
+
+    // Handle tail
+    if ptr < end_ptr {
+        let d = M512_VECTOR_SIZE - sub(end_ptr, ptr);
+        let a = _mm512_loadu_si512(ptr.sub(d) as *const __m512i);
+
+        let quote_mask = _mm512_cmpeq_epi8_mask(a, v_b);
+        let slash_mask = _mm512_cmpeq_epi8_mask(a, v_c);
+        let ctrl_mask = _mm512_cmplt_epu8_mask(a, v_ctrl_limit);
+
+        let mut mask = ((quote_mask | slash_mask | ctrl_mask) as u64).wrapping_shr(d as u32);
+
+        if mask != 0 {
+            let at = sub(ptr, start_ptr);
+            while mask != 0 {
+                let cur = mask.trailing_zeros() as usize;
+                let c = *ptr.add(cur);
+                let escape_byte = ESCAPE[c as usize];
+                debug_assert!(escape_byte != 0);
+                let i = at + cur;
+                if start < i {
+                    result.extend_from_slice(&bytes[start..i]);
+                }
+                write_escape(result, escape_byte, c);
+                start = i + 1;
+                mask &= mask - 1;
+            }
+        }
+    }
 
     // Copy any remaining bytes
@@ -248,217 +242,210 @@ pub unsafe fn escape_avx2(bytes: &[u8], result: &mut Vec<u8>) {
     let mut ptr = start_ptr;
     let mut start = 0;
 
-    if len >= M256_VECTOR_SIZE {
-        let v_translation_a = _mm256_set1_epi8(TRANSLATION_A);
-        let v_below_a = _mm256_set1_epi8(BELOW_A);
-        let v_b = _mm256_set1_epi8(B);
-        let v_c = _mm256_set1_epi8(C);
-
-        // Handle alignment - skip if already aligned
-        const M256_VECTOR_ALIGN: usize = M256_VECTOR_SIZE - 1;
-        let misalignment = start_ptr as usize & M256_VECTOR_ALIGN;
-        if misalignment != 0 {
-            let align = M256_VECTOR_SIZE - misalignment;
-            let mut mask = {
-                let a = _mm256_loadu_si256(ptr as *const __m256i);
-                _mm256_movemask_epi8(_mm256_or_si256(
-                    _mm256_or_si256(_mm256_cmpeq_epi8(a, v_b), _mm256_cmpeq_epi8(a, v_c)),
-                    _mm256_cmpgt_epi8(_mm256_add_epi8(a, v_translation_a), v_below_a),
-                ))
-            };
-
-            if mask != 0 {
-                let at = sub(ptr, start_ptr);
-                let mut cur = mask.trailing_zeros() as usize;
-                while cur < align {
-                    let c = *ptr.add(cur);
-                    let escape_byte = ESCAPE[c as usize];
-                    if escape_byte != 0 {
-                        let i = at + cur;
-                        if start < i {
-                            result.extend_from_slice(&bytes[start..i]);
-                        }
-                        write_escape(result, escape_byte, c);
-                        start = i + 1;
-                    }
-                    mask ^= 1 << cur;
-                    if mask == 0 {
-                        break;
-                    }
-                    cur = mask.trailing_zeros() as usize;
-                }
-            }
-            ptr = ptr.add(align);
-        }
-
-        // Main loop processing 128 bytes at a time
-        if LOOP_SIZE_AVX2 <= len {
-            while ptr <= end_ptr.sub(LOOP_SIZE_AVX2) {
-                debug_assert_eq!(0, (ptr as usize) % M256_VECTOR_SIZE);
-
-                // Prefetch next iteration's data
-                if ptr.add(LOOP_SIZE_AVX2 + PREFETCH_DISTANCE) < end_ptr {
-                    _mm_prefetch(
-                        ptr.add(LOOP_SIZE_AVX2 + PREFETCH_DISTANCE) as *const i8,
-                        _MM_HINT_T0,
-                    );
-                }
-
-                // Load all 4 vectors at once for better pipelining
-                let a0 = _mm256_load_si256(ptr as *const __m256i);
-                let a1 = _mm256_load_si256(ptr.add(M256_VECTOR_SIZE) as *const __m256i);
-                let a2 = _mm256_load_si256(ptr.add(M256_VECTOR_SIZE * 2) as *const __m256i);
-                let a3 = _mm256_load_si256(ptr.add(M256_VECTOR_SIZE * 3) as *const __m256i);
-
-                // Check for quotes (") in all vectors
-                let quote_0 = _mm256_cmpeq_epi8(a0, v_b);
-                let quote_1 = _mm256_cmpeq_epi8(a1, v_b);
-                let quote_2 = _mm256_cmpeq_epi8(a2, v_b);
-                let quote_3 = _mm256_cmpeq_epi8(a3, v_b);
-
-                // Check for backslash (\) in all vectors
-                let slash_0 = _mm256_cmpeq_epi8(a0, v_c);
-                let slash_1 = _mm256_cmpeq_epi8(a1, v_c);
-                let slash_2 = _mm256_cmpeq_epi8(a2, v_c);
-                let slash_3 = _mm256_cmpeq_epi8(a3, v_c);
-
-                // Check for control characters (< 0x20) in all vectors
-                let ctrl_0 = _mm256_cmpgt_epi8(_mm256_add_epi8(a0, v_translation_a), v_below_a);
-                let ctrl_1 = _mm256_cmpgt_epi8(_mm256_add_epi8(a1, v_translation_a), v_below_a);
-                let ctrl_2 = _mm256_cmpgt_epi8(_mm256_add_epi8(a2, v_translation_a), v_below_a);
-                let ctrl_3 = _mm256_cmpgt_epi8(_mm256_add_epi8(a3, v_translation_a), v_below_a);
-
-                // Combine all masks
-                let cmp_a = _mm256_or_si256(_mm256_or_si256(quote_0, slash_0), ctrl_0);
-                let cmp_b = _mm256_or_si256(_mm256_or_si256(quote_1, slash_1), ctrl_1);
-                let cmp_c = _mm256_or_si256(_mm256_or_si256(quote_2, slash_2), ctrl_2);
-                let cmp_d = _mm256_or_si256(_mm256_or_si256(quote_3, slash_3), ctrl_3);
-
-                // Fast path: check if any escaping needed
-                let any_escape =
-                    _mm256_or_si256(_mm256_or_si256(cmp_a, cmp_b), _mm256_or_si256(cmp_c, cmp_d));
-
-                if _mm256_movemask_epi8(any_escape) == 0 {
-                    // No escapes needed, copy whole chunk
-                    if start < sub(ptr, start_ptr) {
-                        result.extend_from_slice(&bytes[start..sub(ptr, start_ptr)]);
-                    }
-                    result.extend_from_slice(std::slice::from_raw_parts(ptr, LOOP_SIZE_AVX2));
-                    start = sub(ptr, start_ptr) + LOOP_SIZE_AVX2;
-                } else {
-                    // Get individual masks only when needed
-                    let mask_a = _mm256_movemask_epi8(cmp_a);
-                    let mask_b = _mm256_movemask_epi8(cmp_b);
-                    let mask_c = _mm256_movemask_epi8(cmp_c);
-                    let mask_d = _mm256_movemask_epi8(cmp_d);
-
-                    // Process each 32-byte chunk that has escapes
-                    process_mask_avx(ptr, start_ptr, result, &mut start, bytes, mask_a, 0);
-                    process_mask_avx(
-                        ptr,
-                        start_ptr,
-                        result,
-                        &mut start,
-                        bytes,
-                        mask_b,
-                        M256_VECTOR_SIZE,
-                    );
-                    process_mask_avx(
-                        ptr,
-                        start_ptr,
-                        result,
-                        &mut start,
-                        bytes,
-                        mask_c,
-                        M256_VECTOR_SIZE * 2,
-                    );
-                    process_mask_avx(
-                        ptr,
-                        start_ptr,
-                        result,
-                        &mut start,
-                        bytes,
-                        mask_d,
-                        M256_VECTOR_SIZE * 3,
-                    );
-                }
-
-                ptr = ptr.add(LOOP_SIZE_AVX2);
-            }
-        }
-
-        // Process remaining aligned chunks
-        while ptr <= end_ptr.sub(M256_VECTOR_SIZE) {
-            debug_assert_eq!(0, (ptr as usize) % M256_VECTOR_SIZE);
-            let mut mask = {
-                let a = _mm256_load_si256(ptr as *const __m256i);
-                _mm256_movemask_epi8(_mm256_or_si256(
-                    _mm256_or_si256(_mm256_cmpeq_epi8(a, v_b), _mm256_cmpeq_epi8(a, v_c)),
-                    _mm256_cmpgt_epi8(_mm256_add_epi8(a, v_translation_a), v_below_a),
-                ))
-            };
-
-            if mask != 0 {
-                let at = sub(ptr, start_ptr);
-                let mut cur = mask.trailing_zeros() as usize;
-                loop {
-                    let c = *ptr.add(cur);
-                    let escape_byte = ESCAPE[c as usize];
-                    if escape_byte != 0 {
-                        let i = at + cur;
-                        if start < i {
-                            result.extend_from_slice(&bytes[start..i]);
-                        }
-                        write_escape(result, escape_byte, c);
-                        start = i + 1;
-                    }
-                    mask ^= 1 << cur;
-                    if mask == 0 {
-                        break;
-                    }
-                    cur = mask.trailing_zeros() as usize;
-                }
-            }
-            ptr = ptr.add(M256_VECTOR_SIZE);
-        }
-
-        // Handle tail
-        if ptr < end_ptr {
-            let d = M256_VECTOR_SIZE - sub(end_ptr, ptr);
-            let mut mask = ({
-                let a = _mm256_loadu_si256(ptr.sub(d) as *const __m256i);
-                _mm256_movemask_epi8(_mm256_or_si256(
-                    _mm256_or_si256(_mm256_cmpeq_epi8(a, v_b), _mm256_cmpeq_epi8(a, v_c)),
-                    _mm256_cmpgt_epi8(_mm256_add_epi8(a, v_translation_a), v_below_a),
-                ))
-            } as u32)
-            .wrapping_shr(d as u32);
-
-            if mask != 0 {
-                let at = sub(ptr, start_ptr);
-                let mut cur = mask.trailing_zeros() as usize;
-                loop {
-                    let c = *ptr.add(cur);
-                    let escape_byte = ESCAPE[c as usize];
-                    if escape_byte != 0 {
-                        let i = at + cur;
-                        if start < i {
-                            result.extend_from_slice(&bytes[start..i]);
-                        }
-                        write_escape(result, escape_byte, c);
-                        start = i + 1;
-                    }
-                    mask ^= 1 << cur;
-                    if mask == 0 {
-                        break;
-                    }
-                    cur = mask.trailing_zeros() as usize;
-                }
-            }
-        }
-    } else {
-        // Fall back to SSE2 for small strings
-        return escape_sse2(bytes, result);
-    }
+    let v_translation_a = _mm256_set1_epi8(TRANSLATION_A);
+    let v_below_a = _mm256_set1_epi8(BELOW_A);
+    let v_b = _mm256_set1_epi8(B);
+    let v_c = _mm256_set1_epi8(C);
+
+    // Handle alignment - skip if already aligned
+    const M256_VECTOR_ALIGN: usize = M256_VECTOR_SIZE - 1;
+    let misalignment = start_ptr as usize & M256_VECTOR_ALIGN;
+    if misalignment != 0 {
+        let align = M256_VECTOR_SIZE - misalignment;
+        let mut mask = {
+            let a = _mm256_loadu_si256(ptr as *const __m256i);
+            _mm256_movemask_epi8(_mm256_or_si256(
+                _mm256_or_si256(_mm256_cmpeq_epi8(a, v_b), _mm256_cmpeq_epi8(a, v_c)),
+                _mm256_cmpgt_epi8(_mm256_add_epi8(a, v_translation_a), v_below_a),
+            ))
+        };
+
+        if mask != 0 {
+            let at = sub(ptr, start_ptr);
+            let mut cur = mask.trailing_zeros() as usize;
+            while cur < align {
+                let c = *ptr.add(cur);
+                let escape_byte = ESCAPE[c as usize];
+                if escape_byte != 0 {
+                    let i = at + cur;
+                    if start < i {
+                        result.extend_from_slice(&bytes[start..i]);
+                    }
+                    write_escape(result, escape_byte, c);
+                    start = i + 1;
+                }
+                mask ^= 1 << cur;
+                if mask == 0 {
+                    break;
+                }
+                cur = mask.trailing_zeros() as usize;
+            }
+        }
+        ptr = ptr.add(align);
+    }
+
+    // Main loop processing 128 bytes at a time
+    while ptr <= end_ptr.sub(LOOP_SIZE_AVX2) {
+        debug_assert_eq!(0, (ptr as usize) % M256_VECTOR_SIZE);
+
+        // Prefetch next iteration's data
+        if ptr.add(LOOP_SIZE_AVX2 + PREFETCH_DISTANCE_AVX2) < end_ptr {
+            _mm_prefetch(
+                ptr.add(LOOP_SIZE_AVX2 + PREFETCH_DISTANCE_AVX2) as *const i8,
+                _MM_HINT_T0,
+            );
+        }
+
+        // Load all 4 vectors at once for better pipelining
+        let a0 = _mm256_load_si256(ptr as *const __m256i);
+        let a1 = _mm256_load_si256(ptr.add(M256_VECTOR_SIZE) as *const __m256i);
+        let a2 = _mm256_load_si256(ptr.add(M256_VECTOR_SIZE * 2) as *const __m256i);
+        let a3 = _mm256_load_si256(ptr.add(M256_VECTOR_SIZE * 3) as *const __m256i);
+
+        // Check for quotes (") in all vectors
+        let quote_0 = _mm256_cmpeq_epi8(a0, v_b);
+        let quote_1 = _mm256_cmpeq_epi8(a1, v_b);
+        let quote_2 = _mm256_cmpeq_epi8(a2, v_b);
+        let quote_3 = _mm256_cmpeq_epi8(a3, v_b);
+
+        // Check for backslash (\) in all vectors
+        let slash_0 = _mm256_cmpeq_epi8(a0, v_c);
+        let slash_1 = _mm256_cmpeq_epi8(a1, v_c);
+        let slash_2 = _mm256_cmpeq_epi8(a2, v_c);
+        let slash_3 = _mm256_cmpeq_epi8(a3, v_c);
+
+        // Check for control characters (< 0x20) in all vectors
+        let ctrl_0 = _mm256_cmpgt_epi8(_mm256_add_epi8(a0, v_translation_a), v_below_a);
+        let ctrl_1 = _mm256_cmpgt_epi8(_mm256_add_epi8(a1, v_translation_a), v_below_a);
+        let ctrl_2 = _mm256_cmpgt_epi8(_mm256_add_epi8(a2, v_translation_a), v_below_a);
+        let ctrl_3 = _mm256_cmpgt_epi8(_mm256_add_epi8(a3, v_translation_a), v_below_a);
+
+        // Combine all masks
+        let cmp_a = _mm256_or_si256(_mm256_or_si256(quote_0, slash_0), ctrl_0);
+        let cmp_b = _mm256_or_si256(_mm256_or_si256(quote_1, slash_1), ctrl_1);
+        let cmp_c = _mm256_or_si256(_mm256_or_si256(quote_2, slash_2), ctrl_2);
+        let cmp_d = _mm256_or_si256(_mm256_or_si256(quote_3, slash_3), ctrl_3);
+
+        // Fast path: check if any escaping needed
+        let any_escape =
+            _mm256_or_si256(_mm256_or_si256(cmp_a, cmp_b), _mm256_or_si256(cmp_c, cmp_d));
+
+        if _mm256_movemask_epi8(any_escape) == 0 {
+            // No escapes needed, copy whole chunk
+            if start < sub(ptr, start_ptr) {
+                result.extend_from_slice(&bytes[start..sub(ptr, start_ptr)]);
+            }
+            result.extend_from_slice(std::slice::from_raw_parts(ptr, LOOP_SIZE_AVX2));
+            start = sub(ptr, start_ptr) + LOOP_SIZE_AVX2;
+        } else {
+            // Get individual masks only when needed
+            let mask_a = _mm256_movemask_epi8(cmp_a);
+            let mask_b = _mm256_movemask_epi8(cmp_b);
+            let mask_c = _mm256_movemask_epi8(cmp_c);
+            let mask_d = _mm256_movemask_epi8(cmp_d);
+
+            // Process each 32-byte chunk that has escapes
+            process_mask_avx(ptr, start_ptr, result, &mut start, bytes, mask_a, 0);
+            process_mask_avx(
+                ptr,
+                start_ptr,
+                result,
+                &mut start,
+                bytes,
+                mask_b,
+                M256_VECTOR_SIZE,
+            );
+            process_mask_avx(
+                ptr,
+                start_ptr,
+                result,
+                &mut start,
+                bytes,
+                mask_c,
+                M256_VECTOR_SIZE * 2,
+            );
+            process_mask_avx(
+                ptr,
+                start_ptr,
+                result,
+                &mut start,
+                bytes,
+                mask_d,
+                M256_VECTOR_SIZE * 3,
+            );
+        }
+
+        ptr = ptr.add(LOOP_SIZE_AVX2);
+    }
+
+    // Process remaining aligned chunks
+    while ptr <= end_ptr.sub(M256_VECTOR_SIZE) {
+        debug_assert_eq!(0, (ptr as usize) % M256_VECTOR_SIZE);
+        let mut mask = {
+            let a = _mm256_load_si256(ptr as *const __m256i);
+            _mm256_movemask_epi8(_mm256_or_si256(
+                _mm256_or_si256(_mm256_cmpeq_epi8(a, v_b), _mm256_cmpeq_epi8(a, v_c)),
+                _mm256_cmpgt_epi8(_mm256_add_epi8(a, v_translation_a), v_below_a),
+            ))
+        };
+
+        if mask != 0 {
+            let at = sub(ptr, start_ptr);
+            let mut cur = mask.trailing_zeros() as usize;
+            loop {
+                let c = *ptr.add(cur);
+                let escape_byte = ESCAPE[c as usize];
+                if escape_byte != 0 {
+                    let i = at + cur;
+                    if start < i {
+                        result.extend_from_slice(&bytes[start..i]);
+                    }
+                    write_escape(result, escape_byte, c);
+                    start = i + 1;
+                }
+                mask ^= 1 << cur;
+                if mask == 0 {
+                    break;
+                }
+                cur = mask.trailing_zeros() as usize;
+            }
+        }
+        ptr = ptr.add(M256_VECTOR_SIZE);
+    }
+
+    // Handle tail
+    if ptr < end_ptr {
+        let d = M256_VECTOR_SIZE - sub(end_ptr, ptr);
+        let mut mask = ({
+            let a = _mm256_loadu_si256(ptr.sub(d) as *const __m256i);
+            _mm256_movemask_epi8(_mm256_or_si256(
+                _mm256_or_si256(_mm256_cmpeq_epi8(a, v_b), _mm256_cmpeq_epi8(a, v_c)),
+                _mm256_cmpgt_epi8(_mm256_add_epi8(a, v_translation_a), v_below_a),
+            ))
+        } as u32)
+        .wrapping_shr(d as u32);
+
+        if mask != 0 {
+            let at = sub(ptr, start_ptr);
+            let mut cur = mask.trailing_zeros() as usize;
+            loop {
+                let c = *ptr.add(cur);
+                let escape_byte = ESCAPE[c as usize];
+                if escape_byte != 0 {
+                    let i = at + cur;
+                    if start < i {
+                        result.extend_from_slice(&bytes[start..i]);
+                    }
+                    write_escape(result, escape_byte, c);
+                    start = i + 1;
+                }
+                mask ^= 1 << cur;
+                if mask == 0 {
+                    break;
+                }
+                cur = mask.trailing_zeros() as usize;
+            }
+        }
+    }
 
     // Copy any remaining bytes
@@ -479,130 +466,113 @@ pub unsafe fn escape_sse2(bytes: &[u8], result: &mut Vec<u8>) {
     const M128_VECTOR_ALIGN: usize = M128_VECTOR_SIZE - 1;
 
-    if len < M128_VECTOR_SIZE {
-        // Scalar fallback for very small strings
-        while ptr < end_ptr {
-            let c = *ptr;
-            let escape_byte = ESCAPE[c as usize];
-            if escape_byte != 0 {
-                let i = sub(ptr, start_ptr);
-                if start < i {
-                    result.extend_from_slice(&bytes[start..i]);
-                }
-                write_escape(result, escape_byte, c);
-                start = i + 1;
-            }
-            ptr = ptr.offset(1);
-        }
-    } else {
-        let v_translation_a = _mm_set1_epi8(TRANSLATION_A);
-        let v_below_a = _mm_set1_epi8(BELOW_A);
-        let v_b = _mm_set1_epi8(B);
-        let v_c = _mm_set1_epi8(C);
-
-        // Handle alignment - skip if already aligned
-        let misalignment = start_ptr as usize & M128_VECTOR_ALIGN;
-        if misalignment != 0 {
-            let align = M128_VECTOR_SIZE - misalignment;
-            let mut mask = {
-                let a = _mm_loadu_si128(ptr as *const __m128i);
-                _mm_movemask_epi8(_mm_or_si128(
-                    _mm_or_si128(_mm_cmpeq_epi8(a, v_b), _mm_cmpeq_epi8(a, v_c)),
-                    _mm_cmpgt_epi8(_mm_add_epi8(a, v_translation_a), v_below_a),
-                ))
-            };
-
-            if mask != 0 {
-                let at = sub(ptr, start_ptr);
-                let mut cur = mask.trailing_zeros() as usize;
-                while cur < align {
-                    let c = *ptr.add(cur);
-                    let escape_byte = ESCAPE[c as usize];
-                    if escape_byte != 0 {
-                        let i = at + cur;
-                        if start < i {
-                            result.extend_from_slice(&bytes[start..i]);
-                        }
-                        write_escape(result, escape_byte, c);
-                        start = i + 1;
-                    }
-                    mask ^= 1 << cur;
-                    if mask == 0 {
-                        break;
-                    }
-                    cur = mask.trailing_zeros() as usize;
-                }
-            }
-            ptr = ptr.add(align);
-        }
-
-        // Main loop
-        while ptr <= end_ptr.sub(M128_VECTOR_SIZE) {
-            debug_assert_eq!(0, (ptr as usize) % M128_VECTOR_SIZE);
-            let mut mask = {
-                let a = _mm_load_si128(ptr as *const __m128i);
-                _mm_movemask_epi8(_mm_or_si128(
-                    _mm_or_si128(_mm_cmpeq_epi8(a, v_b), _mm_cmpeq_epi8(a, v_c)),
-                    _mm_cmpgt_epi8(_mm_add_epi8(a, v_translation_a), v_below_a),
-                ))
-            };
-
-            if mask != 0 {
-                let at = sub(ptr, start_ptr);
-                let mut cur = mask.trailing_zeros() as usize;
-                loop {
-                    let c = *ptr.add(cur);
-                    let escape_byte = ESCAPE[c as usize];
-                    if escape_byte != 0 {
-                        let i = at + cur;
-                        if start < i {
-                            result.extend_from_slice(&bytes[start..i]);
-                        }
-                        write_escape(result, escape_byte, c);
-                        start = i + 1;
-                    }
-                    mask ^= 1 << cur;
-                    if mask == 0 {
-                        break;
-                    }
-                    cur = mask.trailing_zeros() as usize;
-                }
-            }
-            ptr = ptr.add(M128_VECTOR_SIZE);
-        }
-
-        // Handle tail
-        if ptr < end_ptr {
-            let d = M128_VECTOR_SIZE - sub(end_ptr, ptr);
-            let mut mask = ({
-                let a = _mm_loadu_si128(ptr.sub(d) as *const __m128i);
-                _mm_movemask_epi8(_mm_or_si128(
-                    _mm_or_si128(_mm_cmpeq_epi8(a, v_b), _mm_cmpeq_epi8(a, v_c)),
-                    _mm_cmpgt_epi8(_mm_add_epi8(a, v_translation_a), v_below_a),
-                ))
-            } as u16)
-            .wrapping_shr(d as u32);
-
-            if mask != 0 {
-                let at = sub(ptr, start_ptr);
-                let mut cur = mask.trailing_zeros() as usize;
-                loop {
-                    let c = *ptr.add(cur);
-                    let escape_byte = ESCAPE[c as usize];
-                    if escape_byte != 0 {
-                        let i = at + cur;
-                        if start < i {
-                            result.extend_from_slice(&bytes[start..i]);
-                        }
-                        write_escape(result, escape_byte, c);
-                        start = i + 1;
-                    }
-                    mask ^= 1 << cur;
-                    if mask == 0 {
-                        break;
-                    }
-                    cur = mask.trailing_zeros() as usize;
-                }
-            }
-        }
-    }
+    let v_translation_a = _mm_set1_epi8(TRANSLATION_A);
+    let v_below_a = _mm_set1_epi8(BELOW_A);
+    let v_b = _mm_set1_epi8(B);
+    let v_c = _mm_set1_epi8(C);
+
+    // Handle alignment - skip if already aligned
+    let misalignment = start_ptr as usize & M128_VECTOR_ALIGN;
+    if misalignment != 0 {
+        let align = M128_VECTOR_SIZE - misalignment;
+        let mut mask = {
+            let a = _mm_loadu_si128(ptr as *const __m128i);
+            _mm_movemask_epi8(_mm_or_si128(
+                _mm_or_si128(_mm_cmpeq_epi8(a, v_b), _mm_cmpeq_epi8(a, v_c)),
+                _mm_cmpgt_epi8(_mm_add_epi8(a, v_translation_a), v_below_a),
+            ))
+        };
+
+        if mask != 0 {
+            let at = sub(ptr, start_ptr);
+            let mut cur = mask.trailing_zeros() as usize;
+            while cur < align {
+                let c = *ptr.add(cur);
+                let escape_byte = ESCAPE[c as usize];
+                if escape_byte != 0 {
+                    let i = at + cur;
+                    if start < i {
+                        result.extend_from_slice(&bytes[start..i]);
+                    }
+                    write_escape(result, escape_byte, c);
+                    start = i + 1;
+                }
+                mask ^= 1 << cur;
+                if mask == 0 {
+                    break;
+                }
+                cur = mask.trailing_zeros() as usize;
+            }
+        }
+        ptr = ptr.add(align);
+    }
+
+    // Main loop
+    while ptr <= end_ptr.sub(M128_VECTOR_SIZE) {
+        debug_assert_eq!(0, (ptr as usize) % M128_VECTOR_SIZE);
+        let mut mask = {
+            let a = _mm_load_si128(ptr as *const __m128i);
+            _mm_movemask_epi8(_mm_or_si128(
+                _mm_or_si128(_mm_cmpeq_epi8(a, v_b), _mm_cmpeq_epi8(a, v_c)),
+                _mm_cmpgt_epi8(_mm_add_epi8(a, v_translation_a), v_below_a),
+            ))
+        };
+
+        if mask != 0 {
+            let at = sub(ptr, start_ptr);
+            let mut cur = mask.trailing_zeros() as usize;
+            loop {
+                let c = *ptr.add(cur);
+                let escape_byte = ESCAPE[c as usize];
+                if escape_byte != 0 {
+                    let i = at + cur;
+                    if start < i {
+                        result.extend_from_slice(&bytes[start..i]);
+                    }
+                    write_escape(result, escape_byte, c);
+                    start = i + 1;
+                }
+                mask ^= 1 << cur;
+                if mask == 0 {
+                    break;
+                }
+                cur = mask.trailing_zeros() as usize;
+            }
+        }
+        ptr = ptr.add(M128_VECTOR_SIZE);
+    }
+
+    // Handle tail
+    if ptr < end_ptr {
+        let d = M128_VECTOR_SIZE - sub(end_ptr, ptr);
+        let mut mask = ({
+            let a = _mm_loadu_si128(ptr.sub(d) as *const __m128i);
+            _mm_movemask_epi8(_mm_or_si128(
+                _mm_or_si128(_mm_cmpeq_epi8(a, v_b), _mm_cmpeq_epi8(a, v_c)),
+                _mm_cmpgt_epi8(_mm_add_epi8(a, v_translation_a), v_below_a),
+            ))
+        } as u16)
+        .wrapping_shr(d as u32);
+
+        if mask != 0 {
+            let at = sub(ptr, start_ptr);
+            let mut cur = mask.trailing_zeros() as usize;
+            loop {
+                let c = *ptr.add(cur);
+                let escape_byte = ESCAPE[c as usize];
+                if escape_byte != 0 {
+                    let i = at + cur;
+                    if start < i {
+                        result.extend_from_slice(&bytes[start..i]);
+                    }
+                    write_escape(result, escape_byte, c);
+                    start = i + 1;
+                }
+                mask ^= 1 << cur;
+                if mask == 0 {
+                    break;
+                }
+                cur = mask.trailing_zeros() as usize;
+            }
+        }
+    }
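
Note on the AVX2/SSE2 control-byte check: _mm256_cmpgt_epi8(_mm256_add_epi8(a, v_translation_a), v_below_a) detects bytes below 0x20 with signed arithmetic, because SSE2/AVX2 have no unsigned byte compare (the AVX-512 path can use _mm512_cmplt_epu8_mask directly). A minimal scalar model, assuming TRANSLATION_A = i8::MAX - 31 and BELOW_A = i8::MAX - 32; both constants are defined outside this hunk, so those values are an assumption here:

    // Scalar model of the SIMD "byte < 0x20" test used in the AVX2/SSE2 paths.
    // Assumed values: adding TRANSLATION_A maps 0x00..=0x1F to 96..=127, the
    // only range that compares greater than BELOW_A in signed i8 arithmetic.
    const TRANSLATION_A: i8 = i8::MAX - 31; // 96
    const BELOW_A: i8 = i8::MAX - 32; // 95

    fn needs_ctrl_escape(b: u8) -> bool {
        (b as i8).wrapping_add(TRANSLATION_A) > BELOW_A
    }

    fn main() {
        for b in 0u8..=255 {
            assert_eq!(needs_ctrl_escape(b), b < 0x20);
        }
    }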
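The mask-draining loops in this patch are the standard set-bit iteration: trailing_zeros finds the next byte that needs escaping, and either mask &= mask - 1 (AVX-512 paths) or mask ^= 1 << cur (AVX2/SSE2 paths) clears the bit just handled; the two are equivalent here because cur is always the lowest set bit. A stand-alone sketch of the same idiom:

    // Visit the indices of set bits in ascending order, exactly like the
    // escape loops above: find the lowest set bit, handle it, clear it.
    fn for_each_set_bit(mut mask: u64, mut f: impl FnMut(usize)) {
        while mask != 0 {
            let cur = mask.trailing_zeros() as usize; // next byte to escape
            f(cur);
            mask &= mask - 1; // clear the lowest set bit
        }
    }

    fn main() {
        let mut seen = Vec::new();
        for_each_set_bit(0b1010_0100, |i| seen.push(i));
        assert_eq!(seen, [2, 5, 7]);
    }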
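The tail handling never reads past the buffer: when 0 < n < VECTOR_SIZE bytes remain, it backs the pointer up by d = VECTOR_SIZE - n, loads the last full vector ending exactly at end_ptr, and shifts the resulting bitmask right by d so that bit 0 again corresponds to *ptr. A scalar model of that index arithmetic (VECTOR_SIZE shrunk to 4 for readability; the movemask helper below is a stand-in for the intrinsic sequence, not the crate's API):

    const VECTOR_SIZE: usize = 4; // stand-in for M128/M256/M512_VECTOR_SIZE

    // Bit i is set when block[i] would need escaping (quote, backslash, < 0x20).
    fn movemask(block: &[u8]) -> u32 {
        block.iter().enumerate().fold(0, |m, (i, &b)| {
            m | (u32::from(b == b'"' || b == b'\\' || b < 0x20) << i)
        })
    }

    fn main() {
        let buf = b"abcdefg\n"; // only buf[7] needs escaping
        let ptr = 6; // two bytes left: indices 6 and 7
        let d = VECTOR_SIZE - (buf.len() - ptr); // back up by d = 2
        let mask = movemask(&buf[ptr - d..ptr - d + VECTOR_SIZE]).wrapping_shr(d as u32);
        // After the shift, bit i corresponds to buf[ptr + i] again.
        assert_eq!(mask, 0b10); // bit 1 -> buf[7] == b'\n'
    }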