Skip to content

Commit 865c975

Browse files
committed
Fixes case where NetDialTLS is set with HTTPS proxy
1 parent 25c6c2a commit 865c975

File tree

2 files changed

+196
-1
lines changed

2 files changed

+196
-1
lines changed

client.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,7 @@ func (d *Dialer) proxyNetDial(req *http.Request, netDial netDialerFunc) (netDial
426426
if err != nil {
427427
return nil, err
428428
}
429-
// Request should not be proxied; use originial dial function.
429+
// Request should *not* be proxied; use originial dial function.
430430
if proxyURL == nil {
431431
return netDial, nil
432432
}
@@ -439,6 +439,8 @@ func (d *Dialer) proxyNetDial(req *http.Request, netDial netDialerFunc) (netDial
439439
if proxyURL.Scheme == "https" {
440440
if d.NetDialTLSContext != nil {
441441
netDial = d.NetDialTLSContext
442+
// Ensures later TLS handshake occurs to backend over proxied connection.
443+
d.NetDialTLSContext = nil
442444
} else if d.TLSClientConfig == nil {
443445
return nil, errors.New("HTTPS proxy requires TLS dial function or TLS client config")
444446
} else {

client_proxy_server_test.go

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,10 @@ var proxyHandler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Reques
101101
<-done
102102
})
103103

104+
// Permutation 1
105+
//
106+
// Proxy: HTTP
107+
// Backend: HTTP
104108
func TestHTTPProxyAndBackend(t *testing.T) {
105109
// Start the websocket server, which echoes data back to sender.
106110
websocketServer := httptest.NewServer(websocketEchoHandler)
@@ -151,6 +155,11 @@ func TestHTTPProxyAndBackend(t *testing.T) {
151155
}
152156
}
153157

158+
// Permutation 2
159+
//
160+
// Proxy: HTTP
161+
// Backend: HTTP
162+
// DialFn: NetDial (dials proxy)
154163
func TestHTTPProxyWithNetDial(t *testing.T) {
155164
// Start the websocket server, which echoes data back to sender.
156165
websocketServer := httptest.NewServer(websocketEchoHandler)
@@ -209,6 +218,11 @@ func TestHTTPProxyWithNetDial(t *testing.T) {
209218
}
210219
}
211220

221+
// Permutation 3
222+
//
223+
// Proxy: HTTP
224+
// Backend: HTTP
225+
// DialFn: NetDialContext (dials proxy)
212226
func TestHTTPProxyWithNetDialContext(t *testing.T) {
213227
// Start the websocket server, which echoes data back to sender.
214228
websocketServer := httptest.NewServer(websocketEchoHandler)
@@ -267,6 +281,11 @@ func TestHTTPProxyWithNetDialContext(t *testing.T) {
267281
}
268282
}
269283

284+
// Permutation 4
285+
//
286+
// Proxy: HTTPS
287+
// Backend: HTTPS
288+
// TLS Config: set (used for both proxy and backend TLS)
270289
func TestHTTPSProxyAndBackend(t *testing.T) {
271290
// Start the websocket server running TLS.
272291
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
@@ -335,6 +354,12 @@ func TestHTTPSProxyAndBackend(t *testing.T) {
335354
}
336355
}
337356

357+
// Permutation 5
358+
//
359+
// Proxy: HTTPS
360+
// Backend: HTTPS
361+
// DialFn: NetDial (used to dial proxy)
362+
// TLS Config: set (used for both proxy and backend TLS)
338363
func TestHTTPSProxyUsingNetDial(t *testing.T) {
339364
// Start the websocket server running TLS.
340365
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
@@ -413,6 +438,12 @@ func TestHTTPSProxyUsingNetDial(t *testing.T) {
413438
}
414439
}
415440

441+
// Permutation 6
442+
//
443+
// Proxy: HTTPS
444+
// Backend: HTTPS
445+
// DialFn: NetDialContext (used to dial proxy)
446+
// TLS Config: set (used for both proxy and backend TLS)
416447
func TestHTTPSProxyUsingNetDialContext(t *testing.T) {
417448
// Start the websocket server running TLS.
418449
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
@@ -491,6 +522,168 @@ func TestHTTPSProxyUsingNetDialContext(t *testing.T) {
491522
}
492523
}
493524

525+
// Permutation 7
526+
//
527+
// Proxy: HTTPS
528+
// Backend: HTTPS
529+
// DialFn: NetDialTLSContext (used for proxy TLS)
530+
// TLS Config: set (used for backend TLS)
531+
func TestHTTPSProxyUsingNetDialTLSContext(t *testing.T) {
532+
// Start the websocket server running TLS.
533+
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
534+
if err != nil {
535+
t.Fatalf("error creating TLS key pair: %v", err)
536+
}
537+
websocketServer := httptest.NewUnstartedServer(websocketEchoHandler)
538+
websocketServer.TLS = &tls.Config{
539+
Certificates: []tls.Certificate{cert},
540+
}
541+
websocketServer.StartTLS()
542+
defer websocketServer.Close()
543+
websocketURL, err := url.Parse(websocketServer.URL)
544+
if err != nil {
545+
t.Fatalf("error parsing websocket server URL: %v", err)
546+
}
547+
// Start the proxy server running TLS.
548+
var proxyCalled atomic.Int64
549+
proxyServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
550+
proxyCalled.Add(1)
551+
proxyHandler.ServeHTTP(w, req)
552+
}))
553+
proxyServer.TLS = &tls.Config{
554+
Certificates: []tls.Certificate{cert},
555+
}
556+
proxyServer.StartTLS()
557+
defer proxyServer.Close()
558+
proxyServerURL, err := url.Parse(proxyServer.URL)
559+
if err != nil {
560+
t.Fatalf("error parsing websocket server URL: %v", err)
561+
}
562+
// Dial the websocket server to create the websocket connection through
563+
// the proxy. The "NetDialTLSContext" function to dials the proxy and
564+
// performs the TLS handshake. NOTE: Subsequent TLS handshake to backend
565+
// (over proxied connection) uses TLSClientConfig for handshake.
566+
certPool := x509.NewCertPool()
567+
certPool.AppendCertsFromPEM(localhostCert)
568+
tlsConfig := &tls.Config{RootCAs: certPool}
569+
var netDialCalled atomic.Int64
570+
dialer := Dialer{
571+
Proxy: http.ProxyURL(proxyServerURL),
572+
// Dial and TLS handshake function to proxy.
573+
NetDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
574+
netDialCalled.Add(1)
575+
return tls.Dial(network, addr, tlsConfig)
576+
},
577+
// Used for second TLS handshake to backend server over previously
578+
// established proxied connection.
579+
TLSClientConfig: tlsConfig,
580+
Subprotocols: []string{subprotocolv1},
581+
}
582+
websocketURL.Scheme = "wss"
583+
wsClient, _, err := dialer.Dial(websocketURL.String(), nil)
584+
if err != nil {
585+
t.Fatalf("websocket dial error: %v", err)
586+
}
587+
// Generate random data to send/receive over websocket connection.
588+
randomSize := 128 * 1024
589+
randomData := make([]byte, randomSize)
590+
if _, err := rand.Read(randomData); err != nil {
591+
t.Errorf("unexpected error reading random data: %v", err)
592+
}
593+
err = wsClient.WriteMessage(BinaryMessage, randomData)
594+
if err != nil {
595+
t.Errorf("websocket write error: %v", err)
596+
}
597+
// Read all the data from the websocket connection, then verify
598+
_, received, err := wsClient.ReadMessage()
599+
if !bytes.Equal(randomData, received) {
600+
t.Errorf("unexpected data received: %d bytes sent, %d bytes received",
601+
len(received), len(randomData))
602+
}
603+
if e, a := int64(1), netDialCalled.Load(); e != a {
604+
t.Errorf("netDial not called")
605+
}
606+
if e, a := int64(1), proxyCalled.Load(); e != a {
607+
t.Errorf("proxy not called")
608+
}
609+
}
610+
611+
// Permutation 8
612+
//
613+
// Proxy: HTTPS
614+
// Backend: HTTP
615+
// DialFn: NetDialTLSContext (used for proxy TLS)
616+
func TestHTTPSProxyUsingNetDialTLSContextWithHTTPBackend(t *testing.T) {
617+
// Start the websocket server.
618+
websocketServer := httptest.NewUnstartedServer(websocketEchoHandler)
619+
websocketServer.Start()
620+
defer websocketServer.Close()
621+
websocketURL, err := url.Parse(websocketServer.URL)
622+
if err != nil {
623+
t.Fatalf("error parsing websocket server URL: %v", err)
624+
}
625+
// Start the proxy server running TLS.
626+
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
627+
if err != nil {
628+
t.Fatalf("error creating TLS key pair: %v", err)
629+
}
630+
var proxyCalled atomic.Int64
631+
proxyServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
632+
proxyCalled.Add(1)
633+
proxyHandler.ServeHTTP(w, req)
634+
}))
635+
proxyServer.TLS = &tls.Config{
636+
Certificates: []tls.Certificate{cert},
637+
}
638+
proxyServer.StartTLS()
639+
defer proxyServer.Close()
640+
proxyServerURL, err := url.Parse(proxyServer.URL)
641+
if err != nil {
642+
t.Fatalf("error parsing websocket server URL: %v", err)
643+
}
644+
// Dials websocket backend through HTTPS proxy, using NetDialTLSContext.
645+
certPool := x509.NewCertPool()
646+
certPool.AppendCertsFromPEM(localhostCert)
647+
tlsConfig := &tls.Config{RootCAs: certPool}
648+
var netDialCalled atomic.Int64
649+
dialer := Dialer{
650+
Proxy: http.ProxyURL(proxyServerURL),
651+
// Dial and TLS handshake function to proxy.
652+
NetDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
653+
netDialCalled.Add(1)
654+
return tls.Dial(network, addr, tlsConfig)
655+
},
656+
Subprotocols: []string{subprotocolv1},
657+
}
658+
websocketURL.Scheme = "ws"
659+
wsClient, _, err := dialer.Dial(websocketURL.String(), nil)
660+
if err != nil {
661+
t.Fatalf("websocket dial error: %v", err)
662+
}
663+
// Generate random data to send/receive over websocket connection.
664+
randomSize := 128 * 1024
665+
randomData := make([]byte, randomSize)
666+
if _, err := rand.Read(randomData); err != nil {
667+
t.Errorf("unexpected error reading random data: %v", err)
668+
}
669+
err = wsClient.WriteMessage(BinaryMessage, randomData)
670+
if err != nil {
671+
t.Errorf("websocket write error: %v", err)
672+
}
673+
// Read all the data from the websocket connection, then verify
674+
_, received, err := wsClient.ReadMessage()
675+
if !bytes.Equal(randomData, received) {
676+
t.Errorf("unexpected data received: %d bytes sent, %d bytes received",
677+
len(received), len(randomData))
678+
}
679+
if e, a := int64(1), netDialCalled.Load(); e != a {
680+
t.Errorf("netDial not called")
681+
}
682+
if e, a := int64(1), proxyCalled.Load(); e != a {
683+
t.Errorf("proxy not called")
684+
}
685+
}
686+
494687
// localhostCert was generated from crypto/tls/generate_cert.go with the following command:
495688
//
496689
// go run generate_cert.go --rsa-bits 2048 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h

0 commit comments

Comments
 (0)