Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 214 additions & 0 deletions test/src/test_mldsa.c
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@

#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "../../mldsa/src/fips202/fips202.h"
#include "../../mldsa/src/packing.h"
#include "../../mldsa/src/params.h"
#include "../../mldsa/src/polyvec.h"
#include "../../mldsa/src/sign.h"
#include "../../mldsa/src/sys.h"
#include "../notrandombytes/notrandombytes.h"
Expand All @@ -28,6 +34,23 @@
} \
} while (0)

/* Enum to specify which vector to corrupt in tests */
typedef enum
{
CORRUPT_S1,
CORRUPT_S2
} corrupt_vector_t;

/* Struct to define a coefficient corruption test case */
typedef struct
{
corrupt_vector_t vector; /* Which vector to corrupt (s1 or s2) */
unsigned int poly_idx; /* Polynomial index within the vector */
unsigned int coeff_idx; /* Coefficient index within the polynomial */
int32_t corruption_value; /* Value to set the coefficient to */
const char *description; /* Description of the test case */
} corruption_test_case_t;


static int test_sign_core(uint8_t pk[MLDSA_CRYPTO_PUBLICKEYBYTES],
uint8_t sk[MLDSA_CRYPTO_SECRETKEYBYTES],
Expand Down Expand Up @@ -212,6 +235,159 @@ static int test_pk_from_sk(void)
return 0;
}

/* Helper function to check if s1 and s2 coefficients are within valid range
* [-eta, eta] */
static int check_s1_s2_coeffs_in_range(
const uint8_t sk[MLDSA_CRYPTO_SECRETKEYBYTES])
{
uint8_t rho[MLDSA_SEEDBYTES];
uint8_t tr[MLDSA_TRBYTES];
uint8_t key[MLDSA_SEEDBYTES];
mld_polyveck t0;
mld_polyvecl s1;
mld_polyveck s2;

/* Unpack the secret key to extract s1 and s2 */
mld_unpack_sk(rho, tr, key, &t0, &s1, &s2, sk);

/* Check all coefficients in s1 are within [-MLDSA_ETA, MLDSA_ETA] */
for (unsigned int i = 0; i < MLDSA_L; i++)
{
for (unsigned int j = 0; j < MLDSA_N; j++)
{
int32_t coeff = s1.vec[i].coeffs[j];
if (coeff < -MLDSA_ETA || coeff > MLDSA_ETA)
{
return 0; /* Out of range */
}
}
}

/* Check all coefficients in s2 are within [-MLDSA_ETA, MLDSA_ETA] */
for (unsigned int i = 0; i < MLDSA_K; i++)
{
for (unsigned int j = 0; j < MLDSA_N; j++)
{
int32_t coeff = s2.vec[i].coeffs[j];
if (coeff < -MLDSA_ETA || coeff > MLDSA_ETA)
{
return 0; /* Out of range */
}
}
}

return 1; /* All coefficients in range */
}


/* Helper function to test crypto_sign_pk_from_sk with invalid s1 or s2
* coefficients */
static int test_corrupted_sk(const corruption_test_case_t *test_case)
{
uint8_t sk_corrupted[MLDSA_CRYPTO_SECRETKEYBYTES];
int rc;
const char *vector_name = (test_case->vector == CORRUPT_S1) ? "s1" : "s2";

/* Start from a valid key pair */
uint8_t pk_valid[MLDSA_CRYPTO_PUBLICKEYBYTES];
uint8_t sk_valid[MLDSA_CRYPTO_SECRETKEYBYTES];
CHECK(crypto_sign_keypair(pk_valid, sk_valid) == 0);

uint8_t rho[MLDSA_SEEDBYTES];
uint8_t tr[MLDSA_TRBYTES];
uint8_t key[MLDSA_SEEDBYTES];
mld_polyveck t0;
mld_polyvecl s1;
mld_polyveck s2;

mld_unpack_sk(rho, tr, key, &t0, &s1, &s2, sk_valid);

/* Corrupt the specified coefficient */
if (test_case->vector == CORRUPT_S1)
{
/* Validate indices are within bounds */
if (test_case->poly_idx >= MLDSA_L || test_case->coeff_idx >= MLDSA_N)
{
printf("ERROR: s1 indices out of bounds: [%u][%u] (max [%u][%u])\n",
test_case->poly_idx, test_case->coeff_idx, MLDSA_L - 1,
MLDSA_N - 1);
return 1;
}
s1.vec[test_case->poly_idx].coeffs[test_case->coeff_idx] =
test_case->corruption_value;
}
else
{
/* Validate indices are within bounds */
if (test_case->poly_idx >= MLDSA_K || test_case->coeff_idx >= MLDSA_N)
{
printf("ERROR: s2 indices out of bounds: [%u][%u] (max [%u][%u])\n",
test_case->poly_idx, test_case->coeff_idx, MLDSA_K - 1,
MLDSA_N - 1);
return 1;
}
s2.vec[test_case->poly_idx].coeffs[test_case->coeff_idx] =
test_case->corruption_value;
}

/* Regenerate t0, t1, tr, and pk to be consistent with the corrupted vector
* https://github.com/aws/aws-lc/blob/0336dd78a0f2623c1f9b209a98cd497026d9c779/crypto/fipsmodule/ml_dsa/ml_dsa_ref/packing.c#L7-L61
*/
mld_polymat mat;
mld_polyveck t1;
uint8_t pk_consistent[MLDSA_CRYPTO_PUBLICKEYBYTES];
uint8_t tr_consistent[MLDSA_TRBYTES];

/* Expand matrix A from rho */
mld_polyvec_matrix_expand(&mat, rho);

/* Compute t = A * s1 + s2 */
mld_polyvecl s1_ntt = s1;
mld_polyvecl_ntt(&s1_ntt);
mld_polyvec_matrix_pointwise_montgomery(&t1, &mat, &s1_ntt);
mld_polyveck_reduce(&t1);
mld_polyveck_invntt_tomont(&t1);
mld_polyveck_add(&t1, &s2);
mld_polyveck_reduce(&t1);
mld_polyveck_caddq(&t1);

/* Power2Round to get t1 and t0 */
mld_polyveck_power2round(&t1, &t0, &t1);

/* Pack public key and compute tr */
mld_pack_pk(pk_consistent, rho, &t1);
mld_shake256(tr_consistent, MLDSA_TRBYTES, pk_consistent,
MLDSA_CRYPTO_PUBLICKEYBYTES);

/* Pack the corrupted secret key */
mld_pack_sk(sk_corrupted, rho, tr_consistent, key, &t0, &s1, &s2);

/* Verify that it is corrupted */
if (check_s1_s2_coeffs_in_range(sk_corrupted))
{
printf("ERROR: failed to corrupt secret key %s\n", vector_name);
return 1;
}

/* Test crypto_sign_pk_from_sk with corrupted key */
uint8_t pk_derived[MLDSA_CRYPTO_PUBLICKEYBYTES];
rc = crypto_sign_pk_from_sk(pk_derived, sk_corrupted);
MLD_CT_TESTING_DECLASSIFY(&rc, sizeof(int));

if (rc != -1)
{
printf(
"ERROR: pk_from_sk - should fail with corrupted secret key - %s - "
"%s[%u][%u] = %d\n",
test_case->description, vector_name, test_case->poly_idx,
test_case->coeff_idx, test_case->corruption_value);
return 1;
}

return 0;
}


