Skip to content

Commit 65f1aec

Browse files
authored
cmov: use masksel to impl CmovEq (#1343)
Uses the `masksel` function added in #1342 to impl `CmovEq`. This should be less circuitous than the previous implementation. This approach means we can get rid of all of the `test*` functions like `testeq*`, `testne*`, and `testnz*` and implement everything in terms of mask generation. The pure Rust versions of `masknz32` and `masknz64` still internally use `black_box` to prevent optimizations around the non-zero bit extracted before computing the mask.
1 parent b53e6bf commit 65f1aec

File tree

1 file changed

+91
-98
lines changed

1 file changed

+91
-98
lines changed

cmov/src/portable.rs

Lines changed: 91 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,12 @@ impl Cmov for u32 {
5454
impl CmovEq for u32 {
5555
#[inline]
5656
fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
57-
output.cmovnz(&input, testne32(*self, *rhs));
57+
*output = masksel(*output, input, (maskne32(*self, *rhs) & 0xFF) as u8);
5858
}
5959

6060
#[inline]
6161
fn cmoveq(&self, rhs: &Self, input: Condition, output: &mut Condition) {
62-
output.cmovnz(&input, testeq32(*self, *rhs));
62+
*output = masksel(*output, input, (maskeq32(*self, *rhs) & 0xFF) as u8);
6363
}
6464
}
6565

@@ -78,25 +78,49 @@ impl Cmov for u64 {
7878
impl CmovEq for u64 {
7979
#[inline]
8080
fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
81-
output.cmovnz(&input, testne64(*self, *rhs));
81+
*output = masksel(*output, input, (maskne64(*self, *rhs) & 0xFF) as u8);
8282
}
8383

8484
#[inline]
8585
fn cmoveq(&self, rhs: &Self, input: Condition, output: &mut Condition) {
86-
output.cmovnz(&input, testeq64(*self, *rhs));
86+
*output = masksel(*output, input, (maskeq64(*self, *rhs) & 0xFF) as u8);
8787
}
8888
}
8989

90+
/// Returns `u32::MAX` if `x` is equal to `y`, otherwise returns `0` (32-bit version)
91+
fn maskeq32(x: u32, y: u32) -> u32 {
92+
!maskne32(x, y)
93+
}
94+
95+
/// Returns `u64::MAX` if `x` is equal to `y`, otherwise returns `0` (64-bit version)
96+
fn maskeq64(x: u64, y: u64) -> u64 {
97+
!maskne64(x, y)
98+
}
99+
100+
/// Returns `0` if `x` is equal to `y`, otherwise returns `u32::MAX` (32-bit version)
101+
fn maskne32(x: u32, y: u32) -> u32 {
102+
masknz32(x ^ y)
103+
}
104+
105+
/// Returns `0` if `x` is equal to `y`, otherwise returns `u64::MAX` (64-bit version)
106+
fn maskne64(x: u64, y: u64) -> u64 {
107+
masknz64(x ^ y)
108+
}
109+
90110
/// Return a [`u32::MAX`] mask if `condition` is non-zero, otherwise return `0`.
91111
#[cfg(not(target_arch = "arm"))]
92112
fn masknz32(condition: u32) -> u32 {
93-
testnz32(condition).wrapping_neg()
113+
let x = condition | condition.wrapping_neg(); // MSB of `x` now `1` if non-zero
114+
let nz = core::hint::black_box(x >> (u32::BITS - 1)); // Extract MSB
115+
nz.wrapping_neg()
94116
}
95117

96118
/// Return a [`u64::MAX`] mask if `condition` is non-zero, otherwise return `0`.
97119
#[cfg(not(target_arch = "arm"))]
98120
fn masknz64(condition: u64) -> u64 {
99-
testnz64(condition).wrapping_neg()
121+
let x = condition | condition.wrapping_neg(); // MSB of `x` now `1` if non-zero
122+
let nz = core::hint::black_box(x >> (u64::BITS - 1)); // Extract MSB
123+
nz.wrapping_neg()
100124
}
101125

102126
/// Optimized mask generation for ARM32 targets.
@@ -139,53 +163,76 @@ where
139163
(a & !mask) | (b & mask)
140164
}
141165

