Skip to content

Commit eabe934

Browse files
authored
cmov: expand slice impls to all unsigned integers (#1370)
Includes specialized impls for `u16` and `u32` that coalesce to the word size to improve performance when operating over slices.
1 parent 2a8b0b1 commit eabe934

File tree

3 files changed

+177
-33
lines changed

3 files changed

+177
-33
lines changed

cmov/src/array.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
//! Trait impls for core arrays.
22
3-
use crate::{Cmov, CmovEq, Condition, slice::cmovnz_slice_unchecked};
3+
use crate::{Cmov, CmovEq, Condition};
44

55
/// Optimized implementation for byte arrays which coalesces them into word-sized chunks first,
66
/// then performs [`Cmov`] at the word-level to cut down on the total number of instructions.
77
impl<const N: usize> Cmov for [u8; N] {
88
#[inline]
99
fn cmovnz(&mut self, value: &Self, condition: Condition) {
10-
// "unchecked" means it doesn't check the inputs are equal-length, however they are in this
11-
// context because they're two equal-sized arrays
12-
cmovnz_slice_unchecked(self, value, condition);
10+
self.as_mut_slice().cmovnz(value, condition);
1311
}
1412
}
1513

cmov/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#![no_std]
2+
#![cfg_attr(docsrs, feature(doc_cfg))]
23
#![doc = include_str!("../README.md")]
34
#![doc(
45
html_logo_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg",

cmov/src/slice.rs

Lines changed: 174 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,36 @@ type Word = u64;
1111
const WORD_SIZE: usize = size_of::<Word>();
1212
const _: () = assert!(size_of::<usize>() <= WORD_SIZE, "unexpected word size");
1313

14+
/// Assert the lengths of the two slices are equal.
15+
macro_rules! assert_lengths_eq {
16+
($a:expr, $b:expr) => {
17+
assert_eq!(
18+
$a, $b,
19+
"source slice length ({}) does not match destination slice length ({})",
20+
$b, $a
21+
);
22+
};
23+
}
24+
25+
/// Implement [`Cmov`] using a simple loop.
26+
macro_rules! impl_cmov_with_loop {
27+
($int:ty, $doc:expr) => {
28+
#[doc = $doc]
29+
#[doc = "# Panics"]
30+
#[doc = "- if slices have unequal lengths"]
31+
impl Cmov for [$int] {
32+
#[inline]
33+
#[track_caller]
34+
fn cmovnz(&mut self, value: &Self, condition: Condition) {
35+
assert_lengths_eq!(self.len(), value.len());
36+
for (a, b) in self.iter_mut().zip(value.iter()) {
37+
a.cmovnz(b, condition);
38+
}
39+
}
40+
}
41+
};
42+
}
43+
1444
/// Optimized implementation for byte slices which coalesces them into word-sized chunks first,
1545
/// then performs [`Cmov`] at the word-level to cut down on the total number of instructions.
1646
///
@@ -20,18 +50,154 @@ impl Cmov for [u8] {
2050
#[inline]
2151
#[track_caller]
2252
fn cmovnz(&mut self, value: &Self, condition: Condition) {
23-
assert_eq!(
24-
self.len(),
25-
value.len(),
26-
"source slice length ({}) does not match destination slice length ({})",
27-
value.len(),
28-
self.len()
29-
);
53+
assert_lengths_eq!(self.len(), value.len());
3054

31-
cmovnz_slice_unchecked(self, value, condition);
55+
let (dst_chunks, dst_remainder) = slice_as_chunks_mut::<u8, WORD_SIZE>(self);
56+
let (src_chunks, src_remainder) = slice_as_chunks::<u8, WORD_SIZE>(value);
57+
58+
for (dst_chunk, src_chunk) in dst_chunks.iter_mut().zip(src_chunks.iter()) {
59+
let mut a = Word::from_ne_bytes(*dst_chunk);
60+
let b = Word::from_ne_bytes(*src_chunk);
61+
a.cmovnz(&b, condition);
62+
dst_chunk.copy_from_slice(&a.to_ne_bytes());
63+
}
64+
65+
// Process the remainder a byte-at-a-time.
66+
for (a, b) in dst_remainder.iter_mut().zip(src_remainder.iter()) {
67+
a.cmovnz(b, condition);
68+
}
3269
}
3370
}
3471

72+
/// Optimized implementation for slices of `u16` which coalesces them into word-sized chunks first,
73+
/// then performs [`Cmov`] at the word-level to cut down on the total number of instructions.
74+
///
75+
/// # Panics
76+
/// - if slices have unequal lengths
77+
#[cfg(not(target_pointer_width = "64"))]
78+
#[cfg_attr(docsrs, doc(cfg(true)))]
79+
impl Cmov for [u16] {
80+
#[inline]
81+
#[track_caller]
82+
fn cmovnz(&mut self, value: &Self, condition: Condition) {
83+
assert_lengths_eq!(self.len(), value.len());
84+
85+
let (dst_chunks, dst_remainder) = slice_as_chunks_mut::<u16, 2>(self);
86+
let (src_chunks, src_remainder) = slice_as_chunks::<u16, 2>(value);
87+
88+
for (dst_chunk, src_chunk) in dst_chunks.iter_mut().zip(src_chunks.iter()) {
89+
let mut a = Word::from(dst_chunk[0]) | (Word::from(dst_chunk[1]) << 16);
90+
let b = Word::from(src_chunk[0]) | (Word::from(src_chunk[1]) << 16);
91+
a.cmovnz(&b, condition);
92+
dst_chunk[0] = (a & 0xFFFF) as u16;
93+
dst_chunk[1] = (a >> 16) as u16;
94+
}
95+
96+
// If slice is odd-length
97+
if !dst_remainder.is_empty() {
98+
dst_remainder[0].cmovnz(&src_remainder[0], condition);
99+
}
100+
}
101+
}
102+
103+
/// Optimized implementation for slices of `u16` which coalesces them into word-sized chunks first,
104+
/// then performs [`Cmov`] at the word-level to cut down on the total number of instructions.
105+
///
106+
/// # Panics
107+
/// - if slices have unequal lengths
108+
#[cfg(target_pointer_width = "64")]
109+
#[cfg_attr(docsrs, doc(cfg(true)))]
110+
impl Cmov for [u16] {
111+
#[inline]
112+
#[track_caller]
113+
fn cmovnz(&mut self, value: &Self, condition: Condition) {
114+
assert_lengths_eq!(self.len(), value.len());
115+
116+
#[inline(always)]
117+
fn u16x4_to_u64(input: &[u16; 4]) -> u64 {
118+
Word::from(input[0])
119+
| (Word::from(input[1]) << 16)
120+
| (Word::from(input[2]) << 32)
121+
| (Word::from(input[3]) << 48)
122+
}
123+
124+
let (dst_chunks, dst_remainder) = slice_as_chunks_mut::<u16, 4>(self);
125+
let (src_chunks, src_remainder) = slice_as_chunks::<u16, 4>(value);
126+
127+
for (dst_chunk, src_chunk) in dst_chunks.iter_mut().zip(src_chunks.iter()) {
128+
let mut a = u16x4_to_u64(dst_chunk);
129+
let b = u16x4_to_u64(src_chunk);
130+
a.cmovnz(&b, condition);
131+
dst_chunk[0] = (a & 0xFFFF) as u16;
132+
dst_chunk[1] = ((a >> 16) & 0xFFFF) as u16;
133+
dst_chunk[2] = ((a >> 32) & 0xFFFF) as u16;
134+
dst_chunk[3] = ((a >> 48) & 0xFFFF) as u16;
135+
}
136+
137+
for (a, b) in dst_remainder.iter_mut().zip(src_remainder.iter()) {
138+
a.cmovnz(b, condition);
139+
}
140+
}
141+
}
142+
143+
/// Implementation for slices of `u32` on 32-bit platforms, where we can just loop.
144+
///
145+
/// # Panics
146+
/// - if slices have unequal lengths
147+
#[cfg(not(target_pointer_width = "64"))]
148+
#[cfg_attr(docsrs, doc(cfg(true)))]
149+
impl Cmov for [u32] {
150+
#[inline]
151+
#[track_caller]
152+
fn cmovnz(&mut self, value: &Self, condition: Condition) {
153+
assert_lengths_eq!(self.len(), value.len());
154+
155+
for (a, b) in self.iter_mut().zip(value.iter()) {
156+
a.cmovnz(b, condition);
157+
}
158+
}
159+
}
160+
161+
/// Optimized implementation for slices of `u32` which coalesces them into word-sized chunks first,
162+
/// then performs [`Cmov`] at the word-level to cut down on the total number of instructions.
163+
///
164+
/// # Panics
165+
/// - if slices have unequal lengths
166+
#[cfg(target_pointer_width = "64")]
167+
#[cfg_attr(docsrs, doc(cfg(true)))]
168+
impl Cmov for [u32] {
169+
#[inline]
170+
#[track_caller]
171+
fn cmovnz(&mut self, value: &Self, condition: Condition) {
172+
assert_lengths_eq!(self.len(), value.len());
173+
174+
let (dst_chunks, dst_remainder) = slice_as_chunks_mut::<u32, 2>(self);
175+
let (src_chunks, src_remainder) = slice_as_chunks::<u32, 2>(value);
176+
177+
for (dst_chunk, src_chunk) in dst_chunks.iter_mut().zip(src_chunks.iter()) {
178+
let mut a = Word::from(dst_chunk[0]) | (Word::from(dst_chunk[1]) << 32);
179+
let b = Word::from(src_chunk[0]) | (Word::from(src_chunk[1]) << 32);
180+
a.cmovnz(&b, condition);
181+
dst_chunk[0] = (a & 0xFFFF_FFFF) as u32;
182+
dst_chunk[1] = (a >> 32) as u32;
183+
}
184+
185+
// If slice is odd-length
186+
if !dst_remainder.is_empty() {
187+
dst_remainder[0].cmovnz(&src_remainder[0], condition);
188+
}
189+
}
190+
}
191+
192+
impl_cmov_with_loop!(
193+
u64,
194+
"Implementation for `u64` slices where we can just loop."
195+
);
196+
impl_cmov_with_loop!(
197+
u128,
198+
"Implementation for `u128` slices where we can just loop."
199+
);
200+
35201
/// Optimized implementation for byte slices which coalesces them into word-sized chunks first,
36202
/// then performs [`CmovEq`] at the word-level to cut down on the total number of instructions.
37203
///
@@ -63,27 +229,6 @@ impl CmovEq for [u8] {
63229
}
64230
}
65231

