Skip to content

Commit e0845a7

Browse files
Merge pull request #6 from disentangle-network/feature/pq-handshake
Hybrid X25519 + ML-KEM-1024 handshake and full PQ integration
2 parents d43cf56 + c55507a commit e0845a7

File tree

15 files changed

+553
-37
lines changed

15 files changed

+553
-37
lines changed

cert/cert.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,20 @@ func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certific
158158
return c, nil
159159
}
160160

161+
// UnmarshalCertificateFromBytes unmarshals a certificate from raw bytes when the full certificate
162+
// (including public key) is already present. This is used for PQ certificates where the public key
163+
// is not stripped during handshake marshaling.
164+
func UnmarshalCertificateFromBytes(b []byte, v Version, curve Curve) (Certificate, error) {
165+
switch v {
166+
case VersionPre1, Version1:
167+
return unmarshalCertificateV1(b, nil)
168+
case Version2:
169+
return unmarshalCertificateV2(b, nil, curve)
170+
default:
171+
return nil, ErrUnknownVersion
172+
}
173+
}
174+
161175
// CalculateAlternateFingerprint calculates a 2nd fingerprint representation for P256 certificates
162176
// CAPool blocklist testing through `VerifyCertificate` and `VerifyCachedCertificate` automatically performs this step.
163177
func CalculateAlternateFingerprint(c Certificate) (string, error) {

cert/cert_v2.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -239,13 +239,14 @@ func (c *certificateV2) VerifyPrivateKey(curve Curve, key []byte) error {
239239
pub = privkey.PublicKey().Bytes()
240240
case Curve_PQ:
241241
// Host certs use ML-KEM-1024 key agreement keys
242-
if len(key) != mlkem1024.PrivateKeySize {
242+
var sk mlkem1024.PrivateKey
243+
if err := sk.Unpack(key); err != nil {
244+
return ErrInvalidPrivateKey
245+
}
246+
derivedPub, err := sk.Public().(*mlkem1024.PublicKey).MarshalBinary()
247+
if err != nil {
243248
return ErrInvalidPrivateKey
244249
}
245-
_, sk := mlkem1024.NewKeyFromSeed(key[:mlkem1024.KeySeedSize])
246-
pub, _ = sk.MarshalBinary()
247-
// Compare only the public key portion
248-
derivedPub := pub[:mlkem1024.PublicKeySize]
249250
if !bytes.Equal(derivedPub, c.publicKey) {
250251
return ErrPublicPrivateKeyMismatch
251252
}

cert_test/cert.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import (
99
"net/netip"
1010
"time"
1111

12+
"github.com/cloudflare/circl/kem/mlkem/mlkem1024"
13+
"github.com/cloudflare/circl/sign/mldsa/mldsa87"
1214
"github.com/slackhq/nebula/cert"
1315
"golang.org/x/crypto/curve25519"
1416
"golang.org/x/crypto/ed25519"
@@ -30,6 +32,13 @@ func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Ti
3032

3133
pub = elliptic.Marshal(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y)
3234
priv = privk.D.FillBytes(make([]byte, 32))
35+
case cert.Curve_PQ:
36+
pk, sk, err := mldsa87.GenerateKey(rand.Reader)
37+
if err != nil {
38+
panic(err)
39+
}
40+
pub = pk.Bytes()
41+
priv = sk.Bytes()
3342
default:
3443
// There is no default to allow the underlying lib to respond with an error
3544
}
@@ -84,6 +93,8 @@ func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []by
8493
pub, priv = X25519Keypair()
8594
case cert.Curve_P256:
8695
pub, priv = P256Keypair()
96+
case cert.Curve_PQ:
97+
pub, priv = PQKeypair()
8798
default:
8899
panic("unknown curve")
89100
}
@@ -163,3 +174,13 @@ func P256Keypair() ([]byte, []byte) {
163174
pubkey := privkey.PublicKey()
164175
return pubkey.Bytes(), privkey.Bytes()
165176
}
177+
178+
func PQKeypair() ([]byte, []byte) {
179+
pk, sk, err := mlkem1024.GenerateKeyPair(rand.Reader)
180+
if err != nil {
181+
panic(err)
182+
}
183+
pubBytes, _ := pk.MarshalBinary()
184+
privBytes, _ := sk.MarshalBinary()
185+
return pubBytes, privBytes
186+
}

cmd/nebula-cert/ca.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"strings"
1414
"time"
1515

16+
"github.com/cloudflare/circl/sign/mldsa/mldsa87"
1617
"github.com/skip2/go-qrcode"
1718
"github.com/slackhq/nebula/cert"
1819
"github.com/slackhq/nebula/pkclient"
@@ -59,7 +60,7 @@ func newCaFlags() *caFlags {
5960
cf.argonParallelism = cf.set.Uint("argon-parallelism", 4, "Optional: Argon2 parallelism parameter used for encrypted private key passphrase")
6061
cf.argonIterations = cf.set.Uint("argon-iterations", 1, "Optional: Argon2 iterations parameter used for encrypted private key passphrase")
6162
cf.encryption = cf.set.Bool("encrypt", false, "Optional: prompt for passphrase and write out-key in an encrypted format")
62-
cf.curve = cf.set.String("curve", "25519", "EdDSA/ECDSA Curve (25519, P256)")
63+
cf.curve = cf.set.String("curve", "25519", "EdDSA/ECDSA/PQ Curve (25519, P256, PQ)")
6364
cf.p11url = p11Flag(cf.set)
6465

6566
cf.ips = cf.set.String("ips", "", "Deprecated, see -networks")
@@ -243,11 +244,23 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
243244
}
244245
rawPriv = eKey.Bytes()
245246
pub = eKey.PublicKey().Bytes()
247+
case "PQ":
248+
curve = cert.Curve_PQ
249+
pk, sk, err := mldsa87.GenerateKey(rand.Reader)
250+
if err != nil {
251+
return fmt.Errorf("error while generating ML-DSA-87 keys: %s", err)
252+
}
253+
pub = pk.Bytes()
254+
rawPriv = sk.Bytes()
246255
default:
247256
return fmt.Errorf("invalid curve: %s", *cf.curve)
248257
}
249258
}
250259

