Skip to content

Commit 426ba6e

Browse files
sgmendamkannwischer
authored andcommitted
add test for validation of s1 and s2
Signed-off-by: sanketh <sgmenda@amazon.com>
1 parent 33469ef commit 426ba6e

File tree

1 file changed

+214
-0
lines changed

1 file changed

+214
-0
lines changed

test/src/test_mldsa.c

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,13 @@
55

66
#include <stddef.h>
77
#include <stdio.h>
8+
#include <stdlib.h>
89
#include <string.h>
10+
11+
#include "../../mldsa/src/fips202/fips202.h"
12+
#include "../../mldsa/src/packing.h"
13+
#include "../../mldsa/src/params.h"
14+
#include "../../mldsa/src/polyvec.h"
915
#include "../../mldsa/src/sign.h"
1016
#include "../../mldsa/src/sys.h"
1117
#include "../notrandombytes/notrandombytes.h"
@@ -28,6 +34,23 @@
2834
} \
2935
} while (0)
3036

37+
/* Enum to specify which vector to corrupt in tests */
38+
typedef enum
39+
{
40+
CORRUPT_S1,
41+
CORRUPT_S2
42+
} corrupt_vector_t;
43+
44+
/* Struct to define a coefficient corruption test case */
45+
typedef struct
46+
{
47+
corrupt_vector_t vector; /* Which vector to corrupt (s1 or s2) */
48+
unsigned int poly_idx; /* Polynomial index within the vector */
49+
unsigned int coeff_idx; /* Coefficient index within the polynomial */
50+
int32_t corruption_value; /* Value to set the coefficient to */
51+
const char *description; /* Description of the test case */
52+
} corruption_test_case_t;
53+
3154

3255
static int test_sign_core(uint8_t pk[MLDSA_CRYPTO_PUBLICKEYBYTES],
3356
uint8_t sk[MLDSA_CRYPTO_SECRETKEYBYTES],
@@ -212,6 +235,159 @@ static int test_pk_from_sk(void)
212235
return 0;
213236
}
214237

