@@ -85,6 +85,51 @@ func newTLSServer(t *testing.T) *cstServer {
85
85
return & s
86
86
}
87
87
88
+ type cstProxyServer struct {}
89
+
90
+ func (s * cstProxyServer ) ServeHTTP (w http.ResponseWriter , req * http.Request ) {
91
+ if req .Method != http .MethodConnect {
92
+ http .Error (w , "method not allowed" , http .StatusMethodNotAllowed )
93
+ return
94
+ }
95
+
96
+ conn , _ , err := w .(http.Hijacker ).Hijack ()
97
+ if err != nil {
98
+ http .Error (w , err .Error (), http .StatusInternalServerError )
99
+ return
100
+ }
101
+ defer conn .Close ()
102
+
103
+ upstream , err := (& net.Dialer {}).DialContext (req .Context (), "tcp" , req .URL .Host )
104
+ if err != nil {
105
+ _ , _ = fmt .Fprintf (conn , "HTTP/1.1 502 Bad Gateway\r \n \r \n " )
106
+ return
107
+ }
108
+ defer upstream .Close ()
109
+
110
+ _ , _ = fmt .Fprintf (conn , "HTTP/1.1 200 Connection established\r \n \r \n " )
111
+
112
+ wg := sync.WaitGroup {}
113
+ wg .Add (2 )
114
+ go func () {
115
+ defer wg .Done ()
116
+ _ , _ = io .Copy (upstream , conn )
117
+ }()
118
+ go func () {
119
+ defer wg .Done ()
120
+ _ , _ = io .Copy (conn , upstream )
121
+ }()
122
+ wg .Wait ()
123
+ }
124
+
125
+ func newProxyServer () * httptest.Server {
126
+ return httptest .NewServer (& cstProxyServer {})
127
+ }
128
+
129
+ func newTLSProxyServer () * httptest.Server {
130
+ return httptest .NewTLSServer (& cstProxyServer {})
131
+ }
132
+
88
133
func (t cstHandler ) ServeHTTP (w http.ResponseWriter , r * http.Request ) {
89
134
// Because tests wait for a response from a server, we are guaranteed that
90
135
// the wait group count is incremented before the test waits on the group
@@ -165,7 +210,6 @@ func sendRecv(t *testing.T, ws *Conn) {
165
210
}
166
211
167
212
func TestProxyDial (t * testing.T ) {
168
-
169
213
s := newServer (t )
170
214
defer s .Close ()
171
215
@@ -202,6 +246,106 @@ func TestProxyDial(t *testing.T) {
202
246
sendRecv (t , ws )
203
247
}
204
248
249
+ func TestProxyDialer (t * testing.T ) {
250
+ testcases := []struct {
251
+ name string
252
+ isTLS bool
253
+ tlsServerName string
254
+ insecureSkipVerify bool
255
+ netDialTLSContext func (ctx context.Context , network , addr string ) (net.Conn , error )
256
+ }{{
257
+ name : "http" ,
258
+ isTLS : false ,
259
+ }, {
260
+ name : "https" ,
261
+ isTLS : true ,
262
+ }, {
263
+ name : "https with ServerName" ,
264
+ isTLS : true ,
265
+ tlsServerName : "example.com" ,
266
+ }, {
267
+ name : "https with insecureSkipVerify" ,
268
+ isTLS : true ,
269
+ insecureSkipVerify : true ,
270
+ }, {
271
+ name : "https with netDialTLSContext" ,
272
+ isTLS : true ,
273
+ netDialTLSContext : func (ctx context.Context , network , addr string ) (net.Conn , error ) {
274
+ dialer := & tls.Dialer {
275
+ Config : & tls.Config {
276
+ InsecureSkipVerify : true ,
277
+ },
278
+ }
279
+ return dialer .DialContext (ctx , network , addr )
280
+ },
281
+ }}
282
+
283
+ for _ , tc := range testcases {
284
+ t .Run (tc .name , func (tt * testing.T ) {
285
+ s := newServer (tt )
286
+ defer s .Close ()
287
+
288
+ var ps * httptest.Server
289
+ if tc .isTLS {
290
+ ps = newTLSProxyServer ()
291
+ } else {
292
+ ps = newProxyServer ()
293
+ }
294
+
295
+ psurl , _ := url .Parse (ps .URL )
296
+
297
+ netDialCalled := false
298
+
299
+ cstDialer := cstDialer // make local copy for modification on next line.
300
+ cstDialer .Proxy = http .ProxyURL (psurl )
301
+ if tc .isTLS {
302
+ cstDialer .TLSClientConfig = & tls.Config {
303
+ RootCAs : rootCAs (tt , ps ),
304
+ ServerName : tc .tlsServerName ,
305
+ InsecureSkipVerify : tc .insecureSkipVerify ,
306
+ }
307
+ if tc .netDialTLSContext != nil {
308
+ cstDialer .NetDialTLSContext = func (ctx context.Context , network , addr string ) (net.Conn , error ) {
309
+ netDialCalled = true
310
+ return tc .netDialTLSContext (ctx , network , addr )
311
+ }
312
+ } else {
313
+ netDialCalled = true
314
+ }
315
+ } else {
316
+ netDialCalled = true
317
+ }
318
+
319
+ connect := false
320
+ origHandler := ps .Config .Handler
321
+
322
+ // Capture the request Host header.
323
+ ps .Config .Handler = http .HandlerFunc (
324
+ func (w http.ResponseWriter , r * http.Request ) {
325
+ if r .Method == http .MethodConnect {
326
+ connect = true
327
+ }
328
+
329
+ origHandler .ServeHTTP (w , r )
330
+ })
331
+
332
+ ws , _ , err := cstDialer .Dial (s .URL , nil )
333
+ if err != nil {
334
+ tt .Fatalf ("Dial: %v" , err )
335
+ }
336
+ defer ws .Close ()
337
+ sendRecv (tt , ws )
338
+
339
+ if ! connect {
340
+ tt .Error ("connect not received" )
341
+ }
342
+ if ! netDialCalled {
343
+ tt .Error ("netDialTLSContext not called" )
344
+ }
345
+ })
346
+ }
347
+ }
348
+
205
349
func TestProxyAuthorizationDial (t * testing.T ) {
206
350
s := newServer (t )
207
351
defer s .Close ()
0 commit comments