11//! Trait impls for core slices.
22
3- use crate :: utils:: { WORD_SIZE , Word , slice_as_chunks, slice_as_chunks_mut} ;
43use crate :: { Cmov , CmovEq , Condition } ;
4+ use core:: slice;
5+
6+ // Uses 64-bit words on 64-bit targets, 32-bit everywhere else
7+ #[ cfg( not( target_pointer_width = "64" ) ) ]
8+ type Word = u32 ;
9+ #[ cfg( target_pointer_width = "64" ) ]
10+ type Word = u64 ;
11+ const WORD_SIZE : usize = size_of :: < Word > ( ) ;
12+ const _: ( ) = assert ! ( size_of:: <usize >( ) <= WORD_SIZE , "unexpected word size" ) ;
513
614/// Optimized implementation for byte slices which coalesces them into word-sized chunks first,
715/// then performs [`Cmov`] at the word-level to cut down on the total number of instructions.
16+ ///
17+ /// # Panics
18+ /// - if slices have unequal lengths
819impl Cmov for [ u8 ] {
920 #[ inline]
1021 fn cmovnz ( & mut self , value : & Self , condition : Condition ) {
11- let ( self_chunks, self_remainder) = slice_as_chunks_mut :: < u8 , WORD_SIZE > ( self ) ;
12- let ( value_chunks, value_remainder) = slice_as_chunks :: < u8 , WORD_SIZE > ( value) ;
13-
14- for ( self_chunk, value_chunk) in self_chunks. iter_mut ( ) . zip ( value_chunks. iter ( ) ) {
15- let mut a = Word :: from_ne_bytes ( * self_chunk) ;
16- let b = Word :: from_ne_bytes ( * value_chunk) ;
17- a. cmovnz ( & b, condition) ;
18- self_chunk. copy_from_slice ( & a. to_ne_bytes ( ) ) ;
19- }
22+ assert_eq ! (
23+ self . len( ) ,
24+ value. len( ) ,
25+ "source slice length ({}) does not match destination slice length ({})" ,
26+ value. len( ) ,
27+ self . len( )
28+ ) ;
2029
21- // Process the remainder a byte-at-a-time.
22- for ( a, b) in self_remainder. iter_mut ( ) . zip ( value_remainder. iter ( ) ) {
23- a. cmovnz ( b, condition) ;
24- }
30+ cmovnz_slice_unchecked ( self , value, condition) ;
2531 }
2632}
2733
28- impl < T : CmovEq > CmovEq for [ T ] {
34+ /// Optimized implementation for byte arrays which coalesces them into word-sized chunks first,
35+ /// then performs [`CmovEq`] at the word-level to cut down on the total number of instructions.
36+ ///
37+ /// This is only constant-time for equal-length slices, and will short-circuit and set `output`
38+ /// in the event the slices are of unequal length.
39+ impl CmovEq for [ u8 ] {
40+ #[ inline]
2941 fn cmovne ( & self , rhs : & Self , input : Condition , output : & mut Condition ) {
3042 // Short-circuit the comparison if the slices are of different lengths, and set the output
3143 // condition to the input condition.
@@ -34,9 +46,109 @@ impl<T: CmovEq> CmovEq for [T] {
3446 return ;
3547 }
3648
37- // Compare each byte.
38- for ( a, b) in self . iter ( ) . zip ( rhs. iter ( ) ) {
49+ let ( self_chunks, self_remainder) = slice_as_chunks :: < u8 , WORD_SIZE > ( self ) ;
50+ let ( rhs_chunks, rhs_remainder) = slice_as_chunks :: < u8 , WORD_SIZE > ( rhs) ;
51+
52+ for ( self_chunk, rhs_chunk) in self_chunks. iter ( ) . zip ( rhs_chunks. iter ( ) ) {
53+ let a = Word :: from_ne_bytes ( * self_chunk) ;
54+ let b = Word :: from_ne_bytes ( * rhs_chunk) ;
55+ a. cmovne ( & b, input, output) ;
56+ }
57+
58+ // Process the remainder a byte-at-a-time.
59+ for ( a, b) in self_remainder. iter ( ) . zip ( rhs_remainder. iter ( ) ) {
3960 a. cmovne ( b, input, output) ;
4061 }
4162 }
4263}
64+
65+ /// Conditionally move `src` to `dst` in constant-time if `condition` is non-zero.
66+ ///
67+ /// This function does not check the slices are equal-length and expects the caller to do so first.
68+ #[ inline( always) ]
69+ pub ( crate ) fn cmovnz_slice_unchecked ( dst : & mut [ u8 ] , src : & [ u8 ] , condition : Condition ) {
70+ let ( dst_chunks, dst_remainder) = slice_as_chunks_mut :: < u8 , WORD_SIZE > ( dst) ;
71+ let ( src_chunks, src_remainder) = slice_as_chunks :: < u8 , WORD_SIZE > ( src) ;
72+
73+ for ( dst_chunk, src_chunk) in dst_chunks. iter_mut ( ) . zip ( src_chunks. iter ( ) ) {
74+ let mut a = Word :: from_ne_bytes ( * dst_chunk) ;
75+ let b = Word :: from_ne_bytes ( * src_chunk) ;
76+ a. cmovnz ( & b, condition) ;
77+ dst_chunk. copy_from_slice ( & a. to_ne_bytes ( ) ) ;
78+ }
79+
80+ // Process the remainder a byte-at-a-time.
81+ for ( a, b) in dst_remainder. iter_mut ( ) . zip ( src_remainder. iter ( ) ) {
82+ a. cmovnz ( b, condition) ;
83+ }
84+ }
85+
86+ /// Rust core `[T]::as_chunks` vendored because of its 1.88 MSRV.
87+ /// TODO(tarcieri): use upstream function when we bump MSRV
88+ #[ inline]
89+ #[ track_caller]
90+ #[ must_use]
91+ #[ allow( clippy:: integer_division_remainder_used) ]
92+ fn slice_as_chunks < T , const N : usize > ( slice : & [ T ] ) -> ( & [ [ T ; N ] ] , & [ T ] ) {
93+ assert ! ( N != 0 , "chunk size must be non-zero" ) ;
94+ let len_rounded_down = slice. len ( ) / N * N ;
95+ // SAFETY: The rounded-down value is always the same or smaller than the
96+ // original length, and thus must be in-bounds of the slice.
97+ let ( multiple_of_n, remainder) = unsafe { slice. split_at_unchecked ( len_rounded_down) } ;
98+ // SAFETY: We already panicked for zero, and ensured by construction
99+ // that the length of the subslice is a multiple of N.
100+ let array_slice = unsafe { slice_as_chunks_unchecked ( multiple_of_n) } ;
101+ ( array_slice, remainder)
102+ }
103+
104+ /// Rust core `[T]::as_chunks_mut` vendored because of its 1.88 MSRV.
105+ /// TODO(tarcieri): use upstream function when we bump MSRV
106+ #[ inline]
107+ #[ track_caller]
108+ #[ must_use]
109+ #[ allow( clippy:: integer_division_remainder_used) ]
110+ fn slice_as_chunks_mut < T , const N : usize > ( slice : & mut [ T ] ) -> ( & mut [ [ T ; N ] ] , & mut [ T ] ) {
111+ assert ! ( N != 0 , "chunk size must be non-zero" ) ;
112+ let len_rounded_down = slice. len ( ) / N * N ;
113+ // SAFETY: The rounded-down value is always the same or smaller than the
114+ // original length, and thus must be in-bounds of the slice.
115+ let ( multiple_of_n, remainder) = unsafe { slice. split_at_mut_unchecked ( len_rounded_down) } ;
116+ // SAFETY: We already panicked for zero, and ensured by construction
117+ // that the length of the subslice is a multiple of N.
118+ let array_slice = unsafe { slice_as_chunks_unchecked_mut ( multiple_of_n) } ;
119+ ( array_slice, remainder)
120+ }
121+
122+ /// Rust core `[T]::as_chunks_unchecked` vendored because of its 1.88 MSRV.
123+ /// TODO(tarcieri): use upstream function when we bump MSRV
124+ #[ inline]
125+ #[ must_use]
126+ #[ track_caller]
127+ #[ allow( clippy:: integer_division_remainder_used) ]
128+ unsafe fn slice_as_chunks_unchecked < T , const N : usize > ( slice : & [ T ] ) -> & [ [ T ; N ] ] {
129+ // SAFETY: Caller must guarantee that `N` is nonzero and exactly divides the slice length
130+ const { debug_assert ! ( N != 0 ) } ;
131+ debug_assert_eq ! ( slice. len( ) % N , 0 ) ;
132+ let new_len = slice. len ( ) / N ;
133+
134+ // SAFETY: We cast a slice of `new_len * N` elements into
135+ // a slice of `new_len` many `N` elements chunks.
136+ unsafe { slice:: from_raw_parts ( slice. as_ptr ( ) . cast ( ) , new_len) }
137+ }
138+
139+ /// Rust core `[T]::as_chunks_unchecked_mut` vendored because of its 1.88 MSRV.
140+ /// TODO(tarcieri): use upstream function when we bump MSRV
141+ #[ inline]
142+ #[ must_use]
143+ #[ track_caller]
144+ #[ allow( clippy:: integer_division_remainder_used) ]
145+ unsafe fn slice_as_chunks_unchecked_mut < T , const N : usize > ( slice : & mut [ T ] ) -> & mut [ [ T ; N ] ] {
146+ // SAFETY: Caller must guarantee that `N` is nonzero and exactly divides the slice length
147+ const { debug_assert ! ( N != 0 ) } ;
148+ debug_assert_eq ! ( slice. len( ) % N , 0 ) ;
149+ let new_len = slice. len ( ) / N ;
150+
151+ // SAFETY: We cast a slice of `new_len * N` elements into
152+ // a slice of `new_len` many `N` elements chunks.
153+ unsafe { slice:: from_raw_parts_mut ( slice. as_mut_ptr ( ) . cast ( ) , new_len) }
154+ }
0 commit comments