Skip to content

Commit c55507a

Browse files
lcloselclose
authored andcommitted
feat: nebula-cert PQ support, PKI integration, and e2e handshake test
Complete the PQ integration across the full Nebula stack: CLI (nebula-cert): - ca.go: -curve PQ generates ML-DSA-87 CA keypair (V2 only) - sign.go: PQ host certs use ML-KEM-1024 key agreement keys - keygen.go: -curve PQ generates ML-KEM-1024 keypair PKI integration: - pki.go: PQ certs marshal with full public key (not stripped) because MarshalForHandshakes assumes pubkey == Noise PeerStatic - cert.go: UnmarshalCertificateFromBytes for PQ cert reconstruction - cert_v2.go: fix VerifyPrivateKey to use Unpack for ML-KEM-1024 Handshake fixes: - handshake_ix.go: PQ path uses direct unmarshal instead of Recombine (cert pubkey is ML-KEM, not the X25519 PeerStatic) - Skip pubkey==PeerStatic check for PQ (different key types) Test infrastructure: - cert_test/cert.go: PQ cases for NewTestCaCert and NewTestCert - e2e/helpers_test.go: derive curve from CA cert (not hardcoded) - e2e/handshakes_test.go: TestGoodHandshakePQ -- full tunnel test TestGoodHandshakePQ proves: PQ CA signs host cert, two nodes establish tunnel via hybrid X25519+ML-KEM-1024 handshake, bidirectional encrypted data transfer works. 0.41s handshake time. Zero regressions in classical handshake tests.
1 parent d694afe commit c55507a

File tree

10 files changed

+151
-28
lines changed

10 files changed

+151
-28
lines changed

cert/cert.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,20 @@ func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certific
158158
return c, nil
159159
}
160160

