|
| 1 | +// Copyright 2020 Cloudflare, Inc. All rights reserved. Use of this source code |
| 2 | +// is governed by a BSD-style license that can be found in the LICENSE file. |
| 3 | + |
| 4 | +package tls |
| 5 | + |
| 6 | +import ( |
| 7 | + "crypto/x509" |
| 8 | + "fmt" |
| 9 | + "io" |
| 10 | + "io/ioutil" |
| 11 | + "testing" |
| 12 | + "time" |
| 13 | +) |
| 14 | + |
| 15 | +type testTimingInfo struct { |
| 16 | + serverTimingInfo CFEventTLS13ServerHandshakeTimingInfo |
| 17 | + clientTimingInfo CFEventTLS13ClientHandshakeTimingInfo |
| 18 | +} |
| 19 | + |
| 20 | +func (t testTimingInfo) isMonotonicallyIncreasing() bool { |
| 21 | + serverIsMonotonicallyIncreasing := |
| 22 | + t.serverTimingInfo.ProcessClientHello < t.serverTimingInfo.WriteServerHello && |
| 23 | + t.serverTimingInfo.WriteServerHello < t.serverTimingInfo.WriteEncryptedExtensions && |
| 24 | + t.serverTimingInfo.WriteEncryptedExtensions < t.serverTimingInfo.WriteCertificate && |
| 25 | + t.serverTimingInfo.WriteCertificate < t.serverTimingInfo.WriteCertificateVerify && |
| 26 | + t.serverTimingInfo.WriteCertificateVerify < t.serverTimingInfo.WriteServerFinished && |
| 27 | + t.serverTimingInfo.WriteServerFinished < t.serverTimingInfo.ReadCertificate && |
| 28 | + t.serverTimingInfo.ReadCertificate < t.serverTimingInfo.ReadCertificateVerify && |
| 29 | + t.serverTimingInfo.ReadCertificateVerify < t.serverTimingInfo.ReadClientFinished |
| 30 | + |
| 31 | + clientIsMonotonicallyIncreasing := |
| 32 | + t.clientTimingInfo.WriteClientHello < t.clientTimingInfo.ProcessServerHello && |
| 33 | + t.clientTimingInfo.ProcessServerHello < t.clientTimingInfo.ReadEncryptedExtensions && |
| 34 | + t.clientTimingInfo.ReadEncryptedExtensions < t.clientTimingInfo.ReadCertificate && |
| 35 | + t.clientTimingInfo.ReadCertificate < t.clientTimingInfo.ReadCertificateVerify && |
| 36 | + t.clientTimingInfo.ReadCertificateVerify < t.clientTimingInfo.ReadServerFinished && |
| 37 | + t.clientTimingInfo.ReadServerFinished < t.clientTimingInfo.WriteCertificate && |
| 38 | + t.clientTimingInfo.WriteCertificate < t.clientTimingInfo.WriteCertificateVerify && |
| 39 | + t.clientTimingInfo.WriteCertificateVerify < t.clientTimingInfo.WriteClientFinished |
| 40 | + |
| 41 | + return (serverIsMonotonicallyIncreasing && clientIsMonotonicallyIncreasing) |
| 42 | +} |
| 43 | + |
| 44 | +func (r *testTimingInfo) eventHandler(event CFEvent) { |
| 45 | + switch e := event.(type) { |
| 46 | + case CFEventTLS13ServerHandshakeTimingInfo: |
| 47 | + r.serverTimingInfo = e |
| 48 | + case CFEventTLS13ClientHandshakeTimingInfo: |
| 49 | + r.clientTimingInfo = e |
| 50 | + } |
| 51 | +} |
| 52 | + |
| 53 | +func runHandshake(t *testing.T, clientConfig, serverConfig *Config) (timingState testTimingInfo, err error) { |
| 54 | + const sentinel = "SENTINEL\n" |
| 55 | + c, s := localPipe(t) |
| 56 | + errChan := make(chan error) |
| 57 | + |
| 58 | + clientConfig.CFEventHandler = timingState.eventHandler |
| 59 | + serverConfig.CFEventHandler = timingState.eventHandler |
| 60 | + |
| 61 | + go func() { |
| 62 | + cli := Client(c, clientConfig) |
| 63 | + err := cli.Handshake() |
| 64 | + if err != nil { |
| 65 | + errChan <- fmt.Errorf("client: %v", err) |
| 66 | + c.Close() |
| 67 | + return |
| 68 | + } |
| 69 | + defer cli.Close() |
| 70 | + buf, err := ioutil.ReadAll(cli) |
| 71 | + if err != nil { |
| 72 | + t.Errorf("failed to call cli.Read: %v", err) |
| 73 | + } |
| 74 | + if got := string(buf); got != sentinel { |
| 75 | + t.Errorf("read %q from TLS connection, but expected %q", got, sentinel) |
| 76 | + } |
| 77 | + errChan <- nil |
| 78 | + }() |
| 79 | + |
| 80 | + server := Server(s, serverConfig) |
| 81 | + err = server.Handshake() |
| 82 | + if err == nil { |
| 83 | + if _, err := io.WriteString(server, sentinel); err != nil { |
| 84 | + t.Errorf("failed to call server.Write: %v", err) |
| 85 | + } |
| 86 | + if err := server.Close(); err != nil { |
| 87 | + t.Errorf("failed to call server.Close: %v", err) |
| 88 | + } |
| 89 | + err = <-errChan |
| 90 | + } else { |
| 91 | + s.Close() |
| 92 | + <-errChan |
| 93 | + } |
| 94 | + |
| 95 | + return |
| 96 | +} |
| 97 | + |
| 98 | +func TestTLS13HandshakeTiming(t *testing.T) { |
| 99 | + issuer, err := x509.ParseCertificate(testRSACertificateIssuer) |
| 100 | + if err != nil { |
| 101 | + panic(err) |
| 102 | + } |
| 103 | + rootCAs := x509.NewCertPool() |
| 104 | + rootCAs.AddCert(issuer) |
| 105 | + |
| 106 | + const serverName = "example.golang" |
| 107 | + |
| 108 | + baseConfig := &Config{ |
| 109 | + Time: time.Now, |
| 110 | + Rand: zeroSource{}, |
| 111 | + Certificates: make([]Certificate, 1), |
| 112 | + MaxVersion: VersionTLS13, |
| 113 | + RootCAs: rootCAs, |
| 114 | + ClientCAs: rootCAs, |
| 115 | + ClientAuth: RequireAndVerifyClientCert, |
| 116 | + ServerName: serverName, |
| 117 | + } |
| 118 | + baseConfig.Certificates[0].Certificate = [][]byte{testRSACertificate} |
| 119 | + baseConfig.Certificates[0].PrivateKey = testRSAPrivateKey |
| 120 | + |
| 121 | + clientConfig := baseConfig.Clone() |
| 122 | + serverConfig := baseConfig.Clone() |
| 123 | + |
| 124 | + ts, err := runHandshake(t, clientConfig, serverConfig) |
| 125 | + if err != nil { |
| 126 | + t.Fatalf("Handshake failed: %v", err) |
| 127 | + } |
| 128 | + |
| 129 | + if !ts.isMonotonicallyIncreasing() { |
| 130 | + t.Fatalf("Timing information is not monotonic") |
| 131 | + } |
| 132 | +} |
0 commit comments