Skip to content

Commit 7c9ca7b

Browse files
authored
cmov: improve slice remainder handling (#1377)
When chunking slices into words, instead of handling the remainder by iterating over it one element at a time, this change converts the remainder into a single `Word` and applies either `Cmov` or `CmovEq` to that word. For `Cmov`, the resulting `Word` is then written back into `dst_remainder`. The implementation is generic and is used with `u8`, `u16`, and `u32` slices.
1 parent bb30a74 commit 7c9ca7b

File tree

1 file changed

+127
-20
lines changed

1 file changed

+127
-20
lines changed

cmov/src/slice.rs

Lines changed: 127 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
//! Trait impls for core slices.
22
33
use crate::{Cmov, CmovEq, Condition};
4-
use core::slice;
4+
use core::{
5+
ops::{BitOrAssign, Shl},
6+
slice,
7+
};
58

69
// Uses 64-bit words on 64-bit targets, 32-bit everywhere else
710
#[cfg(not(target_pointer_width = "64"))]
@@ -43,10 +46,7 @@ impl Cmov for [u8] {
4346
dst_chunk.copy_from_slice(&a.to_ne_bytes());
4447
}
4548

46-
// Process the remainder a byte-at-a-time.
47-
for (a, b) in dst_remainder.iter_mut().zip(src_remainder.iter()) {
48-
a.cmovnz(b, condition);
49-
}
49+
cmovnz_remainder(dst_remainder, src_remainder, condition);
5050
}
5151
}
5252

@@ -74,10 +74,7 @@ impl Cmov for [u16] {
7474
dst_chunk[1] = (a >> 16) as u16;
7575
}
7676

77-
// If slice is odd-length
78-
if !dst_remainder.is_empty() {
79-
dst_remainder[0].cmovnz(&src_remainder[0], condition);
80-
}
77+
cmovnz_remainder(dst_remainder, src_remainder, condition);
8178
}
8279
}
8380

@@ -115,9 +112,7 @@ impl Cmov for [u16] {
115112
dst_chunk[3] = ((a >> 48) & 0xFFFF) as u16;
116113
}
117114

118-
for (a, b) in dst_remainder.iter_mut().zip(src_remainder.iter()) {
119-
a.cmovnz(b, condition);
120-
}
115+
cmovnz_remainder(dst_remainder, src_remainder, condition);
121116
}
122117
}
123118

@@ -163,10 +158,7 @@ impl Cmov for [u32] {
163158
dst_chunk[1] = (a >> 32) as u32;
164159
}
165160

166-
// If slice is odd-length
167-
if !dst_remainder.is_empty() {
168-
dst_remainder[0].cmovnz(&src_remainder[0], condition);
169-
}
161+
cmovnz_remainder(dst_remainder, src_remainder, condition);
170162
}
171163
}
172164

@@ -323,10 +315,7 @@ impl CmovEq for [u8] {
323315
a.cmovne(&b, input, output);
324316
}
325317

326-
// Process the remainder a byte-at-a-time.
327-
for (a, b) in self_remainder.iter().zip(rhs_remainder.iter()) {
328-
a.cmovne(b, input, output);
329-
}
318+
cmovne_remainder(self_remainder, rhs_remainder, input, output);
330319
}
331320
}
332321

@@ -370,6 +359,68 @@ impl_cmoveq_with_loop!(
370359
"Implementation for `u128` slices where we can just loop."
371360
);
372361

