55
66#include <stddef.h>
77#include <stdio.h>
8- #include <string.h>
98#include <stdlib.h>
10- #include "../mldsa/src/sign.h"
11- #include "../mldsa/src/sys .h"
9+ #include <string.h>
10+ #include "../mldsa/src/fips202/fips202 .h"
1211#include "../mldsa/src/packing.h"
13- #include "../mldsa/src/polyvec.h"
1412#include "../mldsa/src/params.h"
15- #include "../mldsa/src/fips202/fips202.h"
13+ #include "../mldsa/src/polyvec.h"
14+ #include "../mldsa/src/sign.h"
15+ #include "../mldsa/src/sys.h"
1616
1717#include "notrandombytes/notrandombytes.h"
1818
3535 } while (0)
3636
3737/* Enum to specify which vector to corrupt in tests */
38- typedef enum {
38+ typedef enum
39+ {
3940 CORRUPT_S1 ,
4041 CORRUPT_S2
4142} corrupt_vector_t ;
4243
4344/* Struct to define a coefficient corruption test case */
44- typedef struct {
45- corrupt_vector_t vector ; /* Which vector to corrupt (s1 or s2) */
46- unsigned int poly_idx ; /* Polynomial index within the vector */
47- unsigned int coeff_idx ; /* Coefficient index within the polynomial */
48- int32_t corruption_value ; /* Value to set the coefficient to */
49- const char * description ; /* Description of the test case */
45+ typedef struct
46+ {
47+ corrupt_vector_t vector ; /* Which vector to corrupt (s1 or s2) */
48+ unsigned int poly_idx ; /* Polynomial index within the vector */
49+ unsigned int coeff_idx ; /* Coefficient index within the polynomial */
50+ int32_t corruption_value ; /* Value to set the coefficient to */
51+ const char * description ; /* Description of the test case */
5052} corruption_test_case_t ;
5153
5254
@@ -233,7 +235,8 @@ static int test_pk_from_sk(void)
233235 return 0 ;
234236}
235237
236- /* Helper function to check if s1 and s2 coefficients are within valid range [-eta, eta] */
238+ /* Helper function to check if s1 and s2 coefficients are within valid range
239+ * [-eta, eta] */
237240static int check_s1_s2_coeffs_in_range (
238241 const uint8_t sk [MLDSA_CRYPTO_SECRETKEYBYTES ])
239242{
@@ -243,10 +246,10 @@ static int check_s1_s2_coeffs_in_range(
243246 mld_polyveck t0 ;
244247 mld_polyvecl s1 ;
245248 mld_polyveck s2 ;
246-
249+
247250 /* Unpack the secret key to extract s1 and s2 */
248251 mld_unpack_sk (rho , tr , key , & t0 , & s1 , & s2 , sk );
249-
252+
250253 /* Check all coefficients in s1 are within [-MLDSA_ETA, MLDSA_ETA] */
251254 for (unsigned int i = 0 ; i < MLDSA_L ; i ++ )
252255 {
@@ -259,7 +262,7 @@ static int check_s1_s2_coeffs_in_range(
259262 }
260263 }
261264 }
262-
265+
263266 /* Check all coefficients in s2 are within [-MLDSA_ETA, MLDSA_ETA] */
264267 for (unsigned int i = 0 ; i < MLDSA_K ; i ++ )
265268 {
@@ -272,17 +275,18 @@ static int check_s1_s2_coeffs_in_range(
272275 }
273276 }
274277 }
275-
278+
276279 return 1 ; /* All coefficients in range */
277280}
278281
279282
280- /* Helper function to test crypto_sign_pk_from_sk with invalid s1 or s2 coefficients */
281- static int test_corrupted_sk (const corruption_test_case_t * test_case )
283+ /* Helper function to test crypto_sign_pk_from_sk with invalid s1 or s2
284+ * coefficients */
285+ static int test_corrupted_sk (const corruption_test_case_t * test_case )
282286{
283287 uint8_t sk_corrupted [MLDSA_CRYPTO_SECRETKEYBYTES ];
284288 int rc ;
285- const char * vector_name = (test_case -> vector == CORRUPT_S1 ) ? "s1" : "s2" ;
289+ const char * vector_name = (test_case -> vector == CORRUPT_S1 ) ? "s1" : "s2" ;
286290
287291 /* Start from a valid key pair */
288292 uint8_t pk_valid [MLDSA_CRYPTO_PUBLICKEYBYTES ];
@@ -304,22 +308,26 @@ static int test_corrupted_sk(const corruption_test_case_t* test_case)
304308 /* Validate indices are within bounds */
305309 if (test_case -> poly_idx >= MLDSA_L || test_case -> coeff_idx >= MLDSA_N )
306310 {
307- printf ("ERROR: s1 indices out of bounds: [%u][%u] (max [%u][%u])\n" ,
308- test_case -> poly_idx , test_case -> coeff_idx , MLDSA_L - 1 , MLDSA_N - 1 );
311+ printf ("ERROR: s1 indices out of bounds: [%u][%u] (max [%u][%u])\n" ,
312+ test_case -> poly_idx , test_case -> coeff_idx , MLDSA_L - 1 ,
313+ MLDSA_N - 1 );
309314 return 1 ;
310315 }
311- s1 .vec [test_case -> poly_idx ].coeffs [test_case -> coeff_idx ] = test_case -> corruption_value ;
316+ s1 .vec [test_case -> poly_idx ].coeffs [test_case -> coeff_idx ] =
317+ test_case -> corruption_value ;
312318 }
313319 else
314320 {
315321 /* Validate indices are within bounds */
316322 if (test_case -> poly_idx >= MLDSA_K || test_case -> coeff_idx >= MLDSA_N )
317323 {
318- printf ("ERROR: s2 indices out of bounds: [%u][%u] (max [%u][%u])\n" ,
319- test_case -> poly_idx , test_case -> coeff_idx , MLDSA_K - 1 , MLDSA_N - 1 );
324+ printf ("ERROR: s2 indices out of bounds: [%u][%u] (max [%u][%u])\n" ,
325+ test_case -> poly_idx , test_case -> coeff_idx , MLDSA_K - 1 ,
326+ MLDSA_N - 1 );
320327 return 1 ;
321328 }
322- s2 .vec [test_case -> poly_idx ].coeffs [test_case -> coeff_idx ] = test_case -> corruption_value ;
329+ s2 .vec [test_case -> poly_idx ].coeffs [test_case -> coeff_idx ] =
330+ test_case -> corruption_value ;
323331 }
324332
325333 /* Regenerate t0, t1, tr, and pk to be consistent with the corrupted vector
@@ -332,7 +340,7 @@ static int test_corrupted_sk(const corruption_test_case_t* test_case)
332340
333341 /* Expand matrix A from rho */
334342 mld_polyvec_matrix_expand (& mat , rho );
335-
343+
336344 /* Compute t = A * s1 + s2 */
337345 mld_polyvecl s1_ntt = s1 ;
338346 mld_polyvecl_ntt (& s1_ntt );
@@ -342,13 +350,14 @@ static int test_corrupted_sk(const corruption_test_case_t* test_case)
342350 mld_polyveck_add (& t1 , & s2 );
343351 mld_polyveck_reduce (& t1 );
344352 mld_polyveck_caddq (& t1 );
345-
353+
346354 /* Power2Round to get t1 and t0 */
347355 mld_polyveck_power2round (& t1 , & t0 , & t1 );
348-
356+
349357 /* Pack public key and compute tr */
350358 mld_pack_pk (pk_consistent , rho , & t1 );
351- mld_shake256 (tr_consistent , MLDSA_TRBYTES , pk_consistent , MLDSA_CRYPTO_PUBLICKEYBYTES );
359+ mld_shake256 (tr_consistent , MLDSA_TRBYTES , pk_consistent ,
360+ MLDSA_CRYPTO_PUBLICKEYBYTES );
352361
353362 /* Pack the corrupted secret key */
354363 mld_pack_sk (sk_corrupted , rho , tr_consistent , key , & t0 , & s1 , & s2 );
@@ -560,33 +569,38 @@ int main(void)
560569
561570 /* Run comprehensive corrupted key tests */
562571 corruption_test_case_t test_cases [] = {
563- /* Test s1 vector corruptions */
564- {CORRUPT_S1 , 0 , 0 , - (MLDSA_ETA + 1 ), "s1 underflow at [0][0]" },
565- {CORRUPT_S1 , 0 , 0 , MLDSA_ETA + 1 , "s1 overflow at [0][0]" },
566- {CORRUPT_S1 , 0 , 1 , - (MLDSA_ETA + 2 ), "s1 underflow at [0][1]" },
567- {CORRUPT_S1 , 0 , 1 , MLDSA_ETA + 2 , "s1 overflow at [0][1]" },
568- {CORRUPT_S1 , 0 , MLDSA_N - 1 , - (MLDSA_ETA + 1 ), "s1 underflow at [0][N-1]" },
569- {CORRUPT_S1 , 0 , MLDSA_N - 1 , MLDSA_ETA + 1 , "s1 overflow at [0][N-1]" },
570- {CORRUPT_S1 , MLDSA_L - 1 , 0 , - (MLDSA_ETA + 1 ), "s1 underflow at [L-1][0]" },
571- {CORRUPT_S1 , MLDSA_L - 1 , 0 , MLDSA_ETA + 1 , "s1 overflow at [L-1][0]" },
572-
573- /* Test s2 vector corruptions */
574- {CORRUPT_S2 , 0 , 0 , - (MLDSA_ETA + 1 ), "s2 underflow at [0][0]" },
575- {CORRUPT_S2 , 0 , 0 , MLDSA_ETA + 1 , "s2 overflow at [0][0]" },
576- {CORRUPT_S2 , 0 , 1 , - (MLDSA_ETA + 2 ), "s2 underflow at [0][1]" },
577- {CORRUPT_S2 , 0 , 1 , MLDSA_ETA + 2 , "s2 overflow at [0][1]" },
578- {CORRUPT_S2 , 0 , MLDSA_N - 1 , - (MLDSA_ETA + 1 ), "s2 underflow at [0][N-1]" },
579- {CORRUPT_S2 , 0 , MLDSA_N - 1 , MLDSA_ETA + 1 , "s2 overflow at [0][N-1]" },
580- {CORRUPT_S2 , MLDSA_K - 1 , 0 , - (MLDSA_ETA + 1 ), "s2 underflow at [K-1][0]" },
581- {CORRUPT_S2 , MLDSA_K - 1 , 0 , MLDSA_ETA + 1 , "s2 overflow at [K-1][0]" },
572+ /* Test s1 vector corruptions */
573+ {CORRUPT_S1 , 0 , 0 , - (MLDSA_ETA + 1 ), "s1 underflow at [0][0]" },
574+ {CORRUPT_S1 , 0 , 0 , MLDSA_ETA + 1 , "s1 overflow at [0][0]" },
575+ {CORRUPT_S1 , 0 , 1 , - (MLDSA_ETA + 2 ), "s1 underflow at [0][1]" },
576+ {CORRUPT_S1 , 0 , 1 , MLDSA_ETA + 2 , "s1 overflow at [0][1]" },
577+ {CORRUPT_S1 , 0 , MLDSA_N - 1 , - (MLDSA_ETA + 1 ),
578+ "s1 underflow at [0][N-1]" },
579+ {CORRUPT_S1 , 0 , MLDSA_N - 1 , MLDSA_ETA + 1 , "s1 overflow at [0][N-1]" },
580+ {CORRUPT_S1 , MLDSA_L - 1 , 0 , - (MLDSA_ETA + 1 ),
581+ "s1 underflow at [L-1][0]" },
582+ {CORRUPT_S1 , MLDSA_L - 1 , 0 , MLDSA_ETA + 1 , "s1 overflow at [L-1][0]" },
583+
584+ /* Test s2 vector corruptions */
585+ {CORRUPT_S2 , 0 , 0 , - (MLDSA_ETA + 1 ), "s2 underflow at [0][0]" },
586+ {CORRUPT_S2 , 0 , 0 , MLDSA_ETA + 1 , "s2 overflow at [0][0]" },
587+ {CORRUPT_S2 , 0 , 1 , - (MLDSA_ETA + 2 ), "s2 underflow at [0][1]" },
588+ {CORRUPT_S2 , 0 , 1 , MLDSA_ETA + 2 , "s2 overflow at [0][1]" },
589+ {CORRUPT_S2 , 0 , MLDSA_N - 1 , - (MLDSA_ETA + 1 ),
590+ "s2 underflow at [0][N-1]" },
591+ {CORRUPT_S2 , 0 , MLDSA_N - 1 , MLDSA_ETA + 1 , "s2 overflow at [0][N-1]" },
592+ {CORRUPT_S2 , MLDSA_K - 1 , 0 , - (MLDSA_ETA + 1 ),
593+ "s2 underflow at [K-1][0]" },
594+ {CORRUPT_S2 , MLDSA_K - 1 , 0 , MLDSA_ETA + 1 , "s2 overflow at [K-1][0]" },
582595 };
583-
596+
584597 size_t num_test_cases = sizeof (test_cases ) / sizeof (test_cases [0 ]);
585598 for (size_t num = 0 ; num < num_test_cases ; num ++ )
586599 {
587600 if (test_corrupted_sk (& test_cases [num ]))
588601 {
589- printf ("ERROR: Test case %zu failed: %s\n" , num , test_cases [num ].description );
602+ printf ("ERROR: Test case %zu failed: %s\n" , num ,
603+ test_cases [num ].description );
590604 return 1 ;
591605 }
592606 }
0 commit comments