|
1 | 1 | //! Trait impls for core slices. |
2 | 2 |
|
3 | 3 | use crate::{Cmov, CmovEq, Condition}; |
4 | | -use core::slice; |
| 4 | +use core::{ |
| 5 | + ops::{BitOrAssign, Shl}, |
| 6 | + slice, |
| 7 | +}; |
5 | 8 |
|
6 | 9 | // Uses 64-bit words on 64-bit targets, 32-bit everywhere else |
7 | 10 | #[cfg(not(target_pointer_width = "64"))] |
@@ -43,10 +46,7 @@ impl Cmov for [u8] { |
43 | 46 | dst_chunk.copy_from_slice(&a.to_ne_bytes()); |
44 | 47 | } |
45 | 48 |
|
46 | | - // Process the remainder a byte-at-a-time. |
47 | | - for (a, b) in dst_remainder.iter_mut().zip(src_remainder.iter()) { |
48 | | - a.cmovnz(b, condition); |
49 | | - } |
| 49 | + cmovnz_remainder(dst_remainder, src_remainder, condition); |
50 | 50 | } |
51 | 51 | } |
52 | 52 |
|
@@ -74,10 +74,7 @@ impl Cmov for [u16] { |
74 | 74 | dst_chunk[1] = (a >> 16) as u16; |
75 | 75 | } |
76 | 76 |
|
77 | | - // If slice is odd-length |
78 | | - if !dst_remainder.is_empty() { |
79 | | - dst_remainder[0].cmovnz(&src_remainder[0], condition); |
80 | | - } |
| 77 | + cmovnz_remainder(dst_remainder, src_remainder, condition); |
81 | 78 | } |
82 | 79 | } |
83 | 80 |
|
@@ -115,9 +112,7 @@ impl Cmov for [u16] { |
115 | 112 | dst_chunk[3] = ((a >> 48) & 0xFFFF) as u16; |
116 | 113 | } |
117 | 114 |
|
118 | | - for (a, b) in dst_remainder.iter_mut().zip(src_remainder.iter()) { |
119 | | - a.cmovnz(b, condition); |
120 | | - } |
| 115 | + cmovnz_remainder(dst_remainder, src_remainder, condition); |
121 | 116 | } |
122 | 117 | } |
123 | 118 |
|
@@ -163,10 +158,7 @@ impl Cmov for [u32] { |
163 | 158 | dst_chunk[1] = (a >> 32) as u32; |
164 | 159 | } |
165 | 160 |
|
166 | | - // If slice is odd-length |
167 | | - if !dst_remainder.is_empty() { |
168 | | - dst_remainder[0].cmovnz(&src_remainder[0], condition); |
169 | | - } |
| 161 | + cmovnz_remainder(dst_remainder, src_remainder, condition); |
170 | 162 | } |
171 | 163 | } |
172 | 164 |
|
@@ -323,10 +315,7 @@ impl CmovEq for [u8] { |
323 | 315 | a.cmovne(&b, input, output); |
324 | 316 | } |
325 | 317 |
|
326 | | - // Process the remainder a byte-at-a-time. |
327 | | - for (a, b) in self_remainder.iter().zip(rhs_remainder.iter()) { |
328 | | - a.cmovne(b, input, output); |
329 | | - } |
| 318 | + cmovne_remainder(self_remainder, rhs_remainder, input, output); |
330 | 319 | } |
331 | 320 | } |
332 | 321 |
|
@@ -370,6 +359,68 @@ impl_cmoveq_with_loop!( |
370 | 359 | "Implementation for `u128` slices where we can just loop." |
371 | 360 | ); |
372 | 361 |
|
| 362 | +/// Compare the two remainder slices by loading a `Word` then performing `cmovne`. |
| 363 | +#[inline] |
| 364 | +fn cmovne_remainder<T>( |
| 365 | + a_remainder: &[T], |
| 366 | + b_remainder: &[T], |
| 367 | + input: Condition, |
| 368 | + output: &mut Condition, |
| 369 | +) where |
| 370 | + T: Copy, |
| 371 | + Word: From<T>, |
| 372 | +{ |
| 373 | + let a = slice_to_word(a_remainder); |
| 374 | + let b = slice_to_word(b_remainder); |
| 375 | + a.cmovne(&b, input, output); |
| 376 | +} |
| 377 | + |
| 378 | +/// Load the remainder from chunking the slice into a single `Word`, perform `cmovnz`, then write |
| 379 | +/// the result back out to `dst_remainder`. |
| 380 | +#[inline] |
| 381 | +fn cmovnz_remainder<T>(dst_remainder: &mut [T], src_remainder: &[T], condition: Condition) |
| 382 | +where |
| 383 | + T: BitOrAssign + Copy + From<u8> + Shl<usize, Output = T>, |
| 384 | + Word: From<T>, |
| 385 | +{ |
| 386 | + let mut remainder = slice_to_word(dst_remainder); |
| 387 | + remainder.cmovnz(&slice_to_word(src_remainder), condition); |
| 388 | + word_to_slice(remainder, dst_remainder); |
| 389 | +} |
| 390 | + |
| 391 | +/// Create a [`Word`] from the given input slice. |
| 392 | +#[inline] |
| 393 | +fn slice_to_word<T>(slice: &[T]) -> Word |
| 394 | +where |
| 395 | + T: Copy, |
| 396 | + Word: From<T>, |
| 397 | +{ |
| 398 | + debug_assert!(size_of_val(slice) <= WORD_SIZE, "slice too large"); |
| 399 | + slice |
| 400 | + .iter() |
| 401 | + .rev() |
| 402 | + .copied() |
| 403 | + .fold(0, |acc, n| (acc << (size_of::<T>() * 8)) | Word::from(n)) |
| 404 | +} |
| 405 | + |
| 406 | +/// Serialize [`Word`] as bytes using the same byte ordering as `slice_to_word`. |
| 407 | +#[inline] |
| 408 | +fn word_to_slice<T>(word: Word, out: &mut [T]) |
| 409 | +where |
| 410 | + T: BitOrAssign + Copy + From<u8> + Shl<usize, Output = T>, |
| 411 | +{ |
| 412 | + debug_assert!(size_of::<T>() > 0, "can't be used with ZSTs"); |
| 413 | + debug_assert!(out.len() <= WORD_SIZE, "slice too large"); |
| 414 | + |
| 415 | + let bytes = word.to_le_bytes(); |
| 416 | + for (o, chunk) in out.iter_mut().zip(bytes.chunks(size_of::<T>())) { |
| 417 | + *o = T::from(0u8); |
| 418 | + for (i, &byte) in chunk.iter().enumerate() { |
| 419 | + *o |= T::from(byte) << (i * 8); |
| 420 | + } |
| 421 | + } |
| 422 | +} |
| 423 | + |
373 | 424 | /// Rust core `[T]::as_chunks` vendored because of its 1.88 MSRV. |
374 | 425 | /// TODO(tarcieri): use upstream function when we bump MSRV |
375 | 426 | #[inline] |
@@ -439,3 +490,59 @@ unsafe fn slice_as_chunks_unchecked_mut<T, const N: usize>(slice: &mut [T]) -> & |
439 | 490 | // a slice of `new_len` many `N` elements chunks. |
440 | 491 | unsafe { slice::from_raw_parts_mut(slice.as_mut_ptr().cast(), new_len) } |
441 | 492 | } |
| 493 | + |
#[cfg(test)]
mod tests {
    /// `cmovnz_remainder` must leave the destination untouched for a zero condition and replace
    /// it with the source for a non-zero one.
    #[test]
    fn cmovnz_remainder() {
        // On non-64-bit targets two `u16`s exercise the endianness handling; on 64-bit targets
        // three `u16`s exercise an odd-length remainder.
        #[cfg(not(target_pointer_width = "64"))]
        const CASE: ([u16; 2], [u16; 2]) = ([0xAAAA, 0xBBBB], [0x10, 0xFFFF]);
        #[cfg(target_pointer_width = "64")]
        const CASE: ([u16; 3], [u16; 3]) = ([0xAAAA, 0xBBBB, 0xCCCC], [0x10, 0x10, 0xFFFF]);

        let (initial, replacement) = CASE;
        let mut dst = initial;

        // Condition 0: nothing moves.
        super::cmovnz_remainder(&mut dst, &replacement, 0);
        assert_eq!(initial, dst);

        // Condition 1: the source replaces the destination.
        super::cmovnz_remainder(&mut dst, &replacement, 1);
        assert_eq!(replacement, dst);
    }