238+
/* Helper function to check if s1 and s2 coefficients are within valid range
239+
* [-eta, eta] */
240+
static int check_s1_s2_coeffs_in_range(
241+
const uint8_t sk[MLDSA_CRYPTO_SECRETKEYBYTES])
242+
{
243+
uint8_t rho[MLDSA_SEEDBYTES];
244+
uint8_t tr[MLDSA_TRBYTES];
245+
uint8_t key[MLDSA_SEEDBYTES];
246+
mld_polyveck t0;
247+
mld_polyvecl s1;
248+
mld_polyveck s2;
249+
250+
/* Unpack the secret key to extract s1 and s2 */
251+
mld_unpack_sk(rho, tr, key, &t0, &s1, &s2, sk);
252+
253+
/* Check all coefficients in s1 are within [-MLDSA_ETA, MLDSA_ETA] */
254+
for (unsigned int i = 0; i < MLDSA_L; i++)
255+
{
256+
for (unsigned int j = 0; j < MLDSA_N; j++)
257+
{
258+
int32_t coeff = s1.vec[i].coeffs[j];
259+
if (coeff < -MLDSA_ETA || coeff > MLDSA_ETA)
260+
{
261+
return 0; /* Out of range */
262+
}
263+
}
264+
}
265+
266+
/* Check all coefficients in s2 are within [-MLDSA_ETA, MLDSA_ETA] */
267+
for (unsigned int i = 0; i < MLDSA_K; i++)
268+
{
269+
for (unsigned int j = 0; j < MLDSA_N; j++)
270+
{
271+
int32_t coeff = s2.vec[i].coeffs[j];
272+
if (coeff < -MLDSA_ETA || coeff > MLDSA_ETA)
273+
{
274+
return 0; /* Out of range */
275+
}
276+
}
277+
}
278+
279+
return 1; /* All coefficients in range */
280+
}
281+
282+
283+
/* Helper function to test crypto_sign_pk_from_sk with invalid s1 or s2
284+
* coefficients */
285+
static int test_corrupted_sk(const corruption_test_case_t *test_case)
286+
{
287+
uint8_t sk_corrupted[MLDSA_CRYPTO_SECRETKEYBYTES];
288+
int rc;
289+
const char *vector_name = (test_case->vector == CORRUPT_S1) ? "s1" : "s2";
290+
291+
/* Start from a valid key pair */
292+
uint8_t pk_valid[MLDSA_CRYPTO_PUBLICKEYBYTES];
293+
uint8_t sk_valid[MLDSA_CRYPTO_SECRETKEYBYTES];
294+
CHECK(crypto_sign_keypair(pk_valid, sk_valid) == 0);
295+
296+
uint8_t rho[MLDSA_SEEDBYTES];
297+
uint8_t tr[MLDSA_TRBYTES];
298+
uint8_t key[MLDSA_SEEDBYTES];
299+
mld_polyveck t0;
300+
mld_polyvecl s1;
301+
mld_polyveck s2;
302+
303+
mld_unpack_sk(rho, tr, key, &t0, &s1, &s2, sk_valid);
304+
305+
/* Corrupt the specified coefficient */
306+
if (test_case->vector == CORRUPT_S1)
307+
{
308+
/* Validate indices are within bounds */
309+
if (test_case->poly_idx >= MLDSA_L || test_case->coeff_idx >= MLDSA_N)
310+
{
311+
printf("ERROR: s1 indices out of bounds: [%u][%u] (max [%u][%u])\n",
312+
test_case->poly_idx, test_case->coeff_idx, MLDSA_L - 1,
313+
MLDSA_N - 1);
314+
return 1;
315+
}
316+
s1.vec[test_case->poly_idx].coeffs[test_case->coeff_idx] =
317+
test_case->corruption_value;
318+
}
319+
else
320+
{
321+
/* Validate indices are within bounds */
322+
if (test_case->poly_idx >= MLDSA_K || test_case->coeff_idx >= MLDSA_N)
323+
{
324+
printf("ERROR: s2 indices out of bounds: [%u][%u] (max [%u][%u])\n",
325+
test_case->poly_idx, test_case->coeff_idx, MLDSA_K - 1,
326+
MLDSA_N - 1);
327+
return 1;
328+
}
329+
s2.vec[test_case->poly_idx].coeffs[test_case->coeff_idx] =
330+
test_case->corruption_value;
331+
}
332+
333+
/* Regenerate t0, t1, tr, and pk to be consistent with the corrupted vector
334+
* https://github.com/aws/aws-lc/blob/0336dd78a0f2623c1f9b209a98cd497026d9c779/crypto/fipsmodule/ml_dsa/ml_dsa_ref/packing.c#L7-L61
335+
*/
336+
mld_polymat mat;
337+
mld_polyveck t1;
338+
uint8_t pk_consistent[MLDSA_CRYPTO_PUBLICKEYBYTES];
339+
uint8_t tr_consistent[MLDSA_TRBYTES];
340+
341+
/* Expand matrix A from rho */
342+
mld_polyvec_matrix_expand(&mat, rho);
343+
344+
/* Compute t = A * s1 + s2 */
345+
mld_polyvecl s1_ntt = s1;
346+
mld_polyvecl_ntt(&s1_ntt);
347+
mld_polyvec_matrix_pointwise_montgomery(&t1, &mat, &s1_ntt);
348+
mld_polyveck_reduce(&t1);
349+
mld_polyveck_invntt_tomont(&t1);
350+
mld_polyveck_add(&t1, &s2);
351+
mld_polyveck_reduce(&t1);
352+
mld_polyveck_caddq(&t1);
353+
354+
/* Power2Round to get t1 and t0 */
355+
mld_polyveck_power2round(&t1, &t0, &t1);
356+
357+
/* Pack public key and compute tr */
358+
mld_pack_pk(pk_consistent, rho, &t1);
359+
mld_shake256(tr_consistent, MLDSA_TRBYTES, pk_consistent,
360+
MLDSA_CRYPTO_PUBLICKEYBYTES);
361+
362+
/* Pack the corrupted secret key */
363+
mld_pack_sk(sk_corrupted, rho, tr_consistent, key, &t0, &s1, &s2);
364+
365+
/* Verify that it is corrupted */
366+
if (check_s1_s2_coeffs_in_range(sk_corrupted))
367+
{
368+
printf("ERROR: failed to corrupt secret key %s\n", vector_name);
369+
return 1;
370+
}
371+
372+
/* Test crypto_sign_pk_from_sk with corrupted key */
373+
uint8_t pk_derived[MLDSA_CRYPTO_PUBLICKEYBYTES];
374+
rc = crypto_sign_pk_from_sk(pk_derived, sk_corrupted);
375+
MLD_CT_TESTING_DECLASSIFY(&rc, sizeof(int));
376+
377+
if (rc != -1)
378+
{
379+
printf(
380+
"ERROR: pk_from_sk - should fail with corrupted secret key - %s - "
381+
"%s[%u][%u] = %d\n",
382+
test_case->description, vector_name, test_case->poly_idx,
383+
test_case->coeff_idx, test_case->corruption_value);
384+
return 1;
385+
}
386+
387+
return 0;
388+
}
389+
390+
215391
static int test_wrong_pk(void)
216392
{
217393
uint8_t pk[MLDSA_CRYPTO_PUBLICKEYBYTES];
@@ -391,6 +567,44 @@ int main(void)
391567
}
392568
}
393569