161+
// UnmarshalCertificateFromBytes unmarshals a certificate from raw bytes when the full certificate
162+
// (including public key) is already present. This is used for PQ certificates where the public key
163+
// is not stripped during handshake marshaling.
164+
func UnmarshalCertificateFromBytes(b []byte, v Version, curve Curve) (Certificate, error) {
165+
switch v {
166+
case VersionPre1, Version1:
167+
return unmarshalCertificateV1(b, nil)
168+
case Version2:
169+
return unmarshalCertificateV2(b, nil, curve)
170+
default:
171+
return nil, ErrUnknownVersion
172+
}
173+
}
174+
161175
// CalculateAlternateFingerprint calculates a 2nd fingerprint representation for P256 certificates
162176
// CAPool blocklist testing through `VerifyCertificate` and `VerifyCachedCertificate` automatically performs this step.
163177
func CalculateAlternateFingerprint(c Certificate) (string, error) {

cert/cert_v2.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -239,13 +239,14 @@ func (c *certificateV2) VerifyPrivateKey(curve Curve, key []byte) error {
239239
pub = privkey.PublicKey().Bytes()
240240
case Curve_PQ:
241241
// Host certs use ML-KEM-1024 key agreement keys
242-
if len(key) != mlkem1024.PrivateKeySize {
242+
var sk mlkem1024.PrivateKey
243+
if err := sk.Unpack(key); err != nil {
244+
return ErrInvalidPrivateKey
245+
}
246+
derivedPub, err := sk.Public().(*mlkem1024.PublicKey).MarshalBinary()
247+
if err != nil {
243248
return ErrInvalidPrivateKey
244249
}
245-
_, sk := mlkem1024.NewKeyFromSeed(key[:mlkem1024.KeySeedSize])
246-
pub, _ = sk.MarshalBinary()
247-
// Compare only the public key portion
248-
derivedPub := pub[:mlkem1024.PublicKeySize]
249250
if !bytes.Equal(derivedPub, c.publicKey) {
250251
return ErrPublicPrivateKeyMismatch
251252
}

cert_test/cert.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import (
99
"net/netip"
1010
"time"
1111

12+
"github.com/cloudflare/circl/kem/mlkem/mlkem1024"
13+
"github.com/cloudflare/circl/sign/mldsa/mldsa87"
1214
"github.com/slackhq/nebula/cert"
1315
"golang.org/x/crypto/curve25519"
1416
"golang.org/x/crypto/ed25519"
@@ -30,6 +32,13 @@ func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Ti
3032

3133
pub = elliptic.Marshal(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y)
3234
priv = privk.D.FillBytes(make([]byte, 32))
35+
case cert.Curve_PQ:
36+
pk, sk, err := mldsa87.GenerateKey(rand.Reader)
37+
if err != nil {
38+
panic(err)
39+
}
40+
pub = pk.Bytes()
41+
priv = sk.Bytes()
3342
default:
3443
// There is no default to allow the underlying lib to respond with an error
3544
}
@@ -84,6 +93,8 @@ func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []by
8493
pub, priv = X25519Keypair()
8594
case cert.Curve_P256:
8695
pub, priv = P256Keypair()
96+
case cert.Curve_PQ:
97+
pub, priv = PQKeypair()
8798
default:
8899
panic("unknown curve")
89100
}
@@ -163,3 +174,13 @@ func P256Keypair() ([]byte, []byte) {
163174
pubkey := privkey.PublicKey()
164175
return pubkey.Bytes(), privkey.Bytes()
165176
}
177+
178+
func PQKeypair() ([]byte, []byte) {
179+
pk, sk, err := mlkem1024.GenerateKeyPair(rand.Reader)
180+
if err != nil {
181+
panic(err)
182+
}
183+
pubBytes, _ := pk.MarshalBinary()
184+
privBytes, _ := sk.MarshalBinary()
185+
return pubBytes, privBytes
186+
}

cmd/nebula-cert/ca.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"strings"
1414
"time"
1515

16+
"github.com/cloudflare/circl/sign/mldsa/mldsa87"
1617
"github.com/skip2/go-qrcode"
1718
"github.com/slackhq/nebula/cert"
1819
"github.com/slackhq/nebula/pkclient"
@@ -59,7 +60,7 @@ func newCaFlags() *caFlags {
5960
cf.argonParallelism = cf.set.Uint("argon-parallelism", 4, "Optional: Argon2 parallelism parameter used for encrypted private key passphrase")
6061
cf.argonIterations = cf.set.Uint("argon-iterations", 1, "Optional: Argon2 iterations parameter used for encrypted private key passphrase")
6162
cf.encryption = cf.set.Bool("encrypt", false, "Optional: prompt for passphrase and write out-key in an encrypted format")
62-
cf.curve = cf.set.String("curve", "25519", "EdDSA/ECDSA Curve (25519, P256)")
63+
cf.curve = cf.set.String("curve", "25519", "EdDSA/ECDSA/PQ Curve (25519, P256, PQ)")
6364
cf.p11url = p11Flag(cf.set)
6465

6566
cf.ips = cf.set.String("ips", "", "Deprecated, see -networks")
@@ -243,11 +244,23 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
243244
}
244245
rawPriv = eKey.Bytes()
245246
pub = eKey.PublicKey().Bytes()
247+
case "PQ":
248+
curve = cert.Curve_PQ
249+
pk, sk, err := mldsa87.GenerateKey(rand.Reader)
250+
if err != nil {
251+
return fmt.Errorf("error while generating ML-DSA-87 keys: %s", err)
252+
}
253+
pub = pk.Bytes()
254+
rawPriv = sk.Bytes()
246255
default:
247256
return fmt.Errorf("invalid curve: %s", *cf.curve)
248257
}
249258
}
250259