260+
if curve == cert.Curve_PQ && version != cert.Version2 {
261+
return fmt.Errorf("PQ curve requires certificate version 2")
262+
}
263+
251264
t := &cert.TBSCertificate{
252265
Version: version,
253266
Name: *cf.name,

cmd/nebula-cert/keygen.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ func newKeygenFlags() *keygenFlags {
2424
cf.set.Usage = func() {}
2525
cf.outPubPath = cf.set.String("out-pub", "", "Required: path to write the public key to")
2626
cf.outKeyPath = cf.set.String("out-key", "", "Required: path to write the private key to")
27-
cf.curve = cf.set.String("curve", "25519", "ECDH Curve (25519, P256)")
27+
cf.curve = cf.set.String("curve", "25519", "ECDH/KEM Curve (25519, P256, PQ)")
2828
cf.p11url = p11Flag(cf.set)
2929
return &cf
3030
}
@@ -64,6 +64,9 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error {
6464
case "P256":
6565
pub, rawPriv = p256Keypair()
6666
curve = cert.Curve_P256
67+
case "PQ":
68+
pub, rawPriv = pqKeypair()
69+
curve = cert.Curve_PQ
6770
default:
6871
return fmt.Errorf("invalid curve: %s", *cf.curve)
6972
}

cmd/nebula-cert/sign.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414

1515
"github.com/skip2/go-qrcode"
1616
"github.com/slackhq/nebula/cert"
17+
"github.com/slackhq/nebula/noiseutil"
1718
"github.com/slackhq/nebula/pkclient"
1819
"golang.org/x/crypto/curve25519"
1920
)
@@ -405,6 +406,8 @@ func newKeypair(curve cert.Curve) ([]byte, []byte) {
405406
return x25519Keypair()
406407
case cert.Curve_P256:
407408
return p256Keypair()
409+
case cert.Curve_PQ:
410+
return pqKeypair()
408411
default:
409412
return nil, nil
410413
}
@@ -433,6 +436,14 @@ func p256Keypair() ([]byte, []byte) {
433436
return pubkey.Bytes(), privkey.Bytes()
434437
}
435438

439+
func pqKeypair() ([]byte, []byte) {
440+
pub, priv, err := noiseutil.PQKEMKeypair()
441+
if err != nil {
442+
panic(err)
443+
}
444+
return pub, priv
445+
}
446+
436447
func signSummary() string {
437448
return "sign <flags>: create and sign a certificate"
438449
}

connection_state.go

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,17 @@ type ConnectionState struct {
2525
messageCounter atomic.Uint64
2626
window *Bits
2727
writeLock sync.Mutex
28+
29+
// Post-quantum KEM state (only used for Curve_PQ)
30+
pqKemPubKey []byte // Our ephemeral ML-KEM-1024 public key
31+
pqKemPrivKey []byte // Our ephemeral ML-KEM-1024 private key
32+
pqKemSS []byte // Shared secret from KEM exchange
2833
}
2934

