@@ -94,27 +94,31 @@ typedef struct ossl_ml_kem_scalar_st {
9494} scalar ;
9595
9696/* Key material allocation layout */
97- #define DECLARE_ML_KEM_KEYDATA (name , rank , private_sz ) \
97+ #define DECLARE_ML_KEM_PUBKEYDATA (name , rank ) \
9898 struct name##_alloc { \
9999 /* Public vector |t| */ \
100100 scalar tbuf [(rank )]; \
101101 /* Pre-computed matrix |m| (FIPS 203 |A| transpose) */ \
102- scalar mbuf [(rank )* (rank )] \
103- /* optional private key data */ \
104- private_sz \
102+ scalar mbuf [(rank )* (rank )]; \
103+ }
104+
105+ #define DECLARE_ML_KEM_PRVKEYDATA (name , rank ) \
106+ struct name##_alloc { \
107+ scalar sbuf[rank]; \
108+ uint8_t zbuf[2 * ML_KEM_RANDOM_BYTES]; \
105109 }
106110
107111/* Declare variant-specific public and private storage */
108112#define DECLARE_ML_KEM_VARIANT_KEYDATA (bits ) \
109- DECLARE_ML_KEM_KEYDATA(pubkey_##bits, ML_KEM_##bits##_RANK,;); \
110- DECLARE_ML_KEM_KEYDATA(prvkey_##bits, ML_KEM_##bits##_RANK,;\
111- scalar sbuf[ML_KEM_##bits##_RANK]; \
112- uint8_t zbuf[2 * ML_KEM_RANDOM_BYTES];)
113+ DECLARE_ML_KEM_PUBKEYDATA(pubkey_##bits, ML_KEM_##bits##_RANK); \
114+ DECLARE_ML_KEM_PRVKEYDATA(prvkey_##bits, ML_KEM_##bits##_RANK)
115+
113116DECLARE_ML_KEM_VARIANT_KEYDATA (512 );
114117DECLARE_ML_KEM_VARIANT_KEYDATA (768 );
115118DECLARE_ML_KEM_VARIANT_KEYDATA (1024 );
116119#undef DECLARE_ML_KEM_VARIANT_KEYDATA
117- #undef DECLARE_ML_KEM_KEYDATA
120+ #undef DECLARE_ML_KEM_PUBKEYDATA
121+ #undef DECLARE_ML_KEM_PRVKEYDATA
118122
119123typedef __owur
120124int (* CBD_FUNC )(scalar * out , uint8_t in [ML_KEM_RANDOM_BYTES + 1 ],
@@ -1534,26 +1538,35 @@ int decap(uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
15341538/*
15351539 * After allocating storage for public or private key data, update the key
15361540 * component pointers to reference that storage.
1541+ *
1542+ * The caller should only store private data in `priv` *after* a successful
1543+ * (non-zero) return from this function.
15371544 */
15381545static __owur
1539- int add_storage (scalar * p , int private , ML_KEM_KEY * key )
1546+ int add_storage (scalar * pub , scalar * priv , int private , ML_KEM_KEY * key )
15401547{
15411548 int rank = key -> vinfo -> rank ;
15421549
1543- if (p == NULL )
1550+ if (pub == NULL || (private && priv == NULL )) {
1551+ /*
1552+ * One of these could be allocated correctly. It is legal to call free with a NULL
1553+ * pointer, so always attempt to free both allocations here
1554+ */
1555+ OPENSSL_free (pub );
1556+ OPENSSL_secure_free (priv );
15441557 return 0 ;
1558+ }
15451559
15461560 /*
1547- * We're adding key material, the seed buffer will now hold |rho| and
1548- * |pkhash|.
1561+ * We're adding key material, set up rho and pkhash to point to the rho_pkhash buffer
15491562 */
1550- memset (key -> seedbuf , 0 , sizeof (key -> seedbuf ));
1551- key -> rho = key -> seedbuf ;
1552- key -> pkhash = key -> seedbuf + ML_KEM_RANDOM_BYTES ;
1563+ memset (key -> rho_pkhash , 0 , sizeof (key -> rho_pkhash ));
1564+ key -> rho = key -> rho_pkhash ;
1565+ key -> pkhash = key -> rho_pkhash + ML_KEM_RANDOM_BYTES ;
15531566 key -> d = key -> z = NULL ;
15541567
15551568 /* A public key needs space for |t| and |m| */
1556- key -> m = (key -> t = p ) + rank ;
1569+ key -> m = (key -> t = pub ) + rank ;
15571570
15581571 /*
15591572 * A private key also needs space for |s| and |z|.
@@ -1563,7 +1576,7 @@ int add_storage(scalar *p, int private, ML_KEM_KEY *key)
15631576 * non-NULL |d| pointer.
15641577 */
15651578 if (private )
1566- key -> z = (uint8_t * )(rank + (key -> s = key -> m + rank * rank ));
1579+ key -> z = (uint8_t * )(rank + (key -> s = priv ));
15671580 return 1 ;
15681581}
15691582
@@ -1574,27 +1587,29 @@ int add_storage(scalar *p, int private, ML_KEM_KEY *key)
15741587void
15751588ossl_ml_kem_key_reset (ML_KEM_KEY * key )
15761589{
1577- if (key -> t == NULL )
1578- return ;
1590+ /*
1591+ * seedbuf can be allocated and contain |z| and |d| if the key is
1592+ * being created from a private key encoding. Similarly a pending
1593+ * serialised (encoded) private key may be queued up to load.
1594+ * Clear and free that data now.
1595+ */
1596+ if (key -> seedbuf != NULL )
1597+ OPENSSL_secure_clear_free (key -> seedbuf , ML_KEM_SEED_BYTES );
1598+ if (ossl_ml_kem_have_dkenc (key ))
1599+ OPENSSL_secure_clear_free (key -> encoded_dk , key -> vinfo -> prvkey_bytes );
1600+
15791601 /*-
15801602 * Cleanse any sensitive data:
15811603 * - The private vector |s| is immediately followed by the FO failure
15821604 * secret |z|, and seed |d|, we can cleanse all three in one call.
1583- *
1584- * - Otherwise, when key->d is set, cleanse the stashed seed.
1585- *
1586- * If the memory has been allocated with secure memory, it will be cleared
1587- * before being free'd under the OPENSSL_secure_free call.
15881605 */
1589- if (ossl_ml_kem_have_prvkey (key )) {
1590- if (!CRYPTO_secure_allocated (key -> t ))
1591- OPENSSL_cleanse (key -> s , key -> vinfo -> rank * sizeof (scalar ) + 2 * ML_KEM_RANDOM_BYTES );
1592- OPENSSL_secure_free (key -> t );
1593- } else {
1606+ if (key -> t != NULL ) {
1607+ if (ossl_ml_kem_have_prvkey (key ))
1608+ OPENSSL_secure_clear_free (key -> s , key -> vinfo -> prvalloc );
15941609 OPENSSL_free (key -> t );
15951610 }
1596-
1597- key -> d = key -> z = (uint8_t * )(key -> s = key -> m = key -> t = NULL );
1611+ key -> d = key -> z = key -> seedbuf = key -> encoded_dk =
1612+ (uint8_t * )(key -> s = key -> m = key -> t = NULL );
15981613}
15991614
16001615/*
@@ -1640,7 +1655,7 @@ ML_KEM_KEY *ossl_ml_kem_key_new(OSSL_LIB_CTX *libctx, const char *properties,
16401655 key -> shake256_md = EVP_MD_fetch (libctx , "SHAKE256" , properties );
16411656 key -> sha3_256_md = EVP_MD_fetch (libctx , "SHA3-256" , properties );
16421657 key -> sha3_512_md = EVP_MD_fetch (libctx , "SHA3-512" , properties );
1643- key -> d = key -> z = key -> rho = key -> pkhash = key -> encoded_dk = NULL ;
1658+ key -> d = key -> z = key -> rho = key -> pkhash = key -> encoded_dk = key -> seedbuf = NULL ;
16441659 key -> s = key -> m = key -> t = NULL ;
16451660
16461661 if (key -> shake128_md != NULL
@@ -1660,7 +1675,8 @@ ML_KEM_KEY *ossl_ml_kem_key_dup(const ML_KEM_KEY *key, int selection)
16601675{
16611676 int ok = 0 ;
16621677 ML_KEM_KEY * ret ;
1663- void * tmp ;
1678+ void * tmp_pub ;
1679+ void * tmp_priv ;
16641680
16651681 /*
16661682 * Partially decoded keys, not yet imported or loaded, should never be
@@ -1669,9 +1685,11 @@ ML_KEM_KEY *ossl_ml_kem_key_dup(const ML_KEM_KEY *key, int selection)
16691685 if (ossl_ml_kem_decoded_key (key ))
16701686 return NULL ;
16711687
1672- if (key == NULL
1673- || (ret = OPENSSL_memdup (key , sizeof (* key ))) == NULL )
1688+ if (key == NULL )
1689+ return NULL ;
1690+ else if ((ret = OPENSSL_memdup (key , sizeof (* key ))) == NULL )
16741691 return NULL ;
1692+
16751693 ret -> d = ret -> z = ret -> rho = ret -> pkhash = NULL ;
16761694 ret -> s = ret -> m = ret -> t = NULL ;
16771695
@@ -1686,16 +1704,19 @@ ML_KEM_KEY *ossl_ml_kem_key_dup(const ML_KEM_KEY *key, int selection)
16861704 ok = 1 ;
16871705 break ;
16881706 case OSSL_KEYMGMT_SELECT_PUBLIC_KEY :
1689- ok = add_storage (OPENSSL_memdup (key -> t , key -> vinfo -> puballoc ), 0 , ret );
1690- ret -> rho = ret -> seedbuf ;
1691- ret -> pkhash = ret -> rho + ML_KEM_RANDOM_BYTES ;
1707+ ok = add_storage (OPENSSL_memdup (key -> t , key -> vinfo -> puballoc ), NULL , 0 , ret );
16921708 break ;
16931709 case OSSL_KEYMGMT_SELECT_PRIVATE_KEY :
1694- tmp = OPENSSL_secure_malloc (key -> vinfo -> prvalloc );
1695- if (tmp == NULL )
1710+ tmp_pub = OPENSSL_memdup (key -> t , key -> vinfo -> puballoc );
1711+ if (tmp_pub == NULL )
1712+ break ;
1713+ tmp_priv = OPENSSL_secure_malloc (key -> vinfo -> prvalloc );
1714+ if (tmp_priv == NULL ) {
1715+ OPENSSL_free (tmp_pub );
16961716 break ;
1697- memcpy (tmp , key -> t , key -> vinfo -> prvalloc );
1698- ok = add_storage (tmp , 1 , ret );
1717+ }
1718+ if ((ok = add_storage (tmp_pub , tmp_priv , 1 , ret )) != 0 )
1719+ memcpy (tmp_priv , key -> s , key -> vinfo -> prvalloc );
16991720 /* Duplicated keys retain |d|, if available */
17001721 if (key -> d != NULL )
17011722 ret -> d = ret -> z + ML_KEM_RANDOM_BYTES ;
@@ -1725,11 +1746,6 @@ void ossl_ml_kem_key_free(ML_KEM_KEY *key)
17251746 EVP_MD_free (key -> sha3_256_md );
17261747 EVP_MD_free (key -> sha3_512_md );
17271748
1728- if (ossl_ml_kem_decoded_key (key )) {
1729- OPENSSL_cleanse (key -> seedbuf , sizeof (key -> seedbuf ));
1730- if (ossl_ml_kem_have_dkenc (key ))
1731- OPENSSL_secure_clear_free (key -> encoded_dk , key -> vinfo -> prvkey_bytes );
1732- }
17331749 ossl_ml_kem_key_reset (key );
17341750 OPENSSL_free (key );
17351751}
@@ -1783,10 +1799,13 @@ ML_KEM_KEY *ossl_ml_kem_set_seed(const uint8_t *seed, size_t seedlen, ML_KEM_KEY
17831799 || ossl_ml_kem_have_seed (key )
17841800 || seedlen != ML_KEM_SEED_BYTES )
17851801 return NULL ;
1786- /*
1787- * With no public or private key material on hand, we can use the seed
1788- * buffer for |z| and |d|, in that order.
1789- */
1802+
1803+ if (key -> seedbuf == NULL ) {
1804+ key -> seedbuf = OPENSSL_secure_malloc (seedlen );
1805+ if (key -> seedbuf == NULL )
1806+ return NULL ;
1807+ }
1808+
17901809 key -> z = key -> seedbuf ;
17911810 key -> d = key -> z + ML_KEM_RANDOM_BYTES ;
17921811 memcpy (key -> d , seed , ML_KEM_RANDOM_BYTES );
@@ -1813,7 +1832,7 @@ int ossl_ml_kem_parse_public_key(const uint8_t *in, size_t len, ML_KEM_KEY *key)
18131832 || (mdctx = EVP_MD_CTX_new ()) == NULL )
18141833 return 0 ;
18151834
1816- if (add_storage (OPENSSL_malloc (vinfo -> puballoc ), 0 , key ))
1835+ if (add_storage (OPENSSL_malloc (vinfo -> puballoc ), NULL , 0 , key ))
18171836 ret = parse_pubkey (in , mdctx , key );
18181837
18191838 if (!ret )
@@ -1841,7 +1860,8 @@ int ossl_ml_kem_parse_private_key(const uint8_t *in, size_t len,
18411860 || (mdctx = EVP_MD_CTX_new ()) == NULL )
18421861 return 0 ;
18431862
1844- if (add_storage (OPENSSL_secure_malloc (vinfo -> prvalloc ), 1 , key ))
1863+ if (add_storage (OPENSSL_malloc (vinfo -> puballoc ),
1864+ OPENSSL_secure_malloc (vinfo -> prvalloc ), 1 , key ))
18451865 ret = parse_prvkey (in , mdctx , key );
18461866
18471867 if (!ret )
@@ -1870,10 +1890,10 @@ int ossl_ml_kem_genkey(uint8_t *pubenc, size_t publen, ML_KEM_KEY *key)
18701890 if (pubenc != NULL && publen != vinfo -> pubkey_bytes )
18711891 return 0 ;
18721892
1873- if (ossl_ml_kem_have_seed ( key ) ) {
1893+ if (key -> seedbuf != NULL ) {
18741894 if (!ossl_ml_kem_encode_seed (seed , sizeof (seed ), key ))
18751895 return 0 ;
1876- key -> d = key -> z = NULL ;
1896+ ossl_ml_kem_key_reset ( key ) ;
18771897 } else if (RAND_priv_bytes_ex (key -> libctx , seed , sizeof (seed ),
18781898 key -> vinfo -> secbits ) <= 0 ) {
18791899 return 0 ;
@@ -1888,7 +1908,8 @@ int ossl_ml_kem_genkey(uint8_t *pubenc, size_t publen, ML_KEM_KEY *key)
18881908 */
18891909 CONSTTIME_SECRET (seed , ML_KEM_SEED_BYTES );
18901910
1891- if (add_storage (OPENSSL_secure_malloc (vinfo -> prvalloc ), 1 , key ))
1911+ if (add_storage (OPENSSL_malloc (vinfo -> puballoc ),
1912+ OPENSSL_secure_malloc (vinfo -> prvalloc ), 1 , key ))
18921913 ret = genkey (seed , mdctx , pubenc , key );
18931914 OPENSSL_cleanse (seed , sizeof (seed ));
18941915
0 commit comments