Skip to content

Commit cdd40f5

Browse files
Steven Scottgaryburd
authored andcommitted
Add comprehensive host test (#429)
Add table driven test for handling of host in request URL, request header and TLS server name. In addition to testing various uses of host names, this test also confirms that host names are handled the same as the net/http client. The new table driven test replaces TestDialTLS, TestDialTLSNoverify, TestDialTLSBadCert and TestHostHeader. Eliminate duplicated code for constructing root CA.
1 parent 66b9c49 commit cdd40f5

File tree

1 file changed

+194
-79
lines changed

1 file changed

+194
-79
lines changed

client_server_test.go

Lines changed: 194 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@ import (
1111
"crypto/x509"
1212
"encoding/base64"
1313
"encoding/binary"
14+
"fmt"
1415
"io"
1516
"io/ioutil"
17+
"log"
1618
"net"
1719
"net/http"
1820
"net/http/cookiejar"
@@ -42,17 +44,12 @@ var cstDialer = Dialer{
4244
HandshakeTimeout: 30 * time.Second,
4345
}
4446

45-
var cstDialerWithoutHandshakeTimeout = Dialer{
46-
Subprotocols: []string{"p1", "p2"},
47-
ReadBufferSize: 1024,
48-
WriteBufferSize: 1024,
49-
}
50-
5147
type cstHandler struct{ *testing.T }
5248

5349
type cstServer struct {
5450
*httptest.Server
5551
URL string
52+
t *testing.T
5653
}
5754

5855
const (
@@ -288,10 +285,7 @@ func TestDialCookieJar(t *testing.T) {
288285
sendRecv(t, ws)
289286
}
290287

291-
func TestDialTLS(t *testing.T) {
292-
s := newTLSServer(t)
293-
defer s.Close()
294-
288+
func rootCAs(t *testing.T, s *httptest.Server) *x509.CertPool {
295289
certs := x509.NewCertPool()
296290
for _, c := range s.TLS.Certificates {
297291
roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
@@ -302,35 +296,15 @@ func TestDialTLS(t *testing.T) {
302296
certs.AddCert(root)
303297
}
304298
}
305-
306-
d := cstDialer
307-
d.TLSClientConfig = &tls.Config{RootCAs: certs}
308-
ws, _, err := d.Dial(s.URL, nil)
309-
if err != nil {
310-
t.Fatalf("Dial: %v", err)
311-
}
312-
defer ws.Close()
313-
sendRecv(t, ws)
314-
}
315-
316-
func xTestDialTLSBadCert(t *testing.T) {
317-
// This test is deactivated because of noisy logging from the net/http package.
318-
s := newTLSServer(t)
319-
defer s.Close()
320-
321-
ws, _, err := cstDialer.Dial(s.URL, nil)
322-
if err == nil {
323-
ws.Close()
324-
t.Fatalf("Dial: nil")
325-
}
299+
return certs
326300
}
327301

328-
func TestDialTLSNoVerify(t *testing.T) {
302+
func TestDialTLS(t *testing.T) {
329303
s := newTLSServer(t)
330304
defer s.Close()
331305

332306
d := cstDialer
333-
d.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
307+
d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
334308
ws, _, err := d.Dial(s.URL, nil)
335309
if err != nil {
336310
t.Fatalf("Dial: %v", err)
@@ -415,7 +389,8 @@ func TestHandshakeTimeoutInContext(t *testing.T) {
415389
s := newServer(t)
416390
defer s.Close()
417391

418-
d := cstDialerWithoutHandshakeTimeout
392+
d := cstDialer
393+
d.HandshakeTimeout = 0
419394
d.NetDialContext = func(ctx context.Context, n, a string) (net.Conn, error) {
420395
netDialer := &net.Dialer{}
421396
c, err := netDialer.DialContext(ctx, n, a)
@@ -566,33 +541,195 @@ func TestRespOnBadHandshake(t *testing.T) {
566541
}
567542
}
568543

569-
// TestHostHeader confirms that the host header provided in the call to Dial is
570-
// sent to the server.
571-
func TestHostHeader(t *testing.T) {
572-
s := newServer(t)
573-
defer s.Close()
544+
type testLogWriter struct {
545+
t *testing.T
546+
}
574547

575-
specifiedHost := make(chan string, 1)
576-
origHandler := s.Server.Config.Handler
548+
func (w testLogWriter) Write(p []byte) (int, error) {
549+
w.t.Logf("%s", p)
550+
return len(p), nil
551+
}
577552

578-
// Capture the request Host header.
579-
s.Server.Config.Handler = http.HandlerFunc(
580-
func(w http.ResponseWriter, r *http.Request) {
581-
specifiedHost <- r.Host
582-
origHandler.ServeHTTP(w, r)
583-
})
553+
// TestHost tests handling of host names and confirms that it matches net/http.
554+
func TestHost(t *testing.T) {
584555

585-
ws, _, err := cstDialer.Dial(s.URL, http.Header{"Host": {"testhost"}})
586-
if err != nil {
587-
t.Fatalf("Dial: %v", err)
588-
}
589-
defer ws.Close()
556+
upgrader := Upgrader{}
557+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
558+
if IsWebSocketUpgrade(r) {
559+
c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}})
560+
if err != nil {
561+
t.Fatal(err)
562+
}
563+
c.Close()
564+
} else {
565+
w.Header().Set("X-Test-Host", r.Host)
566+
}
567+
})
568+
569+
server := httptest.NewServer(handler)
570+
defer server.Close()
571+
572+
tlsServer := httptest.NewTLSServer(handler)
573+
defer tlsServer.Close()
574+
575+
addrs := map[*httptest.Server]string{server: server.Listener.Addr().String(), tlsServer: tlsServer.Listener.Addr().String()}
576+
wsProtos := map[*httptest.Server]string{server: "ws://", tlsServer: "wss://"}
577+
httpProtos := map[*httptest.Server]string{server: "http://", tlsServer: "https://"}
578+
579+
// Avoid log noise from net/http server by logging to testing.T
580+
server.Config.ErrorLog = log.New(testLogWriter{t}, "", 0)
581+
tlsServer.Config.ErrorLog = server.Config.ErrorLog
582+
583+
cas := rootCAs(t, tlsServer)
584+
585+
tests := []struct {
586+
fail bool // true if dial / get should fail
587+
server *httptest.Server // server to use
588+
url string // host for request URI
589+
header string // optional request host header
590+
tls string // optiona host for tls ServerName
591+
wantAddr string // expected host for dial
592+
wantHeader string // expected request header on server
593+
insecureSkipVerify bool
594+
}{
595+
{
596+
server: server,
597+
url: addrs[server],
598+
wantAddr: addrs[server],
599+
wantHeader: addrs[server],
600+
},
601+
{
602+
server: tlsServer,
603+
url: addrs[tlsServer],
604+
wantAddr: addrs[tlsServer],
605+
wantHeader: addrs[tlsServer],
606+
},
607+
608+
{
609+
server: server,
610+
url: addrs[server],
611+
header: "badhost.com",
612+
wantAddr: addrs[server],
613+
wantHeader: "badhost.com",
614+
},
615+
{
616+
server: tlsServer,
617+
url: addrs[tlsServer],
618+
header: "badhost.com",
619+
wantAddr: addrs[tlsServer],
620+
wantHeader: "badhost.com",
621+
},
622+
623+
{
624+
server: server,
625+
url: "example.com",
626+
header: "badhost.com",
627+
wantAddr: "example.com:80",
628+
wantHeader: "badhost.com",
629+
},
630+
{
631+
server: tlsServer,
632+
url: "example.com",
633+
header: "badhost.com",
634+
wantAddr: "example.com:443",
635+
wantHeader: "badhost.com",
636+
},
590637

591-
if gotHost := <-specifiedHost; gotHost != "testhost" {
592-
t.Fatalf("gotHost = %q, want \"testhost\"", gotHost)
638+
{
639+
server: server,
640+
url: "badhost.com",
641+
header: "example.com",
642+
wantAddr: "badhost.com:80",
643+
wantHeader: "example.com",
644+
},
645+
{
646+
fail: true,
647+
server: tlsServer,
648+
url: "badhost.com",
649+
header: "example.com",
650+
wantAddr: "badhost.com:443",
651+
},
652+
{
653+
server: tlsServer,
654+
url: "badhost.com",
655+
insecureSkipVerify: true,
656+
wantAddr: "badhost.com:443",
657+
wantHeader: "badhost.com",
658+
},
659+
{
660+
server: tlsServer,
661+
url: "badhost.com",
662+
tls: "example.com",
663+
wantAddr: "badhost.com:443",
664+
wantHeader: "badhost.com",
665+
},
593666
}
594667

595-
sendRecv(t, ws)
668+
for i, tt := range tests {
669+
670+
tls := &tls.Config{
671+
RootCAs: cas,
672+
ServerName: tt.tls,
673+
InsecureSkipVerify: tt.insecureSkipVerify,
674+
}
675+
676+
var gotAddr string
677+
dialer := Dialer{
678+
NetDial: func(network, addr string) (net.Conn, error) {
679+
gotAddr = addr
680+
return net.Dial(network, addrs[tt.server])
681+
},
682+
TLSClientConfig: tls,
683+
}
684+
685+
// Test websocket dial
686+
687+
h := http.Header{}
688+
if tt.header != "" {
689+
h.Set("Host", tt.header)
690+
}
691+
c, resp, err := dialer.Dial(wsProtos[tt.server]+tt.url+"/", h)
692+
if err == nil {
693+
c.Close()
694+
}
695+
696+
check := func(protos map[*httptest.Server]string) {
697+
name := fmt.Sprintf("%d: %s%s/ header[Host]=%q, tls.ServerName=%q", i+1, protos[tt.server], tt.url, tt.header, tt.tls)
698+
if gotAddr != tt.wantAddr {
699+
t.Errorf("%s: got addr %s, want %s", name, gotAddr, tt.wantAddr)
700+
}
701+
switch {
702+
case tt.fail && err == nil:
703+
t.Errorf("%s: unexpected success", name)
704+
case !tt.fail && err != nil:
705+
t.Errorf("%s: unexpected error %v", name, err)
706+
case !tt.fail && err == nil:
707+
if gotHost := resp.Header.Get("X-Test-Host"); gotHost != tt.wantHeader {
708+
t.Errorf("%s: got host %s, want %s", name, gotHost, tt.wantHeader)
709+
}
710+
}
711+
}
712+
713+
check(wsProtos)
714+
715+
// Confirm that net/http has same result
716+
717+
transport := &http.Transport{
718+
Dial: dialer.NetDial,
719+
TLSClientConfig: dialer.TLSClientConfig,
720+
}
721+
req, _ := http.NewRequest("GET", httpProtos[tt.server]+tt.url+"/", nil)
722+
if tt.header != "" {
723+
req.Host = tt.header
724+
}
725+
client := &http.Client{Transport: transport}
726+
resp, err = client.Do(req)
727+
if err == nil {
728+
resp.Body.Close()
729+
}
730+
transport.CloseIdleConnections()
731+
check(httpProtos)
732+
}
596733
}
597734

598735
func TestDialCompression(t *testing.T) {
@@ -716,19 +853,8 @@ func TestTracingDialWithContext(t *testing.T) {
716853
s := newTLSServer(t)
717854
defer s.Close()
718855

719-
certs := x509.NewCertPool()
720-
for _, c := range s.TLS.Certificates {
721-
roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
722-
if err != nil {
723-
t.Fatalf("error parsing server's root cert: %v", err)
724-
}
725-
for _, root := range roots {
726-
certs.AddCert(root)
727-
}
728-
}
729-
730856
d := cstDialer
731-
d.TLSClientConfig = &tls.Config{RootCAs: certs}
857+
d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
732858

733859
ws, _, err := d.DialContext(ctx, s.URL, nil)
734860
if err != nil {
@@ -766,19 +892,8 @@ func TestEmptyTracingDialWithContext(t *testing.T) {
766892
s := newTLSServer(t)
767893
defer s.Close()
768894

769-
certs := x509.NewCertPool()
770-
for _, c := range s.TLS.Certificates {
771-
roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
772-
if err != nil {
773-
t.Fatalf("error parsing server's root cert: %v", err)
774-
}
775-
for _, root := range roots {
776-
certs.AddCert(root)
777-
}
778-
}
779-
780895
d := cstDialer
781-
d.TLSClientConfig = &tls.Config{RootCAs: certs}
896+
d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
782897

783898
ws, _, err := d.DialContext(ctx, s.URL, nil)
784899
if err != nil {

0 commit comments

Comments
 (0)