Skip to content

Commit 4bb07d1

Browse files
authored
cmov: impl optimized Cmov for [u8; N] (#1350)
Adds a specialized impl of `Cmov` for byte arrays of generic size which first coalesces elements of the array into word-sized chunks, then calls `Cmov` on those. This should result in significantly more efficient codegen, which can also take advantage of compile-time knowledge of `N` to potentially unroll loops. Unfortunately without specialization this means we can't impl `Cmov` for other types of arrays, but downstream consumers can just iterate over them and call `Cmov::cmov*` on each element easily enough, whereas this optimized implementation for byte arrays actually provides something a lot less trivial than looping over an array.
1 parent ac50346 commit 4bb07d1

File tree

4 files changed

+162
-17
lines changed

4 files changed

+162
-17
lines changed

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1 @@
11
target
2-
**/*proptest-regressions

cmov/src/lib.rs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,51 @@ macro_rules! impl_cmov_traits_for_signed_ints {
196196

197197
impl_cmov_traits_for_signed_ints!(i8 => u8, i16 => u16, i32 => u32, i64 => u64, i128 => u128);
198198

199+
/// Optimized implementation for byte arrays which coalesces them into word-sized chunks first,
200+
/// then performs [`Cmov`] at the word-level to cut down on the total number of instructions.
201+
///
202+
/// With compile-time knowledge of `N`, the compiler should also be able to unroll the loops in
203+
/// cases where efficiency would benefit, reducing the implementation to a sequence of word-sized
204+
/// [`Cmov`] ops (and if `N` isn't word-aligned, followed by a series of 1-byte ops).
205+
impl<const N: usize> Cmov for [u8; N] {
206+
#[inline]
207+
fn cmovnz(&mut self, value: &Self, condition: Condition) {
208+
// Uses 64-bit words on 64-bit targets, 32-bit everywhere else
209+
#[cfg(not(target_pointer_width = "64"))]
210+
type Chunk = u32;
211+
#[cfg(target_pointer_width = "64")]
212+
type Chunk = u64;
213+
const CHUNK_SIZE: usize = size_of::<Chunk>();
214+
215+
// Load a chunk from a byte slice
216+
// TODO(tarcieri): use `array_chunks` when stable (rust-lang/rust##100450)
217+
#[inline]
218+
fn load_chunk(slice: &[u8]) -> Chunk {
219+
Chunk::from_ne_bytes(slice.try_into().expect("should be the right size"))
220+
}
221+
222+
let mut self_chunks = self.chunks_exact_mut(CHUNK_SIZE);
223+
let mut value_chunks = value.chunks_exact(CHUNK_SIZE);
224+
225+
// Process as much input as we can a `Chunk`-at-a-time.
226+
for (self_chunk, value_chunk) in self_chunks.by_ref().zip(value_chunks.by_ref()) {
227+
let mut a = load_chunk(self_chunk);
228+
let b = load_chunk(value_chunk);
229+
a.cmovnz(&b, condition);
230+
self_chunk.copy_from_slice(&a.to_ne_bytes());
231+
}
232+
233+
// Process the remainder a byte-at-a-time.
234+
for (a, b) in self_chunks
235+
.into_remainder()
236+
.iter_mut()
237+
.zip(value_chunks.remainder().iter())
238+
{
239+
a.cmovnz(b, condition);
240+
}
241+
}
242+
}
243+
199244
impl<T: CmovEq> CmovEq for [T] {
200245
fn cmoveq(&self, rhs: &Self, input: Condition, output: &mut Condition) {
201246
let mut tmp = 1u8;

cmov/tests/core_impls.rs

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -122,40 +122,88 @@ int_tests!(
122122
0x2222_2222_2222_2222_3333_3333_3333_3333u128
123123
);
124124

125+
mod arrays {
126+
use cmov::Cmov;
127+
128+
// 127 elements: large enough to test the chunk loop, odd-sized to test remainder handling,
129+
// and with each element different to ensure the operations actually work
130+
const EXAMPLE_A: [u8; 127] = [
131+
0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 0x10, 0x11,
132+
0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20,
133+
0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
134+
0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e,
135+
0x3f, 0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d,
136+
0x4e, 0x4f, 0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x5a, 0x5b, 0x5c,
137+
0x5d, 0x5e, 0x5f, 0x60, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b,
138+
0x6c, 0x6d, 0x6e, 0x6f, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a,
139+
0x7b, 0x7c, 0x7d, 0x7e, 0x7f,
140+
];
141+
142+
const EXAMPLE_B: [u8; 127] = [
143+
0xff, 0xfe, 0xfd, 0xfc, 0xfb, 0xfa, 0xf9, 0xf8, 0xf7, 0xf6, 0xf5, 0xf4, 0xf3, 0xf2, 0xf1,
144+
0xf0, 0xef, 0xee, 0xed, 0xec, 0xeb, 0xea, 0xe9, 0xe8, 0xe7, 0xe6, 0xe5, 0xe4, 0xe3, 0xe2,
145+
0xe1, 0xe0, 0xdf, 0xde, 0xdd, 0xdc, 0xdb, 0xda, 0xd9, 0xd8, 0xd7, 0xd6, 0xd5, 0xd4, 0xd3,
146+
0xd2, 0xd1, 0xd0, 0xcf, 0xce, 0xcd, 0xcc, 0xcb, 0xca, 0xc9, 0xc8, 0xc7, 0xc6, 0xc5, 0xc4,
147+
0xc3, 0xc2, 0xc1, 0xc0, 0xbf, 0xbe, 0xbd, 0xbc, 0xbb, 0xba, 0xb9, 0xb8, 0xb7, 0xb6, 0xb5,
148+
0xb4, 0xb3, 0xb2, 0xb1, 0xb0, 0xaf, 0xae, 0xad, 0xac, 0xab, 0xaa, 0xa9, 0xa8, 0xa7, 0xa6,
149+
0xa5, 0xa4, 0xa3, 0xa2, 0xa1, 0xa0, 0x9f, 0x9e, 0x9d, 0x9c, 0x9b, 0x9a, 0x99, 0x98, 0x97,
150+
0x96, 0x95, 0x94, 0x93, 0x92, 0x91, 0x90, 0x8f, 0x8e, 0x8d, 0x8c, 0x8b, 0x8a, 0x89, 0x88,
151+
0x87, 0x86, 0x85, 0x84, 0x83, 0x82, 0x81,
152+
];
153+
154+
/// Checks `Cmov::cmovnz` on a 127-byte array: a zero condition must leave the destination
/// unchanged, and every non-zero condition must copy the source over it.
///
/// Note: we only provide this impl for `[u8; N]` so we have some optimized way of operating
/// over byte arrays. Unfortunately without specialization we can't also provide a generalized
/// impl, but having good codegen for byte arrays is important.
#[test]
fn u8_cmovnz_works() {
    // Zero condition: no move.
    let mut x = EXAMPLE_A;
    x.cmovnz(&EXAMPLE_B, 0);
    assert_eq!(x, EXAMPLE_A);

    // Every non-zero condition must move. The range is inclusive so that `u8::MAX` (all bits
    // set) is exercised too; a half-open `1..u8::MAX` would silently skip it.
    for cond in 1..=u8::MAX {
        let mut x = EXAMPLE_A;
        x.cmovnz(&EXAMPLE_B, cond);
        assert_eq!(x, EXAMPLE_B, "cmovnz did not move for cond = {}", cond);
    }
}
169+
}
170+
125171
mod slices {
126172
use cmov::CmovEq;
127173

128174
#[test]
129175
fn cmoveq_works() {
176+
let example = [1u8, 2, 3].as_slice();
130177
let mut o = 0u8;
131178

132179
// Same slices.
133-
[1u8, 2, 3].cmoveq(&[1, 2, 3], 43, &mut o);
180+
example.cmoveq(example, 43, &mut o);
134181
assert_eq!(o, 43);
135182

136183
// Different lengths.
137-
[1u8, 2, 3].cmoveq(&[1, 2], 44, &mut o);
184+
example.cmoveq(&[1, 2], 44, &mut o);
138185
assert_ne!(o, 44);
139186

140187
// Different contents.
141-
[1u8, 2, 3].cmoveq(&[1, 2, 4], 45, &mut o);
188+
example.cmoveq(&[1, 2, 4], 45, &mut o);
142189
assert_ne!(o, 45);
143190
}
144191

145192
#[test]
146193
fn cmovne_works() {
194+
let example = [1u8, 2, 3].as_slice();
147195
let mut o = 0u8;
148196

149197
// Same slices.
150-
[1u8, 2, 3].cmovne(&[1, 2, 3], 43, &mut o);
198+
example.cmovne(example, 43, &mut o);
151199
assert_ne!(o, 43);
152200

153201
// Different lengths.
154-
[1u8, 2, 3].cmovne(&[1, 2], 44, &mut o);
202+
example.cmovne(&[1, 2], 44, &mut o);
155203
assert_eq!(o, 44);
156204

157205
// Different contents.
158-
[1u8, 2, 3].cmovne(&[1, 2, 4], 45, &mut o);
206+
example.cmovne(&[1, 2, 4], 45, &mut o);
159207
assert_eq!(o, 45);
160208
}
161209
}

cmov/tests/proptests.rs

Lines changed: 63 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,55 +9,51 @@ macro_rules! int_proptests {
99
proptest! {
1010
#[test]
1111
fn cmovz_works(mut a in any::<$int>(), b in any::<$int>(), cond in any::<u8>()) {
12-
a.cmovz(&b, cond);
13-
1412
let expected = if cond == 0 {
1513
b
1614
} else {
1715
a
1816
};
1917

18+
a.cmovz(&b, cond);
2019
prop_assert_eq!(expected, a);
2120
}
2221

2322
#[test]
2423
fn cmovnz_works(mut a in any::<$int>(), b in any::<$int>(), cond in any::<u8>()) {
25-
a.cmovnz(&b, cond);
26-
2724
let expected = if cond != 0 {
2825
b
2926
} else {
3027
a
3128
};
3229

30+
a.cmovnz(&b, cond);
3331
prop_assert_eq!(expected, a);
3432
}
3533

3634
#[test]
3735
fn cmoveq_works(a in any::<$int>(), b in any::<$int>(), cond in any::<u8>()) {
38-
let mut actual = 0;
39-
a.cmoveq(&b, cond, &mut actual);
40-
4136
let expected = if a == b {
4237
cond
4338
} else {
4439
0
4540
};
4641

42+
let mut actual = 0;
43+
a.cmoveq(&b, cond, &mut actual);
4744
prop_assert_eq!(expected, actual);
4845
}
4946

5047
#[test]
5148
fn cmovne_works(a in any::<$int>(), b in any::<$int>(), cond in any::<u8>()) {
52-
let mut actual = 0;
53-
a.cmovne(&b, cond, &mut actual);
54-
5549
let expected = if a != b {
5650
cond
5751
} else {
5852
0
5953
};
6054

55+
let mut actual = 0;
56+
a.cmovne(&b, cond, &mut actual);
6157
prop_assert_eq!(expected, actual);
6258
}
6359
}
@@ -66,4 +62,61 @@ macro_rules! int_proptests {
6662
};
6763
}
6864

65+
/// Generates one proptest module per `name: size` pair.
///
/// Each generated module contains a `cmovnz_works` property that runs `Cmov::cmovnz` on random
/// `[u8; size]` inputs with a random condition and compares the result against a plain
/// `if`/`else` reference.
macro_rules! byte_array_proptests {
    ( $($name:ident: $size:expr),+ ) => {
        $(
            mod $name {
                use cmov::Cmov;
                use proptest::prelude::*;

                proptest! {
                    #[test]
                    fn cmovnz_works(
                        mut a in any::<[u8; $size]>(),
                        b in any::<[u8; $size]>(),
                        cond in any::<u8>()
                    ) {
                        // Capture the reference result before `a` is mutated below.
                        let expected = if cond != 0 { b } else { a };

                        a.cmovnz(&b, cond);
                        prop_assert_eq!(expected, a);
                    }
                }
            }
        )+
    };
}
94+
6995
int_proptests!(i8, i16, i32, i64, i128, u8, u16, u32, u64, u128);
96+
byte_array_proptests!(
97+
array0: 0,
98+
array1: 1,
99+
array2: 2,
100+
array3: 3,
101+
array4: 4,
102+
array5: 5,
103+
array6: 6,
104+
array7: 7,
105+
array8: 8,
106+
array9: 9,
107+
array10: 10,
108+
array11: 11,
109+
array12: 12,
110+
array13: 13,
111+
array14: 14,
112+
array15: 15,
113+
array16: 16,
114+
array17: 17,
115+
array18: 18,
116+
array19: 19,
117+
array20: 20,
118+
array21: 21,
119+
array22: 22,
120+
array23: 23,
121+
array24: 24
122+
);

0 commit comments

Comments
 (0)