@@ -6,12 +6,14 @@ package websocket
6
6
7
7
import (
8
8
"bytes"
9
+ "context"
9
10
"crypto/tls"
10
11
"errors"
11
12
"io"
12
13
"io/ioutil"
13
14
"net"
14
15
"net/http"
16
+ "net/http/httptrace"
15
17
"net/url"
16
18
"strings"
17
19
"time"
@@ -51,6 +53,10 @@ type Dialer struct {
51
53
// NetDial is nil, net.Dial is used.
52
54
NetDial func (network , addr string ) (net.Conn , error )
53
55
56
+ // NetDialContext specifies the dial function for creating TCP connections. If
57
+ // NetDialContext is nil, net.DialContext is used.
58
+ NetDialContext func (ctx context.Context , network , addr string ) (net.Conn , error )
59
+
54
60
// Proxy specifies a function to return a proxy for a given
55
61
// Request. If the function returns a non-nil error, the
56
62
// request is aborted with the provided error.
@@ -95,6 +101,11 @@ type Dialer struct {
95
101
Jar http.CookieJar
96
102
}
97
103
104
+ // Dial creates a new client connection by calling DialContext with a background context.
105
+ func (d * Dialer ) Dial (urlStr string , requestHeader http.Header ) (* Conn , * http.Response , error ) {
106
+ return d .DialContext (urlStr , requestHeader , context .Background ())
107
+ }
108
+
98
109
var errMalformedURL = errors .New ("malformed ws or wss URL" )
99
110
100
111
func hostPortNoPort (u * url.URL ) (hostPort , hostNoPort string ) {
@@ -124,17 +135,18 @@ var DefaultDialer = &Dialer{
124
135
// nilDialer is dialer to use when receiver is nil.
125
136
var nilDialer Dialer = * DefaultDialer
126
137
127
- // Dial creates a new client connection. Use requestHeader to specify the
138
+ // DialContext creates a new client connection. Use requestHeader to specify the
128
139
// origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie).
129
140
// Use the response.Header to get the selected subprotocol
130
141
// (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
131
142
//
143
+ // The context will be used in the request and in the Dialer
144
+ //
132
145
// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
133
146
// non-nil *http.Response so that callers can handle redirects, authentication,
134
147
// etcetera. The response body may not contain the entire response and does not
135
148
// need to be closed by the application.
136
- func (d * Dialer ) Dial (urlStr string , requestHeader http.Header ) (* Conn , * http.Response , error ) {
137
-
149
+ func (d * Dialer ) DialContext (urlStr string , requestHeader http.Header , ctx context.Context ) (* Conn , * http.Response , error ) {
138
150
if d == nil {
139
151
d = & nilDialer
140
152
}
@@ -172,6 +184,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
172
184
Header : make (http.Header ),
173
185
Host : u .Host ,
174
186
}
187
+ req = req .WithContext (ctx )
175
188
176
189
// Set the cookies present in the cookie jar of the dialer
177
190
if d .Jar != nil {
@@ -215,20 +228,30 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
215
228
req .Header ["Sec-WebSocket-Extensions" ] = []string {"permessage-deflate; server_no_context_takeover; client_no_context_takeover" }
216
229
}
217
230
218
- var deadline time.Time
219
231
if d .HandshakeTimeout != 0 {
220
- deadline = time .Now ().Add (d .HandshakeTimeout )
232
+ var cancel func ()
233
+ ctx , cancel = context .WithTimeout (ctx , d .HandshakeTimeout )
234
+ defer cancel ()
221
235
}
222
236
223
237
// Get network dial function.
224
- netDial := d .NetDial
225
- if netDial == nil {
226
- netDialer := & net.Dialer {Deadline : deadline }
227
- netDial = netDialer .Dial
238
+ var netDial func (network , add string ) (net.Conn , error )
239
+
240
+ if d .NetDialContext != nil {
241
+ netDial = func (network , addr string ) (net.Conn , error ) {
242
+ return d .NetDialContext (ctx , network , addr )
243
+ }
244
+ } else if d .NetDial != nil {
245
+ netDial = d .NetDial
246
+ } else {
247
+ netDialer := & net.Dialer {}
248
+ netDial = func (network , addr string ) (net.Conn , error ) {
249
+ return netDialer .DialContext (ctx , network , addr )
250
+ }
228
251
}
229
252
230
253
// If needed, wrap the dial function to set the connection deadline.
231
- if ! deadline . Equal (time. Time {}) {
254
+ if deadline , ok := ctx . Deadline (); ok {
232
255
forwardDial := netDial
233
256
netDial = func (network , addr string ) (net.Conn , error ) {
234
257
c , err := forwardDial (network , addr )
@@ -260,7 +283,17 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
260
283
}
261
284
262
285
hostPort , hostNoPort := hostPortNoPort (u )
286
+ trace := httptrace .ContextClientTrace (ctx )
287
+ if trace != nil && trace .GetConn != nil {
288
+ trace .GetConn (hostPort )
289
+ }
290
+
263
291
netConn , err := netDial ("tcp" , hostPort )
292
+ if trace != nil && trace .GotConn != nil {
293
+ trace .GotConn (httptrace.GotConnInfo {
294
+ Conn : netConn ,
295
+ })
296
+ }
264
297
if err != nil {
265
298
return nil , nil , err
266
299
}
@@ -278,13 +311,16 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
278
311
}
279
312
tlsConn := tls .Client (netConn , cfg )
280
313
netConn = tlsConn
281
- if err := tlsConn .Handshake (); err != nil {
282
- return nil , nil , err
314
+
315
+ var err error
316
+ if trace != nil {
317
+ err = doHandshakeWithTrace (trace , tlsConn , cfg )
318
+ } else {
319
+ err = doHandshake (tlsConn , cfg )
283
320
}
284
- if ! cfg .InsecureSkipVerify {
285
- if err := tlsConn .VerifyHostname (cfg .ServerName ); err != nil {
286
- return nil , nil , err
287
- }
321
+
322
+ if err != nil {
323
+ return nil , nil , err
288
324
}
289
325
}
290
326
@@ -294,6 +330,12 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
294
330
return nil , nil , err
295
331
}
296
332
333
+ if trace != nil && trace .GotFirstResponseByte != nil {
334
+ if peek , err := conn .br .Peek (1 ); err == nil && len (peek ) == 1 {
335
+ trace .GotFirstResponseByte ()
336
+ }
337
+ }
338
+
297
339
resp , err := http .ReadResponse (conn .br , req )
298
340
if err != nil {
299
341
return nil , nil , err
@@ -339,3 +381,15 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
339
381
netConn = nil // to avoid close in defer.
340
382
return conn , resp , nil
341
383
}
384
+
385
+ func doHandshake (tlsConn * tls.Conn , cfg * tls.Config ) error {
386
+ if err := tlsConn .Handshake (); err != nil {
387
+ return err
388
+ }
389
+ if ! cfg .InsecureSkipVerify {
390
+ if err := tlsConn .VerifyHostname (cfg .ServerName ); err != nil {
391
+ return err
392
+ }
393
+ }
394
+ return nil
395
+ }
0 commit comments