Commit b53e6bf
cmov: add generic masksel function to portable backend (#1342)
Generically implements selection between two values based upon a provided mask value, rather than duplicating this logic in each integer size's `Cmov` impl.
1 parent 65c6520 commit b53e6bf
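
The idea being factored out is constant-time selection via a bitmask: a mask of all-zeros or all-ones picks between two values without a branch. A minimal standalone sketch of the pattern (illustrative only; `select_u32` is a hypothetical name, not the crate's API):

    /// Select `a` when `mask == 0`, `b` when `mask == u32::MAX`.
    fn select_u32(a: u32, b: u32, mask: u32) -> u32 {
        (a & !mask) | (b & mask)
    }

    fn main() {
        let mask = 1u32.wrapping_neg(); // non-zero condition -> all-ones mask
        assert_eq!(select_u32(23, 42, 0), 23);
        assert_eq!(select_u32(23, 42, mask), 42);
    }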


cmov/src/portable.rs

Lines changed: 87 additions & 65 deletions
@@ -7,7 +7,9 @@
 //! optimizer potentially inserting branches.
 
 use crate::{Cmov, CmovEq, Condition};
+use core::ops::{BitAnd, BitOr, Not};
 
+// Uses `Cmov` impl for `u32`
 impl Cmov for u16 {
     #[inline]
     fn cmovnz(&mut self, value: &u16, condition: Condition) {
@@ -24,6 +26,7 @@ impl Cmov for u16 {
     }
 }
 
+// Uses `CmovEq` impl for `u32`
 impl CmovEq for u16 {
     #[inline]
     fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
@@ -39,59 +42,103 @@ impl CmovEq for u16 {
 impl Cmov for u32 {
     #[inline]
     fn cmovnz(&mut self, value: &Self, condition: Condition) {
-        let mask = masknz32(condition);
-        *self = (*self & !mask) | (*value & mask);
+        *self = masksel(*self, *value, masknz32(condition.into()));
     }
 
     #[inline]
     fn cmovz(&mut self, value: &Self, condition: Condition) {
-        let mask = masknz32(condition);
-        *self = (*self & mask) | (*value & !mask);
+        *self = masksel(*self, *value, !masknz32(condition.into()));
     }
 }
 
 impl CmovEq for u32 {
     #[inline]
     fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
-        let ne = testne32(*self, *rhs);
-        output.cmovnz(&input, ne);
+        output.cmovnz(&input, testne32(*self, *rhs));
     }
 
     #[inline]
     fn cmoveq(&self, rhs: &Self, input: Condition, output: &mut Condition) {
-        let eq = testeq32(*self, *rhs);
-        output.cmovnz(&input, eq);
+        output.cmovnz(&input, testeq32(*self, *rhs));
     }
 }
 
 impl Cmov for u64 {
     #[inline]
     fn cmovnz(&mut self, value: &Self, condition: Condition) {
-        let mask = masknz64(condition);
-        *self = (*self & !mask) | (*value & mask);
+        *self = masksel(*self, *value, masknz64(condition.into()));
     }
 
     #[inline]
     fn cmovz(&mut self, value: &Self, condition: Condition) {
-        let mask = masknz64(condition);
-        *self = (*self & mask) | (*value & !mask);
+        *self = masksel(*self, *value, !masknz64(condition.into()));
     }
 }
 
 impl CmovEq for u64 {
     #[inline]
     fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
-        let ne = testne64(*self, *rhs);
-        output.cmovnz(&input, ne);
+        output.cmovnz(&input, testne64(*self, *rhs));
     }
 
     #[inline]
     fn cmoveq(&self, rhs: &Self, input: Condition, output: &mut Condition) {
-        let eq = testeq64(*self, *rhs);
-        output.cmovnz(&input, eq);
+        output.cmovnz(&input, testeq64(*self, *rhs));
     }
 }
 
+/// Return a [`u32::MAX`] mask if `condition` is non-zero, otherwise return zero for a zero input.
+#[cfg(not(target_arch = "arm"))]
+fn masknz32(condition: u32) -> u32 {
+    testnz32(condition).wrapping_neg()
+}
+
+/// Return a [`u64::MAX`] mask if `condition` is non-zero, otherwise return zero for a zero input.
+#[cfg(not(target_arch = "arm"))]
+fn masknz64(condition: u64) -> u64 {
+    testnz64(condition).wrapping_neg()
+}
+
+/// Optimized mask generation for ARM32 targets.
+///
+/// This is written in assembly both for performance and because we've had problematic code
+/// generation in this routine in the past which led to the insertion of a branch; using
+/// assembly should guarantee that won't happen again in the future (CVE-2026-23519).
+#[cfg(target_arch = "arm")]
+fn masknz32(condition: u32) -> u32 {
+    let mut mask = condition;
+    unsafe {
+        core::arch::asm!(
+            "rsbs {0}, {0}, #0",  // Reverse subtract: carry is set iff `condition` is zero
+            "sbcs {0}, {0}, {0}", // Subtract with carry: 0 if carry set, all-ones otherwise
+            inout(reg) mask,
+            options(nostack, nomem),
+        );
+    }
+    mask
+}
+
+/// 64-bit wrapper for targets that implement 32-bit mask generation in assembly.
+#[cfg(target_arch = "arm")]
+fn masknz64(condition: u64) -> u64 {
+    let lo = masknz32((condition & 0xFFFF_FFFF) as u32);
+    let hi = masknz32((condition >> 32) as u32);
+    let mask = (lo | hi) as u64;
+    mask | mask << 32
+}
+
+/// Given a supplied mask of `0` or all 1-bits (i.e. `u*::MAX`), select `a` if the mask is all-zeros
+/// and `b` if the mask is all-ones.
+///
+/// This function shouldn't be used with a mask that isn't `0` or `u*::MAX`.
+#[inline]
+fn masksel<T>(a: T, b: T, mask: T) -> T
+where
+    T: BitAnd<Output = T> + BitOr<Output = T> + Copy + Not<Output = T>,
+{
+    (a & !mask) | (b & mask)
+}
+
 /// Returns `1` if `x` is equal to `y`, otherwise returns `0` (32-bit version)
 fn testeq32(x: u32, y: u32) -> Condition {
     testne32(x, y) ^ 1
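
(Aside: the portable masknz helpers above rely on two's complement negation mapping 0 to 0 and 1 to all-ones. A quick illustration, not part of the diff:

    fn main() {
        assert_eq!(0u32.wrapping_neg(), 0);        // zero stays zero
        assert_eq!(1u32.wrapping_neg(), u32::MAX); // -1 == all bits set
        assert_eq!(1u64.wrapping_neg(), u64::MAX);
    }

This is why `testnz32(condition).wrapping_neg()` yields exactly the 0-or-all-ones mask that `masksel` expects.)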
@@ -120,46 +167,37 @@ fn testnz32(mut x: u32) -> u32 {
 
 /// Returns `0` if `x` is `0`, otherwise returns `1` (64-bit version)
 fn testnz64(mut x: u64) -> u64 {
-    x |= x.wrapping_neg(); // MSB now set if non-zero
+    x |= x.wrapping_neg(); // MSB now set if non-zero (or unset if zero)
     core::hint::black_box(x >> (u64::BITS - 1)) // Extract MSB
 }
 
-/// Return a [`u32::MAX`] mask if `condition` is non-zero, otherwise return zero for a zero input.
-#[cfg(not(target_arch = "arm"))]
-fn masknz32(condition: Condition) -> u32 {
-    testnz32(condition.into()).wrapping_neg()
-}
-
-/// Return a [`u64::MAX`] mask if `condition` is non-zero, otherwise return zero for a zero input.
-#[cfg(not(target_arch = "arm"))]
-fn masknz64(condition: Condition) -> u64 {
-    testnz64(condition.into()).wrapping_neg()
-}
+#[cfg(test)]
+mod tests {
+    #[test]
+    fn masknz32() {
+        assert_eq!(super::masknz32(0), 0);
+        for i in 1..=u8::MAX {
+            assert_eq!(super::masknz32(i.into()), u32::MAX);
+        }
+    }
 
-/// Optimized mask generation for ARM32 targets.
-#[cfg(target_arch = "arm")]
-fn masknz32(condition: u8) -> u32 {
-    let mut out = condition as u32;
-    unsafe {
-        core::arch::asm!(
-            "rsbs {0}, {0}, #0", // Reverse subtract
-            "sbcs {0}, {0}, {0}", // Subtract with carry, setting flags
-            inout(reg) out,
-            options(nostack, nomem),
-        );
+    #[test]
+    fn masknz64() {
+        assert_eq!(super::masknz64(0), 0);
+        for i in 1..=u8::MAX {
+            assert_eq!(super::masknz64(i.into()), u64::MAX);
+        }
     }
-    out
-}
 
-/// 64-bit wrapper for targets that implement 32-bit mask generation in assembly.
-#[cfg(target_arch = "arm")]
-fn masknz64(condition: u8) -> u64 {
-    let mask = masknz32(condition) as u64;
-    mask | mask << 32
-}
+    #[test]
+    fn masksel() {
+        assert_eq!(super::masksel(23u8, 42u8, 0u8), 23u8);
+        assert_eq!(super::masksel(23u8, 42u8, u8::MAX), 42u8);
+
+        assert_eq!(super::masksel(17u32, 101077u32, 0u32), 17u32);
+        assert_eq!(super::masksel(17u32, 101077u32, u32::MAX), 101077u32);
+    }
 
-#[cfg(test)]
-mod tests {
     #[test]
     fn testeq32() {
         assert_eq!(super::testeq32(0, 0), 1);
@@ -219,20 +257,4 @@ mod tests {
             assert_eq!(super::testnz64(i as u64), 1);
         }
     }
-
-    #[test]
-    fn masknz32() {
-        assert_eq!(super::masknz32(0), 0);
-        for i in 1..=u8::MAX {
-            assert_eq!(super::masknz32(i), u32::MAX);
-        }
-    }
-
-    #[test]
-    fn masknz64() {
-        assert_eq!(super::masknz64(0), 0);
-        for i in 1..=u8::MAX {
-            assert_eq!(super::masknz64(i), u64::MAX);
-        }
-    }
 }
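
For context on how the changed impls are exercised, a minimal usage sketch against the crate's public `Cmov` trait (assuming `Condition` is the crate's `u8` alias; values per the impls above):

    use cmov::Cmov;

    fn main() {
        let mut x = 1u32;
        x.cmovnz(&2, 1); // condition != 0: take the new value, branch-free
        assert_eq!(x, 2);
        x.cmovnz(&3, 0); // condition == 0: keep the current value
        assert_eq!(x, 2);
    }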
