99 "github.com/sirupsen/logrus"
1010 "github.com/slackhq/nebula/cert"
1111 "github.com/slackhq/nebula/header"
12+ "github.com/slackhq/nebula/noiseutil"
1213)
1314
1415// NOISE IX Handshakes
@@ -71,6 +72,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
7172 Time : uint64 (time .Now ().UnixNano ()),
7273 Cert : crtHs ,
7374 CertVersion : uint32 (v ),
75+ KemPublicKey : ci .pqKemPubKey , // nil for non-PQ curves (omitted from wire)
7476 },
7577 }
7678
@@ -139,7 +141,13 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
139141 return
140142 }
141143
142- rc , err := cert .Recombine (cert .Version (hs .Details .CertVersion ), hs .Details .Cert , ci .H .PeerStatic (), ci .Curve ())
144+ // For PQ, the cert's public key is ML-KEM-1024, not the X25519 PeerStatic.
145+ // Pass nil to Recombine so the cert uses its own embedded public key.
146+ var peerStaticForCert []byte
147+ if ci .myCert .Curve () != cert .Curve_PQ {
148+ peerStaticForCert = ci .H .PeerStatic ()
149+ }
150+ rc , err := cert .Recombine (cert .Version (hs .Details .CertVersion ), hs .Details .Cert , peerStaticForCert , ci .Curve ())
143151 if err != nil {
144152 f .l .WithError (err ).WithField ("from" , via ).
145153 WithField ("handshake" , m {"stage" : 1 , "style" : "ix_psk0" }).
@@ -167,11 +175,15 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
167175 return
168176 }
169177
170- if ! bytes .Equal (remoteCert .Certificate .PublicKey (), ci .H .PeerStatic ()) {
171- f .l .WithField ("from" , via ).
172- WithField ("handshake" , m {"stage" : 1 , "style" : "ix_psk0" }).
173- WithField ("cert" , remoteCert ).Info ("public key mismatch between certificate and handshake" )
174- return
178+ // For PQ certs, the cert's public key is ML-KEM-1024 while PeerStatic is the
179+ // ephemeral X25519 key. Skip the pubkey==PeerStatic check for PQ.
180+ if ci .myCert .Curve () != cert .Curve_PQ {
181+ if ! bytes .Equal (remoteCert .Certificate .PublicKey (), ci .H .PeerStatic ()) {
182+ f .l .WithField ("from" , via ).
183+ WithField ("handshake" , m {"stage" : 1 , "style" : "ix_psk0" }).
184+ WithField ("cert" , remoteCert ).Info ("public key mismatch between certificate and handshake" )
185+ return
186+ }
175187 }
176188
177189 if remoteCert .Certificate .Version () != ci .myCert .Version () {
@@ -289,6 +301,21 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
289301 // Update the time in case their clock is way off from ours
290302 hs .Details .Time = uint64 (time .Now ().UnixNano ())
291303
304+ // PQ hybrid: if initiator sent a KEM public key, encapsulate to it
305+ if len (hs .Details .KemPublicKey ) > 0 && ci .myCert .Curve () == cert .Curve_PQ {
306+ ct , ss , err := noiseutil .PQKEMEncapsulate (hs .Details .KemPublicKey )
307+ if err != nil {
308+ f .l .WithError (err ).WithField ("from" , via ).
309+ WithField ("handshake" , m {"stage" : 1 , "style" : "ix_psk0" }).
310+ Error ("PQ KEM encapsulation failed" )
311+ return
312+ }
313+ hs .Details .KemCiphertext = ct
314+ ci .pqKemSS = ss
315+ // Clear the initiator's KEM public key from the response (only ciphertext goes back)
316+ hs .Details .KemPublicKey = nil
317+ }
318+
292319 hsBytes , err := hs .Marshal ()
293320 if err != nil {
294321 f .l .WithError (err ).WithField ("vpnAddrs" , hostinfo .vpnAddrs ).WithField ("from" , via ).
@@ -333,8 +360,20 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
333360 ci .window .Update (f .l , 2 )
334361
335362 ci .peerCert = remoteCert
336- ci .dKey = NewNebulaCipherState (dKey )
337- ci .eKey = NewNebulaCipherState (eKey )
363+
364+ // PQ hybrid key mixing: combine Noise DH keys with KEM shared secret
365+ if len (ci .pqKemSS ) > 0 {
366+ ci .dKey , ci .eKey , err = pqMixCipherStates (dKey , eKey , ci .pqKemSS )
367+ if err != nil {
368+ f .l .WithError (err ).WithField ("from" , via ).
369+ WithField ("handshake" , m {"stage" : 1 , "style" : "ix_psk0" }).
370+ Error ("PQ hybrid key mixing failed" )
371+ return
372+ }
373+ } else {
374+ ci .dKey = NewNebulaCipherState (dKey )
375+ ci .eKey = NewNebulaCipherState (eKey )
376+ }
338377
339378 hostinfo .remotes = f .lightHouse .QueryCache (vpnAddrs )
340379 if ! via .IsRelayed {
@@ -514,7 +553,12 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
514553 return true
515554 }
516555
517- rc , err := cert .Recombine (cert .Version (hs .Details .CertVersion ), hs .Details .Cert , ci .H .PeerStatic (), ci .Curve ())
556+ // For PQ, the cert's public key is ML-KEM-1024, not the X25519 PeerStatic.
557+ var peerStaticForCert2 []byte
558+ if ci .myCert .Curve () != cert .Curve_PQ {
559+ peerStaticForCert2 = ci .H .PeerStatic ()
560+ }
561+ rc , err := cert .Recombine (cert .Version (hs .Details .CertVersion ), hs .Details .Cert , peerStaticForCert2 , ci .Curve ())
518562 if err != nil {
519563 f .l .WithError (err ).WithField ("from" , via ).
520564 WithField ("vpnAddrs" , hostinfo .vpnAddrs ).
@@ -543,11 +587,15 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
543587 e .Info ("Invalid certificate from host" )
544588 return true
545589 }
546- if ! bytes .Equal (remoteCert .Certificate .PublicKey (), ci .H .PeerStatic ()) {
547- f .l .WithField ("from" , via ).
548- WithField ("handshake" , m {"stage" : 2 , "style" : "ix_psk0" }).
549- WithField ("cert" , remoteCert ).Info ("public key mismatch between certificate and handshake" )
550- return true
590+ // For PQ certs, the cert's public key is ML-KEM-1024 while PeerStatic is the
591+ // ephemeral X25519 key. Skip the pubkey==PeerStatic check for PQ.
592+ if ci .myCert .Curve () != cert .Curve_PQ {
593+ if ! bytes .Equal (remoteCert .Certificate .PublicKey (), ci .H .PeerStatic ()) {
594+ f .l .WithField ("from" , via ).
595+ WithField ("handshake" , m {"stage" : 2 , "style" : "ix_psk0" }).
596+ WithField ("cert" , remoteCert ).Info ("public key mismatch between certificate and handshake" )
597+ return true
598+ }
551599 }
552600
553601 if len (remoteCert .Certificate .Networks ()) == 0 {
@@ -570,8 +618,29 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
570618
571619 // Store their cert and our symmetric keys
572620 ci .peerCert = remoteCert
573- ci .dKey = NewNebulaCipherState (dKey )
574- ci .eKey = NewNebulaCipherState (eKey )
621+
622+ // PQ hybrid: if responder sent a KEM ciphertext, decapsulate and mix keys
623+ if len (hs .Details .KemCiphertext ) > 0 && len (ci .pqKemPrivKey ) > 0 {
624+ ss , kemErr := noiseutil .PQKEMDecapsulate (ci .pqKemPrivKey , hs .Details .KemCiphertext )
625+ if kemErr != nil {
626+ f .l .WithError (kemErr ).WithField ("vpnAddrs" , hostinfo .vpnAddrs ).WithField ("from" , via ).
627+ WithField ("handshake" , m {"stage" : 2 , "style" : "ix_psk0" }).
628+ Error ("PQ KEM decapsulation failed" )
629+ return true
630+ }
631+ ci .pqKemSS = ss
632+
633+ ci .dKey , ci .eKey , err = pqMixCipherStates (dKey , eKey , ci .pqKemSS )
634+ if err != nil {
635+ f .l .WithError (err ).WithField ("vpnAddrs" , hostinfo .vpnAddrs ).WithField ("from" , via ).
636+ WithField ("handshake" , m {"stage" : 2 , "style" : "ix_psk0" }).
637+ Error ("PQ hybrid key mixing failed" )
638+ return true
639+ }
640+ } else {
641+ ci .dKey = NewNebulaCipherState (dKey )
642+ ci .eKey = NewNebulaCipherState (eKey )
643+ }
575644
576645 // Make sure the current udpAddr being used is set for responding
577646 if ! via .IsRelayed {
@@ -676,3 +745,32 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
676745
677746 return false
678747}
748+
749+ // pqMixCipherStates takes the Noise-derived cipher states and mixes them with
750+ // the ML-KEM-1024 shared secret to produce hybrid cipher states. This is the
751+ // NIST-recommended hybrid combiner: both the classical DH key and the PQ KEM
752+ // key must be broken to compromise the session.
753+ //
754+ // The hybrid key is derived via HKDF-SHA256:
755+ // hybridKey = HKDF(noiseKey, kemSharedSecret, "nebula-pq-hybrid-v1")
756+ func pqMixCipherStates (dKey , eKey * noise.CipherState , kemSS []byte ) (* NebulaCipherState , * NebulaCipherState , error ) {
757+ dNoiseKey := dKey .UnsafeKey ()
758+ eNoiseKey := eKey .UnsafeKey ()
759+
760+ dHybrid , err := noiseutil .HybridMixKeys (dNoiseKey , kemSS )
761+ if err != nil {
762+ return nil , nil , err
763+ }
764+ eHybrid , err := noiseutil .HybridMixKeys (eNoiseKey , kemSS )
765+ if err != nil {
766+ return nil , nil , err
767+ }
768+
769+ // Create new cipher states with the hybrid keys.
770+ // We wrap the Noise cipher (extracted via Cipher()) with the hybrid key
771+ // by constructing a fresh NebulaCipherState that uses AES-256-GCM directly.
772+ dCipher := noiseutil .CipherAESGCM .Cipher (dHybrid )
773+ eCipher := noiseutil .CipherAESGCM .Cipher (eHybrid )
774+
775+ return & NebulaCipherState {c : dCipher }, & NebulaCipherState {c : eCipher }, nil
776+ }
0 commit comments