88#include <stdint.h>
99#include "cbmc.h"
1010#include "common.h"
11+ #include "ct.h"
1112
1213#define MLD_2_POW_D (1 << MLDSA_D)
1314
14- #define mld_power2round MLD_NAMESPACE(power2round)
1515/*************************************************
1616 * Name: mld_power2round
1717 *
2626 * Reference: In the reference implementation, a1 is passed as a
2727 * return value instead.
2828 **************************************************/
29- void mld_power2round (int32_t * a0 , int32_t * a1 , int32_t a )
29+ static MLD_INLINE void mld_power2round (int32_t * a0 , int32_t * a1 , int32_t a )
3030__contract__ (
3131 requires (memory_no_alias (a0 , sizeof (int32_t )))
3232 requires (memory_no_alias (a1 , sizeof (int32_t )))
@@ -36,10 +36,12 @@ __contract__(
3636 ensures (* a0 > - (MLD_2_POW_D /2 ) && * a0 <= (MLD_2_POW_D /2 ))
3737 ensures (* a1 >= 0 && * a1 <= (MLDSA_Q - 1 ) / MLD_2_POW_D )
3838 ensures ((* a1 * MLD_2_POW_D + * a0 - a ) % MLDSA_Q == 0 )
39- );
39+ )
40+ {
41+ * a1 = (a + (1 << (MLDSA_D - 1 )) - 1 ) >> MLDSA_D ;
42+ * a0 = a - (* a1 << MLDSA_D );
43+ }
4044
41-
42- #define mld_decompose MLD_NAMESPACE(decompose)
4345/*************************************************
4446 * Name: mld_decompose
4547 *
@@ -56,7 +58,7 @@ __contract__(
5658 *
5759 * Reference: a1 is passed as a return value instead
5860 **************************************************/
59- void mld_decompose (int32_t * a0 , int32_t * a1 , int32_t a )
61+ static MLD_INLINE void mld_decompose (int32_t * a0 , int32_t * a1 , int32_t a )
6062__contract__ (
6163 requires (memory_no_alias (a0 , sizeof (int32_t )))
6264 requires (memory_no_alias (a1 , sizeof (int32_t )))
@@ -68,9 +70,32 @@ __contract__(
6870 ensures (* a0 >= - MLDSA_GAMMA2 && * a0 <= MLDSA_GAMMA2 )
6971 ensures (* a1 >= 0 && * a1 < (MLDSA_Q - 1 )/(2 * MLDSA_GAMMA2 ))
7072 ensures ((* a1 * 2 * MLDSA_GAMMA2 + * a0 - a ) % MLDSA_Q == 0 )
71- );
73+ )
74+ {
75+ * a1 = (a + 127 ) >> 7 ;
76+ /* We know a >= 0 and a < MLDSA_Q, so... */
77+ cassert (* a1 >= 0 && * a1 <= 65472 );
78+
79+ #if MLDSA_MODE == 2
80+ * a1 = (* a1 * 11275 + (1 << 23 )) >> 24 ;
81+ cassert (* a1 >= 0 && * a1 <= 44 );
82+
83+ * a1 = mld_ct_sel_int32 (0 , * a1 , mld_ct_cmask_neg_i32 (43 - * a1 ));
84+ cassert (* a1 >= 0 && * a1 <= 43 );
85+ #else /* MLDSA_MODE == 2 */
86+ * a1 = (* a1 * 1025 + (1 << 21 )) >> 22 ;
87+ cassert (* a1 >= 0 && * a1 <= 16 );
88+
89+ * a1 &= 15 ;
90+ cassert (* a1 >= 0 && * a1 <= 15 );
91+
92+ #endif /* MLDSA_MODE != 2 */
93+
94+ * a0 = a - * a1 * 2 * MLDSA_GAMMA2 ;
95+ * a0 = mld_ct_sel_int32 (* a0 - MLDSA_Q , * a0 ,
96+ mld_ct_cmask_neg_i32 ((MLDSA_Q - 1 ) / 2 - * a0 ));
97+ }
7298
73- #define mld_make_hint MLD_NAMESPACE(make_hint)
7499/*************************************************
75100 * Name: mld_make_hint
76101 *
@@ -82,12 +107,20 @@ __contract__(
82107 *
83108 * Returns 1 if overflow, 0 otherwise
84109 **************************************************/
85- unsigned int mld_make_hint (int32_t a0 , int32_t a1 )
110+ static MLD_INLINE unsigned int mld_make_hint (int32_t a0 , int32_t a1 )
86111__contract__ (
87112 ensures (return_value >= 0 && return_value <= 1 )
88- );
113+ )
114+ {
115+ if (a0 > MLDSA_GAMMA2 || a0 < - MLDSA_GAMMA2 ||
116+ (a0 == - MLDSA_GAMMA2 && a1 != 0 ))
117+ {
118+ return 1 ;
119+ }
120+
121+ return 0 ;
122+ }
89123
90- #define mld_use_hint MLD_NAMESPACE(use_hint)
91124/*************************************************
92125 * Name: mld_use_hint
93126 *
@@ -98,11 +131,41 @@ __contract__(
98131 *
99132 * Returns corrected high bits.
100133 **************************************************/
101- int32_t mld_use_hint (int32_t a , unsigned int hint )
134+ static MLD_INLINE int32_t mld_use_hint (int32_t a , unsigned int hint )
102135__contract__ (
103136 requires (hint >= 0 && hint <= 1 )
104137 requires (a >= 0 && a < MLDSA_Q )
105138 ensures (return_value >= 0 && return_value < (MLDSA_Q - 1 )/(2 * MLDSA_GAMMA2 ))
106- );
139+ )
140+ {
141+ int32_t a0 , a1 ;
142+
143+ mld_decompose (& a0 , & a1 , a );
144+ if (hint == 0 )
145+ {
146+ return a1 ;
147+ }
148+
149+ #if MLDSA_MODE == 2
150+ if (a0 > 0 )
151+ {
152+ return (a1 == 43 ) ? 0 : a1 + 1 ;
153+ }
154+ else
155+ {
156+ return (a1 == 0 ) ? 43 : a1 - 1 ;
157+ }
158+ #else /* MLDSA_MODE == 2 */
159+ if (a0 > 0 )
160+ {
161+ return (a1 + 1 ) & 15 ;
162+ }
163+ else
164+ {
165+ return (a1 - 1 ) & 15 ;
166+ }
167+ #endif /* MLDSA_MODE != 2 */
168+ }
169+
107170
108171#endif /* !MLD_ROUNDING_H */
0 commit comments