Skip to content

Commit ec19bed

Browse files
Add support for ProxyConnectHeader in the dialer.
Add the ability to pass ProxyConnectHeader to the dialer. This set of headers will be used when a CONNECT request is made to an http(s) proxy.
1 parent e064f32 commit ec19bed

File tree

3 files changed

+58
-6
lines changed

3 files changed

+58
-6
lines changed

client.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ type Dialer struct {
8787
// If Proxy is nil or returns a nil *URL, no proxy is used.
8888
Proxy func(*http.Request) (*url.URL, error)
8989

90+
// ProxyConnectHeader specifies optional headers to use during proxy connect requests.
91+
ProxyConnectHeader http.Header
92+
9093
// TLSClientConfig specifies the TLS configuration to use with tls.Client.
9194
// If nil, the default configuration is used.
9295
// If NetDialTLSContext is set, Dial assumes the TLS handshake
@@ -416,7 +419,7 @@ func (d *Dialer) netDialFn(ctx context.Context, proxyURL *url.URL, backendURL *u
416419
}
417420
// Proxy dialing is wrapped to implement CONNECT method and possibly proxy auth.
418421
if proxyURL != nil {
419-
return proxyFromURL(proxyURL, netDial)
422+
return proxyFromURL(proxyURL, netDial, d.ProxyConnectHeader)
420423
}
421424
return netDial, nil
422425
}

client_server_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,46 @@ func TestProxyAuthorizationDial(t *testing.T) {
242242
sendRecv(t, ws)
243243
}
244244

245+
func TestProxyDialConnectHeaders(t *testing.T) {
246+
s := newServer(t)
247+
defer s.Close()
248+
249+
surl, _ := url.Parse(s.Server.URL)
250+
251+
cstDialer := cstDialer // make local copy for modification on next line.
252+
cstDialer.Proxy = http.ProxyURL(surl)
253+
cstDialer.ProxyConnectHeader = http.Header{"User-Agent": []string{"test-proxy-agent"}}
254+
255+
connect := false
256+
origHandler := s.Server.Config.Handler
257+
258+
// Capture the request Host header.
259+
s.Server.Config.Handler = http.HandlerFunc(
260+
func(w http.ResponseWriter, r *http.Request) {
261+
t.Logf("Request headers: %v", r.Header)
262+
userAgent := r.Header.Get("User-Agent")
263+
if r.Method == http.MethodConnect && userAgent == "test-proxy-agent" {
264+
connect = true
265+
w.WriteHeader(http.StatusOK)
266+
return
267+
}
268+
269+
if !connect {
270+
t.Log("connect with proxy connect headers not received")
271+
http.Error(w, "connect with proxy connect headers not received", http.StatusMethodNotAllowed)
272+
return
273+
}
274+
origHandler.ServeHTTP(w, r)
275+
})
276+
277+
ws, _, err := cstDialer.Dial(s.URL, nil)
278+
if err != nil {
279+
t.Fatalf("Dial: %v", err)
280+
}
281+
defer ws.Close()
282+
sendRecv(t, ws)
283+
}
284+
245285
func TestDial(t *testing.T) {
246286
s := newServer(t)
247287
defer s.Close()

proxy.go

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ func (fn netDialerFunc) DialContext(ctx context.Context, network, addr string) (
2828
return fn(ctx, network, addr)
2929
}
3030

31-
func proxyFromURL(proxyURL *url.URL, forwardDial netDialerFunc) (netDialerFunc, error) {
31+
func proxyFromURL(proxyURL *url.URL, forwardDial netDialerFunc, connectHeader http.Header) (netDialerFunc, error) {
3232
if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
33-
return (&httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDial}).DialContext, nil
33+
return (&httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDial, proxyConnectHeader: connectHeader}).DialContext, nil
3434
}
3535
dialer, err := proxy.FromURL(proxyURL, forwardDial)
3636
if err != nil {
@@ -45,8 +45,13 @@ func proxyFromURL(proxyURL *url.URL, forwardDial netDialerFunc) (netDialerFunc,
4545
}
4646

4747
type httpProxyDialer struct {
48-
proxyURL *url.URL
49-
forwardDial netDialerFunc
48+
proxyURL *url.URL
49+
forwardDial netDialerFunc
50+
proxyConnectHeader http.Header
51+
}
52+
53+
func (hpd *httpProxyDialer) Dial(network, addr string) (net.Conn, error) {
54+
return hpd.DialContext(context.Background(), network, addr)
5055
}
5156

5257
func (hpd *httpProxyDialer) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
@@ -56,7 +61,11 @@ func (hpd *httpProxyDialer) DialContext(ctx context.Context, network string, add
5661
return nil, err
5762
}
5863

59-
connectHeader := make(http.Header)
64+
connectHeader := hpd.proxyConnectHeader
65+
if hpd.proxyConnectHeader == nil {
66+
connectHeader := make(http.Header)
67+
}
68+
6069
if user := hpd.proxyURL.User; user != nil {
6170
proxyUser := user.Username()
6271
if proxyPassword, passwordSet := user.Password(); passwordSet {

0 commit comments

Comments
 (0)