8
8
9
9
10
10
class ML_DSA :
11
- def __init__ (self , parameter_set ):
11
+ def __init__ (self , parameter_set : dict ):
12
12
self .d = parameter_set ["d" ]
13
13
self .k = parameter_set ["k" ]
14
14
self .l = parameter_set ["l" ]
@@ -28,7 +28,7 @@ def __init__(self, parameter_set):
28
28
# use the method `set_drbg_seed()`
29
29
self .random_bytes = os .urandom
30
30
31
- def set_drbg_seed (self , seed ):
31
+ def set_drbg_seed (self , seed : bytes ):
32
32
"""
33
33
Change entropy source to a DRBG and seed it with provided value.
34
34
@@ -69,7 +69,7 @@ def _expand_matrix_from_seed(self, rho):
69
69
Helper function which generates a element of size
70
70
k x l from a seed `rho`.
71
71
"""
72
- A_data = [[0 for _ in range (self .l )] for _ in range (self .k )]
72
+ A_data = [[self . R . zero () for _ in range (self .l )] for _ in range (self .k )]
73
73
for i in range (self .k ):
74
74
for j in range (self .l ):
75
75
A_data [i ][j ] = self .R .rejection_sample_ntt_poly (rho , i , j )
@@ -357,16 +357,16 @@ def _verify_internal(self, pk_bytes, m, sig_bytes):
357
357
358
358
return c_tilde == self ._h (mu + w_prime_bytes , self .c_tilde_bytes )
359
359
360
- def keygen (self ):
360
+ def keygen (self ) -> tuple [ bytes , bytes ] :
361
361
"""
362
362
Generates a public-private key pair following
363
363
Algorithm 1 (FIPS 204)
364
364
"""
365
365
zeta = self .random_bytes (32 )
366
366
pk , sk = self ._keygen_internal (zeta )
367
- return pk , sk
367
+ return ( pk , sk )
368
368
369
- def key_derive (self , seed ) :
369
+ def key_derive (self , seed : bytes ) -> tuple [ bytes , bytes ] :
370
370
"""
371
371
Derive a verification key and corresponding signing key
372
372
following the approach from Section 6.1 (FIPS 204)
@@ -383,7 +383,9 @@ def key_derive(self, seed):
383
383
pk , sk = self ._keygen_internal (seed )
384
384
return (pk , sk )
385
385
386
- def sign (self , sk_bytes , m , ctx = b"" , deterministic = False ):
386
+ def sign (
387
+ self , sk : bytes , m : bytes , ctx : bytes = b"" , deterministic : bool = False
388
+ ) -> bytes :
387
389
"""
388
390
Generates an ML-DSA signature following
389
391
Algorithm 2 (FIPS 204)
@@ -402,10 +404,10 @@ def sign(self, sk_bytes, m, ctx=b"", deterministic=False):
402
404
m_prime = bytes ([0 ]) + bytes ([len (ctx )]) + ctx + m
403
405
404
406
# Compute the signature of m_prime
405
- sig_bytes = self ._sign_internal (sk_bytes , m_prime , rnd )
407
+ sig_bytes = self ._sign_internal (sk , m_prime , rnd )
406
408
return sig_bytes
407
409
408
- def verify (self , pk_bytes , m , sig_bytes , ctx = b"" ):
410
+ def verify (self , pk : bytes , m : bytes , sig : bytes , ctx : bytes = b"" ) -> bool :
409
411
"""
410
412
Verifies a signature sigma for a message M following
411
413
Algorithm 3 (FIPS 204)
@@ -418,21 +420,21 @@ def verify(self, pk_bytes, m, sig_bytes, ctx=b""):
418
420
# Format the message using the context
419
421
m_prime = bytes ([0 ]) + bytes ([len (ctx )]) + ctx + m
420
422
421
- return self ._verify_internal (pk_bytes , m_prime , sig_bytes )
423
+ return self ._verify_internal (pk , m_prime , sig )
422
424
423
425
"""
424
426
The following additional function follows an outline from:
425
427
https://github.com/aws/aws-lc/pull/2142
426
428
which computes pk_bytes when only the sk_bytes are known.
427
429
"""
428
430
429
- def pk_from_sk (self , sk_bytes : bytes ) -> bytes :
431
+ def pk_from_sk (self , sk : bytes ) -> bytes :
430
432
"""
431
433
Given the packed representation of a ML-DSA secret key,
432
434
compute the corresponding packed public key bytes.
433
435
"""
434
436
# First unpack the secret key
435
- rho , K , tr , s1 , s2 , t0 = self ._unpack_sk (sk_bytes )
437
+ rho , _ , tr , s1 , s2 , _ = self ._unpack_sk (sk )
436
438
437
439
# Compute the matrix A from rho in NTT form
438
440
A_hat = self ._expand_matrix_from_seed (rho )
@@ -446,13 +448,13 @@ def pk_from_sk(self, sk_bytes: bytes) -> bytes:
446
448
t1 , _ = t .power_2_round (self .d )
447
449
448
450
# The packed public key is made from rho || t1
449
- pk_bytes = self ._pack_pk (rho , t1 )
451
+ pk = self ._pack_pk (rho , t1 )
450
452
451
453
# Ensure the public key matches the hash within the secret key
452
- if tr != self ._h (pk_bytes , 64 ):
454
+ if tr != self ._h (pk , 64 ):
453
455
raise ValueError ("malformed secret key" )
454
456
455
- return pk_bytes
457
+ return pk
456
458
457
459
"""
458
460
The following external mu functions are not in FIPS 204, but are in
@@ -462,7 +464,7 @@ def pk_from_sk(self, sk_bytes: bytes) -> bytes:
462
464
https://datatracker.ietf.org/doc/html/draft-ietf-lamps-dilithium-certificates-07
463
465
"""
464
466
465
- def prehash_external_mu (self , pk_bytes , m , ctx = b"" ):
467
+ def prehash_external_mu (self , pk : bytes , m : bytes , ctx : bytes = b"" ) -> bytes :
466
468
"""
467
469
Prehash the message `m` with context `ctx` together with
468
470
the public key. For use with `sign_external_mu()`
@@ -472,22 +474,24 @@ def prehash_external_mu(self, pk_bytes, m, ctx=b""):
472
474
raise ValueError (
473
475
f"ctx bytes must have length at most 255, ctx has length { len (ctx ) = } "
474
476
)
475
- if len (pk_bytes ) != self ._pk_size ():
477
+ if len (pk ) != self ._pk_size ():
476
478
raise ValueError (
477
479
f"Public key size doesn't match this ML-DSA object parameters,"
478
- f"received { len (pk_bytes ) = } , expected: { self ._pk_size ()} "
480
+ f"received { len (pk ) = } , expected: { self ._pk_size ()} "
479
481
)
480
482
481
483
# Format the message using the context
482
484
m_prime = bytes ([0 ]) + bytes ([len (ctx )]) + ctx + m
483
485
484
486
# Compute mu by hashing the public key into the message
485
- tr = self ._h (pk_bytes , 64 )
487
+ tr = self ._h (pk , 64 )
486
488
mu = self ._h (tr + m_prime , 64 )
487
489
488
490
return mu
489
491
490
- def sign_external_mu (self , sk_bytes , mu , deterministic = False ):
492
+ def sign_external_mu (
493
+ self , sk : bytes , mu : bytes , deterministic : bool = False
494
+ ) -> bytes :
491
495
"""
492
496
Generates an ML-DSA signature of a message given the prehash
493
497
mu = H(H(pk), M')
@@ -505,5 +509,5 @@ def sign_external_mu(self, sk_bytes, mu, deterministic=False):
505
509
506
510
# Compute the signature given external mu, we set the external_mu
507
511
# to True
508
- sig_bytes = self ._sign_internal (sk_bytes , mu , rnd , external_mu = True )
509
- return sig_bytes
512
+ sig = self ._sign_internal (sk , mu , rnd , external_mu = True )
513
+ return sig
0 commit comments