Skip to content

Commit 89d5790

Browse files
authored
perf: SIMD common prefix (#49)
1 parent 61ab7c7 commit 89d5790

File tree

2 files changed

+38
-55
lines changed

2 files changed

+38
-55
lines changed

benches/benches/cmp.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ pub fn group(criterion: &mut Criterion) {
1818
let other = n.slice(..n.len().saturating_sub(1));
1919
(n, other)
2020
}),
21-
|(a, b)| a.common_prefix_length(black_box(&b)),
21+
|(a, b)| a.common_prefix_length(&b),
2222
);
2323
}
2424
}

src/nibbles.rs

Lines changed: 37 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ impl Ord for Nibbles {
163163
let l = cmp::min(self_len, other_len);
164164
let len_cmp = self.len().cmp(&other.len());
165165

166-
let byte_idx = first_diff_byte_idx(&self.nibbles, &other.nibbles);
166+
let byte_idx = longest_prefix_byte(&self.nibbles, &other.nibbles);
167167
let r = if byte_idx < l {
168168
// SAFETY: `byte_idx` < 32, so `31 - byte_idx` is valid.
169169
let le_idx = 31 - byte_idx;
@@ -647,20 +647,9 @@ impl Nibbles {
647647
}
648648

649649
/// Returns `true` if this nibble sequence starts with the given prefix.
650+
#[inline]
650651
pub fn starts_with(&self, other: &Self) -> bool {
651-
// Fast path: if lengths don't allow prefix, return false
652-
if other.len() > self.len() {
653-
return false;
654-
}
655-
656-
// Fast path: empty prefix always matches
657-
if other.is_empty() {
658-
return true;
659-
}
660-
661-
// Direct comparison using masks
662-
let mask = SLICE_MASKS[other.len()];
663-
(self.nibbles & mask) == other.nibbles
652+
other.len() <= self.len() && (self.nibbles & SLICE_MASKS[other.len()]) == other.nibbles
664653
}
665654

666655
/// Returns `true` if this nibble sequence ends with the given suffix.
@@ -783,41 +772,15 @@ impl Nibbles {
783772
/// let b = Nibbles::from_nibbles(&[0x0A, 0x0B, 0x0C, 0x0E]);
784773
/// assert_eq!(a.common_prefix_length(&b), 3);
785774
/// ```
775+
#[inline]
786776
pub fn common_prefix_length(&self, other: &Self) -> usize {
787-
// Handle empty cases
788-
if self.is_empty() || other.is_empty() {
789-
return 0;
790-
}
791-
792-
let min_nibble_len = self.len().min(other.len());
793-
794-
// Fast path for small sequences that fit in one u64 limb
795-
if min_nibble_len <= 16 {
796-
// Extract the highest u64 limb which contains all the nibbles
797-
let self_limb = self.nibbles.as_limbs()[3];
798-
let other_limb = other.nibbles.as_limbs()[3];
799-
800-
// Create mask for the nibbles we care about
801-
let mask = u64::MAX << ((16 - min_nibble_len) * 4);
802-
let xor = (self_limb ^ other_limb) & mask;
803-
804-
if xor == 0 {
805-
return min_nibble_len;
806-
} else {
807-
return xor.leading_zeros() as usize / 4;
808-
}
809-
}
810-
811-
let xor = if min_nibble_len == NIBBLES && self.len() == other.len() {
812-
// No need to mask for 64 nibble sequences, just XOR
813-
self.nibbles ^ other.nibbles
814-
} else {
815-
// For other lengths, mask the nibbles we care about, and then XOR
816-
let mask = SLICE_MASKS[min_nibble_len];
817-
(self.nibbles ^ other.nibbles) & mask
818-
};
777+
let l = self.len().min(other.len());
778+
self.common_prefix_length_raw(other).min(l)
779+
}
819780

820-
if xor == U256::ZERO { min_nibble_len } else { xor.leading_zeros() / 4 }
781+
#[inline]
782+
fn common_prefix_length_raw(&self, other: &Self) -> usize {
783+
longest_prefix_bit(&self.nibbles, &other.nibbles) / 4
821784
}
822785

823786
/// Returns the total number of bits in this [`Nibbles`].
@@ -1308,27 +1271,47 @@ const fn as_le_slice(x: &U256) -> ByteContainer<'_, { U256::BYTES }> {
13081271
}
13091272

13101273
#[inline]
1311-
fn first_diff_byte_idx(a: &U256, b: &U256) -> usize {
1274+
fn longest_prefix_byte(a: &U256, b: &U256) -> usize {
1275+
longest_prefix::<false>(a, b) / 8
1276+
}
1277+
1278+
#[inline]
1279+
fn longest_prefix_bit(a: &U256, b: &U256) -> usize {
1280+
longest_prefix::<true>(a, b)
1281+
}
1282+
1283+
#[inline]
1284+
fn longest_prefix<const EXACT: bool>(a: &U256, b: &U256) -> usize {
13121285
cfg_if! {
13131286
if #[cfg(target_arch = "x86_64")] {
13141287
#[cfg(feature = "std")]
13151288
let enabled = std::is_x86_feature_detected!("avx2");
13161289
#[cfg(not(feature = "std"))]
13171290
let enabled = cfg!(target_feature = "avx2");
13181291
if enabled {
1319-
use core::arch::x86_64::*;
13201292
return unsafe {
1321-
let a = _mm256_loadu_si256(a.as_limbs().as_ptr().cast());
1322-
let b = _mm256_loadu_si256(b.as_limbs().as_ptr().cast());
1323-
let diff = _mm256_cmpeq_epi8(a, b);
1293+
use core::arch::x86_64::*;
1294+
let x = _mm256_loadu_si256(a.as_limbs().as_ptr().cast());
1295+
let y = _mm256_loadu_si256(b.as_limbs().as_ptr().cast());
1296+
let diff = _mm256_cmpeq_epi8(x, y);
13241297
let mask = _mm256_movemask_epi8(diff);
1325-
mask.leading_ones() as usize
1298+
let bytes = mask.leading_ones() as usize;
1299+
if !EXACT || bytes == 32 {
1300+
return bytes * 8;
1301+
}
1302+
let le_idx = 31 - bytes;
1303+
let a = *a.as_le_slice().get_unchecked(le_idx);
1304+
let b = *b.as_le_slice().get_unchecked(le_idx);
1305+
let diff = a ^ b;
1306+
let bits = diff.leading_zeros() as usize;
1307+
bytes * 8 + bits
13261308
};
13271309
}
13281310
}
13291311
}
13301312

1331-
(*a ^ *b).leading_zeros() / 8
1313+
let diff = *a ^ *b;
1314+
diff.leading_zeros()
13321315
}
13331316

13341317
#[inline]

0 commit comments

Comments
 (0)