Skip to content

Commit ceae452

Browse files
juliensgaryburd
authored andcommitted
Add context in the Dialer
1 parent b378cae commit ceae452

File tree

4 files changed

+230
-16
lines changed

4 files changed

+230
-16
lines changed

client.go

Lines changed: 70 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@ package websocket
66

77
import (
88
"bytes"
9+
"context"
910
"crypto/tls"
1011
"errors"
1112
"io"
1213
"io/ioutil"
1314
"net"
1415
"net/http"
16+
"net/http/httptrace"
1517
"net/url"
1618
"strings"
1719
"time"
@@ -51,6 +53,10 @@ type Dialer struct {
5153
// NetDial is nil, net.Dial is used.
5254
NetDial func(network, addr string) (net.Conn, error)
5355

56+
// NetDialContext specifies the dial function for creating TCP connections. If
57+
// NetDialContext is nil, net.DialContext is used.
58+
NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
59+
5460
// Proxy specifies a function to return a proxy for a given
5561
// Request. If the function returns a non-nil error, the
5662
// request is aborted with the provided error.
@@ -95,6 +101,11 @@ type Dialer struct {
95101
Jar http.CookieJar
96102
}
97103

104+
// Dial creates a new client connection by calling DialContext with a background context.
105+
func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
106+
return d.DialContext(urlStr, requestHeader, context.Background())
107+
}
108+
98109
var errMalformedURL = errors.New("malformed ws or wss URL")
99110

100111
func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
@@ -124,17 +135,18 @@ var DefaultDialer = &Dialer{
124135
// nilDialer is dialer to use when receiver is nil.
125136
var nilDialer Dialer = *DefaultDialer
126137

127-
// Dial creates a new client connection. Use requestHeader to specify the
138+
// DialContext creates a new client connection. Use requestHeader to specify the
128139
// origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie).
129140
// Use the response.Header to get the selected subprotocol
130141
// (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
131142
//
143+
// The context will be used in the request and in the Dialer
144+
//
132145
// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
133146
// non-nil *http.Response so that callers can handle redirects, authentication,
134147
// etcetera. The response body may not contain the entire response and does not
135148
// need to be closed by the application.
136-
func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
137-
149+
func (d *Dialer) DialContext(urlStr string, requestHeader http.Header, ctx context.Context) (*Conn, *http.Response, error) {
138150
if d == nil {
139151
d = &nilDialer
140152
}
@@ -172,6 +184,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
172184
Header: make(http.Header),
173185
Host: u.Host,
174186
}
187+
req = req.WithContext(ctx)
175188

176189
// Set the cookies present in the cookie jar of the dialer
177190
if d.Jar != nil {
@@ -215,20 +228,30 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
215228
req.Header["Sec-WebSocket-Extensions"] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"}
216229
}
217230

218-
var deadline time.Time
219231
if d.HandshakeTimeout != 0 {
220-
deadline = time.Now().Add(d.HandshakeTimeout)
232+
var cancel func()
233+
ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout)
234+
defer cancel()
221235
}
222236

223237
// Get network dial function.
224-
netDial := d.NetDial
225-
if netDial == nil {
226-
netDialer := &net.Dialer{Deadline: deadline}
227-
netDial = netDialer.Dial
238+
var netDial func(network, add string) (net.Conn, error)
239+
240+
if d.NetDialContext != nil {
241+
netDial = func(network, addr string) (net.Conn, error) {
242+
return d.NetDialContext(ctx, network, addr)
243+
}
244+
} else if d.NetDial != nil {
245+
netDial = d.NetDial
246+
} else {
247+
netDialer := &net.Dialer{}
248+
netDial = func(network, addr string) (net.Conn, error) {
249+
return netDialer.DialContext(ctx, network, addr)
250+
}
228251
}
229252

230253
// If needed, wrap the dial function to set the connection deadline.
231-
if !deadline.Equal(time.Time{}) {
254+
if deadline, ok := ctx.Deadline(); ok {
232255
forwardDial := netDial
233256
netDial = func(network, addr string) (net.Conn, error) {
234257
c, err := forwardDial(network, addr)
@@ -260,7 +283,17 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
260283
}
261284

262285
hostPort, hostNoPort := hostPortNoPort(u)
286+
trace := httptrace.ContextClientTrace(ctx)
287+
if trace != nil && trace.GetConn != nil {
288+
trace.GetConn(hostPort)
289+
}
290+
263291
netConn, err := netDial("tcp", hostPort)
292+
if trace != nil && trace.GotConn != nil {
293+
trace.GotConn(httptrace.GotConnInfo{
294+
Conn: netConn,
295+
})
296+
}
264297
if err != nil {
265298
return nil, nil, err
266299
}
@@ -278,13 +311,16 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
278311
}
279312
tlsConn := tls.Client(netConn, cfg)
280313
netConn = tlsConn
281-
if err := tlsConn.Handshake(); err != nil {
282-
return nil, nil, err
314+
315+
var err error
316+
if trace != nil {
317+
err = doHandshakeWithTrace(trace, tlsConn, cfg)
318+
} else {
319+
err = doHandshake(tlsConn, cfg)
283320
}
284-
if !cfg.InsecureSkipVerify {
285-
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
286-
return nil, nil, err
287-
}
321+
322+
if err != nil {
323+
return nil, nil, err
288324
}
289325
}
290326

@@ -294,6 +330,12 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
294330
return nil, nil, err
295331
}
296332

333+
if trace != nil && trace.GotFirstResponseByte != nil {
334+
if peek, err := conn.br.Peek(1); err == nil && len(peek) == 1 {
335+
trace.GotFirstResponseByte()
336+
}
337+
}
338+
297339
resp, err := http.ReadResponse(conn.br, req)
298340
if err != nil {
299341
return nil, nil, err
@@ -339,3 +381,15 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
339381
netConn = nil // to avoid close in defer.
340382
return conn, resp, nil
341383
}
384+
385+
func doHandshake(tlsConn *tls.Conn, cfg *tls.Config) error {
386+
if err := tlsConn.Handshake(); err != nil {
387+
return err
388+
}
389+
if !cfg.InsecureSkipVerify {
390+
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
391+
return err
392+
}
393+
}
394+
return nil
395+
}

client_server_test.go

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package websocket
66

77
import (
88
"bytes"
9+
"context"
910
"crypto/tls"
1011
"crypto/x509"
1112
"encoding/base64"
@@ -16,6 +17,7 @@ import (
1617
"net/http"
1718
"net/http/cookiejar"
1819
"net/http/httptest"
20+
"net/http/httptrace"
1921
"net/url"
2022
"reflect"
2123
"strings"
@@ -40,6 +42,12 @@ var cstDialer = Dialer{
4042
HandshakeTimeout: 30 * time.Second,
4143
}
4244

45+
var cstDialerWithoutHandshakeTimeout = Dialer{
46+
Subprotocols: []string{"p1", "p2"},
47+
ReadBufferSize: 1024,
48+
WriteBufferSize: 1024,
49+
}
50+
4351
type cstHandler struct{ *testing.T }
4452

4553
type cstServer struct {
@@ -403,6 +411,26 @@ func TestHandshakeTimeout(t *testing.T) {
403411
ws.Close()
404412
}
405413

414+
func TestHandshakeTimeoutInContext(t *testing.T) {
415+
s := newServer(t)
416+
defer s.Close()
417+
418+
d := cstDialerWithoutHandshakeTimeout
419+
d.NetDialContext = func(ctx context.Context, n, a string) (net.Conn, error) {
420+
netDialer := &net.Dialer{}
421+
c, err := netDialer.DialContext(ctx, n, a)
422+
return &requireDeadlineNetConn{c: c, t: t}, err
423+
}
424+
425+
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(30*time.Second))
426+
defer cancel()
427+
ws, _, err := d.DialContext(s.URL, nil, ctx)
428+
if err != nil {
429+
t.Fatal("Dial:", err)
430+
}
431+
ws.Close()
432+
}
433+
406434
func TestDialBadScheme(t *testing.T) {
407435
s := newServer(t)
408436
defer s.Close()
@@ -659,3 +687,104 @@ func TestSocksProxyDial(t *testing.T) {
659687
defer ws.Close()
660688
sendRecv(t, ws)
661689
}
690+
691+
func TestTracingDialWithContext(t *testing.T) {
692+
693+
var headersWrote, requestWrote, getConn, gotConn, connectDone, gotFirstResponseByte bool
694+
trace := &httptrace.ClientTrace{
695+
WroteHeaders: func() {
696+
headersWrote = true
697+
},
698+
WroteRequest: func(httptrace.WroteRequestInfo) {
699+
requestWrote = true
700+
},
701+
GetConn: func(hostPort string) {
702+
getConn = true
703+
},
704+
GotConn: func(info httptrace.GotConnInfo) {
705+
gotConn = true
706+
},
707+
ConnectDone: func(network, addr string, err error) {
708+
connectDone = true
709+
},
710+
GotFirstResponseByte: func() {
711+
gotFirstResponseByte = true
712+
},
713+
}
714+
ctx := httptrace.WithClientTrace(context.Background(), trace)
715+
716+
s := newTLSServer(t)
717+
defer s.Close()
718+
719+
certs := x509.NewCertPool()
720+
for _, c := range s.TLS.Certificates {
721+
roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
722+
if err != nil {
723+
t.Fatalf("error parsing server's root cert: %v", err)
724+
}
725+
for _, root := range roots {
726+
certs.AddCert(root)
727+
}
728+
}
729+
730+
d := cstDialer
731+
d.TLSClientConfig = &tls.Config{RootCAs: certs}
732+
733+
ws, _, err := d.DialContext(s.URL, nil, ctx)
734+
if err != nil {
735+
t.Fatalf("Dial: %v", err)
736+
}
737+
738+
if !headersWrote {
739+
t.Fatal("Headers was not written")
740+
}
741+
if !requestWrote {
742+
t.Fatal("Request was not written")
743+
}
744+
if !getConn {
745+
t.Fatal("getConn was not called")
746+
}
747+
if !gotConn {
748+
t.Fatal("gotConn was not called")
749+
}
750+
if !connectDone {
751+
t.Fatal("connectDone was not called")
752+
}
753+
if !gotFirstResponseByte {
754+
t.Fatal("GotFirstResponseByte was not called")
755+
}
756+
757+
defer ws.Close()
758+
sendRecv(t, ws)
759+
}
760+
761+
func TestEmptyTracingDialWithContext(t *testing.T) {
762+
763+
trace := &httptrace.ClientTrace{}
764+
ctx := httptrace.WithClientTrace(context.Background(), trace)
765+
766+
s := newTLSServer(t)
767+
defer s.Close()
768+
769+
certs := x509.NewCertPool()
770+
for _, c := range s.TLS.Certificates {
771+
roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
772+
if err != nil {
773+
t.Fatalf("error parsing server's root cert: %v", err)
774+
}
775+
for _, root := range roots {
776+
certs.AddCert(root)
777+
}
778+
}
779+
780+
d := cstDialer
781+
d.TLSClientConfig = &tls.Config{RootCAs: certs}
782+
783+
ws, _, err := d.DialContext(s.URL, nil, ctx)
784+
if err != nil {
785+
t.Fatalf("Dial: %v", err)
786+
}
787+
788+
defer ws.Close()
789+
sendRecv(t, ws)
790+
}

trace.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// +build go1.8
2+
3+
package websocket
4+
5+
import (
6+
"crypto/tls"
7+
"net/http/httptrace"
8+
)
9+
10+
func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error {
11+
if trace.TLSHandshakeStart != nil {
12+
trace.TLSHandshakeStart()
13+
}
14+
err := doHandshake(tlsConn, cfg)
15+
if trace.TLSHandshakeDone != nil {
16+
trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
17+
}
18+
return err
19+
}

trace_17.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// +build !go1.8
2+
3+
package websocket
4+
5+
import (
6+
"crypto/tls"
7+
"net/http/httptrace"
8+
)
9+
10+
func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error {
11+
return doHandshake(tlsConn, cfg)
12+
}

0 commit comments

Comments
 (0)