77//! optimizer potentially inserting branches.
88
99use crate :: { Cmov , CmovEq , Condition } ;
10+ use core:: ops:: { BitAnd , BitOr , Not } ;
1011
12+ // Uses `Cmov` impl for `u32`
1113impl Cmov for u16 {
1214 #[ inline]
1315 fn cmovnz ( & mut self , value : & u16 , condition : Condition ) {
@@ -24,6 +26,7 @@ impl Cmov for u16 {
2426 }
2527}
2628
29+ // Uses `CmovEq` impl for `u32`
2730impl CmovEq for u16 {
2831 #[ inline]
2932 fn cmovne ( & self , rhs : & Self , input : Condition , output : & mut Condition ) {
@@ -39,59 +42,103 @@ impl CmovEq for u16 {
3942impl Cmov for u32 {
4043 #[ inline]
4144 fn cmovnz ( & mut self , value : & Self , condition : Condition ) {
42- let mask = masknz32 ( condition) ;
43- * self = ( * self & !mask) | ( * value & mask) ;
45+ * self = masksel ( * self , * value, masknz32 ( condition. into ( ) ) ) ;
4446 }
4547
4648 #[ inline]
4749 fn cmovz ( & mut self , value : & Self , condition : Condition ) {
48- let mask = masknz32 ( condition) ;
49- * self = ( * self & mask) | ( * value & !mask) ;
50+ * self = masksel ( * self , * value, !masknz32 ( condition. into ( ) ) ) ;
5051 }
5152}
5253
5354impl CmovEq for u32 {
5455 #[ inline]
5556 fn cmovne ( & self , rhs : & Self , input : Condition , output : & mut Condition ) {
56- let ne = testne32 ( * self , * rhs) ;
57- output. cmovnz ( & input, ne) ;
57+ output. cmovnz ( & input, testne32 ( * self , * rhs) ) ;
5858 }
5959
6060 #[ inline]
6161 fn cmoveq ( & self , rhs : & Self , input : Condition , output : & mut Condition ) {
62- let eq = testeq32 ( * self , * rhs) ;
63- output. cmovnz ( & input, eq) ;
62+ output. cmovnz ( & input, testeq32 ( * self , * rhs) ) ;
6463 }
6564}
6665
6766impl Cmov for u64 {
6867 #[ inline]
6968 fn cmovnz ( & mut self , value : & Self , condition : Condition ) {
70- let mask = masknz64 ( condition) ;
71- * self = ( * self & !mask) | ( * value & mask) ;
69+ * self = masksel ( * self , * value, masknz64 ( condition. into ( ) ) ) ;
7270 }
7371
7472 #[ inline]
7573 fn cmovz ( & mut self , value : & Self , condition : Condition ) {
76- let mask = masknz64 ( condition) ;
77- * self = ( * self & mask) | ( * value & !mask) ;
74+ * self = masksel ( * self , * value, !masknz64 ( condition. into ( ) ) ) ;
7875 }
7976}
8077
8178impl CmovEq for u64 {
8279 #[ inline]
8380 fn cmovne ( & self , rhs : & Self , input : Condition , output : & mut Condition ) {
84- let ne = testne64 ( * self , * rhs) ;
85- output. cmovnz ( & input, ne) ;
81+ output. cmovnz ( & input, testne64 ( * self , * rhs) ) ;
8682 }
8783
8884 #[ inline]
8985 fn cmoveq ( & self , rhs : & Self , input : Condition , output : & mut Condition ) {
90- let eq = testeq64 ( * self , * rhs) ;
91- output. cmovnz ( & input, eq) ;
86+ output. cmovnz ( & input, testeq64 ( * self , * rhs) ) ;
9287 }
9388}
9489
90+ /// Return a [`u32::MAX`] mask if `condition` is non-zero, otherwise return zero for a zero input.
91+ #[ cfg( not( target_arch = "arm" ) ) ]
92+ fn masknz32 ( condition : u32 ) -> u32 {
93+ testnz32 ( condition) . wrapping_neg ( )
94+ }
95+
96+ /// Return a [`u64::MAX`] mask if `condition` is non-zero, otherwise return zero for a zero input.
97+ #[ cfg( not( target_arch = "arm" ) ) ]
98+ fn masknz64 ( condition : u64 ) -> u64 {
99+ testnz64 ( condition) . wrapping_neg ( )
100+ }
101+
/// Optimized mask generation for ARM32 targets: expands a non-zero
/// `condition` into an all-ones mask, and zero into zero.
///
/// Written in assembly both for performance and because past code generation
/// for the pure-Rust equivalent inserted a branch (CVE-2026-23519); hand-written
/// assembly guarantees that cannot happen again.
#[cfg(target_arch = "arm")]
fn masknz32(condition: u32) -> u32 {
    let mut mask = condition;
    // SAFETY: the asm operates solely on the register holding `mask`; it
    // reads/writes no memory and does not touch the stack, matching the
    // declared `nostack, nomem` options.
    unsafe {
        core::arch::asm!(
            "rsbs {0}, {0}, #0",  // Reverse subtract
            "sbcs {0}, {0}, {0}", // Subtract with carry, setting flags
            inout(reg) mask,
            options(nostack, nomem),
        );
    }
    mask
}
120+
/// 64-bit wrapper for targets that implement 32-bit mask generation in assembly.
#[cfg(target_arch = "arm")]
fn masknz64(condition: u64) -> u64 {
    // A 64-bit value is non-zero iff either 32-bit half is non-zero, so the
    // per-half masks are OR-ed and then replicated into both output halves.
    let low = masknz32(condition as u32);
    let high = masknz32((condition >> 32) as u32);
    let folded = (low | high) as u64;
    (folded << 32) | folded
}
129+
/// Branchless two-way select driven by a bit mask.
///
/// For a mask of all-zeros this returns `a`; for a mask of all-ones
/// (`u*::MAX`) it returns `b`. Any other mask value mixes bits of both
/// operands and is a misuse of this helper.
#[inline]
fn masksel<T>(a: T, b: T, mask: T) -> T
where
    T: BitAnd<Output = T> + BitOr<Output = T> + Copy + Not<Output = T>,
{
    let kept = a & !mask;
    let replaced = b & mask;
    kept | replaced
}
141+
95142/// Returns `1` if `x` is equal to `y`, otherwise returns `0` (32-bit version)
96143fn testeq32 ( x : u32 , y : u32 ) -> Condition {
97144 testne32 ( x, y) ^ 1
@@ -120,46 +167,37 @@ fn testnz32(mut x: u32) -> u32 {
120167
/// Returns `0` if `x` is `0`, otherwise returns `1` (64-bit version)
fn testnz64(x: u64) -> u64 {
    // For any non-zero x, either x or its two's complement has the most
    // significant bit set, so OR-ing them makes the MSB a non-zero flag.
    let folded = x | x.wrapping_neg();
    // `black_box` keeps the optimizer from collapsing this back into a
    // comparison (which could compile to a branch).
    core::hint::black_box(folded >> (u64::BITS - 1))
}
126173
127- /// Return a [`u32::MAX`] mask if `condition` is non-zero, otherwise return zero for a zero input.
128- #[ cfg( not( target_arch = "arm" ) ) ]
129- fn masknz32 ( condition : Condition ) -> u32 {
130- testnz32 ( condition. into ( ) ) . wrapping_neg ( )
131- }
132-
133- /// Return a [`u64::MAX`] mask if `condition` is non-zero, otherwise return zero for a zero input.
134- #[ cfg( not( target_arch = "arm" ) ) ]
135- fn masknz64 ( condition : Condition ) -> u64 {
136- testnz64 ( condition. into ( ) ) . wrapping_neg ( )
137- }
174+ #[ cfg( test) ]
175+ mod tests {
176+ #[ test]
177+ fn masknz32 ( ) {
178+ assert_eq ! ( super :: masknz32( 0 ) , 0 ) ;
179+ for i in 1 ..=u8:: MAX {
180+ assert_eq ! ( super :: masknz32( i. into( ) ) , u32 :: MAX ) ;
181+ }
182+ }
138183
139- /// Optimized mask generation for ARM32 targets.
140- #[ cfg( target_arch = "arm" ) ]
141- fn masknz32 ( condition : u8 ) -> u32 {
142- let mut out = condition as u32 ;
143- unsafe {
144- core:: arch:: asm!(
145- "rsbs {0}, {0}, #0" , // Reverse subtract
146- "sbcs {0}, {0}, {0}" , // Subtract with carry, setting flags
147- inout( reg) out,
148- options( nostack, nomem) ,
149- ) ;
184+ #[ test]
185+ fn masknz64 ( ) {
186+ assert_eq ! ( super :: masknz64( 0 ) , 0 ) ;
187+ for i in 1 ..=u8:: MAX {
188+ assert_eq ! ( super :: masknz64( i. into( ) ) , u64 :: MAX ) ;
189+ }
150190 }
151- out
152- }
153191
154- /// 64-bit wrapper for targets that implement 32-bit mask generation in assembly.
155- #[ cfg( target_arch = "arm" ) ]
156- fn masknz64 ( condition : u8 ) -> u64 {
157- let mask = masknz32 ( condition) as u64 ;
158- mask | mask << 32
159- }
192+ #[ test]
193+ fn masksel ( ) {
194+ assert_eq ! ( super :: masksel( 23u8 , 42u8 , 0u8 ) , 23u8 ) ;
195+ assert_eq ! ( super :: masksel( 23u8 , 42u8 , u8 :: MAX ) , 42u8 ) ;
196+
197+ assert_eq ! ( super :: masksel( 17u32 , 101077u32 , 0u32 ) , 17u32 ) ;
198+ assert_eq ! ( super :: masksel( 17u32 , 101077u32 , u32 :: MAX ) , 101077u32 ) ;
199+ }
160200
161- #[ cfg( test) ]
162- mod tests {
163201 #[ test]
164202 fn testeq32 ( ) {
165203 assert_eq ! ( super :: testeq32( 0 , 0 ) , 1 ) ;
@@ -219,20 +257,4 @@ mod tests {
219257 assert_eq ! ( super :: testnz64( i as u64 ) , 1 ) ;
220258 }
221259 }
222-
223- #[ test]
224- fn masknz32 ( ) {
225- assert_eq ! ( super :: masknz32( 0 ) , 0 ) ;
226- for i in 1 ..=u8:: MAX {
227- assert_eq ! ( super :: masknz32( i) , u32 :: MAX ) ;
228- }
229- }
230-
231- #[ test]
232- fn masknz64 ( ) {
233- assert_eq ! ( super :: masknz64( 0 ) , 0 ) ;
234- for i in 1 ..=u8:: MAX {
235- assert_eq ! ( super :: masknz64( i) , u64 :: MAX ) ;
236- }
237- }
238260}
0 commit comments