Skip to content

Commit 9c4ae29

Browse files
committed
Check hint format in unpack_h method
This changes add four checks for strong unforgeability 1. Accumulated offsets are monotonic increasing 2. Last offset is less than omega 3. Zefo fields should be all zero 4. Non-zero fields should be monotonic increasing
1 parent 22f0d83 commit 9c4ae29

File tree

1 file changed

+20
-2
lines changed

1 file changed

+20
-2
lines changed

src/dilithium_py/ml_dsa/ml_dsa.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,14 +163,28 @@ def _unpack_sk(self, sk_bytes):
163163

164164
def _unpack_h(self, h_bytes):
165165
offsets = [0] + list(h_bytes[-self.k :])
166+
# check offsets are monotonic increasing
167+
if any(offsets[i] > offsets[i + 1] for i in range(len(offsets) - 1)):
168+
raise ValueError("Offsets in h_bytes are not monotonic increasing")
169+
# check offset[-1] is smaller than the length of h_bytes
170+
if offsets[-1] > self.omega:
171+
raise ValueError("Accumulate offset of hints exceeds omega")
172+
# check zero fields are all zeros
173+
if any(b != 0 for b in h_bytes[offsets[-1] : self.omega]):
174+
raise ValueError("Non-zero fields in h_bytes are not all zeros")
175+
166176
non_zero_positions = [
167177
list(h_bytes[offsets[i] : offsets[i + 1]]) for i in range(self.k)
168178
]
169179

170180
matrix = []
171181
for poly_non_zero in non_zero_positions:
172182
coeffs = [0 for _ in range(256)]
173-
for non_zero in poly_non_zero:
183+
for i, non_zero in enumerate(poly_non_zero):
184+
if i > 0 and non_zero < poly_non_zero[i - 1]:
185+
raise ValueError(
186+
"Non-zero positions in h_bytes are not monotonic increasing"
187+
)
174188
coeffs[non_zero] = 1
175189
matrix.append([self.R(coeffs)])
176190
return self.M(matrix)
@@ -294,7 +308,11 @@ def _verify_internal(self, pk_bytes, m, sig_bytes):
294308
following Algorithm 8 (FIPS 204)
295309
"""
296310
rho, t1 = self._unpack_pk(pk_bytes)
297-
c_tilde, z, h = self._unpack_sig(sig_bytes)
311+
try:
312+
c_tilde, z, h = self._unpack_sig(sig_bytes)
313+
except ValueError:
314+
# verify failed if malformed input signature
315+
return False
298316

299317
if h.sum_hint() > self.omega:
300318
return False

0 commit comments

Comments
 (0)