Skip to content

Commit 56281cc

Browse files
authored
Merge pull request #22 from GiacomoPope/prehash_mldsa
Include Hash ML-DSA
2 parents 3d0a299 + 0105b90 commit 56281cc

File tree

7 files changed

+324
-4
lines changed

7 files changed

+324
-4
lines changed

README.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,33 @@ so you can simply import the NIST level you want to play with:
144144
The above example would also work with the other NIST levels
145145
`ML_DSA_65` and `ML_DSA_87`.
146146

147+
#### Hash ML-DSA
148+
149+
Following algorithms 4 and 5 of FIPS 204 we also include a version of pre-hash ML-DSA which hashes the message before signing it using SHA512 by default for
150+
all three security levels. This is used in much the same way as ML-DSA:
151+
152+
```python
153+
>>> from dilithium_py.ml_dsa import HASH_ML_DSA_44_WITH_SHA512
154+
>>>
155+
>>> # Example of signing
156+
>>> pk, sk = HASH_ML_DSA_44_WITH_SHA512.keygen()
157+
>>> msg = b"Your message signed by ML_DSA"
158+
>>> sig = HASH_ML_DSA_44_WITH_SHA512.sign(sk, msg)
159+
>>> assert HASH_ML_DSA_44_WITH_SHA512.verify(pk, msg, sig)
160+
>>>
161+
>>> # Verification will fail with the wrong msg or pk
162+
>>> assert not HASH_ML_DSA_44_WITH_SHA512.verify(pk, b"", sig)
163+
>>> pk_new, sk_new = HASH_ML_DSA_44_WITH_SHA512.keygen()
164+
>>> assert not HASH_ML_DSA_44_WITH_SHA512.verify(pk_new, msg, sig)
165+
```
166+
167+
There is also support for other hash functions (at the time, only SHA256 and SHAKE128), but there seem to only be OIDs for the pre-hash version using SHA512
168+
so this is what is included. To access signing with other hash functions the methods are `HASH_ML_DSA_44_WITH_SHA512._sign_with_pre_hash` and
169+
`HASH_ML_DSA_44_WITH_SHA512._verify_with_pre_hash`. For more information see the
170+
implementation and comments in `hash_ml_dsa.py`.
171+
172+
The pre-hash version of ML-DSA has purposefully been added to a child class of ML-DSA as the signatures which are produced between these variants are incompatible.
173+
147174
### Benchmarks
148175

