Skip to content

Commit 3dddd1c

Browse files
committed
add tests
1 parent bad5b0a commit 3dddd1c

File tree

2 files changed

+151
-8
lines changed

2 files changed

+151
-8
lines changed

client.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -258,10 +258,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
258258
forwardDial := newNetDialerFunc(proxyURL.Scheme, d.NetDial, d.NetDialContext, d.NetDialTLSContext)
259259
if proxyURL.Scheme == "https" && d.NetDialTLSContext == nil {
260260
tlsClientConfig := cloneTLSConfig(d.TLSClientConfig)
261-
if d.TLSClientConfig == nil {
262-
tlsClientConfig = &tls.Config{
263-
ServerName: proxyURL.Hostname(),
264-
}
261+
if tlsClientConfig.ServerName == "" {
262+
_, hostNoPort := hostPortNoPort(proxyURL)
263+
tlsClientConfig.ServerName = hostNoPort
265264
}
266265
netDial = newHTTPProxyDialerFunc(proxyURL, forwardDial, tlsClientConfig)
267266
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
@@ -369,7 +368,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
369368
if proto != "http/1.1" {
370369
return nil, nil, fmt.Errorf(
371370
"websocket: protocol %q was given but is not supported;"+
372-
"sharing tls.Config with net/http Transport can cause this error: %w",
371+
"sharing tlsServerName.Config with net/http Transport can cause this error: %w",
373372
proto, err,
374373
)
375374
}

client_server_test.go

Lines changed: 147 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,51 @@ func newTLSServer(t *testing.T) *cstServer {
8585
return &s
8686
}
8787

88+
type cstProxyServer struct{}
89+
90+
func (s *cstProxyServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
91+
if req.Method != http.MethodConnect {
92+
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
93+
return
94+
}
95+
96+
conn, _, err := w.(http.Hijacker).Hijack()
97+
if err != nil {
98+
http.Error(w, err.Error(), http.StatusInternalServerError)
99+
return
100+
}
101+
defer conn.Close()
102+
103+
upstream, err := (&net.Dialer{}).DialContext(req.Context(), "tcp", req.URL.Host)
104+
if err != nil {
105+
_, _ = fmt.Fprintf(conn, "HTTP/1.1 502 Bad Gateway\r\n\r\n")
106+
return
107+
}
108+
defer upstream.Close()
109+
110+
_, _ = fmt.Fprintf(conn, "HTTP/1.1 200 Connection established\r\n\r\n")
111+
112+
wg := sync.WaitGroup{}
113+
wg.Add(2)
114+
go func() {
115+
defer wg.Done()
116+
_, _ = io.Copy(upstream, conn)
117+
}()
118+
go func() {
119+
defer wg.Done()
120+
_, _ = io.Copy(conn, upstream)
121+
}()
122+
wg.Wait()
123+
}
124+
125+
func newProxyServer() *httptest.Server {
126+
return httptest.NewServer(&cstProxyServer{})
127+
}
128+
129+
func newTLSProxyServer() *httptest.Server {
130+
return httptest.NewTLSServer(&cstProxyServer{})
131+
}
132+
88133
func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
89134
// Because tests wait for a response from a server, we are guaranteed that
90135
// the wait group count is incremented before the test waits on the group
@@ -165,7 +210,6 @@ func sendRecv(t *testing.T, ws *Conn) {
165210
}
166211

167212
func TestProxyDial(t *testing.T) {
168-
169213
s := newServer(t)
170214
defer s.Close()
171215

@@ -202,6 +246,106 @@ func TestProxyDial(t *testing.T) {
202246
sendRecv(t, ws)
203247
}
204248

249+
func TestProxyDialer(t *testing.T) {
250+
testcases := []struct {
251+
name string
252+
isTLS bool
253+
tlsServerName string // optional host for tls ServerName
254+
insecureSkipVerify bool
255+
netDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
256+
}{{
257+
name: "http",
258+
isTLS: false,
259+
}, {
260+
name: "https",
261+
isTLS: true,
262+
}, {
263+
name: "https with ServerName",
264+
isTLS: true,
265+
tlsServerName: "example.com",
266+
}, {
267+
name: "https with insecureSkipVerify",
268+
isTLS: true,
269+
insecureSkipVerify: true,
270+
}, {
271+
name: "https with netDialTLSContext",
272+
isTLS: true,
273+
netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
274+
dialer := &tls.Dialer{
275+
Config: &tls.Config{
276+
InsecureSkipVerify: true,
277+
},
278+
}
279+
return dialer.DialContext(ctx, network, addr)
280+
},
281+
}}
282+
283+
for _, tc := range testcases {
284+
t.Run(tc.name, func(tt *testing.T) {
285+
s := newServer(tt)
286+
defer s.Close()
287+
288+
var ps *httptest.Server
289+
if tc.isTLS {
290+
ps = newTLSProxyServer()
291+
} else {
292+
ps = newProxyServer()
293+
}
294+
295+
psurl, _ := url.Parse(ps.URL)
296+
297+
netDialCalled := false
298+
299+
cstDialer := cstDialer // make local copy for modification on next line.
300+
cstDialer.Proxy = http.ProxyURL(psurl)
301+
if tc.isTLS {
302+
cstDialer.TLSClientConfig = &tls.Config{
303+
RootCAs: rootCAs(tt, ps),
304+
ServerName: tc.tlsServerName,
305+
InsecureSkipVerify: tc.insecureSkipVerify,
306+
}
307+
if tc.netDialTLSContext != nil {
308+
cstDialer.NetDialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
309+
netDialCalled = true
310+
return tc.netDialTLSContext(ctx, network, addr)
311+
}
312+
} else {
313+
netDialCalled = true
314+
}
315+
} else {
316+
netDialCalled = true
317+
}
318+
319+
connect := false
320+
origHandler := ps.Config.Handler
321+
322+
// Capture the request Host header.
323+
ps.Config.Handler = http.HandlerFunc(
324+
func(w http.ResponseWriter, r *http.Request) {
325+
if r.Method == http.MethodConnect {
326+
connect = true
327+
}
328+
329+
origHandler.ServeHTTP(w, r)
330+
})
331+
332+
ws, _, err := cstDialer.Dial(s.URL, nil)
333+
if err != nil {
334+
tt.Fatalf("Dial: %v", err)
335+
}
336+
defer ws.Close()
337+
sendRecv(tt, ws)
338+
339+
if !connect {
340+
tt.Error("connect not received")
341+
}
342+
if !netDialCalled {
343+
tt.Error("netDialTLSContext not called")
344+
}
345+
})
346+
}
347+
}
348+
205349
func TestProxyAuthorizationDial(t *testing.T) {
206350
s := newServer(t)
207351
defer s.Close()
@@ -652,7 +796,7 @@ func TestHost(t *testing.T) {
652796
server *httptest.Server // server to use
653797
url string // host for request URI
654798
header string // optional request host header
655-
tls string // optional host for tls ServerName
799+
tls string // optional host for tlsServerName ServerName
656800
wantAddr string // expected host for dial
657801
wantHeader string // expected request header on server
658802
insecureSkipVerify bool
@@ -759,7 +903,7 @@ func TestHost(t *testing.T) {
759903
}
760904

761905
check := func(protos map[*httptest.Server]string) {
762-
name := fmt.Sprintf("%d: %s%s/ header[Host]=%q, tls.ServerName=%q", i+1, protos[tt.server], tt.url, tt.header, tt.tls)
906+
name := fmt.Sprintf("%d: %s%s/ header[Host]=%q, tlsServerName.ServerName=%q", i+1, protos[tt.server], tt.url, tt.header, tt.tls)
763907
if gotAddr != tt.wantAddr {
764908
t.Errorf("%s: got addr %s, want %s", name, gotAddr, tt.wantAddr)
765909
}

0 commit comments

Comments
 (0)