Skip to content

Commit c844f00

Browse files
committed
cmov: impl optimized CmovEq for [u8] [BREAKING]
Note: version bumped to v0.5.0-pre to denote breaking change (not for release). Perhaps the first and foremost use case for a crate like this (or `subtle` or `ctutils`) is comparing byte slices in constant time. However, the existing codegen for this is bad, because it goes a byte-at-a-time, converting them to a `u32` or `u64`, then emitting predication instructions (or using bitwise masking) on each individual byte. Instead this removes the `CmovEq` impl for `[T]` and replaces it with an optimized impl of `CmovEq` for `[u8]`, reusing the code for the optimized `CmovEq` impl for arrays added in #1353. This approach goes in word-sized chunks of the slice, converting them to a word-sized integer (`u32` or `u64`) and using the `CmovEq` impl on those types, which should result in much more efficient code. With this change all of the slice chunking code is now in the `slice` module, which lets us move the vendored copies of `[T]::as_chunks(_mut)` there, get rid of a `utils` module, and rename it back to `macros` (though that's perhaps a misnomer as it contains only one macro). A small change to the `Cmov` impl added in #1354: it panics if the input sizes aren't equal, using the same panic message as `copy_from_slice`.
1 parent 18d53e8 commit c844f00

File tree

9 files changed

+254
-196
lines changed

9 files changed

+254
-196
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

cmov/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "cmov"
3-
version = "0.4.6"
3+
version = "0.5.0-pre"
44
authors = ["RustCrypto Developers"]
55
edition = "2024"
66
rust-version = "1.85"

cmov/src/array.rs

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,22 @@
11
//! Trait impls for core arrays.
22
3-
use crate::{
4-
Cmov, CmovEq, Condition,
5-
utils::{WORD_SIZE, Word, slice_as_chunks},
6-
};
3+
use crate::{Cmov, CmovEq, Condition, slice::cmovnz_slice_unchecked};
74

85
/// Optimized implementation for byte arrays which coalesces them into word-sized chunks first,
96
/// then performs [`Cmov`] at the word-level to cut down on the total number of instructions.
107
impl<const N: usize> Cmov for [u8; N] {
118
#[inline]
129
fn cmovnz(&mut self, value: &Self, condition: Condition) {
13-
self.as_mut_slice().cmovnz(value, condition);
10+
// "unchecked" means it doesn't check the inputs are equal-length, however they are in this
11+
// context because they're two equal-sized arrays
12+
cmovnz_slice_unchecked(self, value, condition);
1413
}
1514
}
1615

1716
/// Optimized implementation for byte arrays which coalesces them into word-sized chunks first,
1817
/// then performs [`CmovEq`] at the word-level to cut down on the total number of instructions.
1918
impl<const N: usize> CmovEq for [u8; N] {
2019
fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
21-
let (self_chunks, self_remainder) = slice_as_chunks::<u8, WORD_SIZE>(self);
22-
let (rhs_chunks, rhs_remainder) = slice_as_chunks::<u8, WORD_SIZE>(rhs);
23-
24-
for (self_chunk, rhs_chunk) in self_chunks.iter().zip(rhs_chunks.iter()) {
25-
let a = Word::from_ne_bytes(*self_chunk);
26-
let b = Word::from_ne_bytes(*rhs_chunk);
27-
a.cmovne(&b, input, output);
28-
}
29-
30-
// Process the remainder a byte-at-a-time.
31-
for (a, b) in self_remainder.iter().zip(rhs_remainder.iter()) {
32-
a.cmovne(b, input, output);
33-
}
20+
self.as_slice().cmovne(rhs, input, output);
3421
}
3522
}

cmov/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
)]
2929

3030
#[macro_use]
31-
mod utils;
31+
mod macros;
3232

3333
#[cfg(not(miri))]
3434
#[cfg(target_arch = "aarch64")]

cmov/src/macros.rs

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
//! Macro definitions.
2+
3+
/// Generates a mask the width of the given unsigned integer type `$uint` if the input value is
4+
/// non-zero.
5+
///
6+
/// Uses `core::hint::black_box` to coerce our desired codegen based on real-world observations
7+
/// of the assembly generated by Rust/LLVM.
8+
///
9+
/// Implemented as a macro instead of a generic function because it uses functionality for which
10+
/// there aren't available `core` traits, e.g. `wrapping_neg`.
11+
///
12+
/// See also:
13+
/// - CVE-2026-23519
14+
/// - RustCrypto/utils#1332
15+
macro_rules! masknz {
16+
($value:tt : $uint:ident) => {{
17+
let mut value: $uint = $value;
18+
value |= value.wrapping_neg(); // has MSB `1` if non-zero, `0` if zero
19+
20+
// use `black_box` to obscure we're computing a 1-bit value
21+
core::hint::black_box(
22+
value >> ($uint::BITS - 1), // Extract MSB
23+
)
24+
.wrapping_neg() // Generate $uint::MAX mask if `black_box` outputs `1`
25+
}};
26+
}
27+
28+
#[cfg(test)]
29+
mod tests {
30+
// Spot check up to a given limit
31+
const TEST_LIMIT: u8 = 128;
32+
33+
macro_rules! masknz_test {
34+
( $($name:ident : $uint:ident),+ ) => {
35+
$(
36+
#[test]
37+
fn $name() {
38+
assert_eq!(masknz!(0: $uint), 0);
39+
40+
// Test lower values
41+
for i in 1..=$uint::from(TEST_LIMIT) {
42+
assert_eq!(masknz!(i: $uint), $uint::MAX);
43+
}
44+
45+
// Test upper values
46+
for i in ($uint::MAX - $uint::from(TEST_LIMIT))..=$uint::MAX {
47+
assert_eq!(masknz!(i: $uint), $uint::MAX);
48+
}
49+
}
50+
)+
51+
}
52+
}
53+
54+
// Ensure the macro works with any types we might use it with (we only use u8, u32, and u64)
55+
masknz_test!(
56+
masknz_u8: u8,
57+
masknz_u16: u16,
58+
masknz_u32: u32,
59+
masknz_u64: u64,
60+
masknz_u128: u128
61+
);
62+
}