3035
func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
3136
var dhFunc noise.DHFunc
37+
var pqKemPub, pqKemPriv []byte
38+
3239
switch crt.Curve() {
3340
case cert.Curve_CURVE25519:
3441
dhFunc = noise.DH25519
@@ -38,6 +45,19 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
3845
} else {
3946
dhFunc = noiseutil.DHP256
4047
}
48+
case cert.Curve_PQ:
49+
// Hybrid mode: X25519 DH for classical security + ML-KEM-1024 for PQ security.
50+
// The Noise IX handshake uses X25519 for the DH tokens. The ML-KEM-1024
51+
// exchange is layered on top via the handshake payload (KemPublicKey/KemCiphertext).
52+
// Both shared secrets are mixed into the final symmetric keys via HKDF.
53+
dhFunc = noise.DH25519
54+
55+
// Generate ephemeral ML-KEM-1024 keypair for this handshake
56+
var err error
57+
pqKemPub, pqKemPriv, err = noiseutil.PQKEMKeypair()
58+
if err != nil {
59+
return nil, fmt.Errorf("NewConnectionState: ML-KEM-1024 keygen failed: %s", err)
60+
}
4161
default:
4262
return nil, fmt.Errorf("invalid curve: %s", crt.Curve())
4363
}
@@ -49,7 +69,24 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
4969
ncs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
5070
}
5171

52-
static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
72+
// For PQ hybrid mode, use a fresh X25519 keypair as the Noise "static" key.
73+
// The actual identity authentication comes from the ML-DSA-87 cert signature,
74+
// not from the DH static key. This is safe because:
75+
// 1. The cert is verified against the CA in the handshake payload
76+
// 2. The KEM exchange provides the PQ-secure key agreement
77+
// 3. The X25519 DH provides defense-in-depth
78+
var static noise.DHKey
79+
if crt.Curve() == cert.Curve_PQ {
80+
// Generate ephemeral X25519 keypair (cert's public key is ML-KEM, not X25519)
81+
var err error
82+
static, err = noise.DH25519.GenerateKeypair(rand.Reader)
83+
if err != nil {
84+
return nil, fmt.Errorf("NewConnectionState: X25519 keygen failed: %s", err)
85+
}
86+
} else {
87+
static = noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
88+
}
89+
5390
hs, err := noise.NewHandshakeState(noise.Config{
5491
CipherSuite: ncs,
5592
Random: rand.Reader,
@@ -67,10 +104,12 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
67104
// The queue and ready params prevent a counter race that would happen when
68105
// sending stored packets and simultaneously accepting new traffic.
69106
ci := &ConnectionState{
70-
H: hs,
71-
initiator: initiator,
72-
window: NewBits(ReplayWindow),
73-
myCert: crt,
107+
H: hs,
108+
initiator: initiator,
109+
window: NewBits(ReplayWindow),
110+
myCert: crt,
111+
pqKemPubKey: pqKemPub,
112+
pqKemPrivKey: pqKemPriv,
74113
}
75114
// always start the counter from 2, as packet 1 and packet 2 are handshake packets.
76115
ci.messageCounter.Add(2)

e2e/handshakes_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,48 @@ func TestGoodHandshake(t *testing.T) {
134134
theirControl.Stop()
135135
}
136136

137+
func TestGoodHandshakePQ(t *testing.T) {
138+
// Post-quantum handshake test: ML-DSA-87 certificates + hybrid X25519/ML-KEM-1024 key exchange
139+
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_PQ, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
140+
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me-pq", "10.128.0.1/24", nil)
141+
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them-pq", "10.128.0.2/24", nil)
142+
143+
// Put their info in our lighthouse
144+
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
145+
146+
// Start the servers
147+
myControl.Start()
148+
theirControl.Start()
149+
150+
t.Log("PQ: Send a udp packet through to begin standing up the tunnel")
151+
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from PQ me"))
152+
153+
t.Log("PQ: Have them consume my stage 0 packet. They have a tunnel now")
154+
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
155+
156+
t.Log("PQ: Have me consume their stage 1 packet. I have a tunnel now")
157+
myControl.InjectUDPPacket(theirControl.GetFromUDP(true))
158+
159+
t.Log("PQ: Wait until we see my cached packet come through")
160+
myControl.WaitForType(1, 0, theirControl)
161+
162+
t.Log("PQ: Make sure our host infos are correct")
163+
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
164+
165+
t.Log("PQ: Get that cached packet and make sure it looks right")
166+
myCachedPacket := theirControl.GetFromTun(true)
167+
assertUdpPacket(t, []byte("Hi from PQ me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
168+
169+
t.Log("PQ: Do a bidirectional tunnel test")
170+
r := router.NewR(t, myControl, theirControl)
171+
defer r.RenderFlow()
172+
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
173+
174+
r.RenderHostmaps("PQ Final hostmaps", myControl, theirControl)
175+
myControl.Stop()
176+
theirControl.Stop()
177+
}
178+
137179
func TestGoodHandshakeNoOverlap(t *testing.T) {
138180
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
139181
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", nil)

e2e/helpers_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certific
102102
}
103103
}
104104

105-
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, unsafeNetworks, []string{})
105+
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, caCrt.Curve(), caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, unsafeNetworks, []string{})
106106

107107
caB, err := caCrt.MarshalPEM()
108108
if err != nil {

0 commit comments

Comments
 (0)