    /// The first slice element must land in the least significant bits of the word.
    #[test]
    fn slice_to_word() {
        let from_bytes = super::slice_to_word(&[0xCCu8, 0xBB, 0xAA]);
        assert_eq!(0xAABBCC, from_bytes);

        let from_u16_pair = super::slice_to_word(&[0xBBBBu16, 0xAAAA]);
        assert_eq!(0xAAAABBBB, from_u16_pair);

        #[cfg(target_pointer_width = "64")]
        {
            let from_u16_triple = super::slice_to_word(&[0xCCCCu16, 0xBBBB, 0xAAAA]);
            assert_eq!(0xAAAABBBBCCCC, from_u16_triple);
        }
    }

    /// Unpacking must mirror `slice_to_word`: lowest bits go to the first element.
    #[test]
    fn word_to_slice() {
        let mut bytes = [0u8; 3];
        super::word_to_slice(0xAABBCC, &mut bytes);
        assert_eq!(&[0xCC, 0xBB, 0xAA], &bytes);

        let mut u16_pair = [0u16; 2];
        super::word_to_slice(0xAAAABBBB, &mut u16_pair);
        assert_eq!(&[0xBBBB, 0xAAAA], &u16_pair);

        #[cfg(target_pointer_width = "64")]
        {
            let mut u16_triple = [0u16; 3];
            super::word_to_slice(0xAAAABBBBCCCC, &mut u16_triple);
            assert_eq!(&[0xCCCC, 0xBBBB, 0xAAAA], &u16_triple);
        }
    }
}
0 commit comments