@@ -76,12 +76,14 @@ func (d testDatagram) String() string {
7676}
7777
7878type testPacket struct {
79- ptype packetType
80- version uint32
81- num packetNumber
82- dstConnID []byte
83- srcConnID []byte
84- frames []debugFrame
79+ ptype packetType
80+ version uint32
81+ num packetNumber
82+ keyPhaseBit bool
83+ keyNumber int
84+ dstConnID []byte
85+ srcConnID []byte
86+ frames []debugFrame
8587}
8688
8789func (p testPacket ) String () string {
@@ -102,6 +104,9 @@ func (p testPacket) String() string {
102104 return b .String ()
103105}
104106
107+ // maxTestKeyPhases is the maximum number of 1-RTT keys we'll generate in a test.
108+ const maxTestKeyPhases = 3
109+
105110// A testConn is a Conn whose external interactions (sending and receiving packets,
106111// setting timers) can be manipulated in tests.
107112type testConn struct {
@@ -122,9 +127,10 @@ type testConn struct {
122127 // the Initial packet.
123128 keysInitial fixedKeyPair
124129 keysHandshake fixedKeyPair
125- keysAppData fixedKeyPair
126- rsecrets [numberSpaceCount ]testKeySecret
127- wsecrets [numberSpaceCount ]testKeySecret
130+ rkeyAppData test1RTTKeys
131+ wkeyAppData test1RTTKeys
132+ rsecrets [numberSpaceCount ]keySecret
133+ wsecrets [numberSpaceCount ]keySecret
128134
129135 // testConn uses a test hook to snoop on the conn's TLS events.
130136 // CRYPTO data produced by the conn's QUICConn is placed in
@@ -156,10 +162,19 @@ type testConn struct {
156162 // Frame types to ignore in tests.
157163 ignoreFrames map [byte ]bool
158164
165+ // Values to set in packets sent to the conn.
166+ sendKeyNumber int
167+ sendKeyPhaseBit bool
168+
159169 asyncTestState
160170}
161171
162- type testKeySecret struct {
172+ type test1RTTKeys struct {
173+ hdr headerKey
174+ pkt [maxTestKeyPhases ]packetKey
175+ }
176+
177+ type keySecret struct {
163178 suite uint16
164179 secret []byte
165180}
@@ -333,12 +348,20 @@ func (tc *testConn) logDatagram(text string, d *testDatagram) {
333348 }
334349 tc .t .Logf ("%v datagram%v" , text , pad )
335350 for _ , p := range d .packets {
351+ var s string
336352 switch p .ptype {
337353 case packetType1RTT :
338- tc . t . Logf (" %v pnum=%v" , p .ptype , p .num )
354+ s = fmt . Sprintf (" %v pnum=%v" , p .ptype , p .num )
339355 default :
340- tc .t .Logf (" %v pnum=%v ver=%v dst={%x} src={%x}" , p .ptype , p .num , p .version , p .dstConnID , p .srcConnID )
356+ s = fmt .Sprintf (" %v pnum=%v ver=%v dst={%x} src={%x}" , p .ptype , p .num , p .version , p .dstConnID , p .srcConnID )
357+ }
358+ if p .keyPhaseBit {
359+ s += fmt .Sprintf (" KeyPhase" )
341360 }
361+ if p .keyNumber != 0 {
362+ s += fmt .Sprintf (" keynum=%v" , p .keyNumber )
363+ }
364+ tc .t .Log (s )
342365 for _ , f := range p .frames {
343366 tc .t .Logf (" %v" , f )
344367 }
@@ -381,12 +404,14 @@ func (tc *testConn) writeFrames(ptype packetType, frames ...debugFrame) {
381404 }
382405 d := & testDatagram {
383406 packets : []* testPacket {{
384- ptype : ptype ,
385- num : tc .peerNextPacketNum [space ],
386- frames : frames ,
387- version : 1 ,
388- dstConnID : dstConnID ,
389- srcConnID : tc .peerConnID ,
407+ ptype : ptype ,
408+ num : tc .peerNextPacketNum [space ],
409+ keyNumber : tc .sendKeyNumber ,
410+ keyPhaseBit : tc .sendKeyPhaseBit ,
411+ frames : frames ,
412+ version : 1 ,
413+ dstConnID : dstConnID ,
414+ srcConnID : tc .peerConnID ,
390415 }},
391416 }
392417 if ptype == packetTypeInitial && tc .conn .side == serverSide {
@@ -580,6 +605,22 @@ func (tc *testConn) wantFrame(expectation string, wantType packetType, want debu
580605 }
581606}
582607
608+ // wantFrameType indicates that we expect the Conn to send a frame,
609+ // although we don't care about the contents.
610+ func (tc * testConn ) wantFrameType (expectation string , wantType packetType , want debugFrame ) {
611+ tc .t .Helper ()
612+ got , gotType := tc .readFrame ()
613+ if got == nil {
614+ tc .t .Fatalf ("%v:\n connection is idle\n want %v frame: %v" , expectation , wantType , want )
615+ }
616+ if gotType != wantType {
617+ tc .t .Fatalf ("%v:\n got %v packet, want %v\n got frame: %v" , expectation , gotType , wantType , got )
618+ }
619+ if reflect .TypeOf (got ) != reflect .TypeOf (want ) {
620+ tc .t .Fatalf ("%v:\n got frame: %v\n want frame of type: %v" , expectation , got , want )
621+ }
622+ }
623+
583624// wantIdle indicates that we expect the Conn to not send any more frames.
584625func (tc * testConn ) wantIdle (expectation string ) {
585626 tc .t .Helper ()
@@ -615,28 +656,42 @@ func (tc *testConn) encodeTestPacket(p *testPacket, pad int) []byte {
615656 }
616657 w .appendPaddingTo (pad )
617658 if p .ptype != packetType1RTT {
618- var k fixedKeyPair
659+ var k fixedKeys
619660 switch p .ptype {
620661 case packetTypeInitial :
621- k = tc .keysInitial
662+ k = tc .keysInitial . w
622663 case packetTypeHandshake :
623- k = tc .keysHandshake
664+ k = tc .keysHandshake . w
624665 }
625- if ! k .canWrite () {
666+ if ! k .isSet () {
626667 tc .t .Fatalf ("sending %v packet with no write key" , p .ptype )
627668 }
628- w .finishProtectedLongHeaderPacket (pnumMaxAcked , k . w , longPacket {
669+ w .finishProtectedLongHeaderPacket (pnumMaxAcked , k , longPacket {
629670 ptype : p .ptype ,
630671 version : p .version ,
631672 num : p .num ,
632673 dstConnID : p .dstConnID ,
633674 srcConnID : p .srcConnID ,
634675 })
635676 } else {
636- if ! tc .keysAppData .canWrite () {
637- tc .t .Fatalf ("sending %v packet with no write key" , p .ptype )
677+ if ! tc .wkeyAppData .hdr .isSet () {
678+ tc .t .Fatalf ("sending 1-RTT packet with no write key" )
679+ }
680+ // Somewhat hackish: Generate a temporary updatingKeyPair that will
681+ // always use our desired key phase.
682+ k := & updatingKeyPair {
683+ w : updatingKeys {
684+ hdr : tc .wkeyAppData .hdr ,
685+ pkt : [2 ]packetKey {
686+ tc .wkeyAppData .pkt [p .keyNumber ],
687+ tc .wkeyAppData .pkt [p .keyNumber ],
688+ },
689+ },
690+ }
691+ if p .keyPhaseBit {
692+ k .phase |= keyPhaseBit
638693 }
639- w .finish1RTTPacket (p .num , pnumMaxAcked , p .dstConnID , tc . keysAppData . w )
694+ w .finish1RTTPacket (p .num , pnumMaxAcked , p .dstConnID , k )
640695 }
641696 return w .datagram ()
642697}
@@ -682,25 +737,45 @@ func (tc *testConn) parseTestDatagram(buf []byte) *testDatagram {
682737 })
683738 buf = buf [n :]
684739 } else {
685- if ! tc .keysAppData . canRead () {
740+ if ! tc .rkeyAppData . hdr . isSet () {
686741 tc .t .Fatalf ("reading 1-RTT packet with no read key" )
687742 }
688743 var pnumMax packetNumber // TODO: Track packet numbers.
689- p , n := parse1RTTPacket (buf , tc .keysAppData .r , len (tc .peerConnID ), pnumMax )
690- if n < 0 {
691- tc .t .Fatalf ("packet parse error" )
744+ pnumOff := 1 + len (tc .peerConnID )
745+ // Try unprotecting the packet with the first maxTestKeyPhases keys.
746+ var phase int
747+ var pnum packetNumber
748+ var hdr []byte
749+ var pay []byte
750+ var err error
751+ for phase = 0 ; phase < maxTestKeyPhases ; phase ++ {
752+ b := append ([]byte {}, buf ... )
753+ hdr , pay , pnum , err = tc .rkeyAppData .hdr .unprotect (b , pnumOff , pnumMax )
754+ if err != nil {
755+ tc .t .Fatalf ("1-RTT packet header parse error" )
756+ }
757+ k := tc .rkeyAppData .pkt [phase ]
758+ pay , err = k .unprotect (hdr , pay , pnum )
759+ if err == nil {
760+ break
761+ }
692762 }
693- frames , err := tc .parseTestFrames (p .payload )
763+ if err != nil {
764+ tc .t .Fatalf ("1-RTT packet payload parse error" )
765+ }
766+ frames , err := tc .parseTestFrames (pay )
694767 if err != nil {
695768 tc .t .Fatal (err )
696769 }
697770 d .packets = append (d .packets , & testPacket {
698- ptype : packetType1RTT ,
699- num : p .num ,
700- dstConnID : buf [1 :][:len (tc .peerConnID )],
701- frames : frames ,
771+ ptype : packetType1RTT ,
772+ num : pnum ,
773+ dstConnID : hdr [1 :][:len (tc .peerConnID )],
774+ keyPhaseBit : hdr [0 ]& keyPhaseBit != 0 ,
775+ keyNumber : phase ,
776+ frames : frames ,
702777 })
703- buf = buf [n :]
778+ buf = buf [len ( buf ) :]
704779 }
705780 }
706781 // This is rather hackish: If the last frame in the last packet
@@ -766,7 +841,7 @@ type testConnHooks testConn
766841// and verify that both sides of the connection are getting
767842// matching keys.
768843func (tc * testConnHooks ) handleTLSEvent (e tls.QUICEvent ) {
769- checkKey := func (typ string , secrets * [numberSpaceCount ]testKeySecret , e tls.QUICEvent ) {
844+ checkKey := func (typ string , secrets * [numberSpaceCount ]keySecret , e tls.QUICEvent ) {
770845 var space numberSpace
771846 switch {
772847 case e .Level == tls .QUICEncryptionLevelHandshake :
@@ -781,25 +856,32 @@ func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) {
781856 secrets [space ].suite = e .Suite
782857 secrets [space ].secret = append ([]byte {}, e .Data ... )
783858 } else if secrets [space ].suite != e .Suite || ! bytes .Equal (secrets [space ].secret , e .Data ) {
784- tc .t .Errorf ("%v key mismatch for level %v" , typ , e .Level )
859+ tc .t .Errorf ("%v key mismatch for level for level %v" , typ , e .Level )
860+ }
861+ }
862+ setAppDataKey := func (suite uint16 , secret []byte , k * test1RTTKeys ) {
863+ k .hdr .init (suite , secret )
864+ for i := 0 ; i < len (k .pkt ); i ++ {
865+ k .pkt [i ].init (suite , secret )
866+ secret = updateSecret (suite , secret )
785867 }
786868 }
787869 switch e .Kind {
788870 case tls .QUICSetReadSecret :
789- checkKey ("read " , & tc .rsecrets , e )
871+ checkKey ("write " , & tc .wsecrets , e )
790872 switch e .Level {
791873 case tls .QUICEncryptionLevelHandshake :
792874 tc .keysHandshake .w .init (e .Suite , e .Data )
793875 case tls .QUICEncryptionLevelApplication :
794- tc . keysAppData . w . init (e .Suite , e .Data )
876+ setAppDataKey (e .Suite , e .Data , & tc . wkeyAppData )
795877 }
796878 case tls .QUICSetWriteSecret :
797- checkKey ("write " , & tc .wsecrets , e )
879+ checkKey ("read " , & tc .rsecrets , e )
798880 switch e .Level {
799881 case tls .QUICEncryptionLevelHandshake :
800882 tc .keysHandshake .r .init (e .Suite , e .Data )
801883 case tls .QUICEncryptionLevelApplication :
802- tc . keysAppData . r . init (e .Suite , e .Data )
884+ setAppDataKey (e .Suite , e .Data , & tc . rkeyAppData )
803885 }
804886 case tls .QUICWriteData :
805887 tc .cryptoDataOut [e .Level ] = append (tc .cryptoDataOut [e .Level ], e .Data ... )
@@ -811,20 +893,20 @@ func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) {
811893 case tls .QUICNoEvent :
812894 return
813895 case tls .QUICSetReadSecret :
814- checkKey ("write" , & tc .wsecrets , e )
896+ checkKey ("write" , & tc .rsecrets , e )
815897 switch e .Level {
816898 case tls .QUICEncryptionLevelHandshake :
817899 tc .keysHandshake .r .init (e .Suite , e .Data )
818900 case tls .QUICEncryptionLevelApplication :
819- tc . keysAppData . r . init (e .Suite , e .Data )
901+ setAppDataKey (e .Suite , e .Data , & tc . rkeyAppData )
820902 }
821903 case tls .QUICSetWriteSecret :
822- checkKey ("read" , & tc .rsecrets , e )
904+ checkKey ("read" , & tc .wsecrets , e )
823905 switch e .Level {
824906 case tls .QUICEncryptionLevelHandshake :
825907 tc .keysHandshake .w .init (e .Suite , e .Data )
826908 case tls .QUICEncryptionLevelApplication :
827- tc . keysAppData . w . init (e .Suite , e .Data )
909+ setAppDataKey (e .Suite , e .Data , & tc . wkeyAppData )
828910 }
829911 case tls .QUICWriteData :
830912 tc .cryptoDataIn [e .Level ] = append (tc .cryptoDataIn [e .Level ], e .Data ... )
0 commit comments