@@ -13,6 +13,7 @@ import (
1313 "time"
1414
1515 "github.com/cloudflare/circl/kem/kyber/kyber1024"
16+ "github.com/cloudflare/circl/sign/dilithium/mode5"
1617 "golang.org/x/crypto/blake2s"
1718 "golang.org/x/crypto/chacha20poly1305"
1819 "golang.org/x/crypto/poly1305"
@@ -62,13 +63,13 @@ const (
6263)
6364
6465const (
65- MessageInitiationSize = 148 + (MLKEMCiphertextSize + poly1305 .TagSize ) // size of handshake initiation message
66- MessageResponseSize = 92 // size of response message
67- MessageCookieReplySize = 64 // size of cookie reply message
68- MessageTransportHeaderSize = 16 // size of data preceding content in transport message
69- MessageTransportSize = MessageTransportHeaderSize + poly1305 .TagSize // size of empty transport
70- MessageKeepaliveSize = MessageTransportSize // size of keepalive
71- MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message
66+ MessageInitiationSize = 148 + (MLKEMCiphertextSize + poly1305 .TagSize ) + MLDSASignatureSize // size of handshake initiation message
67+ MessageResponseSize = 92 // size of response message
68+ MessageCookieReplySize = 64 // size of cookie reply message
69+ MessageTransportHeaderSize = 16 // size of data preceding content in transport message
70+ MessageTransportSize = MessageTransportHeaderSize + poly1305 .TagSize // size of empty transport
71+ MessageKeepaliveSize = MessageTransportSize // size of keepalive
72+ MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message
7273)
7374
7475const (
@@ -90,6 +91,7 @@ type MessageInitiation struct {
9091 Static [NoisePublicKeySize + poly1305 .TagSize ]byte
9192 MLKEM [MLKEMCiphertextSize + poly1305 .TagSize ]byte
9293 Timestamp [tai64n .TimestampSize + poly1305 .TagSize ]byte
94+ Signature MLDSASignature
9395 MAC1 [blake2s .Size128 ]byte
9496 MAC2 [blake2s .Size128 ]byte
9597}
@@ -131,8 +133,9 @@ func (msg *MessageInitiation) unmarshal(b []byte) error {
131133 copy (msg .Static [:], b [8 + len (msg .Ephemeral ):])
132134 copy (msg .MLKEM [:], b [8 + len (msg .Ephemeral )+ len (msg .Static ):])
133135 copy (msg .Timestamp [:], b [8 + len (msg .Ephemeral )+ len (msg .Static )+ len (msg .MLKEM ):])
134- copy (msg .MAC1 [:], b [8 + len (msg .Ephemeral )+ len (msg .Static )+ len (msg .MLKEM )+ len (msg .Timestamp ):])
135- copy (msg .MAC2 [:], b [8 + len (msg .Ephemeral )+ len (msg .Static )+ len (msg .MLKEM )+ len (msg .Timestamp )+ len (msg .MAC1 ):])
136+ copy (msg .Signature [:], b [8 + len (msg .Ephemeral )+ len (msg .Static )+ len (msg .MLKEM )+ len (msg .Timestamp ):])
137+ copy (msg .MAC1 [:], b [8 + len (msg .Ephemeral )+ len (msg .Static )+ len (msg .MLKEM )+ len (msg .Timestamp )+ len (msg .Signature ):])
138+ copy (msg .MAC2 [:], b [8 + len (msg .Ephemeral )+ len (msg .Static )+ len (msg .MLKEM )+ len (msg .Timestamp )+ len (msg .Signature )+ len (msg .MAC1 ):])
136139
137140 return nil
138141}
@@ -148,8 +151,9 @@ func (msg *MessageInitiation) marshal(b []byte) error {
148151 copy (b [8 + len (msg .Ephemeral ):], msg .Static [:])
149152 copy (b [8 + len (msg .Ephemeral )+ len (msg .Static ):], msg .MLKEM [:])
150153 copy (b [8 + len (msg .Ephemeral )+ len (msg .Static )+ len (msg .MLKEM ):], msg .Timestamp [:])
151- copy (b [8 + len (msg .Ephemeral )+ len (msg .Static )+ len (msg .MLKEM )+ len (msg .Timestamp ):], msg .MAC1 [:])
152- copy (b [8 + len (msg .Ephemeral )+ len (msg .Static )+ len (msg .MLKEM )+ len (msg .Timestamp )+ len (msg .MAC1 ):], msg .MAC2 [:])
154+ copy (b [8 + len (msg .Ephemeral )+ len (msg .Static )+ len (msg .MLKEM )+ len (msg .Timestamp ):], msg .Signature [:])
155+ copy (b [8 + len (msg .Ephemeral )+ len (msg .Static )+ len (msg .MLKEM )+ len (msg .Timestamp )+ len (msg .Signature ):], msg .MAC1 [:])
156+ copy (b [8 + len (msg .Ephemeral )+ len (msg .Static )+ len (msg .MLKEM )+ len (msg .Timestamp )+ len (msg .Signature )+ len (msg .MAC1 ):], msg .MAC2 [:])
153157
154158 return nil
155159}
@@ -223,6 +227,7 @@ type Handshake struct {
223227 remoteIndex uint32 // index for sending
224228 remoteStatic NoisePublicKey // long term key
225229 remoteMLKEMStatic MLKEMPublicKey // long term remote ML-KEM static public key
230+ remoteMLDSAStatic MLDSAPublicKey // long term remote ML-DSA static public key
226231 remoteEphemeral NoisePublicKey // ephemeral public key
227232 precomputedStaticStatic [NoisePublicKeySize ]byte // precomputed shared secret
228233 lastTimestamp tai64n.Timestamp
@@ -348,6 +353,21 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
348353
349354 handshake .mixHash (msg .Timestamp [:])
350355 handshake .state = handshakeInitiationCreated
356+
357+ signScheme := mode5 .Scheme ()
358+ skSign , err := signScheme .UnmarshalBinaryPrivateKey (device .staticIdentity .mldsaPrivateKey [:])
359+ if err != nil {
360+ return nil , err
361+ }
362+
363+ messageToSign := make ([]byte , MessageInitiationSize )
364+ if err := msg .marshal (messageToSign ); err != nil {
365+ return nil , err
366+ }
367+
368+ signature := signScheme .Sign (skSign , messageToSign [:MessageInitiationSize - blake2s .Size128 * 2 - MLDSASignatureSize ], nil )
369+ copy (msg .Signature [:], signature )
370+
351371 return & msg , nil
352372}
353373
@@ -391,6 +411,21 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
391411 return nil
392412 }
393413
414+ signScheme := mode5 .Scheme ()
415+ pkSign , err := signScheme .UnmarshalBinaryPublicKey (peer .handshake .remoteMLDSAStatic [:])
416+ if err != nil {
417+ return nil
418+ }
419+
420+ messageToCheck := make ([]byte , MessageInitiationSize )
421+ if err := msg .marshal (messageToCheck ); err != nil {
422+ return nil
423+ }
424+
425+ if ! signScheme .Verify (pkSign , messageToCheck [:MessageInitiationSize - blake2s .Size128 * 2 - MLDSASignatureSize ], msg .Signature [:], nil ) {
426+ return nil
427+ }
428+
394429 handshake := & peer .handshake
395430
396431 // decrypt KEM ciphertext
0 commit comments