static int test_wrong_pk(void)
{
uint8_t pk[MLDSA_CRYPTO_PUBLICKEYBYTES];
Expand Down Expand Up @@ -391,6 +567,44 @@ int main(void)
}
}

/* Run comprehensive corrupted key tests */
corruption_test_case_t test_cases[] = {
/* Test s1 vector corruptions */
{CORRUPT_S1, 0, 0, -(MLDSA_ETA + 1), "s1 underflow at [0][0]"},
{CORRUPT_S1, 0, 0, MLDSA_ETA + 1, "s1 overflow at [0][0]"},
{CORRUPT_S1, 0, 1, -(MLDSA_ETA + 2), "s1 underflow at [0][1]"},
{CORRUPT_S1, 0, 1, MLDSA_ETA + 2, "s1 overflow at [0][1]"},
{CORRUPT_S1, 0, MLDSA_N - 1, -(MLDSA_ETA + 1),
"s1 underflow at [0][N-1]"},
{CORRUPT_S1, 0, MLDSA_N - 1, MLDSA_ETA + 1, "s1 overflow at [0][N-1]"},
{CORRUPT_S1, MLDSA_L - 1, 0, -(MLDSA_ETA + 1),
"s1 underflow at [L-1][0]"},
{CORRUPT_S1, MLDSA_L - 1, 0, MLDSA_ETA + 1, "s1 overflow at [L-1][0]"},

/* Test s2 vector corruptions */
{CORRUPT_S2, 0, 0, -(MLDSA_ETA + 1), "s2 underflow at [0][0]"},
{CORRUPT_S2, 0, 0, MLDSA_ETA + 1, "s2 overflow at [0][0]"},
{CORRUPT_S2, 0, 1, -(MLDSA_ETA + 2), "s2 underflow at [0][1]"},
{CORRUPT_S2, 0, 1, MLDSA_ETA + 2, "s2 overflow at [0][1]"},
{CORRUPT_S2, 0, MLDSA_N - 1, -(MLDSA_ETA + 1),
"s2 underflow at [0][N-1]"},
{CORRUPT_S2, 0, MLDSA_N - 1, MLDSA_ETA + 1, "s2 overflow at [0][N-1]"},
{CORRUPT_S2, MLDSA_K - 1, 0, -(MLDSA_ETA + 1),
"s2 underflow at [K-1][0]"},
{CORRUPT_S2, MLDSA_K - 1, 0, MLDSA_ETA + 1, "s2 overflow at [K-1][0]"},
};

size_t num_test_cases = sizeof(test_cases) / sizeof(test_cases[0]);
for (size_t num = 0; num < num_test_cases; num++)
{
if (test_corrupted_sk(&test_cases[num]))
{
printf("ERROR: Test case %zu failed: %s\n", num,
test_cases[num].description);
return 1;
}
}

printf("MLDSA_CRYPTO_SECRETKEYBYTES: %d\n", MLDSA_CRYPTO_SECRETKEYBYTES);
printf("MLDSA_CRYPTO_PUBLICKEYBYTES: %d\n", MLDSA_CRYPTO_PUBLICKEYBYTES);
printf("MLDSA_CRYPTO_BYTES: %d\n", MLDSA_CRYPTO_BYTES);
Expand Down
Loading