cmov/src/slice.rs

Lines changed: 129 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,43 @@
11
//! Trait impls for core slices.
22
3-
use crate::utils::{WORD_SIZE, Word, slice_as_chunks, slice_as_chunks_mut};
43
use crate::{Cmov, CmovEq, Condition};
4+
use core::slice;
5+
6+
// Uses 64-bit words on 64-bit targets, 32-bit everywhere else
7+
#[cfg(not(target_pointer_width = "64"))]
8+
type Word = u32;
9+
#[cfg(target_pointer_width = "64")]
10+
type Word = u64;
11+
const WORD_SIZE: usize = size_of::<Word>();
12+
const _: () = assert!(size_of::<usize>() <= WORD_SIZE, "unexpected word size");
513

614
/// Optimized implementation for byte slices which coalesces them into word-sized chunks first,
715
/// then performs [`Cmov`] at the word-level to cut down on the total number of instructions.
16+
///
17+
/// # Panics
18+
/// - if slices have unequal lengths
819
impl Cmov for [u8] {
920
#[inline]
1021
fn cmovnz(&mut self, value: &Self, condition: Condition) {
11-
let (self_chunks, self_remainder) = slice_as_chunks_mut::<u8, WORD_SIZE>(self);
12-
let (value_chunks, value_remainder) = slice_as_chunks::<u8, WORD_SIZE>(value);
13-
14-
for (self_chunk, value_chunk) in self_chunks.iter_mut().zip(value_chunks.iter()) {
15-
let mut a = Word::from_ne_bytes(*self_chunk);
16-
let b = Word::from_ne_bytes(*value_chunk);
17-
a.cmovnz(&b, condition);
18-
self_chunk.copy_from_slice(&a.to_ne_bytes());
19-
}
22+
assert_eq!(
23+
self.len(),
24+
value.len(),
25+
"source slice length ({}) does not match destination slice length ({})",
26+
value.len(),
27+
self.len()
28+
);
2029

21-
// Process the remainder a byte-at-a-time.
22-
for (a, b) in self_remainder.iter_mut().zip(value_remainder.iter()) {
23-
a.cmovnz(b, condition);
24-
}
30+
cmovnz_slice_unchecked(self, value, condition);
2531
}
2632
}
2733

