diff --git a/mldsa/mldsa_native.h b/mldsa/mldsa_native.h index 67aa25bf5..ecd83a7ce 100644 --- a/mldsa/mldsa_native.h +++ b/mldsa/mldsa_native.h @@ -735,8 +735,16 @@ size_t MLD_API_NAMESPACE(prepare_domain_separation_prefix)( /************************************************* * Name: crypto_sign_pk_from_sk * - * Description: Derives public key from secret key with validation. - * Checks that t0 and tr stored in sk match recomputed values. + * Description: Performs basic validity checks on secret key, and derives + * public key. + * + * Referring to the decoding of the secret key + * `sk=(rho, K, tr, s1, s2, t0)` + * (cf. [@FIPS204, Algorithm 25 skDecode]), + * the following checks are performed: + * - Check that s1 and s2 have coefficients in + * [-MLDSA_ETA, MLDSA_ETA] + * - Check that t0 and tr stored in sk match recomputed values. * * Arguments: - uint8_t pk[CRYPTO_PUBLICKEYBYTES]: output public key * - const uint8_t sk[CRYPTO_SECRETKEYBYTES]: input secret key diff --git a/mldsa/src/sign.c b/mldsa/src/sign.c index ef5ea6bea..2896cf2b6 100644 --- a/mldsa/src/sign.c +++ b/mldsa/src/sign.c @@ -1297,7 +1297,7 @@ MLD_EXTERNAL_API int crypto_sign_pk_from_sk(uint8_t pk[MLDSA_CRYPTO_PUBLICKEYBYTES], const uint8_t sk[MLDSA_CRYPTO_SECRETKEYBYTES]) { - uint8_t cmp, cmp0, cmp1; + uint8_t check, cmp0, cmp1, chk1, chk2; int ret; MLD_ALLOC(rho, uint8_t, MLDSA_SEEDBYTES); MLD_ALLOC(tr, uint8_t, MLDSA_TRBYTES); @@ -1320,6 +1320,10 @@ int crypto_sign_pk_from_sk(uint8_t pk[MLDSA_CRYPTO_PUBLICKEYBYTES], /* Unpack secret key */ mld_unpack_sk(rho, tr, key, t0, s1, s2, sk); + /* Validate s1 and s2 coefficients are within [-MLDSA_ETA, MLDSA_ETA] */ + chk1 = mld_polyvecl_chknorm(s1, MLDSA_ETA + 1) & 0xFF; + chk2 = mld_polyveck_chknorm(s2, MLDSA_ETA + 1) & 0xFF; + /* Recompute t0, t1, tr, and pk from rho, s1, s2 */ ret = mld_compute_t0_t1_tr_from_sk_components(t0_computed, t1, tr_computed, pk, rho, s1, s2); @@ -1333,11 +1337,11 @@ int crypto_sign_pk_from_sk(uint8_t pk[MLDSA_CRYPTO_PUBLICKEYBYTES], sizeof(mld_polyveck)); cmp1 = mld_ct_memcmp((const uint8_t *)tr, (const uint8_t *)tr_computed, MLDSA_TRBYTES); - cmp = mld_value_barrier_u8(cmp0 | cmp1); + check = mld_value_barrier_u8(cmp0 | cmp1 | chk1 | chk2); /* Declassify the final result of the validity check. */ - MLD_CT_TESTING_DECLASSIFY(&cmp, sizeof(cmp)); - ret = (cmp != 0) ? MLD_ERR_FAIL : 0; + MLD_CT_TESTING_DECLASSIFY(&check, sizeof(check)); + ret = (check != 0) ? MLD_ERR_FAIL : 0; cleanup: diff --git a/mldsa/src/sign.h b/mldsa/src/sign.h index 7c8ecc97c..03cb077f2 100644 --- a/mldsa/src/sign.h +++ b/mldsa/src/sign.h @@ -745,8 +745,16 @@ __contract__( /************************************************* * Name: crypto_sign_pk_from_sk * - * Description: Derives public key from secret key with validation. - * Checks that t0 and tr stored in sk match recomputed values. + * Description: Performs basic validity checks on secret key, and derives + * public key. + * + * Referring to the decoding of the secret key + * `sk=(rho, K, tr, s1, s2, t0)` + * (cf. [@FIPS204, Algorithm 25 skDecode]), + * the following checks are performed: + * - Check that s1 and s2 have coefficients in + * [-MLDSA_ETA, MLDSA_ETA] + * - Check that t0 and tr stored in sk match recomputed values. * * Arguments: - uint8_t pk[MLDSA_CRYPTO_PUBLICKEYBYTES]: output public key * - const uint8_t sk[MLDSA_CRYPTO_SECRETKEYBYTES]: input secret diff --git a/proofs/cbmc/crypto_sign_pk_from_sk/Makefile b/proofs/cbmc/crypto_sign_pk_from_sk/Makefile index 2168e005f..e7213ac91 100644 --- a/proofs/cbmc/crypto_sign_pk_from_sk/Makefile +++ b/proofs/cbmc/crypto_sign_pk_from_sk/Makefile @@ -22,6 +22,8 @@ PROJECT_SOURCES += $(SRCDIR)/mldsa/src/sign.c CHECK_FUNCTION_CONTRACTS=$(MLD_NAMESPACE)pk_from_sk USE_FUNCTION_CONTRACTS=$(MLD_NAMESPACE)unpack_sk USE_FUNCTION_CONTRACTS+=$(MLD_NAMESPACE)pack_pk +USE_FUNCTION_CONTRACTS+=$(MLD_NAMESPACE)polyvecl_chknorm +USE_FUNCTION_CONTRACTS+=$(MLD_NAMESPACE)polyveck_chknorm USE_FUNCTION_CONTRACTS+=mld_compute_t0_t1_tr_from_sk_components USE_FUNCTION_CONTRACTS+=mld_value_barrier_u8 USE_FUNCTION_CONTRACTS+=mld_ct_memcmp