362+
/// Compare the two remainder slices by loading a `Word` then performing `cmovne`.
363+
#[inline]
364+
fn cmovne_remainder<T>(
365+
a_remainder: &[T],
366+
b_remainder: &[T],
367+
input: Condition,
368+
output: &mut Condition,
369+
) where
370+
T: Copy,
371+
Word: From<T>,
372+
{
373+
let a = slice_to_word(a_remainder);
374+
let b = slice_to_word(b_remainder);
375+
a.cmovne(&b, input, output);
376+
}
377+
378+
/// Load the remainder from chunking the slice into a single `Word`, perform `cmovnz`, then write
379+
/// the result back out to `dst_remainder`.
380+
#[inline]
381+
fn cmovnz_remainder<T>(dst_remainder: &mut [T], src_remainder: &[T], condition: Condition)
382+
where
383+
T: BitOrAssign + Copy + From<u8> + Shl<usize, Output = T>,
384+
Word: From<T>,
385+
{
386+
let mut remainder = slice_to_word(dst_remainder);
387+
remainder.cmovnz(&slice_to_word(src_remainder), condition);
388+
word_to_slice(remainder, dst_remainder);
389+
}
390+
391+
/// Create a [`Word`] from the given input slice.
392+
#[inline]
393+
fn slice_to_word<T>(slice: &[T]) -> Word
394+
where
395+
T: Copy,
396+
Word: From<T>,
397+
{
398+
debug_assert!(size_of_val(slice) <= WORD_SIZE, "slice too large");
399+
slice
400+
.iter()
401+
.rev()
402+
.copied()
403+
.fold(0, |acc, n| (acc << (size_of::<T>() * 8)) | Word::from(n))
404+
}
405+
406+
/// Serialize [`Word`] as bytes using the same byte ordering as `slice_to_word`.
407+
#[inline]
408+
fn word_to_slice<T>(word: Word, out: &mut [T])
409+
where
410+
T: BitOrAssign + Copy + From<u8> + Shl<usize, Output = T>,
411+
{
412+
debug_assert!(size_of::<T>() > 0, "can't be used with ZSTs");
413+
debug_assert!(out.len() <= WORD_SIZE, "slice too large");
414+
415+
let bytes = word.to_le_bytes();
416+
for (o, chunk) in out.iter_mut().zip(bytes.chunks(size_of::<T>())) {
417+
*o = T::from(0u8);
418+
for (i, &byte) in chunk.iter().enumerate() {
419+
*o |= T::from(byte) << (i * 8);
420+
}
421+
}
422+
}
423+
373424
/// Rust core `[T]::as_chunks` vendored because of its 1.88 MSRV.
374425
/// TODO(tarcieri): use upstream function when we bump MSRV
375426
#[inline]
@@ -439,3 +490,59 @@ unsafe fn slice_as_chunks_unchecked_mut<T, const N: usize>(slice: &mut [T]) -> &
439490
// a slice of `new_len` many `N` elements chunks.
440491
unsafe { slice::from_raw_parts_mut(slice.as_mut_ptr().cast(), new_len) }
441492
}
493+
494+
#[cfg(test)]
495+
mod tests {
496+
#[test]
497+
fn cmovnz_remainder() {
498+
// - Test endianness handling on non-64-bit platforms
499+
// - Test handling of odd length slices on 64-bit platforms
500+
#[cfg(not(target_pointer_width = "64"))]
501+
const A_U16: [u16; 2] = [0xAAAA, 0xBBBB];
502+
#[cfg(target_pointer_width = "64")]
503+
const A_U16: [u16; 3] = [0xAAAA, 0xBBBB, 0xCCCC];
504+
505+
#[cfg(not(target_pointer_width = "64"))]
506+
const B_U16: [u16; 2] = [0x10, 0xFFFF];
507+
#[cfg(target_pointer_width = "64")]
508+
const B_U16: [u16; 3] = [0x10, 0x10, 0xFFFF];
509+
510+
let mut out = A_U16;
511+
512+
super::cmovnz_remainder(&mut out, &B_U16, 0);
513+
assert_eq!(A_U16, out);
514+
515+
super::cmovnz_remainder(&mut out, &B_U16, 1);
516+
assert_eq!(B_U16, out);
517+
}
518+
519+
#[test]
520+
fn slice_to_word() {
521+
assert_eq!(0xAABBCC, super::slice_to_word(&[0xCCu8, 0xBB, 0xAA]));
522+
assert_eq!(0xAAAABBBB, super::slice_to_word(&[0xBBBBu16, 0xAAAA]));
523+
524+
#[cfg(target_pointer_width = "64")]
525+
assert_eq!(
526+
0xAAAABBBBCCCC,
527+
super::slice_to_word(&[0xCCCCu16, 0xBBBB, 0xAAAA])
528+
);
529+
}
530+
531+
#[test]
532+
fn word_to_slice() {
533+
let mut out = [0u8; 3];
534+
super::word_to_slice(0xAABBCC, &mut out);
535+
assert_eq!(&[0xCC, 0xBB, 0xAA], &out);
536+
537+
let mut out = [0u16; 2];
538+
super::word_to_slice(0xAAAABBBB, &mut out);
539+
assert_eq!(&[0xBBBB, 0xAAAA], &out);
540+
541+
#[cfg(target_pointer_width = "64")]
542+
{
543+
let mut out = [0u16; 3];
544+
super::word_to_slice(0xAAAABBBBCCCC, &mut out);
545+
assert_eq!(&[0xCCCC, 0xBBBB, 0xAAAA], &out);
546+
}
547+
}
548+
}

0 commit comments

Comments
 (0)