Skip to content

Commit 18f2095

Browse files
committed
quic: handle peer-initiated key updates
RFC 9001, Section 6. For golang/go#58547 Change-Id: I3700043d27ab41536521b547ecf5e632a08eb1b5 Reviewed-on: https://go-review.googlesource.com/c/net/+/528835 LUCI-TryBot-Result: Go LUCI <[email protected]> Reviewed-by: Jonathan Amsterdam <[email protected]>
1 parent 008c0af commit 18f2095

File tree

10 files changed

+472
-67
lines changed

10 files changed

+472
-67
lines changed

internal/quic/conn.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ type Conn struct {
4444
// Packet protection keys, CRYPTO streams, and TLS state.
4545
keysInitial fixedKeyPair
4646
keysHandshake fixedKeyPair
47-
keysAppData fixedKeyPair
47+
keysAppData updatingKeyPair
4848
crypto [numberSpaceCount]cryptoStream
4949
tls *tls.QUICConn
5050

internal/quic/conn_recv.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ func (c *Conn) handle1RTT(now time.Time, buf []byte) int {
8989
}
9090

9191
pnumMax := c.acks[appDataSpace].largestSeen()
92-
p, n := parse1RTTPacket(buf, c.keysAppData.r, connIDLen, pnumMax)
92+
p, n := parse1RTTPacket(buf, &c.keysAppData, connIDLen, pnumMax)
9393
if n < 0 {
9494
return -1
9595
}
@@ -247,7 +247,7 @@ func (c *Conn) handleFrames(now time.Time, ptype packetType, space numberSpace,
247247

248248
func (c *Conn) handleAckFrame(now time.Time, space numberSpace, payload []byte) int {
249249
c.loss.receiveAckStart()
250-
_, ackDelay, n := consumeAckFrame(payload, func(rangeIndex int, start, end packetNumber) {
250+
largest, ackDelay, n := consumeAckFrame(payload, func(rangeIndex int, start, end packetNumber) {
251251
if end > c.loss.nextNumber(space) {
252252
// Acknowledgement of a packet we never sent.
253253
c.abort(now, localTransportError(errProtocolViolation))
@@ -280,6 +280,9 @@ func (c *Conn) handleAckFrame(now time.Time, space numberSpace, payload []byte)
280280
delay = ackDelay.Duration(uint8(c.peerAckDelayExponent))
281281
}
282282
c.loss.receiveAckEnd(now, space, delay, c.handleAckOrLoss)
283+
if space == appDataSpace {
284+
c.keysAppData.handleAckFor(largest)
285+
}
283286
return n
284287
}
285288

internal/quic/conn_send.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) {
128128
if logPackets {
129129
logSentPacket(c, packetType1RTT, pnum, nil, dstConnID, c.w.payload())
130130
}
131-
if sent := c.w.finish1RTTPacket(pnum, pnumMaxAcked, dstConnID, c.keysAppData.w); sent != nil {
131+
if sent := c.w.finish1RTTPacket(pnum, pnumMaxAcked, dstConnID, &c.keysAppData); sent != nil {
132132
c.loss.packetSent(now, appDataSpace, sent)
133133
}
134134
}
@@ -197,16 +197,23 @@ func (c *Conn) appendFrames(now time.Time, space numberSpace, pnum packetNumber,
197197
// All frames other than ACK and PADDING are ack-eliciting,
198198
// so if the packet is ack-eliciting we've added additional
199199
// frames to it.
200-
if shouldSendAck || c.w.sent.ackEliciting {
201-
// Either we are willing to send an ACK-only packet,
202-
// or we've added additional frames.
203-
c.acks[space].sentAck()
204-
} else {
200+
if !shouldSendAck && !c.w.sent.ackEliciting {
205201
// There's nothing in this packet but ACK frames, and
206202
// we don't want to send an ACK-only packet at this time.
207203
// Abandoning the packet means we wrote an ACK frame for
208204
// nothing, but constructing the frame is cheap.
209205
c.w.abandonPacket()
206+
return
207+
}
208+
// Either we are willing to send an ACK-only packet,
209+
// or we've added additional frames.
210+
c.acks[space].sentAck()
211+
if !c.w.sent.ackEliciting && c.keysAppData.needAckEliciting() {
212+
// The peer has initiated a key update.
213+
// We haven't sent them any packets yet in the new phase.
214+
// Make this an ack-eliciting packet.
215+
// Their ack of this packet will complete the key update.
216+
c.w.appendPingFrame()
210217
}
211218
}()
212219
}

internal/quic/conn_test.go

Lines changed: 128 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,14 @@ func (d testDatagram) String() string {
7676
}
7777

7878
type testPacket struct {
79-
ptype packetType
80-
version uint32
81-
num packetNumber
82-
dstConnID []byte
83-
srcConnID []byte
84-
frames []debugFrame
79+
ptype packetType
80+
version uint32
81+
num packetNumber
82+
keyPhaseBit bool
83+
keyNumber int
84+
dstConnID []byte
85+
srcConnID []byte
86+
frames []debugFrame
8587
}
8688

8789
func (p testPacket) String() string {
@@ -102,6 +104,9 @@ func (p testPacket) String() string {
102104
return b.String()
103105
}
104106

107+
// maxTestKeyPhases is the maximum number of 1-RTT keys we'll generate in a test.
108+
const maxTestKeyPhases = 3
109+
105110
// A testConn is a Conn whose external interactions (sending and receiving packets,
106111
// setting timers) can be manipulated in tests.
107112
type testConn struct {
@@ -122,9 +127,10 @@ type testConn struct {
122127
// the Initial packet.
123128
keysInitial fixedKeyPair
124129
keysHandshake fixedKeyPair
125-
keysAppData fixedKeyPair
126-
rsecrets [numberSpaceCount]testKeySecret
127-
wsecrets [numberSpaceCount]testKeySecret
130+
rkeyAppData test1RTTKeys
131+
wkeyAppData test1RTTKeys
132+
rsecrets [numberSpaceCount]keySecret
133+
wsecrets [numberSpaceCount]keySecret
128134

129135
// testConn uses a test hook to snoop on the conn's TLS events.
130136
// CRYPTO data produced by the conn's QUICConn is placed in
@@ -156,10 +162,19 @@ type testConn struct {
156162
// Frame types to ignore in tests.
157163
ignoreFrames map[byte]bool
158164

165+
// Values to set in packets sent to the conn.
166+
sendKeyNumber int
167+
sendKeyPhaseBit bool
168+
159169
asyncTestState
160170
}
161171

162-
type testKeySecret struct {
172+
type test1RTTKeys struct {
173+
hdr headerKey
174+
pkt [maxTestKeyPhases]packetKey
175+
}
176+
177+
type keySecret struct {
163178
suite uint16
164179
secret []byte
165180
}
@@ -333,12 +348,20 @@ func (tc *testConn) logDatagram(text string, d *testDatagram) {
333348
}
334349
tc.t.Logf("%v datagram%v", text, pad)
335350
for _, p := range d.packets {
351+
var s string
336352
switch p.ptype {
337353
case packetType1RTT:
338-
tc.t.Logf(" %v pnum=%v", p.ptype, p.num)
354+
s = fmt.Sprintf(" %v pnum=%v", p.ptype, p.num)
339355
default:
340-
tc.t.Logf(" %v pnum=%v ver=%v dst={%x} src={%x}", p.ptype, p.num, p.version, p.dstConnID, p.srcConnID)
356+
s = fmt.Sprintf(" %v pnum=%v ver=%v dst={%x} src={%x}", p.ptype, p.num, p.version, p.dstConnID, p.srcConnID)
357+
}
358+
if p.keyPhaseBit {
359+
s += fmt.Sprintf(" KeyPhase")
341360
}
361+
if p.keyNumber != 0 {
362+
s += fmt.Sprintf(" keynum=%v", p.keyNumber)
363+
}
364+
tc.t.Log(s)
342365
for _, f := range p.frames {
343366
tc.t.Logf(" %v", f)
344367
}
@@ -381,12 +404,14 @@ func (tc *testConn) writeFrames(ptype packetType, frames ...debugFrame) {
381404
}
382405
d := &testDatagram{
383406
packets: []*testPacket{{
384-
ptype: ptype,
385-
num: tc.peerNextPacketNum[space],
386-
frames: frames,
387-
version: 1,
388-
dstConnID: dstConnID,
389-
srcConnID: tc.peerConnID,
407+
ptype: ptype,
408+
num: tc.peerNextPacketNum[space],
409+
keyNumber: tc.sendKeyNumber,
410+
keyPhaseBit: tc.sendKeyPhaseBit,
411+
frames: frames,
412+
version: 1,
413+
dstConnID: dstConnID,
414+
srcConnID: tc.peerConnID,
390415
}},
391416
}
392417
if ptype == packetTypeInitial && tc.conn.side == serverSide {
@@ -580,6 +605,22 @@ func (tc *testConn) wantFrame(expectation string, wantType packetType, want debu
580605
}
581606
}
582607

608+
// wantFrameType indicates that we expect the Conn to send a frame,
609+
// although we don't care about the contents.
610+
func (tc *testConn) wantFrameType(expectation string, wantType packetType, want debugFrame) {
611+
tc.t.Helper()
612+
got, gotType := tc.readFrame()
613+
if got == nil {
614+
tc.t.Fatalf("%v:\nconnection is idle\nwant %v frame: %v", expectation, wantType, want)
615+
}
616+
if gotType != wantType {
617+
tc.t.Fatalf("%v:\ngot %v packet, want %v\ngot frame: %v", expectation, gotType, wantType, got)
618+
}
619+
if reflect.TypeOf(got) != reflect.TypeOf(want) {
620+
tc.t.Fatalf("%v:\ngot frame: %v\nwant frame of type: %v", expectation, got, want)
621+
}
622+
}
623+
583624
// wantIdle indicates that we expect the Conn to not send any more frames.
584625
func (tc *testConn) wantIdle(expectation string) {
585626
tc.t.Helper()
@@ -615,28 +656,42 @@ func (tc *testConn) encodeTestPacket(p *testPacket, pad int) []byte {
615656
}
616657
w.appendPaddingTo(pad)
617658
if p.ptype != packetType1RTT {
618-
var k fixedKeyPair
659+
var k fixedKeys
619660
switch p.ptype {
620661
case packetTypeInitial:
621-
k = tc.keysInitial
662+
k = tc.keysInitial.w
622663
case packetTypeHandshake:
623-
k = tc.keysHandshake
664+
k = tc.keysHandshake.w
624665
}
625-
if !k.canWrite() {
666+
if !k.isSet() {
626667
tc.t.Fatalf("sending %v packet with no write key", p.ptype)
627668
}
628-
w.finishProtectedLongHeaderPacket(pnumMaxAcked, k.w, longPacket{
669+
w.finishProtectedLongHeaderPacket(pnumMaxAcked, k, longPacket{
629670
ptype: p.ptype,
630671
version: p.version,
631672
num: p.num,
632673
dstConnID: p.dstConnID,
633674
srcConnID: p.srcConnID,
634675
})
635676
} else {
636-
if !tc.keysAppData.canWrite() {
637-
tc.t.Fatalf("sending %v packet with no write key", p.ptype)
677+
if !tc.wkeyAppData.hdr.isSet() {
678+
tc.t.Fatalf("sending 1-RTT packet with no write key")
679+
}
680+
// Somewhat hackish: Generate a temporary updatingKeyPair that will
681+
// always use our desired key phase.
682+
k := &updatingKeyPair{
683+
w: updatingKeys{
684+
hdr: tc.wkeyAppData.hdr,
685+
pkt: [2]packetKey{
686+
tc.wkeyAppData.pkt[p.keyNumber],
687+
tc.wkeyAppData.pkt[p.keyNumber],
688+
},
689+
},
690+
}
691+
if p.keyPhaseBit {
692+
k.phase |= keyPhaseBit
638693
}
639-
w.finish1RTTPacket(p.num, pnumMaxAcked, p.dstConnID, tc.keysAppData.w)
694+
w.finish1RTTPacket(p.num, pnumMaxAcked, p.dstConnID, k)
640695
}
641696
return w.datagram()
642697
}
@@ -682,25 +737,45 @@ func (tc *testConn) parseTestDatagram(buf []byte) *testDatagram {
682737
})
683738
buf = buf[n:]
684739
} else {
685-
if !tc.keysAppData.canRead() {
740+
if !tc.rkeyAppData.hdr.isSet() {
686741
tc.t.Fatalf("reading 1-RTT packet with no read key")
687742
}
688743
var pnumMax packetNumber // TODO: Track packet numbers.
689-
p, n := parse1RTTPacket(buf, tc.keysAppData.r, len(tc.peerConnID), pnumMax)
690-
if n < 0 {
691-
tc.t.Fatalf("packet parse error")
744+
pnumOff := 1 + len(tc.peerConnID)
745+
// Try unprotecting the packet with the first maxTestKeyPhases keys.
746+
var phase int
747+
var pnum packetNumber
748+
var hdr []byte
749+
var pay []byte
750+
var err error
751+
for phase = 0; phase < maxTestKeyPhases; phase++ {
752+
b := append([]byte{}, buf...)
753+
hdr, pay, pnum, err = tc.rkeyAppData.hdr.unprotect(b, pnumOff, pnumMax)
754+
if err != nil {
755+
tc.t.Fatalf("1-RTT packet header parse error")
756+
}
757+
k := tc.rkeyAppData.pkt[phase]
758+
pay, err = k.unprotect(hdr, pay, pnum)
759+
if err == nil {
760+
break
761+
}
692762
}
693-
frames, err := tc.parseTestFrames(p.payload)
763+
if err != nil {
764+
tc.t.Fatalf("1-RTT packet payload parse error")
765+
}
766+
frames, err := tc.parseTestFrames(pay)
694767
if err != nil {
695768
tc.t.Fatal(err)
696769
}
697770
d.packets = append(d.packets, &testPacket{
698-
ptype: packetType1RTT,
699-
num: p.num,
700-
dstConnID: buf[1:][:len(tc.peerConnID)],
701-
frames: frames,
771+
ptype: packetType1RTT,
772+
num: pnum,
773+
dstConnID: hdr[1:][:len(tc.peerConnID)],
774+
keyPhaseBit: hdr[0]&keyPhaseBit != 0,
775+
keyNumber: phase,
776+
frames: frames,
702777
})
703-
buf = buf[n:]
778+
buf = buf[len(buf):]
704779
}
705780
}
706781
// This is rather hackish: If the last frame in the last packet
@@ -766,7 +841,7 @@ type testConnHooks testConn
766841
// and verify that both sides of the connection are getting
767842
// matching keys.
768843
func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) {
769-
checkKey := func(typ string, secrets *[numberSpaceCount]testKeySecret, e tls.QUICEvent) {
844+
checkKey := func(typ string, secrets *[numberSpaceCount]keySecret, e tls.QUICEvent) {
770845
var space numberSpace
771846
switch {
772847
case e.Level == tls.QUICEncryptionLevelHandshake:
@@ -781,25 +856,32 @@ func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) {
781856
secrets[space].suite = e.Suite
782857
secrets[space].secret = append([]byte{}, e.Data...)
783858
} else if secrets[space].suite != e.Suite || !bytes.Equal(secrets[space].secret, e.Data) {
784-
tc.t.Errorf("%v key mismatch for level %v", typ, e.Level)
859+
tc.t.Errorf("%v key mismatch for level for level %v", typ, e.Level)
860+
}
861+
}
862+
setAppDataKey := func(suite uint16, secret []byte, k *test1RTTKeys) {
863+
k.hdr.init(suite, secret)
864+
for i := 0; i < len(k.pkt); i++ {
865+
k.pkt[i].init(suite, secret)
866+
secret = updateSecret(suite, secret)
785867
}
786868
}
787869
switch e.Kind {
788870
case tls.QUICSetReadSecret:
789-
checkKey("read", &tc.rsecrets, e)
871+
checkKey("write", &tc.wsecrets, e)
790872
switch e.Level {
791873
case tls.QUICEncryptionLevelHandshake:
792874
tc.keysHandshake.w.init(e.Suite, e.Data)
793875
case tls.QUICEncryptionLevelApplication:
794-
tc.keysAppData.w.init(e.Suite, e.Data)
876+
setAppDataKey(e.Suite, e.Data, &tc.wkeyAppData)
795877
}
796878
case tls.QUICSetWriteSecret:
797-
checkKey("write", &tc.wsecrets, e)
879+
checkKey("read", &tc.rsecrets, e)
798880
switch e.Level {
799881
case tls.QUICEncryptionLevelHandshake:
800882
tc.keysHandshake.r.init(e.Suite, e.Data)
801883
case tls.QUICEncryptionLevelApplication:
802-
tc.keysAppData.r.init(e.Suite, e.Data)
884+
setAppDataKey(e.Suite, e.Data, &tc.rkeyAppData)
803885
}
804886
case tls.QUICWriteData:
805887
tc.cryptoDataOut[e.Level] = append(tc.cryptoDataOut[e.Level], e.Data...)
@@ -811,20 +893,20 @@ func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) {
811893
case tls.QUICNoEvent:
812894
return
813895
case tls.QUICSetReadSecret:
814-
checkKey("write", &tc.wsecrets, e)
896+
checkKey("write", &tc.rsecrets, e)
815897
switch e.Level {
816898
case tls.QUICEncryptionLevelHandshake:
817899
tc.keysHandshake.r.init(e.Suite, e.Data)
818900
case tls.QUICEncryptionLevelApplication:
819-
tc.keysAppData.r.init(e.Suite, e.Data)
901+
setAppDataKey(e.Suite, e.Data, &tc.rkeyAppData)
820902
}
821903
case tls.QUICSetWriteSecret:
822-
checkKey("read", &tc.rsecrets, e)
904+
checkKey("read", &tc.wsecrets, e)
823905
switch e.Level {
824906
case tls.QUICEncryptionLevelHandshake:
825907
tc.keysHandshake.w.init(e.Suite, e.Data)
826908
case tls.QUICEncryptionLevelApplication:
827-
tc.keysAppData.w.init(e.Suite, e.Data)
909+
setAppDataKey(e.Suite, e.Data, &tc.wkeyAppData)
828910
}
829911
case tls.QUICWriteData:
830912
tc.cryptoDataIn[e.Level] = append(tc.cryptoDataIn[e.Level], e.Data...)

0 commit comments

Comments
 (0)