|
5 | 5 |
|
6 | 6 | #include <stddef.h> |
7 | 7 | #include <stdio.h> |
| 8 | +#include <stdlib.h> |
8 | 9 | #include <string.h> |
| 10 | + |
| 11 | +#include "../../mldsa/src/fips202/fips202.h" |
| 12 | +#include "../../mldsa/src/packing.h" |
| 13 | +#include "../../mldsa/src/params.h" |
| 14 | +#include "../../mldsa/src/polyvec.h" |
9 | 15 | #include "../../mldsa/src/sign.h" |
10 | 16 | #include "../../mldsa/src/sys.h" |
11 | 17 | #include "../notrandombytes/notrandombytes.h" |
|
28 | 34 | } \ |
29 | 35 | } while (0) |
30 | 36 |
|
| 37 | +/* Enum to specify which vector to corrupt in tests */ |
| 38 | +typedef enum |
| 39 | +{ |
| 40 | + CORRUPT_S1, |
| 41 | + CORRUPT_S2 |
| 42 | +} corrupt_vector_t; |
| 43 | + |
| 44 | +/* Struct to define a coefficient corruption test case */ |
| 45 | +typedef struct |
| 46 | +{ |
| 47 | + corrupt_vector_t vector; /* Which vector to corrupt (s1 or s2) */ |
| 48 | + unsigned int poly_idx; /* Polynomial index within the vector */ |
| 49 | + unsigned int coeff_idx; /* Coefficient index within the polynomial */ |
| 50 | + int32_t corruption_value; /* Value to set the coefficient to */ |
| 51 | + const char *description; /* Description of the test case */ |
| 52 | +} corruption_test_case_t; |
| 53 | + |
31 | 54 |
|
32 | 55 | static int test_sign_core(uint8_t pk[MLDSA_CRYPTO_PUBLICKEYBYTES], |
33 | 56 | uint8_t sk[MLDSA_CRYPTO_SECRETKEYBYTES], |
@@ -212,6 +235,159 @@ static int test_pk_from_sk(void) |
212 | 235 | return 0; |
213 | 236 | } |
214 | 237 |
|
| 238 | +/* Helper function to check if s1 and s2 coefficients are within valid range |
| 239 | + * [-eta, eta] */ |
| 240 | +static int check_s1_s2_coeffs_in_range( |
| 241 | + const uint8_t sk[MLDSA_CRYPTO_SECRETKEYBYTES]) |
| 242 | +{ |
| 243 | + uint8_t rho[MLDSA_SEEDBYTES]; |
| 244 | + uint8_t tr[MLDSA_TRBYTES]; |
| 245 | + uint8_t key[MLDSA_SEEDBYTES]; |
| 246 | + mld_polyveck t0; |
| 247 | + mld_polyvecl s1; |
| 248 | + mld_polyveck s2; |
| 249 | + |
| 250 | + /* Unpack the secret key to extract s1 and s2 */ |
| 251 | + mld_unpack_sk(rho, tr, key, &t0, &s1, &s2, sk); |
| 252 | + |
| 253 | + /* Check all coefficients in s1 are within [-MLDSA_ETA, MLDSA_ETA] */ |
| 254 | + for (unsigned int i = 0; i < MLDSA_L; i++) |
| 255 | + { |
| 256 | + for (unsigned int j = 0; j < MLDSA_N; j++) |
| 257 | + { |
| 258 | + int32_t coeff = s1.vec[i].coeffs[j]; |
| 259 | + if (coeff < -MLDSA_ETA || coeff > MLDSA_ETA) |
| 260 | + { |
| 261 | + return 0; /* Out of range */ |
| 262 | + } |
| 263 | + } |
| 264 | + } |
| 265 | + |
| 266 | + /* Check all coefficients in s2 are within [-MLDSA_ETA, MLDSA_ETA] */ |
| 267 | + for (unsigned int i = 0; i < MLDSA_K; i++) |
| 268 | + { |
| 269 | + for (unsigned int j = 0; j < MLDSA_N; j++) |
| 270 | + { |
| 271 | + int32_t coeff = s2.vec[i].coeffs[j]; |
| 272 | + if (coeff < -MLDSA_ETA || coeff > MLDSA_ETA) |
| 273 | + { |
| 274 | + return 0; /* Out of range */ |
| 275 | + } |
| 276 | + } |
| 277 | + } |
| 278 | + |
| 279 | + return 1; /* All coefficients in range */ |
| 280 | +} |
| 281 | + |
| 282 | + |
| 283 | +/* Helper function to test crypto_sign_pk_from_sk with invalid s1 or s2 |
| 284 | + * coefficients */ |
| 285 | +static int test_corrupted_sk(const corruption_test_case_t *test_case) |
| 286 | +{ |
| 287 | + uint8_t sk_corrupted[MLDSA_CRYPTO_SECRETKEYBYTES]; |
| 288 | + int rc; |
| 289 | + const char *vector_name = (test_case->vector == CORRUPT_S1) ? "s1" : "s2"; |
| 290 | + |
| 291 | + /* Start from a valid key pair */ |
| 292 | + uint8_t pk_valid[MLDSA_CRYPTO_PUBLICKEYBYTES]; |
| 293 | + uint8_t sk_valid[MLDSA_CRYPTO_SECRETKEYBYTES]; |
| 294 | + CHECK(crypto_sign_keypair(pk_valid, sk_valid) == 0); |
| 295 | + |
| 296 | + uint8_t rho[MLDSA_SEEDBYTES]; |
| 297 | + uint8_t tr[MLDSA_TRBYTES]; |
| 298 | + uint8_t key[MLDSA_SEEDBYTES]; |
| 299 | + mld_polyveck t0; |
| 300 | + mld_polyvecl s1; |
| 301 | + mld_polyveck s2; |
| 302 | + |
| 303 | + mld_unpack_sk(rho, tr, key, &t0, &s1, &s2, sk_valid); |
| 304 | + |
| 305 | + /* Corrupt the specified coefficient */ |
| 306 | + if (test_case->vector == CORRUPT_S1) |
| 307 | + { |
| 308 | + /* Validate indices are within bounds */ |
| 309 | + if (test_case->poly_idx >= MLDSA_L || test_case->coeff_idx >= MLDSA_N) |
| 310 | + { |
| 311 | + printf("ERROR: s1 indices out of bounds: [%u][%u] (max [%u][%u])\n", |
| 312 | + test_case->poly_idx, test_case->coeff_idx, MLDSA_L - 1, |
| 313 | + MLDSA_N - 1); |
| 314 | + return 1; |
| 315 | + } |
| 316 | + s1.vec[test_case->poly_idx].coeffs[test_case->coeff_idx] = |
| 317 | + test_case->corruption_value; |
| 318 | + } |
| 319 | + else |
| 320 | + { |
| 321 | + /* Validate indices are within bounds */ |
| 322 | + if (test_case->poly_idx >= MLDSA_K || test_case->coeff_idx >= MLDSA_N) |
| 323 | + { |
| 324 | + printf("ERROR: s2 indices out of bounds: [%u][%u] (max [%u][%u])\n", |
| 325 | + test_case->poly_idx, test_case->coeff_idx, MLDSA_K - 1, |
| 326 | + MLDSA_N - 1); |
| 327 | + return 1; |
| 328 | + } |
| 329 | + s2.vec[test_case->poly_idx].coeffs[test_case->coeff_idx] = |
| 330 | + test_case->corruption_value; |
| 331 | + } |
| 332 | + |
| 333 | + /* Regenerate t0, t1, tr, and pk to be consistent with the corrupted vector |
| 334 | + * https://github.com/aws/aws-lc/blob/0336dd78a0f2623c1f9b209a98cd497026d9c779/crypto/fipsmodule/ml_dsa/ml_dsa_ref/packing.c#L7-L61 |
| 335 | + */ |
| 336 | + mld_polymat mat; |
| 337 | + mld_polyveck t1; |
| 338 | + uint8_t pk_consistent[MLDSA_CRYPTO_PUBLICKEYBYTES]; |
| 339 | + uint8_t tr_consistent[MLDSA_TRBYTES]; |
| 340 | + |
| 341 | + /* Expand matrix A from rho */ |
| 342 | + mld_polyvec_matrix_expand(&mat, rho); |
| 343 | + |
| 344 | + /* Compute t = A * s1 + s2 */ |
| 345 | + mld_polyvecl s1_ntt = s1; |
| 346 | + mld_polyvecl_ntt(&s1_ntt); |
| 347 | + mld_polyvec_matrix_pointwise_montgomery(&t1, &mat, &s1_ntt); |
| 348 | + mld_polyveck_reduce(&t1); |
| 349 | + mld_polyveck_invntt_tomont(&t1); |
| 350 | + mld_polyveck_add(&t1, &s2); |
| 351 | + mld_polyveck_reduce(&t1); |
| 352 | + mld_polyveck_caddq(&t1); |
| 353 | + |
| 354 | + /* Power2Round to get t1 and t0 */ |
| 355 | + mld_polyveck_power2round(&t1, &t0, &t1); |
| 356 | + |
| 357 | + /* Pack public key and compute tr */ |
| 358 | + mld_pack_pk(pk_consistent, rho, &t1); |
| 359 | + mld_shake256(tr_consistent, MLDSA_TRBYTES, pk_consistent, |
| 360 | + MLDSA_CRYPTO_PUBLICKEYBYTES); |
| 361 | + |
| 362 | + /* Pack the corrupted secret key */ |
| 363 | + mld_pack_sk(sk_corrupted, rho, tr_consistent, key, &t0, &s1, &s2); |
| 364 | + |
| 365 | + /* Verify that it is corrupted */ |
| 366 | + if (check_s1_s2_coeffs_in_range(sk_corrupted)) |
| 367 | + { |
| 368 | + printf("ERROR: failed to corrupt secret key %s\n", vector_name); |
| 369 | + return 1; |
| 370 | + } |
| 371 | + |
| 372 | + /* Test crypto_sign_pk_from_sk with corrupted key */ |
| 373 | + uint8_t pk_derived[MLDSA_CRYPTO_PUBLICKEYBYTES]; |
| 374 | + rc = crypto_sign_pk_from_sk(pk_derived, sk_corrupted); |
| 375 | + MLD_CT_TESTING_DECLASSIFY(&rc, sizeof(int)); |
| 376 | + |
| 377 | + if (rc != -1) |
| 378 | + { |
| 379 | + printf( |
| 380 | + "ERROR: pk_from_sk - should fail with corrupted secret key - %s - " |
| 381 | + "%s[%u][%u] = %d\n", |
| 382 | + test_case->description, vector_name, test_case->poly_idx, |
| 383 | + test_case->coeff_idx, test_case->corruption_value); |
| 384 | + return 1; |
| 385 | + } |
| 386 | + |
| 387 | + return 0; |
| 388 | +} |
| 389 | + |
| 390 | + |
215 | 391 | static int test_wrong_pk(void) |
216 | 392 | { |
217 | 393 | uint8_t pk[MLDSA_CRYPTO_PUBLICKEYBYTES]; |
@@ -391,6 +567,44 @@ int main(void) |
391 | 567 | } |
392 | 568 | } |
393 | 569 |
|
| 570 | + /* Run comprehensive corrupted key tests */ |
| 571 | + corruption_test_case_t test_cases[] = { |
| 572 | + /* Test s1 vector corruptions */ |
| 573 | + {CORRUPT_S1, 0, 0, -(MLDSA_ETA + 1), "s1 underflow at [0][0]"}, |
| 574 | + {CORRUPT_S1, 0, 0, MLDSA_ETA + 1, "s1 overflow at [0][0]"}, |
| 575 | + {CORRUPT_S1, 0, 1, -(MLDSA_ETA + 2), "s1 underflow at [0][1]"}, |
| 576 | + {CORRUPT_S1, 0, 1, MLDSA_ETA + 2, "s1 overflow at [0][1]"}, |
| 577 | + {CORRUPT_S1, 0, MLDSA_N - 1, -(MLDSA_ETA + 1), |
| 578 | + "s1 underflow at [0][N-1]"}, |
| 579 | + {CORRUPT_S1, 0, MLDSA_N - 1, MLDSA_ETA + 1, "s1 overflow at [0][N-1]"}, |
| 580 | + {CORRUPT_S1, MLDSA_L - 1, 0, -(MLDSA_ETA + 1), |
| 581 | + "s1 underflow at [L-1][0]"}, |
| 582 | + {CORRUPT_S1, MLDSA_L - 1, 0, MLDSA_ETA + 1, "s1 overflow at [L-1][0]"}, |
| 583 | + |
| 584 | + /* Test s2 vector corruptions */ |
| 585 | + {CORRUPT_S2, 0, 0, -(MLDSA_ETA + 1), "s2 underflow at [0][0]"}, |
| 586 | + {CORRUPT_S2, 0, 0, MLDSA_ETA + 1, "s2 overflow at [0][0]"}, |
| 587 | + {CORRUPT_S2, 0, 1, -(MLDSA_ETA + 2), "s2 underflow at [0][1]"}, |
| 588 | + {CORRUPT_S2, 0, 1, MLDSA_ETA + 2, "s2 overflow at [0][1]"}, |
| 589 | + {CORRUPT_S2, 0, MLDSA_N - 1, -(MLDSA_ETA + 1), |
| 590 | + "s2 underflow at [0][N-1]"}, |
| 591 | + {CORRUPT_S2, 0, MLDSA_N - 1, MLDSA_ETA + 1, "s2 overflow at [0][N-1]"}, |
| 592 | + {CORRUPT_S2, MLDSA_K - 1, 0, -(MLDSA_ETA + 1), |
| 593 | + "s2 underflow at [K-1][0]"}, |
| 594 | + {CORRUPT_S2, MLDSA_K - 1, 0, MLDSA_ETA + 1, "s2 overflow at [K-1][0]"}, |
| 595 | + }; |
| 596 | + |
| 597 | + size_t num_test_cases = sizeof(test_cases) / sizeof(test_cases[0]); |
| 598 | + for (size_t num = 0; num < num_test_cases; num++) |
| 599 | + { |
| 600 | + if (test_corrupted_sk(&test_cases[num])) |
| 601 | + { |
| 602 | + printf("ERROR: Test case %zu failed: %s\n", num, |
| 603 | + test_cases[num].description); |
| 604 | + return 1; |
| 605 | + } |
| 606 | + } |
| 607 | + |
394 | 608 | printf("MLDSA_CRYPTO_SECRETKEYBYTES: %d\n", MLDSA_CRYPTO_SECRETKEYBYTES); |
395 | 609 | printf("MLDSA_CRYPTO_PUBLICKEYBYTES: %d\n", MLDSA_CRYPTO_PUBLICKEYBYTES); |
396 | 610 | printf("MLDSA_CRYPTO_BYTES: %d\n", MLDSA_CRYPTO_BYTES); |
|
0 commit comments