570+
/* Run comprehensive corrupted key tests */
571+
corruption_test_case_t test_cases[] = {
572+
/* Test s1 vector corruptions */
573+
{CORRUPT_S1, 0, 0, -(MLDSA_ETA + 1), "s1 underflow at [0][0]"},
574+
{CORRUPT_S1, 0, 0, MLDSA_ETA + 1, "s1 overflow at [0][0]"},
575+
{CORRUPT_S1, 0, 1, -(MLDSA_ETA + 2), "s1 underflow at [0][1]"},
576+
{CORRUPT_S1, 0, 1, MLDSA_ETA + 2, "s1 overflow at [0][1]"},
577+
{CORRUPT_S1, 0, MLDSA_N - 1, -(MLDSA_ETA + 1),
578+
"s1 underflow at [0][N-1]"},
579+
{CORRUPT_S1, 0, MLDSA_N - 1, MLDSA_ETA + 1, "s1 overflow at [0][N-1]"},
580+
{CORRUPT_S1, MLDSA_L - 1, 0, -(MLDSA_ETA + 1),
581+
"s1 underflow at [L-1][0]"},
582+
{CORRUPT_S1, MLDSA_L - 1, 0, MLDSA_ETA + 1, "s1 overflow at [L-1][0]"},
583+
584+
/* Test s2 vector corruptions */
585+
{CORRUPT_S2, 0, 0, -(MLDSA_ETA + 1), "s2 underflow at [0][0]"},
586+
{CORRUPT_S2, 0, 0, MLDSA_ETA + 1, "s2 overflow at [0][0]"},
587+
{CORRUPT_S2, 0, 1, -(MLDSA_ETA + 2), "s2 underflow at [0][1]"},
588+
{CORRUPT_S2, 0, 1, MLDSA_ETA + 2, "s2 overflow at [0][1]"},
589+
{CORRUPT_S2, 0, MLDSA_N - 1, -(MLDSA_ETA + 1),
590+
"s2 underflow at [0][N-1]"},
591+
{CORRUPT_S2, 0, MLDSA_N - 1, MLDSA_ETA + 1, "s2 overflow at [0][N-1]"},
592+
{CORRUPT_S2, MLDSA_K - 1, 0, -(MLDSA_ETA + 1),
593+
"s2 underflow at [K-1][0]"},
594+
{CORRUPT_S2, MLDSA_K - 1, 0, MLDSA_ETA + 1, "s2 overflow at [K-1][0]"},
595+
};
596+
597+
size_t num_test_cases = sizeof(test_cases) / sizeof(test_cases[0]);
598+
for (size_t num = 0; num < num_test_cases; num++)
599+
{
600+
if (test_corrupted_sk(&test_cases[num]))
601+
{
602+
printf("ERROR: Test case %zu failed: %s\n", num,
603+
test_cases[num].description);
604+
return 1;
605+
}
606+
}
607+
394608
printf("MLDSA_CRYPTO_SECRETKEYBYTES: %d\n", MLDSA_CRYPTO_SECRETKEYBYTES);
395609
printf("MLDSA_CRYPTO_PUBLICKEYBYTES: %d\n", MLDSA_CRYPTO_PUBLICKEYBYTES);
396610
printf("MLDSA_CRYPTO_BYTES: %d\n", MLDSA_CRYPTO_BYTES);

0 commit comments

Comments
 (0)