Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ type Dialer struct {
// If Proxy is nil or returns a nil *URL, no proxy is used.
Proxy func(*http.Request) (*url.URL, error)

// ProxyConnectHeader specifies optional headers to use during proxy connect requests.
ProxyConnectHeader http.Header

// TLSClientConfig specifies the TLS configuration to use with tls.Client.
// If nil, the default configuration is used.
// If NetDialTLSContext is set, Dial assumes the TLS handshake
Expand Down Expand Up @@ -416,7 +419,7 @@ func (d *Dialer) netDialFn(ctx context.Context, proxyURL *url.URL, backendURL *u
}
// Proxy dialing is wrapped to implement CONNECT method and possibly proxy auth.
if proxyURL != nil {
return proxyFromURL(proxyURL, netDial)
return proxyFromURL(proxyURL, netDial, d.ProxyConnectHeader)
}
return netDial, nil
}
Expand Down
40 changes: 40 additions & 0 deletions client_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,46 @@ func TestProxyAuthorizationDial(t *testing.T) {
sendRecv(t, ws)
}

func TestProxyDialConnectHeaders(t *testing.T) {
s := newServer(t)
defer s.Close()

surl, _ := url.Parse(s.Server.URL)

cstDialer := cstDialer // make local copy for modification on next line.
cstDialer.Proxy = http.ProxyURL(surl)
cstDialer.ProxyConnectHeader = http.Header{"User-Agent": []string{"test-proxy-agent"}}

connect := false
origHandler := s.Server.Config.Handler

// Capture the request Host header.
s.Server.Config.Handler = http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
t.Logf("Request headers: %v", r.Header)
userAgent := r.Header.Get("User-Agent")
if r.Method == http.MethodConnect && userAgent == "test-proxy-agent" {
connect = true
w.WriteHeader(http.StatusOK)
return
}

if !connect {
t.Log("connect with proxy connect headers not received")
http.Error(w, "connect with proxy connect headers not received", http.StatusMethodNotAllowed)
return
}
origHandler.ServeHTTP(w, r)
})

ws, _, err := cstDialer.Dial(s.URL, nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer ws.Close()
sendRecv(t, ws)
}

func TestDial(t *testing.T) {
s := newServer(t)
defer s.Close()
Expand Down
19 changes: 14 additions & 5 deletions proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ func (fn netDialerFunc) DialContext(ctx context.Context, network, addr string) (
return fn(ctx, network, addr)
}

func proxyFromURL(proxyURL *url.URL, forwardDial netDialerFunc) (netDialerFunc, error) {
func proxyFromURL(proxyURL *url.URL, forwardDial netDialerFunc, connectHeader http.Header) (netDialerFunc, error) {
if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
return (&httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDial}).DialContext, nil
return (&httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDial, proxyConnectHeader: connectHeader}).DialContext, nil
}
dialer, err := proxy.FromURL(proxyURL, forwardDial)
if err != nil {
Expand All @@ -45,8 +45,13 @@ func proxyFromURL(proxyURL *url.URL, forwardDial netDialerFunc) (netDialerFunc,
}

type httpProxyDialer struct {
proxyURL *url.URL
forwardDial netDialerFunc
proxyURL *url.URL
forwardDial netDialerFunc
proxyConnectHeader http.Header
}

func (hpd *httpProxyDialer) Dial(network, addr string) (net.Conn, error) {
return hpd.DialContext(context.Background(), network, addr)
}

func (hpd *httpProxyDialer) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
Expand All @@ -56,7 +61,11 @@ func (hpd *httpProxyDialer) DialContext(ctx context.Context, network string, add
return nil, err
}

connectHeader := make(http.Header)
connectHeader := hpd.proxyConnectHeader
if hpd.proxyConnectHeader == nil {
connectHeader = make(http.Header)
}

if user := hpd.proxyURL.User; user != nil {
proxyUser := user.Username()
if proxyPassword, passwordSet := user.Password(); passwordSet {
Expand Down