Skip to content

Commit 008c0af

Browse files
committed
quic: refactor keys for key updates
Refactor how we store encryption keys in preparation for adding support for key updates. Previously, we had a single "keys" type containing header and packet protection key material. With key update, the 1-RTT header protection keys are consistent across the lifetime of a connection, while packet protection keys vary. Separate out the header and packet protection keys into distinct types. Add "fixed" key types for keys which remain fixed across a connection's lifetime and do not update. For the moment, 1-RTT keys are still fixed. Remove a number of can-never-happen error returns from key handling paths. We were previously inconsistent about where to panic and where to return an error on these paths; we now consistently panic in paths where errors can only occur due to a bug. (For example, attempting to create an AEAD with an incorrect secret size.) No functional changes, this is purely refactoring. For golang/go#58547 Change-Id: I49f83091517186e452845b65a1597add60e5fc92 Reviewed-on: https://go-review.googlesource.com/c/net/+/529155 Reviewed-by: Jonathan Amsterdam <[email protected]> LUCI-TryBot-Result: Go LUCI <[email protected]>
1 parent 6a4de22 commit 008c0af

File tree

10 files changed

+384
-280
lines changed

10 files changed

+384
-280
lines changed

internal/quic/conn.go

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,11 @@ type Conn struct {
4242
idleTimeout time.Time
4343

4444
// Packet protection keys, CRYPTO streams, and TLS state.
45-
rkeys [numberSpaceCount]keys
46-
wkeys [numberSpaceCount]keys
47-
crypto [numberSpaceCount]cryptoStream
48-
tls *tls.QUICConn
45+
keysInitial fixedKeyPair
46+
keysHandshake fixedKeyPair
47+
keysAppData fixedKeyPair
48+
crypto [numberSpaceCount]cryptoStream
49+
tls *tls.QUICConn
4950

5051
// handshakeConfirmed is set when the handshake is confirmed.
5152
// For server connections, it tracks sending HANDSHAKE_DONE.
@@ -156,8 +157,12 @@ func (c *Conn) confirmHandshake(now time.Time) {
156157
// discardKeys discards unused packet protection keys.
157158
// https://www.rfc-editor.org/rfc/rfc9001#section-4.9
158159
func (c *Conn) discardKeys(now time.Time, space numberSpace) {
159-
c.rkeys[space].discard()
160-
c.wkeys[space].discard()
160+
switch space {
161+
case initialSpace:
162+
c.keysInitial.discard()
163+
case handshakeSpace:
164+
c.keysHandshake.discard()
165+
}
161166
c.loss.discardKeys(now, space)
162167
}
163168

internal/quic/conn_recv.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ func (c *Conn) handleDatagram(now time.Time, dgram *datagram) {
2626
// https://www.rfc-editor.org/rfc/rfc9000#section-14.1-4
2727
return
2828
}
29-
n = c.handleLongHeader(now, ptype, initialSpace, buf)
29+
n = c.handleLongHeader(now, ptype, initialSpace, c.keysInitial.r, buf)
3030
case packetTypeHandshake:
31-
n = c.handleLongHeader(now, ptype, handshakeSpace, buf)
31+
n = c.handleLongHeader(now, ptype, handshakeSpace, c.keysHandshake.r, buf)
3232
case packetType1RTT:
3333
n = c.handle1RTT(now, buf)
3434
default:
@@ -43,13 +43,13 @@ func (c *Conn) handleDatagram(now time.Time, dgram *datagram) {
4343
}
4444
}
4545

46-
func (c *Conn) handleLongHeader(now time.Time, ptype packetType, space numberSpace, buf []byte) int {
47-
if !c.rkeys[space].isSet() {
46+
func (c *Conn) handleLongHeader(now time.Time, ptype packetType, space numberSpace, k fixedKeys, buf []byte) int {
47+
if !k.isSet() {
4848
return skipLongHeaderPacket(buf)
4949
}
5050

5151
pnumMax := c.acks[space].largestSeen()
52-
p, n := parseLongHeaderPacket(buf, c.rkeys[space], pnumMax)
52+
p, n := parseLongHeaderPacket(buf, k, pnumMax)
5353
if n < 0 {
5454
return -1
5555
}
@@ -82,14 +82,14 @@ func (c *Conn) handleLongHeader(now time.Time, ptype packetType, space numberSpa
8282
}
8383

8484
func (c *Conn) handle1RTT(now time.Time, buf []byte) int {
85-
if !c.rkeys[appDataSpace].isSet() {
85+
if !c.keysAppData.canRead() {
8686
// 1-RTT packets extend to the end of the datagram,
8787
// so skip the remainder of the datagram if we can't parse this.
8888
return len(buf)
8989
}
9090

9191
pnumMax := c.acks[appDataSpace].largestSeen()
92-
p, n := parse1RTTPacket(buf, c.rkeys[appDataSpace], connIDLen, pnumMax)
92+
p, n := parse1RTTPacket(buf, c.keysAppData.r, connIDLen, pnumMax)
9393
if n < 0 {
9494
return -1
9595
}

internal/quic/conn_send.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) {
5959
// Initial packet.
6060
pad := false
6161
var sentInitial *sentPacket
62-
if k := c.wkeys[initialSpace]; k.isSet() {
62+
if c.keysInitial.canWrite() {
6363
pnumMaxAcked := c.acks[initialSpace].largestSeen()
6464
pnum := c.loss.nextNumber(initialSpace)
6565
p := longPacket{
@@ -74,7 +74,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) {
7474
if logPackets {
7575
logSentPacket(c, packetTypeInitial, pnum, p.srcConnID, p.dstConnID, c.w.payload())
7676
}
77-
sentInitial = c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, k, p)
77+
sentInitial = c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, c.keysInitial.w, p)
7878
if sentInitial != nil {
7979
// Client initial packets need to be sent in a datagram padded to
8080
// at least 1200 bytes. We can't add the padding yet, however,
@@ -86,7 +86,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) {
8686
}
8787

8888
// Handshake packet.
89-
if k := c.wkeys[handshakeSpace]; k.isSet() {
89+
if c.keysHandshake.canWrite() {
9090
pnumMaxAcked := c.acks[handshakeSpace].largestSeen()
9191
pnum := c.loss.nextNumber(handshakeSpace)
9292
p := longPacket{
@@ -101,7 +101,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) {
101101
if logPackets {
102102
logSentPacket(c, packetTypeHandshake, pnum, p.srcConnID, p.dstConnID, c.w.payload())
103103
}
104-
if sent := c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, k, p); sent != nil {
104+
if sent := c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, c.keysHandshake.w, p); sent != nil {
105105
c.loss.packetSent(now, handshakeSpace, sent)
106106
if c.side == clientSide {
107107
// "[...] a client MUST discard Initial keys when it first
@@ -113,7 +113,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) {
113113
}
114114

115115
// 1-RTT packet.
116-
if k := c.wkeys[appDataSpace]; k.isSet() {
116+
if c.keysAppData.canWrite() {
117117
pnumMaxAcked := c.acks[appDataSpace].largestSeen()
118118
pnum := c.loss.nextNumber(appDataSpace)
119119
c.w.start1RTTPacket(pnum, pnumMaxAcked, dstConnID)
@@ -128,7 +128,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) {
128128
if logPackets {
129129
logSentPacket(c, packetType1RTT, pnum, nil, dstConnID, c.w.payload())
130130
}
131-
if sent := c.w.finish1RTTPacket(pnum, pnumMaxAcked, dstConnID, k); sent != nil {
131+
if sent := c.w.finish1RTTPacket(pnum, pnumMaxAcked, dstConnID, c.keysAppData.w); sent != nil {
132132
c.loss.packetSent(now, appDataSpace, sent)
133133
}
134134
}
@@ -157,7 +157,10 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) {
157157
sentInitial.inFlight = true
158158
}
159159
}
160-
if k := c.wkeys[initialSpace]; k.isSet() {
160+
// If we're a client and this Initial packet is coalesced
161+
// with a Handshake packet, then we've discarded Initial keys
162+
// since constructing the packet and shouldn't record it as in-flight.
163+
if c.keysInitial.canWrite() {
161164
c.loss.packetSent(now, initialSpace, sentInitial)
162165
}
163166
}

internal/quic/conn_test.go

Lines changed: 73 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -113,15 +113,18 @@ type testConn struct {
113113
timerLastFired time.Time
114114
idlec chan struct{} // only accessed on the conn's loop
115115

116-
// Read and write keys are distinct from the conn's keys,
116+
// Keys are distinct from the conn's keys,
117117
// because the test may know about keys before the conn does.
118118
// For example, when sending a datagram with coalesced
119119
// Initial and Handshake packets to a client conn,
120120
// we use Handshake keys to encrypt the packet.
121121
// The client only acquires those keys when it processes
122122
// the Initial packet.
123-
rkeys [numberSpaceCount]keyData // for packets sent to the conn
124-
wkeys [numberSpaceCount]keyData // for packets sent by the conn
123+
keysInitial fixedKeyPair
124+
keysHandshake fixedKeyPair
125+
keysAppData fixedKeyPair
126+
rsecrets [numberSpaceCount]testKeySecret
127+
wsecrets [numberSpaceCount]testKeySecret
125128

126129
// testConn uses a test hook to snoop on the conn's TLS events.
127130
// CRYPTO data produced by the conn's QUICConn is placed in
@@ -156,10 +159,9 @@ type testConn struct {
156159
asyncTestState
157160
}
158161

159-
type keyData struct {
162+
type testKeySecret struct {
160163
suite uint16
161164
secret []byte
162-
k keys
163165
}
164166

165167
// newTestConn creates a Conn for testing.
@@ -225,8 +227,8 @@ func newTestConn(t *testing.T, side connSide, opts ...any) *testConn {
225227
}
226228
tc.conn = conn
227229

228-
tc.wkeys[initialSpace].k = conn.wkeys[initialSpace]
229-
tc.rkeys[initialSpace].k = conn.rkeys[initialSpace]
230+
tc.keysInitial.r = conn.keysInitial.w
231+
tc.keysInitial.w = conn.keysInitial.r
230232

231233
tc.wait()
232234
return tc
@@ -611,22 +613,30 @@ func (tc *testConn) encodeTestPacket(p *testPacket, pad int) []byte {
611613
for _, f := range p.frames {
612614
f.write(&w)
613615
}
614-
space := spaceForPacketType(p.ptype)
615-
if !tc.rkeys[space].k.isSet() {
616-
tc.t.Fatalf("sending packet with no %v keys available", space)
617-
return nil
618-
}
619616
w.appendPaddingTo(pad)
620617
if p.ptype != packetType1RTT {
621-
w.finishProtectedLongHeaderPacket(pnumMaxAcked, tc.rkeys[space].k, longPacket{
618+
var k fixedKeyPair
619+
switch p.ptype {
620+
case packetTypeInitial:
621+
k = tc.keysInitial
622+
case packetTypeHandshake:
623+
k = tc.keysHandshake
624+
}
625+
if !k.canWrite() {
626+
tc.t.Fatalf("sending %v packet with no write key", p.ptype)
627+
}
628+
w.finishProtectedLongHeaderPacket(pnumMaxAcked, k.w, longPacket{
622629
ptype: p.ptype,
623630
version: p.version,
624631
num: p.num,
625632
dstConnID: p.dstConnID,
626633
srcConnID: p.srcConnID,
627634
})
628635
} else {
629-
w.finish1RTTPacket(p.num, pnumMaxAcked, p.dstConnID, tc.rkeys[space].k)
636+
if !tc.keysAppData.canWrite() {
637+
tc.t.Fatalf("sending %v packet with no write key", p.ptype)
638+
}
639+
w.finish1RTTPacket(p.num, pnumMaxAcked, p.dstConnID, tc.keysAppData.w)
630640
}
631641
return w.datagram()
632642
}
@@ -642,13 +652,19 @@ func (tc *testConn) parseTestDatagram(buf []byte) *testDatagram {
642652
break
643653
}
644654
ptype := getPacketType(buf)
645-
space := spaceForPacketType(ptype)
646-
if !tc.wkeys[space].k.isSet() {
647-
tc.t.Fatalf("no keys for space %v, packet type %v", space, ptype)
648-
}
649655
if isLongHeader(buf[0]) {
656+
var k fixedKeyPair
657+
switch ptype {
658+
case packetTypeInitial:
659+
k = tc.keysInitial
660+
case packetTypeHandshake:
661+
k = tc.keysHandshake
662+
}
663+
if !k.canRead() {
664+
tc.t.Fatalf("reading %v packet with no read key", ptype)
665+
}
650666
var pnumMax packetNumber // TODO: Track packet numbers.
651-
p, n := parseLongHeaderPacket(buf, tc.wkeys[space].k, pnumMax)
667+
p, n := parseLongHeaderPacket(buf, k.r, pnumMax)
652668
if n < 0 {
653669
tc.t.Fatalf("packet parse error")
654670
}
@@ -666,8 +682,11 @@ func (tc *testConn) parseTestDatagram(buf []byte) *testDatagram {
666682
})
667683
buf = buf[n:]
668684
} else {
685+
if !tc.keysAppData.canRead() {
686+
tc.t.Fatalf("reading 1-RTT packet with no read key")
687+
}
669688
var pnumMax packetNumber // TODO: Track packet numbers.
670-
p, n := parse1RTTPacket(buf, tc.wkeys[space].k, len(tc.peerConnID), pnumMax)
689+
p, n := parse1RTTPacket(buf, tc.keysAppData.r, len(tc.peerConnID), pnumMax)
671690
if n < 0 {
672691
tc.t.Fatalf("packet parse error")
673692
}
@@ -747,12 +766,7 @@ type testConnHooks testConn
747766
// and verify that both sides of the connection are getting
748767
// matching keys.
749768
func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) {
750-
setKey := func(keys *[numberSpaceCount]keyData, e tls.QUICEvent) {
751-
k, err := newKeys(e.Suite, e.Data)
752-
if err != nil {
753-
tc.t.Errorf("newKeys: %v", err)
754-
return
755-
}
769+
checkKey := func(typ string, secrets *[numberSpaceCount]testKeySecret, e tls.QUICEvent) {
756770
var space numberSpace
757771
switch {
758772
case e.Level == tls.QUICEncryptionLevelHandshake:
@@ -763,25 +777,30 @@ func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) {
763777
tc.t.Errorf("unexpected encryption level %v", e.Level)
764778
return
765779
}
766-
s := "read"
767-
if keys == &tc.wkeys {
768-
s = "write"
769-
}
770-
if keys[space].k.isSet() {
771-
if keys[space].suite != e.Suite || !bytes.Equal(keys[space].secret, e.Data) {
772-
tc.t.Errorf("%v key mismatch for level for level %v", s, e.Level)
773-
}
774-
return
780+
if secrets[space].secret == nil {
781+
secrets[space].suite = e.Suite
782+
secrets[space].secret = append([]byte{}, e.Data...)
783+
} else if secrets[space].suite != e.Suite || !bytes.Equal(secrets[space].secret, e.Data) {
784+
tc.t.Errorf("%v key mismatch for level %v", typ, e.Level)
775785
}
776-
keys[space].suite = e.Suite
777-
keys[space].secret = append([]byte{}, e.Data...)
778-
keys[space].k = k
779786
}
780787
switch e.Kind {
781788
case tls.QUICSetReadSecret:
782-
setKey(&tc.rkeys, e)
789+
checkKey("read", &tc.rsecrets, e)
790+
switch e.Level {
791+
case tls.QUICEncryptionLevelHandshake:
792+
tc.keysHandshake.w.init(e.Suite, e.Data)
793+
case tls.QUICEncryptionLevelApplication:
794+
tc.keysAppData.w.init(e.Suite, e.Data)
795+
}
783796
case tls.QUICSetWriteSecret:
784-
setKey(&tc.wkeys, e)
797+
checkKey("write", &tc.wsecrets, e)
798+
switch e.Level {
799+
case tls.QUICEncryptionLevelHandshake:
800+
tc.keysHandshake.r.init(e.Suite, e.Data)
801+
case tls.QUICEncryptionLevelApplication:
802+
tc.keysAppData.r.init(e.Suite, e.Data)
803+
}
785804
case tls.QUICWriteData:
786805
tc.cryptoDataOut[e.Level] = append(tc.cryptoDataOut[e.Level], e.Data...)
787806
tc.peerTLSConn.HandleData(e.Level, e.Data)
@@ -792,9 +811,21 @@ func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) {
792811
case tls.QUICNoEvent:
793812
return
794813
case tls.QUICSetReadSecret:
795-
setKey(&tc.wkeys, e)
814+
checkKey("write", &tc.wsecrets, e)
815+
switch e.Level {
816+
case tls.QUICEncryptionLevelHandshake:
817+
tc.keysHandshake.r.init(e.Suite, e.Data)
818+
case tls.QUICEncryptionLevelApplication:
819+
tc.keysAppData.r.init(e.Suite, e.Data)
820+
}
796821
case tls.QUICSetWriteSecret:
797-
setKey(&tc.rkeys, e)
822+
checkKey("read", &tc.rsecrets, e)
823+
switch e.Level {
824+
case tls.QUICEncryptionLevelHandshake:
825+
tc.keysHandshake.w.init(e.Suite, e.Data)
826+
case tls.QUICEncryptionLevelApplication:
827+
tc.keysAppData.w.init(e.Suite, e.Data)
828+
}
798829
case tls.QUICWriteData:
799830
tc.cryptoDataIn[e.Level] = append(tc.cryptoDataIn[e.Level], e.Data...)
800831
case tls.QUICTransportParameters:

0 commit comments

Comments
 (0)