Skip to content

Commit 08b01de

Browse files
authored
Merge pull request #18 from GiacomoPope/pk_from_sk
allow creation of pk from sk
2 parents 45e5f4e + 3972bb6 commit 08b01de

File tree

4 files changed

+46
-6
lines changed

4 files changed

+46
-6
lines changed

src/dilithium_py/ml_dsa/ml_dsa.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,40 @@ def verify(self, pk_bytes, m, sig_bytes, ctx=b""):
420420

421421
return self._verify_internal(pk_bytes, m_prime, sig_bytes)
422422

423+
"""
424+
The following additional function follows an outline from:
425+
https://github.com/aws/aws-lc/pull/2142
426+
which computes pk_bytes when only the sk_bytes are known.
427+
"""
428+
429+
def pk_from_sk(self, sk_bytes: bytes) -> bytes:
430+
"""
431+
Given the packed representation of a ML-DSA secret key,
432+
compute the corresponding packed public key bytes.
433+
"""
434+
# First unpack the secret key
435+
rho, K, tr, s1, s2, t0 = self._unpack_sk(sk_bytes)
436+
437+
# Compute the matrix A from rho in NTT form
438+
A_hat = self._expand_matrix_from_seed(rho)
439+
440+
# Convert s1 to NTT form
441+
s1_hat = s1.to_ntt()
442+
443+
# Compute the polynomial t, we have the lower bits t0,
444+
# but we need the higher bits t1 for the public key
445+
t = (A_hat @ s1_hat).from_ntt() + s2
446+
t1, _ = t.power_2_round(self.d)
447+
448+
# The packed public key is made from rho || t1
449+
pk_bytes = self._pack_pk(rho, t1)
450+
451+
# Ensure the public key matches the hash within the secret key
452+
if tr != self._h(pk_bytes, 64):
453+
raise ValueError("maleformed secret key")
454+
455+
return pk_bytes
456+
423457
"""
424458
The following external mu functions are not in FIPS 204, but are in
425459
Appendix D of the following IETF draft and are included for experimentation

src/dilithium_py/ml_dsa/pkcs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -266,10 +266,10 @@ def sk_from_der(enc_key):
266266
if not expanded:
267267
pk, expanded = ml_dsa.key_derive(seed)
268268

269-
if not pk and seed:
270-
# deriving verifying key from expanded key in ML-DSA is non-trivial
271-
# do that only for encodings that include the seed
272-
pk, _ = ml_dsa.key_derive(seed)
269+
if not pk:
270+
# If we reach here, we need to compute the public key
271+
# directly from the secret key bytes
272+
pk = ml_dsa.pk_from_sk(expanded)
273273

274274
return ml_dsa, expanded, seed, pk
275275

tests/test_ml_dsa.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@ def generic_test_ml_dsa(self, ML_DSA, count=5):
3131
check_wrong_msg = ML_DSA.verify(pk, b"", sig, ctx=ctx)
3232
check_no_ctx = ML_DSA.verify(pk, msg, sig)
3333

34+
# Generate the public key directly from the secret key
35+
recovered_pk = ML_DSA.pk_from_sk(sk)
36+
37+
# Check that recovering the pk works
38+
self.assertEqual(pk, recovered_pk)
39+
3440
# Check that signature works
3541
self.assertTrue(check_verify)
3642

tests/test_pkcs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def test_import_from_expanded(self):
161161
self.assertIs(ml_dsa, self.ml_dsa)
162162
self.assertEqual(sk, self.sk)
163163
self.assertEqual(seed, None)
164-
self.assertEqual(pk, None)
164+
self.assertEqual(pk, self.pk)
165165

166166

167167
@unittest.skipUnless(ECDSA_PRESENT, "requires ecdsa package")
@@ -1006,7 +1006,7 @@ def test_sk_sanity_expanded_only(self):
10061006
self.assertEqual(self.ml_dsa, ml_dsa)
10071007
self.assertEqual(self.sk, expanded)
10081008
self.assertEqual(None, seed)
1009-
self.assertEqual(None, pk)
1009+
self.assertEqual(self.pk, pk)
10101010

10111011
def test_sk_trailing_junk(self):
10121012
enc = (

0 commit comments

Comments
 (0)