142-
/// Returns `1` if `x` is equal to `y`, otherwise returns `0` (32-bit version)
143-
fn testeq32(x: u32, y: u32) -> Condition {
144-
testne32(x, y) ^ 1
145-
}
146-
147-
/// Returns `1` if `x` is equal to `y`, otherwise returns `0` (64-bit version)
148-
fn testeq64(x: u64, y: u64) -> Condition {
149-
testne64(x, y) ^ 1
150-
}
166+
#[cfg(test)]
167+
mod tests {
168+
// Spot check up to a given limit
169+
const TEST_LIMIT: u32 = 65536;
151170

152-
/// Returns `0` if `x` is equal to `y`, otherwise returns `1` (32-bit version)
153-
fn testne32(x: u32, y: u32) -> Condition {
154-
(testnz32(x ^ y) & 0xFF) as Condition
155-
}
171+
#[test]
172+
fn maskeq32() {
173+
assert_eq!(super::maskeq32(0, 0), u32::MAX);
174+
assert_eq!(super::maskeq32(1, 0), 0);
175+
assert_eq!(super::maskeq32(0, 1), 0);
176+
assert_eq!(super::maskeq32(1, 1), u32::MAX);
177+
assert_eq!(super::maskeq32(u32::MAX, 1), 0);
178+
assert_eq!(super::maskeq32(1, u32::MAX), 0);
179+
assert_eq!(super::maskeq32(u32::MAX, u32::MAX), u32::MAX);
180+
}
156181

157-
/// Returns `0` if `x` is equal to `y`, otherwise returns `1` (64-bit version)
158-
fn testne64(x: u64, y: u64) -> Condition {
159-
(testnz64(x ^ y) & 0xFF) as Condition
160-
}
182+
#[test]
183+
fn maskeq64() {
184+
assert_eq!(super::maskeq64(0, 0), u64::MAX);
185+
assert_eq!(super::maskeq64(1, 0), 0);
186+
assert_eq!(super::maskeq64(0, 1), 0);
187+
assert_eq!(super::maskeq64(1, 1), u64::MAX);
188+
assert_eq!(super::maskeq64(u64::MAX, 1), 0);
189+
assert_eq!(super::maskeq64(1, u64::MAX), 0);
190+
assert_eq!(super::maskeq64(u64::MAX, u64::MAX), u64::MAX);
191+
}
161192

162-
/// Returns `0` if `x` is `0`, otherwise returns `1` (32-bit version)
163-
fn testnz32(mut x: u32) -> u32 {
164-
x |= x.wrapping_neg(); // MSB now set if non-zero
165-
core::hint::black_box(x >> (u32::BITS - 1)) // Extract MSB
166-
}
193+
#[test]
194+
fn maskne32() {
195+
assert_eq!(super::maskne32(0, 0), 0);
196+
assert_eq!(super::maskne32(1, 0), u32::MAX);
197+
assert_eq!(super::maskne32(0, 1), u32::MAX);
198+
assert_eq!(super::maskne32(1, 1), 0);
199+
assert_eq!(super::maskne32(u32::MAX, 1), u32::MAX);
200+
assert_eq!(super::maskne32(1, u32::MAX), u32::MAX);
201+
assert_eq!(super::maskne32(u32::MAX, u32::MAX), 0);
202+
}
167203

168-
/// Returns `0` if `x` is `0`, otherwise returns `1` (64-bit version)
169-
fn testnz64(mut x: u64) -> u64 {
170-
x |= x.wrapping_neg(); // MSB now set if non-zero (or unset if zero)
171-
core::hint::black_box(x >> (u64::BITS - 1)) // Extract MSB
172-
}
204+
#[test]
205+
fn maskne64() {
206+
assert_eq!(super::maskne64(0, 0), 0);
207+
assert_eq!(super::maskne64(1, 0), u64::MAX);
208+
assert_eq!(super::maskne64(0, 1), u64::MAX);
209+
assert_eq!(super::maskne64(1, 1), 0);
210+
assert_eq!(super::maskne64(u64::MAX, 1), u64::MAX);
211+
assert_eq!(super::maskne64(1, u64::MAX), u64::MAX);
212+
assert_eq!(super::maskne64(u64::MAX, u64::MAX), 0);
213+
}
173214

174-
#[cfg(test)]
175-
mod tests {
176215
#[test]
177216
fn masknz32() {
178217
assert_eq!(super::masknz32(0), 0);
179-
for i in 1..=u8::MAX {
180-
assert_eq!(super::masknz32(i.into()), u32::MAX);
218+
for i in 1..=TEST_LIMIT {
219+
assert_eq!(super::masknz32(i), u32::MAX);
220+
}
221+
222+
for i in (u32::MAX - TEST_LIMIT)..=u32::MAX {
223+
assert_eq!(super::masknz32(i), u32::MAX);
181224
}
182225
}
183226

184227
#[test]
185228
fn masknz64() {
186229
assert_eq!(super::masknz64(0), 0);
187-
for i in 1..=u8::MAX {
188-
assert_eq!(super::masknz64(i.into()), u64::MAX);
230+
for i in 1..=(TEST_LIMIT as u64) {
231+
assert_eq!(super::masknz64(i), u64::MAX);
232+
}
233+
234+
for i in (u64::MAX - TEST_LIMIT as u64)..=u64::MAX {
235+
assert_eq!(super::masknz64(i), u64::MAX);
189236
}
190237
}
191238

@@ -196,65 +243,11 @@ mod tests {
196243

197244
assert_eq!(super::masksel(17u32, 101077u32, 0u32), 17u32);
198245
assert_eq!(super::masksel(17u32, 101077u32, u32::MAX), 101077u32);
199-
}
200-
201-
#[test]
202-
fn testeq32() {
203-
assert_eq!(super::testeq32(0, 0), 1);
204-
assert_eq!(super::testeq32(1, 0), 0);
205-
assert_eq!(super::testeq32(0, 1), 0);
206-
assert_eq!(super::testeq32(1, 1), 1);
207-
assert_eq!(super::testeq32(u32::MAX, 1), 0);
208-
assert_eq!(super::testeq32(1, u32::MAX), 0);
209-
assert_eq!(super::testeq32(u32::MAX, u32::MAX), 1);
210-
}
211-
212-
#[test]
213-
fn testeq64() {
214-
assert_eq!(super::testeq64(0, 0), 1);
215-
assert_eq!(super::testeq64(1, 0), 0);
216-
assert_eq!(super::testeq64(0, 1), 0);
217-
assert_eq!(super::testeq64(1, 1), 1);
218-
assert_eq!(super::testeq64(u64::MAX, 1), 0);
219-
assert_eq!(super::testeq64(1, u64::MAX), 0);
220-
assert_eq!(super::testeq64(u64::MAX, u64::MAX), 1);
221-
}
222246

223-
#[test]
224-
fn testne32() {
225-
assert_eq!(super::testne32(0, 0), 0);
226-
assert_eq!(super::testne32(1, 0), 1);
227-
assert_eq!(super::testne32(0, 1), 1);
228-
assert_eq!(super::testne32(1, 1), 0);
229-
assert_eq!(super::testne32(u32::MAX, 1), 1);
230-
assert_eq!(super::testne32(1, u32::MAX), 1);
231-
assert_eq!(super::testne32(u32::MAX, u32::MAX), 0);
232-
}
233-
234-
#[test]
235-
fn testne64() {
236-
assert_eq!(super::testne64(0, 0), 0);
237-
assert_eq!(super::testne64(1, 0), 1);
238-
assert_eq!(super::testne64(0, 1), 1);
239-
assert_eq!(super::testne64(1, 1), 0);
240-
assert_eq!(super::testne64(u64::MAX, 1), 1);
241-
assert_eq!(super::testne64(1, u64::MAX), 1);
242-
assert_eq!(super::testne64(u64::MAX, u64::MAX), 0);
243-
}
244-
245-
#[test]
246-
fn testnz32() {
247-
assert_eq!(super::testnz32(0), 0);
248-
for i in 1..=u8::MAX {
249-
assert_eq!(super::testnz32(i as u32), 1);
250-
}
251-
}
252-
253-
#[test]
254-
fn testnz64() {
255-
assert_eq!(super::testnz64(0), 0);
256-
for i in 1..=u8::MAX {
257-
assert_eq!(super::testnz64(i as u64), 1);
258-
}
247+
assert_eq!(super::masksel(129u64, 0xFFEEDDCCBBAA9988u64, 0u64), 129u64);
248+
assert_eq!(
249+
super::masksel(129u64, 0xFFEEDDCCBBAA9988u64, u64::MAX),
250+
0xFFEEDDCCBBAA9988u64
251+
);
259252
}
260253
}

0 commit comments

Comments
 (0)