149176
Some very rough benchmarks to give an idea about performance:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "dilithium-py"
7-
version = "1.1.0"
7+
version = "1.3.0"
88
requires-python = ">= 3.9"
99
description = "A pure python implementation of ML-DSA (FIPS 204)"
1010
readme = "README.md"

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
setup(
44
name="dilithium-py",
5-
version="1.1.0",
5+
version="1.3.0",
66
python_requires=">=3.9",
77
description="A pure python implementation of ML-DSA (FIPS 204)",
88
long_description=open("README.md").read(),
Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1-
from .default_parameters import ML_DSA_44, ML_DSA_65, ML_DSA_87
1+
from .default_parameters import (
2+
ML_DSA_44,
3+
ML_DSA_65,
4+
ML_DSA_87,
5+
HASH_ML_DSA_44_WITH_SHA512,
6+
HASH_ML_DSA_65_WITH_SHA512,
7+
HASH_ML_DSA_87_WITH_SHA512,
8+
)
29

3-
__all__ = ["ML_DSA_44", "ML_DSA_65", "ML_DSA_87"]
10+
__all__ = [
11+
"ML_DSA_44",
12+
"ML_DSA_65",
13+
"ML_DSA_87",
14+
"HASH_ML_DSA_44_WITH_SHA512",
15+
"HASH_ML_DSA_65_WITH_SHA512",
16+
"HASH_ML_DSA_87_WITH_SHA512",
17+
]

src/dilithium_py/ml_dsa/default_parameters.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from .ml_dsa import ML_DSA
2+
from .hash_ml_dsa import HashML_DSA
3+
24

35
DEFAULT_PARAMETERS = {
46
"ML_DSA_44": {
@@ -37,8 +39,48 @@
3739
"c_tilde_bytes": 64,
3840
"oid": (2, 16, 840, 1, 101, 3, 4, 3, 19),
3941
},
42+
"HASH_ML_DSA_44_WITH_SHA512": {
43+
"d": 13, # number of bits dropped from t
44+
"tau": 39, # number of ±1 in c
45+
"gamma_1": 131072, # coefficient range of y: 2^17
46+
"gamma_2": 95232, # low order rounding range: (q-1)/88
47+
"k": 4, # Dimensions of A = (k, l)
48+
"l": 4, # Dimensions of A = (k, l)
49+
"eta": 2, # Private key range
50+
"omega": 80, # Max number of ones in hint
51+
"c_tilde_bytes": 32,
52+
"oid": (2, 16, 840, 1, 101, 3, 4, 3, 32),
53+
},
54+
"HASH_ML_DSA_65_WITH_SHA512": {
55+
"d": 13, # number of bits dropped from t
56+
"tau": 49, # number of ±1 in c
57+
"gamma_1": 524288, # coefficient range of y: 2^19
58+
"gamma_2": 261888, # low order rounding range: (q-1)/32
59+
"k": 6, # Dimensions of A = (k, l)
60+
"l": 5, # Dimensions of A = (k, l)
61+
"eta": 4, # Private key range
62+
"omega": 55, # Max number of ones in hint
63+
"c_tilde_bytes": 48,
64+
"oid": (2, 16, 840, 1, 101, 3, 4, 3, 33),
65+
},
66+
"HASH_ML_DSA_87_WITH_SHA512": {
67+
"d": 13, # number of bits dropped from t
68+
"tau": 60, # number of ±1 in c
69+
"gamma_1": 524288, # coefficient range of y: 2^19
70+
"gamma_2": 261888, # low order rounding range: (q-1)/32
71+
"k": 8, # Dimensions of A = (k, l)
72+
"l": 7, # Dimensions of A = (k, l)
73+
"eta": 2, # Private key range
74+
"omega": 75, # Max number of ones in hint
75+
"c_tilde_bytes": 64,
76+
"oid": (2, 16, 840, 1, 101, 3, 4, 3, 34),
77+
},
4078
}
4179

4280
ML_DSA_44 = ML_DSA(DEFAULT_PARAMETERS["ML_DSA_44"])
4381
ML_DSA_65 = ML_DSA(DEFAULT_PARAMETERS["ML_DSA_65"])
4482
ML_DSA_87 = ML_DSA(DEFAULT_PARAMETERS["ML_DSA_87"])
83+
84+
HASH_ML_DSA_44_WITH_SHA512 = HashML_DSA(DEFAULT_PARAMETERS["ML_DSA_44"])
85+
HASH_ML_DSA_65_WITH_SHA512 = HashML_DSA(DEFAULT_PARAMETERS["ML_DSA_65"])
86+
HASH_ML_DSA_87_WITH_SHA512 = HashML_DSA(DEFAULT_PARAMETERS["ML_DSA_87"])
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
from .ml_dsa import ML_DSA
2+
from hashlib import sha256, sha512, shake_128
3+
4+
5+
class HashML_DSA(ML_DSA):
6+
def _hash_with_oid(self, m: bytes, hash_name: str) -> tuple[bytes, bytes]:
7+
hash_name = hash_name.upper()
8+
9+
if hash_name == "SHA256":
10+
oid = bytes(
11+
[0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01]
12+
)
13+
ph_m = sha256(m).digest()
14+
elif hash_name == "SHA512":
15+
oid = bytes(
16+
[0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03]
17+
)
18+
ph_m = sha512(m).digest()
19+
elif hash_name == "SHAKE128":
20+
oid = bytes(
21+
[0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x0B]
22+
)
23+
ph_m = shake_128(m).digest(32)
24+
else:
25+
raise ValueError(f"unsupported hash algorithm: {hash_name}")
26+
27+
return oid, ph_m
28+
29+
def _sign_with_pre_hash(
30+
self,
31+
sk: bytes,
32+
m: bytes,
33+
hash_name: str,
34+
ctx: bytes = b"",
35+
deterministic: bool = False,
36+
) -> bytes:
37+
"""
38+
Generates a HashML-DSA signature following Algorithm 4 (FIPS 204)
39+
40+
The hash name is a string which selects the pre-hash function and
41+
can currently be SHA256, SHA512 or SHAKE128.
42+
"""
43+
if len(ctx) > 255:
44+
raise ValueError(
45+
f"ctx bytes must have length at most 255, ctx has length {len(ctx) = }"
46+
)
47+
48+
if deterministic:
49+
rnd = bytes([0] * 32)
50+
else:
51+
rnd = self.random_bytes(32)
52+
53+
# Prehash the message and return the OID of the used hash
54+
oid, ph_m = self._hash_with_oid(m, hash_name)
55+
56+
# Format the message using the context
57+
m_prime = bytes([1]) + bytes([len(ctx)]) + ctx + oid + ph_m
58+
59+
# Compute the signature of m_prime
60+
sig_bytes = self._sign_internal(sk, m_prime, rnd)
61+
return sig_bytes
62+
63+
def _verify_with_pre_hash(
64+
self, pk: bytes, m: bytes, sig: bytes, hash_name: str, ctx: bytes = b""
65+
) -> bool:
66+
"""
67+
Verifies a signature sigma for a message M following algorithm 5 (FIPS 204)
68+
"""
69+
if len(ctx) > 255:
70+
raise ValueError(
71+
f"ctx bytes must have length at most 255, ctx has length {len(ctx) = }"
72+
)
73+
74+
# Prehash the message and return the OID of the used hash
75+
oid, ph_m = self._hash_with_oid(m, hash_name)
76+
77+
# Format the message using the context
78+
m_prime = bytes([1]) + bytes([len(ctx)]) + ctx + oid + ph_m
79+
80+
return self._verify_internal(pk, m_prime, sig)
81+
82+
def sign(
83+
self,
84+
sk: bytes,
85+
m: bytes,
86+
ctx: bytes = b"",
87+
deterministic: bool = False,
88+
) -> bytes:
89+
"""
90+
Generates a HashML-DSA signature following Algorithm 4 (FIPS 204)
91+
with SHA512 as the chosen hash function.
92+
"""
93+
return self._sign_with_pre_hash(sk, m, "SHA512", ctx, deterministic)
94+
95+
def verify(self, pk: bytes, m: bytes, sig: bytes, ctx: bytes = b"") -> bool:
96+
"""
97+
Verifies a signature sigma for a message M following algorithm 5 (FIPS 204)
98+
with SHA512 as the chosen hash function.
99+
"""
100+
return self._verify_with_pre_hash(pk, m, sig, "SHA512", ctx)

tests/test_hash_ml_dsa.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import unittest
2+
import os
3+
from dilithium_py.ml_dsa import (
4+
HASH_ML_DSA_44_WITH_SHA512,
5+
HASH_ML_DSA_65_WITH_SHA512,
6+
HASH_ML_DSA_87_WITH_SHA512,
7+
)
8+
9+
10+
class TestHashMLDSA(unittest.TestCase):
11+
"""
12+
Test ML DSA for internal consistency by generating signatures
13+
and verifying them!
14+
"""
15+
16+
def generic_test_hash_ml_dsa(self, HASH_ML_DSA, hash_name="SHA512", count=5):
17+
for _ in range(count):
18+
msg = b"Signed by HASH_ML_DSA" + os.urandom(16)
19+
ctx = os.urandom(128)
20+
21+
# Perform signature process
22+
pk, sk = HASH_ML_DSA.keygen()
23+
sig = HASH_ML_DSA.sign(sk, msg, ctx=ctx)
24+
check_verify = HASH_ML_DSA.verify(pk, msg, sig, ctx=ctx)
25+
26+
# Generate some fail cases
27+
pk_bad, _ = HASH_ML_DSA.keygen()
28+
check_wrong_pk = HASH_ML_DSA.verify(pk_bad, msg, sig, ctx=ctx)
29+
check_wrong_msg = HASH_ML_DSA.verify(pk, b"", sig, ctx=ctx)
30+
check_no_ctx = HASH_ML_DSA.verify(pk, msg, sig)
31+
32+
# Check with user-supplied hashes
33+
hash_sig = HASH_ML_DSA._sign_with_pre_hash(sk, msg, hash_name, ctx=ctx)
34+
check_hash_verify = HASH_ML_DSA._verify_with_pre_hash(
35+
pk, msg, hash_sig, hash_name, ctx=ctx
36+
)
37+
38+
# Check with the wrong hashes
39+
if hash_name == "SHA512":
40+
bad_hash = "SHA256"
41+
else:
42+
bad_hash = "SHA512"
43+
check_wrong_hash = HASH_ML_DSA._verify_with_pre_hash(
44+
pk, msg, hash_sig, bad_hash, ctx=ctx
45+
)
46+
47+
# Check that signature works
48+
self.assertTrue(check_verify)
49+
50+
# Check that signature works with custom hash
51+
self.assertTrue(check_hash_verify)
52+
53+
# Ensure the hashes need to match
54+
self.assertFalse(check_wrong_hash)
55+
56+
# Check changing the key breaks verify
57+
self.assertFalse(check_wrong_pk)
58+
59+
# Check changing the message breaks verify
60+
self.assertFalse(check_wrong_msg)
61+
62+
# Check removing the context breaks verify
63+
self.assertFalse(check_no_ctx)
64+
65+
# Default hash is SHA512
66+
def test_hash_ml_dsa_44(self):
67+
self.generic_test_hash_ml_dsa(HASH_ML_DSA_44_WITH_SHA512)
68+
69+
def test_hash_ml_dsa_65(self):
70+
self.generic_test_hash_ml_dsa(HASH_ML_DSA_65_WITH_SHA512)
71+
72+
def test_hash_ml_dsa_87(self):
73+
self.generic_test_hash_ml_dsa(HASH_ML_DSA_87_WITH_SHA512)
74+
75+
# Test with SHA256
76+
def test_hash_ml_dsa_44_sha256(self):
77+
self.generic_test_hash_ml_dsa(HASH_ML_DSA_44_WITH_SHA512, "SHA256")
78+
79+
def test_hash_ml_dsa_65_sha256(self):
80+
self.generic_test_hash_ml_dsa(HASH_ML_DSA_65_WITH_SHA512, "SHA256")
81+
82+
def test_hash_ml_dsa_87_sha256(self):
83+
self.generic_test_hash_ml_dsa(HASH_ML_DSA_87_WITH_SHA512, "SHA256")
84+
85+
# Test with SHAKE128
86+
def test_hash_ml_dsa_44_shake128(self):
87+
self.generic_test_hash_ml_dsa(HASH_ML_DSA_44_WITH_SHA512, "SHAKE128")
88+
89+
def test_hash_ml_dsa_65_shake128(self):
90+
self.generic_test_hash_ml_dsa(HASH_ML_DSA_65_WITH_SHA512, "SHAKE128")
91+
92+
def test_hash_ml_dsa_87_shake128(self):
93+
self.generic_test_hash_ml_dsa(HASH_ML_DSA_87_WITH_SHA512, "SHAKE128")
94+
95+
96+
class TestHashMLDSADeterministic(unittest.TestCase):
97+
"""
98+
Test ML DSA for internal consistency by generating signatures
99+
and verifying them!
100+
"""
101+
102+
def generic_test_hash_ml_dsa(self, HASH_ML_DSA, count=5):
103+
for _ in range(count):
104+
msg = b"Signed by HASH_ML_DSA" + os.urandom(16)
105+
ctx = os.urandom(128)
106+
107+
# Perform signature process
108+
pk, sk = HASH_ML_DSA.keygen()
109+
sig = HASH_ML_DSA.sign(sk, msg, ctx=ctx, deterministic=True)
110+
check_verify = HASH_ML_DSA.verify(pk, msg, sig, ctx=ctx)
111+
112+
# Generate some fail cases
113+
pk_bad, _ = HASH_ML_DSA.keygen()
114+
check_wrong_pk = HASH_ML_DSA.verify(pk_bad, msg, sig, ctx=ctx)
115+
check_wrong_msg = HASH_ML_DSA.verify(pk, b"", sig, ctx=ctx)
116+
check_no_ctx = HASH_ML_DSA.verify(pk, msg, sig)
117+
118+
# Check that signature works
119+
self.assertTrue(check_verify)
120+
121+
# Check changing the key breaks verify
122+
self.assertFalse(check_wrong_pk)
123+
124+
# Check changing the message breaks verify
125+
self.assertFalse(check_wrong_msg)
126+
127+
# Check removing the context breaks verify
128+
self.assertFalse(check_no_ctx)
129+
130+
def test_hash_ml_dsa_44(self):
131+
self.generic_test_hash_ml_dsa(HASH_ML_DSA_44_WITH_SHA512)
132+
133+
def test_hash_ml_dsa_65(self):
134+
self.generic_test_hash_ml_dsa(HASH_ML_DSA_65_WITH_SHA512)
135+
136+
def test_hash_ml_dsa_87(self):
137+
self.generic_test_hash_ml_dsa(HASH_ML_DSA_87_WITH_SHA512)

0 commit comments

Comments
 (0)