Skip to content

Commit 6689812

Browse files
committed
allow creation of pk from sk following aws/aws-lc#2142
1 parent 45e5f4e commit 6689812

File tree

3 files changed

+44
-4
lines changed

3 files changed

+44
-4
lines changed

src/dilithium_py/ml_dsa/ml_dsa.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,40 @@ def verify(self, pk_bytes, m, sig_bytes, ctx=b""):
420420

421421
return self._verify_internal(pk_bytes, m_prime, sig_bytes)
422422

423+
"""
424+
The following additional function follows an outline from:
425+
https://github.com/aws/aws-lc/pull/2142
426+
which computes pk_bytes when only the sk_bytes are known.
427+
"""
428+
429+
def pk_from_sk(self, sk_bytes: bytes) -> bytes:
430+
"""
431+
Given the packed representation of a ML-DSA secret key,
432+
compute the corresponding packed public key bytes.
433+
"""
434+
# First unpack the secret key
435+
rho, K, tr, s1, s2, t0 = self._unpack_sk(sk_bytes)
436+
437+
# Compute the matrix A from rho in NTT form
438+
A_hat = self._expand_matrix_from_seed(rho)
439+
440+
# Convert s1 to NTT form
441+
s1_hat = s1.to_ntt()
442+
443+
# Compute the polynomial t, we have the lower bits t0,
444+
# but we need the higher bits t1 for the public key
445+
t = (A_hat @ s1_hat).from_ntt() + s2
446+
t1, _ = t.power_2_round(self.d)
447+
448+
# The packed public key is made from rho || t1
449+
pk_bytes = self._pack_pk(rho, t1)
450+
451+
# Ensure the public key matches the hash within the secret key
452+
if tr != self._h(pk_bytes, 64):
453+
raise ValueError("maleformed secret key")
454+
455+
return pk_bytes
456+
423457
"""
424458
The following external mu functions are not in FIPS 204, but are in
425459
Appendix D of the following IETF draft and are included for experimentation

src/dilithium_py/ml_dsa/pkcs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -266,10 +266,10 @@ def sk_from_der(enc_key):
266266
if not expanded:
267267
pk, expanded = ml_dsa.key_derive(seed)
268268

269-
if not pk and seed:
270-
# deriving verifying key from expanded key in ML-DSA is non-trivial
271-
# do that only for encodings that include the seed
272-
pk, _ = ml_dsa.key_derive(seed)
269+
if not pk:
270+
# If we reach here, we need to compute the public key
271+
# directly from the secret key bytes
272+
pk = ml_dsa.pk_from_sk(expanded)
273273

274274
return ml_dsa, expanded, seed, pk
275275

tests/test_ml_dsa.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@ def generic_test_ml_dsa(self, ML_DSA, count=5):
3131
check_wrong_msg = ML_DSA.verify(pk, b"", sig, ctx=ctx)
3232
check_no_ctx = ML_DSA.verify(pk, msg, sig)
3333

34+
# Generate the public key directly from the secret key
35+
recovered_pk = ML_DSA.pk_from_sk(sk)
36+
37+
# Check that recovering the pk works
38+
self.assertEqual(pk, recovered_pk)
39+
3440
# Check that signature works
3541
self.assertTrue(check_verify)
3642

0 commit comments

Comments
 (0)