260+
if curve == cert.Curve_PQ && version != cert.Version2 {
261+
return fmt.Errorf("PQ curve requires certificate version 2")
262+
}
263+
251264
t := &cert.TBSCertificate{
252265
Version: version,
253266
Name: *cf.name,

cmd/nebula-cert/keygen.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ func newKeygenFlags() *keygenFlags {
2424
cf.set.Usage = func() {}
2525
cf.outPubPath = cf.set.String("out-pub", "", "Required: path to write the public key to")
2626
cf.outKeyPath = cf.set.String("out-key", "", "Required: path to write the private key to")
27-
cf.curve = cf.set.String("curve", "25519", "ECDH Curve (25519, P256)")
27+
cf.curve = cf.set.String("curve", "25519", "ECDH/KEM Curve (25519, P256, PQ)")
2828
cf.p11url = p11Flag(cf.set)
2929
return &cf
3030
}
@@ -64,6 +64,9 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error {
6464
case "P256":
6565
pub, rawPriv = p256Keypair()
6666
curve = cert.Curve_P256
67+
case "PQ":
68+
pub, rawPriv = pqKeypair()
69+
curve = cert.Curve_PQ
6770
default:
6871
return fmt.Errorf("invalid curve: %s", *cf.curve)
6972
}

cmd/nebula-cert/sign.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414

1515
"github.com/skip2/go-qrcode"
1616
"github.com/slackhq/nebula/cert"
17+
"github.com/slackhq/nebula/noiseutil"
1718
"github.com/slackhq/nebula/pkclient"
1819
"golang.org/x/crypto/curve25519"
1920
)
@@ -405,6 +406,8 @@ func newKeypair(curve cert.Curve) ([]byte, []byte) {
405406
return x25519Keypair()
406407
case cert.Curve_P256:
407408
return p256Keypair()
409+
case cert.Curve_PQ:
410+
return pqKeypair()
408411
default:
409412
return nil, nil
410413
}
@@ -433,6 +436,14 @@ func p256Keypair() ([]byte, []byte) {
433436
return pubkey.Bytes(), privkey.Bytes()
434437
}
435438

439+
func pqKeypair() ([]byte, []byte) {
440+
pub, priv, err := noiseutil.PQKEMKeypair()
441+
if err != nil {
442+
panic(err)
443+
}
444+
return pub, priv
445+
}
446+
436447
func signSummary() string {
437448
return "sign <flags>: create and sign a certificate"
438449
}

e2e/handshakes_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,48 @@ func TestGoodHandshake(t *testing.T) {
134134
theirControl.Stop()
135135
}
136136

137+
func TestGoodHandshakePQ(t *testing.T) {
138+
// Post-quantum handshake test: ML-DSA-87 certificates + hybrid X25519/ML-KEM-1024 key exchange
139+
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_PQ, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
140+
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me-pq", "10.128.0.1/24", nil)
141+
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them-pq", "10.128.0.2/24", nil)
142+
143+
// Put their info in our lighthouse
144+
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
145+
146+
// Start the servers
147+
myControl.Start()
148+
theirControl.Start()
149+
150+
t.Log("PQ: Send a udp packet through to begin standing up the tunnel")
151+
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from PQ me"))
152+
153+
t.Log("PQ: Have them consume my stage 0 packet. They have a tunnel now")
154+
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
155+
156+
t.Log("PQ: Have me consume their stage 1 packet. I have a tunnel now")
157+
myControl.InjectUDPPacket(theirControl.GetFromUDP(true))
158+
159+
t.Log("PQ: Wait until we see my cached packet come through")
160+
myControl.WaitForType(1, 0, theirControl)
161+
162+
t.Log("PQ: Make sure our host infos are correct")
163+
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
164+
165+
t.Log("PQ: Get that cached packet and make sure it looks right")
166+
myCachedPacket := theirControl.GetFromTun(true)
167+
assertUdpPacket(t, []byte("Hi from PQ me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
168+
169+
t.Log("PQ: Do a bidirectional tunnel test")
170+
r := router.NewR(t, myControl, theirControl)
171+
defer r.RenderFlow()
172+
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
173+
174+
r.RenderHostmaps("PQ Final hostmaps", myControl, theirControl)
175+
myControl.Stop()
176+
theirControl.Stop()
177+
}
178+
137179
func TestGoodHandshakeNoOverlap(t *testing.T) {
138180
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
139181
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", nil)

e2e/helpers_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certific
102102
}
103103
}
104104

105-
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, unsafeNetworks, []string{})
105+
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, caCrt.Curve(), caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, unsafeNetworks, []string{})
106106