66-
/// Conditionally move `src` to `dst` in constant-time if `condition` is non-zero.
67-
///
68-
/// This function does not check the slices are equal-length and expects the caller to do so first.
69-
#[inline(always)]
70-
pub(crate) fn cmovnz_slice_unchecked(dst: &mut [u8], src: &[u8], condition: Condition) {
71-
let (dst_chunks, dst_remainder) = slice_as_chunks_mut::<u8, WORD_SIZE>(dst);
72-
let (src_chunks, src_remainder) = slice_as_chunks::<u8, WORD_SIZE>(src);
73-
74-
for (dst_chunk, src_chunk) in dst_chunks.iter_mut().zip(src_chunks.iter()) {
75-
let mut a = Word::from_ne_bytes(*dst_chunk);
76-
let b = Word::from_ne_bytes(*src_chunk);
77-
a.cmovnz(&b, condition);
78-
dst_chunk.copy_from_slice(&a.to_ne_bytes());
79-
}
80-
81-
// Process the remainder a byte-at-a-time.
82-
for (a, b) in dst_remainder.iter_mut().zip(src_remainder.iter()) {
83-
a.cmovnz(b, condition);
84-
}
85-
}
86-
87232
/// Rust core `[T]::as_chunks` vendored because of its 1.88 MSRV.
88233
/// TODO(tarcieri): use upstream function when we bump MSRV
89234
#[inline]

0 commit comments

Comments
 (0)