28-
impl<T: CmovEq> CmovEq for [T] {
34+
/// Optimized implementation for byte slices which coalesces them into word-sized chunks first,
35+
/// then performs [`CmovEq`] at the word-level to cut down on the total number of instructions.
36+
///
37+
/// This is only constant-time for equal-length slices, and will short-circuit and set `output`
38+
/// in the event the slices are of unequal length.
39+
impl CmovEq for [u8] {
40+
#[inline]
2941
fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
3042
// Short-circuit the comparison if the slices are of different lengths, and set the output
3143
// condition to the input condition.
@@ -34,9 +46,109 @@ impl<T: CmovEq> CmovEq for [T] {
3446
return;
3547
}
3648

37-
// Compare each byte.
38-
for (a, b) in self.iter().zip(rhs.iter()) {
49+
let (self_chunks, self_remainder) = slice_as_chunks::<u8, WORD_SIZE>(self);
50+
let (rhs_chunks, rhs_remainder) = slice_as_chunks::<u8, WORD_SIZE>(rhs);
51+
52+
for (self_chunk, rhs_chunk) in self_chunks.iter().zip(rhs_chunks.iter()) {
53+
let a = Word::from_ne_bytes(*self_chunk);
54+
let b = Word::from_ne_bytes(*rhs_chunk);
55+
a.cmovne(&b, input, output);
56+
}
57+
58+
// Process the remainder a byte-at-a-time.
59+
for (a, b) in self_remainder.iter().zip(rhs_remainder.iter()) {
3960
a.cmovne(b, input, output);
4061
}
4162
}
4263
}
64+
65+
/// Conditionally move `src` to `dst` in constant-time if `condition` is non-zero.
66+
///
67+
/// This function does not check the slices are equal-length and expects the caller to do so first.
68+
#[inline(always)]
69+
pub(crate) fn cmovnz_slice_unchecked(dst: &mut [u8], src: &[u8], condition: Condition) {
70+
let (dst_chunks, dst_remainder) = slice_as_chunks_mut::<u8, WORD_SIZE>(dst);
71+
let (src_chunks, src_remainder) = slice_as_chunks::<u8, WORD_SIZE>(src);
72+
73+
for (dst_chunk, src_chunk) in dst_chunks.iter_mut().zip(src_chunks.iter()) {
74+
let mut a = Word::from_ne_bytes(*dst_chunk);
75+
let b = Word::from_ne_bytes(*src_chunk);
76+
a.cmovnz(&b, condition);
77+
dst_chunk.copy_from_slice(&a.to_ne_bytes());
78+
}
79+
80+
// Process the remainder a byte-at-a-time.
81+
for (a, b) in dst_remainder.iter_mut().zip(src_remainder.iter()) {
82+
a.cmovnz(b, condition);
83+
}
84+
}
85+
86+
/// Rust core `[T]::as_chunks` vendored because of its 1.88 MSRV.
87+
/// TODO(tarcieri): use upstream function when we bump MSRV
88+
#[inline]
89+
#[track_caller]
90+
#[must_use]
91+
#[allow(clippy::integer_division_remainder_used)]
92+
fn slice_as_chunks<T, const N: usize>(slice: &[T]) -> (&[[T; N]], &[T]) {
93+
assert!(N != 0, "chunk size must be non-zero");
94+
let len_rounded_down = slice.len() / N * N;
95+
// SAFETY: The rounded-down value is always the same or smaller than the
96+
// original length, and thus must be in-bounds of the slice.
97+
let (multiple_of_n, remainder) = unsafe { slice.split_at_unchecked(len_rounded_down) };
98+
// SAFETY: We already panicked for zero, and ensured by construction
99+
// that the length of the subslice is a multiple of N.
100+
let array_slice = unsafe { slice_as_chunks_unchecked(multiple_of_n) };
101+
(array_slice, remainder)
102+
}
103+
104+
/// Rust core `[T]::as_chunks_mut` vendored because of its 1.88 MSRV.
105+
/// TODO(tarcieri): use upstream function when we bump MSRV
106+
#[inline]
107+
#[track_caller]
108+
#[must_use]
109+
#[allow(clippy::integer_division_remainder_used)]
110+
fn slice_as_chunks_mut<T, const N: usize>(slice: &mut [T]) -> (&mut [[T; N]], &mut [T]) {
111+
assert!(N != 0, "chunk size must be non-zero");
112+
let len_rounded_down = slice.len() / N * N;
113+
// SAFETY: The rounded-down value is always the same or smaller than the
114+
// original length, and thus must be in-bounds of the slice.
115+
let (multiple_of_n, remainder) = unsafe { slice.split_at_mut_unchecked(len_rounded_down) };
116+
// SAFETY: We already panicked for zero, and ensured by construction
117+
// that the length of the subslice is a multiple of N.
118+
let array_slice = unsafe { slice_as_chunks_unchecked_mut(multiple_of_n) };
119+
(array_slice, remainder)
120+
}
121+
122+
/// Rust core `[T]::as_chunks_unchecked` vendored because of its 1.88 MSRV.
123+
/// TODO(tarcieri): use upstream function when we bump MSRV
124+
#[inline]
125+
#[must_use]
126+
#[track_caller]
127+
#[allow(clippy::integer_division_remainder_used)]
128+
unsafe fn slice_as_chunks_unchecked<T, const N: usize>(slice: &[T]) -> &[[T; N]] {
129+
// SAFETY: Caller must guarantee that `N` is nonzero and exactly divides the slice length
130+
const { debug_assert!(N != 0) };
131+
debug_assert_eq!(slice.len() % N, 0);
132+
let new_len = slice.len() / N;
133+
134+
// SAFETY: We cast a slice of `new_len * N` elements into
135+
// a slice of `new_len` many `N` elements chunks.
136+
unsafe { slice::from_raw_parts(slice.as_ptr().cast(), new_len) }
137+
}
138+
139+
/// Rust core `[T]::as_chunks_unchecked_mut` vendored because of its 1.88 MSRV.
140+
/// TODO(tarcieri): use upstream function when we bump MSRV
141+
#[inline]
142+
#[must_use]
143+
#[track_caller]
144+
#[allow(clippy::integer_division_remainder_used)]
145+
unsafe fn slice_as_chunks_unchecked_mut<T, const N: usize>(slice: &mut [T]) -> &mut [[T; N]] {
146+
// SAFETY: Caller must guarantee that `N` is nonzero and exactly divides the slice length
147+
const { debug_assert!(N != 0) };
148+
debug_assert_eq!(slice.len() % N, 0);
149+
let new_len = slice.len() / N;
150+
151+
// SAFETY: We cast a slice of `new_len * N` elements into
152+
// a slice of `new_len` many `N` elements chunks.
153+
unsafe { slice::from_raw_parts_mut(slice.as_mut_ptr().cast(), new_len) }
154+
}

0 commit comments

Comments
 (0)