107107
caB, err := caCrt.MarshalPEM()
108108
if err != nil {

handshake_ix.go

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -141,18 +141,27 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
141141
return
142142
}
143143

144-
// For PQ, the cert's public key is ML-KEM-1024, not the X25519 PeerStatic.
145-
// Pass nil to Recombine so the cert uses its own embedded public key.
146-
var peerStaticForCert []byte
147-
if ci.myCert.Curve() != cert.Curve_PQ {
148-
peerStaticForCert = ci.H.PeerStatic()
149-
}
150-
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, peerStaticForCert, ci.Curve())
151-
if err != nil {
152-
f.l.WithError(err).WithField("from", via).
153-
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
154-
Info("Handshake did not contain a certificate")
155-
return
144+
var rc cert.Certificate
145+
if ci.myCert.Curve() == cert.Curve_PQ {
146+
// PQ certs include the full public key in the handshake bytes (not stripped).
147+
// Unmarshal directly instead of using Recombine, which expects PeerStatic.
148+
var unmarshalErr error
149+
rc, unmarshalErr = cert.UnmarshalCertificateFromBytes(hs.Details.Cert, cert.Version(hs.Details.CertVersion), ci.Curve())
150+
if unmarshalErr != nil {
151+
f.l.WithError(unmarshalErr).WithField("from", via).
152+
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
153+
Info("Handshake did not contain a valid PQ certificate")
154+
return
155+
}
156+
} else {
157+
var recombineErr error
158+
rc, recombineErr = cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
159+
if recombineErr != nil {
160+
f.l.WithError(recombineErr).WithField("from", via).
161+
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
162+
Info("Handshake did not contain a certificate")
163+
return
164+
}
156165
}
157166

158167
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
@@ -553,12 +562,12 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
553562
return true
554563
}
555564

556-
// For PQ, the cert's public key is ML-KEM-1024, not the X25519 PeerStatic.
557-
var peerStaticForCert2 []byte
558-
if ci.myCert.Curve() != cert.Curve_PQ {
559-
peerStaticForCert2 = ci.H.PeerStatic()
565+
var rc cert.Certificate
566+
if ci.myCert.Curve() == cert.Curve_PQ {
567+
rc, err = cert.UnmarshalCertificateFromBytes(hs.Details.Cert, cert.Version(hs.Details.CertVersion), ci.Curve())
568+
} else {
569+
rc, err = cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
560570
}
561-
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, peerStaticForCert2, ci.Curve())
562571
if err != nil {
563572
f.l.WithError(err).WithField("from", via).
564573
WithField("vpnAddrs", hostinfo.vpnAddrs).

pki.go

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -402,9 +402,18 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
402402
}
403403
}
404404

405-
v2hs, err := v2.MarshalForHandshakes()
406-
if err != nil {
407-
return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err)
405+
var v2hs []byte
406+
var v2hsErr error
407+
if v2.Curve() == cert.Curve_PQ {
408+
// PQ certs must include the full public key (ML-KEM-1024) in handshake bytes
409+
// because MarshalForHandshakes strips it, and the Noise PeerStatic (X25519)
410+
// is different from the cert's key agreement key.
411+
v2hs, v2hsErr = v2.Marshal()
412+
} else {
413+
v2hs, v2hsErr = v2.MarshalForHandshakes()
414+
}
415+
if v2hsErr != nil {
416+
return nil, fmt.Errorf("error marshalling certificate for handshake: %w", v2hsErr)
408417
}
409418
cs.v2Cert = v2
410419
cs.v2HandshakeBytes = v2hs

0 commit comments

Comments
 (0)