@@ -11,8 +11,10 @@ import (
11
11
"crypto/x509"
12
12
"encoding/base64"
13
13
"encoding/binary"
14
+ "fmt"
14
15
"io"
15
16
"io/ioutil"
17
+ "log"
16
18
"net"
17
19
"net/http"
18
20
"net/http/cookiejar"
@@ -42,17 +44,12 @@ var cstDialer = Dialer{
42
44
HandshakeTimeout : 30 * time .Second ,
43
45
}
44
46
45
- var cstDialerWithoutHandshakeTimeout = Dialer {
46
- Subprotocols : []string {"p1" , "p2" },
47
- ReadBufferSize : 1024 ,
48
- WriteBufferSize : 1024 ,
49
- }
50
-
51
47
type cstHandler struct { * testing.T }
52
48
53
49
type cstServer struct {
54
50
* httptest.Server
55
51
URL string
52
+ t * testing.T
56
53
}
57
54
58
55
const (
@@ -288,10 +285,7 @@ func TestDialCookieJar(t *testing.T) {
288
285
sendRecv (t , ws )
289
286
}
290
287
291
- func TestDialTLS (t * testing.T ) {
292
- s := newTLSServer (t )
293
- defer s .Close ()
294
-
288
+ func rootCAs (t * testing.T , s * httptest.Server ) * x509.CertPool {
295
289
certs := x509 .NewCertPool ()
296
290
for _ , c := range s .TLS .Certificates {
297
291
roots , err := x509 .ParseCertificates (c .Certificate [len (c .Certificate )- 1 ])
@@ -302,35 +296,15 @@ func TestDialTLS(t *testing.T) {
302
296
certs .AddCert (root )
303
297
}
304
298
}
305
-
306
- d := cstDialer
307
- d .TLSClientConfig = & tls.Config {RootCAs : certs }
308
- ws , _ , err := d .Dial (s .URL , nil )
309
- if err != nil {
310
- t .Fatalf ("Dial: %v" , err )
311
- }
312
- defer ws .Close ()
313
- sendRecv (t , ws )
314
- }
315
-
316
- func xTestDialTLSBadCert (t * testing.T ) {
317
- // This test is deactivated because of noisy logging from the net/http package.
318
- s := newTLSServer (t )
319
- defer s .Close ()
320
-
321
- ws , _ , err := cstDialer .Dial (s .URL , nil )
322
- if err == nil {
323
- ws .Close ()
324
- t .Fatalf ("Dial: nil" )
325
- }
299
+ return certs
326
300
}
327
301
328
- func TestDialTLSNoVerify (t * testing.T ) {
302
+ func TestDialTLS (t * testing.T ) {
329
303
s := newTLSServer (t )
330
304
defer s .Close ()
331
305
332
306
d := cstDialer
333
- d .TLSClientConfig = & tls.Config {InsecureSkipVerify : true }
307
+ d .TLSClientConfig = & tls.Config {RootCAs : rootCAs ( t , s . Server ) }
334
308
ws , _ , err := d .Dial (s .URL , nil )
335
309
if err != nil {
336
310
t .Fatalf ("Dial: %v" , err )
@@ -415,7 +389,8 @@ func TestHandshakeTimeoutInContext(t *testing.T) {
415
389
s := newServer (t )
416
390
defer s .Close ()
417
391
418
- d := cstDialerWithoutHandshakeTimeout
392
+ d := cstDialer
393
+ d .HandshakeTimeout = 0
419
394
d .NetDialContext = func (ctx context.Context , n , a string ) (net.Conn , error ) {
420
395
netDialer := & net.Dialer {}
421
396
c , err := netDialer .DialContext (ctx , n , a )
@@ -566,33 +541,195 @@ func TestRespOnBadHandshake(t *testing.T) {
566
541
}
567
542
}
568
543
569
- // TestHostHeader confirms that the host header provided in the call to Dial is
570
- // sent to the server.
571
- func TestHostHeader (t * testing.T ) {
572
- s := newServer (t )
573
- defer s .Close ()
544
+ type testLogWriter struct {
545
+ t * testing.T
546
+ }
574
547
575
- specifiedHost := make (chan string , 1 )
576
- origHandler := s .Server .Config .Handler
548
+ func (w testLogWriter ) Write (p []byte ) (int , error ) {
549
+ w .t .Logf ("%s" , p )
550
+ return len (p ), nil
551
+ }
577
552
578
- // Capture the request Host header.
579
- s .Server .Config .Handler = http .HandlerFunc (
580
- func (w http.ResponseWriter , r * http.Request ) {
581
- specifiedHost <- r .Host
582
- origHandler .ServeHTTP (w , r )
583
- })
553
+ // TestHost tests handling of host names and confirms that it matches net/http.
554
+ func TestHost (t * testing.T ) {
584
555
585
- ws , _ , err := cstDialer .Dial (s .URL , http.Header {"Host" : {"testhost" }})
586
- if err != nil {
587
- t .Fatalf ("Dial: %v" , err )
588
- }
589
- defer ws .Close ()
556
+ upgrader := Upgrader {}
557
+ handler := http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
558
+ if IsWebSocketUpgrade (r ) {
559
+ c , err := upgrader .Upgrade (w , r , http.Header {"X-Test-Host" : {r .Host }})
560
+ if err != nil {
561
+ t .Fatal (err )
562
+ }
563
+ c .Close ()
564
+ } else {
565
+ w .Header ().Set ("X-Test-Host" , r .Host )
566
+ }
567
+ })
568
+
569
+ server := httptest .NewServer (handler )
570
+ defer server .Close ()
571
+
572
+ tlsServer := httptest .NewTLSServer (handler )
573
+ defer tlsServer .Close ()
574
+
575
+ addrs := map [* httptest.Server ]string {server : server .Listener .Addr ().String (), tlsServer : tlsServer .Listener .Addr ().String ()}
576
+ wsProtos := map [* httptest.Server ]string {server : "ws://" , tlsServer : "wss://" }
577
+ httpProtos := map [* httptest.Server ]string {server : "http://" , tlsServer : "https://" }
578
+
579
+ // Avoid log noise from net/http server by logging to testing.T
580
+ server .Config .ErrorLog = log .New (testLogWriter {t }, "" , 0 )
581
+ tlsServer .Config .ErrorLog = server .Config .ErrorLog
582
+
583
+ cas := rootCAs (t , tlsServer )
584
+
585
+ tests := []struct {
586
+ fail bool // true if dial / get should fail
587
+ server * httptest.Server // server to use
588
+ url string // host for request URI
589
+ header string // optional request host header
590
+ tls string // optiona host for tls ServerName
591
+ wantAddr string // expected host for dial
592
+ wantHeader string // expected request header on server
593
+ insecureSkipVerify bool
594
+ }{
595
+ {
596
+ server : server ,
597
+ url : addrs [server ],
598
+ wantAddr : addrs [server ],
599
+ wantHeader : addrs [server ],
600
+ },
601
+ {
602
+ server : tlsServer ,
603
+ url : addrs [tlsServer ],
604
+ wantAddr : addrs [tlsServer ],
605
+ wantHeader : addrs [tlsServer ],
606
+ },
607
+
608
+ {
609
+ server : server ,
610
+ url : addrs [server ],
611
+ header : "badhost.com" ,
612
+ wantAddr : addrs [server ],
613
+ wantHeader : "badhost.com" ,
614
+ },
615
+ {
616
+ server : tlsServer ,
617
+ url : addrs [tlsServer ],
618
+ header : "badhost.com" ,
619
+ wantAddr : addrs [tlsServer ],
620
+ wantHeader : "badhost.com" ,
621
+ },
622
+
623
+ {
624
+ server : server ,
625
+ url : "example.com" ,
626
+ header : "badhost.com" ,
627
+ wantAddr : "example.com:80" ,
628
+ wantHeader : "badhost.com" ,
629
+ },
630
+ {
631
+ server : tlsServer ,
632
+ url : "example.com" ,
633
+ header : "badhost.com" ,
634
+ wantAddr : "example.com:443" ,
635
+ wantHeader : "badhost.com" ,
636
+ },
590
637
591
- if gotHost := <- specifiedHost ; gotHost != "testhost" {
592
- t .Fatalf ("gotHost = %q, want \" testhost\" " , gotHost )
638
+ {
639
+ server : server ,
640
+ url : "badhost.com" ,
641
+ header : "example.com" ,
642
+ wantAddr : "badhost.com:80" ,
643
+ wantHeader : "example.com" ,
644
+ },
645
+ {
646
+ fail : true ,
647
+ server : tlsServer ,
648
+ url : "badhost.com" ,
649
+ header : "example.com" ,
650
+ wantAddr : "badhost.com:443" ,
651
+ },
652
+ {
653
+ server : tlsServer ,
654
+ url : "badhost.com" ,
655
+ insecureSkipVerify : true ,
656
+ wantAddr : "badhost.com:443" ,
657
+ wantHeader : "badhost.com" ,
658
+ },
659
+ {
660
+ server : tlsServer ,
661
+ url : "badhost.com" ,
662
+ tls : "example.com" ,
663
+ wantAddr : "badhost.com:443" ,
664
+ wantHeader : "badhost.com" ,
665
+ },
593
666
}
594
667
595
- sendRecv (t , ws )
668
+ for i , tt := range tests {
669
+
670
+ tls := & tls.Config {
671
+ RootCAs : cas ,
672
+ ServerName : tt .tls ,
673
+ InsecureSkipVerify : tt .insecureSkipVerify ,
674
+ }
675
+
676
+ var gotAddr string
677
+ dialer := Dialer {
678
+ NetDial : func (network , addr string ) (net.Conn , error ) {
679
+ gotAddr = addr
680
+ return net .Dial (network , addrs [tt .server ])
681
+ },
682
+ TLSClientConfig : tls ,
683
+ }
684
+
685
+ // Test websocket dial
686
+
687
+ h := http.Header {}
688
+ if tt .header != "" {
689
+ h .Set ("Host" , tt .header )
690
+ }
691
+ c , resp , err := dialer .Dial (wsProtos [tt .server ]+ tt .url + "/" , h )
692
+ if err == nil {
693
+ c .Close ()
694
+ }
695
+
696
+ check := func (protos map [* httptest.Server ]string ) {
697
+ name := fmt .Sprintf ("%d: %s%s/ header[Host]=%q, tls.ServerName=%q" , i + 1 , protos [tt .server ], tt .url , tt .header , tt .tls )
698
+ if gotAddr != tt .wantAddr {
699
+ t .Errorf ("%s: got addr %s, want %s" , name , gotAddr , tt .wantAddr )
700
+ }
701
+ switch {
702
+ case tt .fail && err == nil :
703
+ t .Errorf ("%s: unexpected success" , name )
704
+ case ! tt .fail && err != nil :
705
+ t .Errorf ("%s: unexpected error %v" , name , err )
706
+ case ! tt .fail && err == nil :
707
+ if gotHost := resp .Header .Get ("X-Test-Host" ); gotHost != tt .wantHeader {
708
+ t .Errorf ("%s: got host %s, want %s" , name , gotHost , tt .wantHeader )
709
+ }
710
+ }
711
+ }
712
+
713
+ check (wsProtos )
714
+
715
+ // Confirm that net/http has same result
716
+
717
+ transport := & http.Transport {
718
+ Dial : dialer .NetDial ,
719
+ TLSClientConfig : dialer .TLSClientConfig ,
720
+ }
721
+ req , _ := http .NewRequest ("GET" , httpProtos [tt .server ]+ tt .url + "/" , nil )
722
+ if tt .header != "" {
723
+ req .Host = tt .header
724
+ }
725
+ client := & http.Client {Transport : transport }
726
+ resp , err = client .Do (req )
727
+ if err == nil {
728
+ resp .Body .Close ()
729
+ }
730
+ transport .CloseIdleConnections ()
731
+ check (httpProtos )
732
+ }
596
733
}
597
734
598
735
func TestDialCompression (t * testing.T ) {
@@ -716,19 +853,8 @@ func TestTracingDialWithContext(t *testing.T) {
716
853
s := newTLSServer (t )
717
854
defer s .Close ()
718
855
719
- certs := x509 .NewCertPool ()
720
- for _ , c := range s .TLS .Certificates {
721
- roots , err := x509 .ParseCertificates (c .Certificate [len (c .Certificate )- 1 ])
722
- if err != nil {
723
- t .Fatalf ("error parsing server's root cert: %v" , err )
724
- }
725
- for _ , root := range roots {
726
- certs .AddCert (root )
727
- }
728
- }
729
-
730
856
d := cstDialer
731
- d .TLSClientConfig = & tls.Config {RootCAs : certs }
857
+ d .TLSClientConfig = & tls.Config {RootCAs : rootCAs ( t , s . Server ) }
732
858
733
859
ws , _ , err := d .DialContext (ctx , s .URL , nil )
734
860
if err != nil {
@@ -766,19 +892,8 @@ func TestEmptyTracingDialWithContext(t *testing.T) {
766
892
s := newTLSServer (t )
767
893
defer s .Close ()
768
894
769
- certs := x509 .NewCertPool ()
770
- for _ , c := range s .TLS .Certificates {
771
- roots , err := x509 .ParseCertificates (c .Certificate [len (c .Certificate )- 1 ])
772
- if err != nil {
773
- t .Fatalf ("error parsing server's root cert: %v" , err )
774
- }
775
- for _ , root := range roots {
776
- certs .AddCert (root )
777
- }
778
- }
779
-
780
895
d := cstDialer
781
- d .TLSClientConfig = & tls.Config {RootCAs : certs }
896
+ d .TLSClientConfig = & tls.Config {RootCAs : rootCAs ( t , s . Server ) }
782
897
783
898
ws , _ , err := d .DialContext (ctx , s .URL , nil )
784
899
if err != nil {
0 commit comments