Skip to content

Commit c359b81

Browse files
committed
Add HTTP proxy support for tunnel connections
1 parent 1cedefa commit c359b81

File tree

5 files changed

+334
-22
lines changed

5 files changed

+334
-22
lines changed

ingress/origin_dialer.go

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"time"
1010

1111
"github.com/rs/zerolog"
12+
"golang.org/x/net/proxy"
1213
)
1314

1415
// OriginTCPDialer provides a TCP dial operation to a requested address.
@@ -115,20 +116,33 @@ func (d *OriginDialerService) DialUDP(addr netip.AddrPort) (net.Conn, error) {
115116
}
116117

117118
type Dialer struct {
118-
Dialer net.Dialer
119+
Dialer proxy.Dialer
119120
}
120121

121122
func NewDialer(config WarpRoutingConfig) *Dialer {
123+
// Create proxy-aware dialer for warp routing
124+
proxyDialer := createProxyDialer(config.ConnectTimeout.Duration, config.TCPKeepAlive.Duration, nil)
122125
return &Dialer{
123-
Dialer: net.Dialer{
124-
Timeout: config.ConnectTimeout.Duration,
125-
KeepAlive: config.TCPKeepAlive.Duration,
126-
},
126+
Dialer: proxyDialer,
127127
}
128128
}
129129

130+
// createProxyDialer creates a proxy.Dialer that respects proxy environment variables
131+
func createProxyDialer(timeout, keepAlive time.Duration, logger *zerolog.Logger) proxy.Dialer {
132+
// Reuse the unified proxy logic from origin_service.go
133+
return newProxyAwareDialer(timeout, keepAlive, logger)
134+
}
135+
130136
func (d *Dialer) DialTCP(ctx context.Context, dest netip.AddrPort) (net.Conn, error) {
131-
conn, err := d.Dialer.DialContext(ctx, "tcp", dest.String())
137+
var conn net.Conn
138+
var err error
139+
140+
if contextDialer, ok := d.Dialer.(proxy.ContextDialer); ok {
141+
conn, err = contextDialer.DialContext(ctx, "tcp", dest.String())
142+
} else {
143+
conn, err = d.Dialer.Dial("tcp", dest.String())
144+
}
145+
132146
if err != nil {
133147
return nil, fmt.Errorf("unable to dial tcp to origin %s: %w", dest, err)
134148
}

ingress/origin_proxy.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"net/http"
99

1010
"github.com/rs/zerolog"
11+
"golang.org/x/net/proxy"
1112
)
1213

1314
// HTTPOriginProxy can be implemented by origin services that want to proxy http requests.
@@ -86,7 +87,15 @@ func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
8687
}
8788

8889
func (o *rawTCPService) EstablishConnection(ctx context.Context, dest string, logger *zerolog.Logger) (OriginConnection, error) {
89-
conn, err := o.dialer.DialContext(ctx, "tcp", dest)
90+
var conn net.Conn
91+
var err error
92+
93+
if contextDialer, ok := o.dialer.(proxy.ContextDialer); ok {
94+
conn, err = contextDialer.DialContext(ctx, "tcp", dest)
95+
} else {
96+
conn, err = o.dialer.Dial("tcp", dest)
97+
}
98+
9099
if err != nil {
91100
return nil, err
92101
}
@@ -105,7 +114,13 @@ func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string,
105114
dest = o.dest
106115
}
107116

108-
conn, err := o.dialer.DialContext(ctx, "tcp", dest)
117+
var conn net.Conn
118+
if contextDialer, ok := o.dialer.(proxy.ContextDialer); ok {
119+
conn, err = contextDialer.DialContext(ctx, "tcp", dest)
120+
} else {
121+
conn, err = o.dialer.Dial("tcp", dest)
122+
}
123+
109124
if err != nil {
110125
return nil, err
111126
}

ingress/origin_proxy_test.go

Lines changed: 148 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ import (
88
"net/http"
99
"net/http/httptest"
1010
"net/url"
11+
"os"
1112
"testing"
13+
"time"
1214

1315
"github.com/stretchr/testify/assert"
1416
"github.com/stretchr/testify/require"
@@ -24,7 +26,10 @@ func TestRawTCPServiceEstablishConnection(t *testing.T) {
2426
listenerClosed := make(chan struct{})
2527
tcpListenRoutine(originListener, listenerClosed)
2628

27-
rawTCPService := &rawTCPService{name: ServiceWarpRouting}
29+
rawTCPService := &rawTCPService{
30+
name: ServiceWarpRouting,
31+
dialer: newProxyAwareDialer(30*time.Second, 30*time.Second, nil),
32+
}
2833

2934
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil)
3035
require.NoError(t, err)
@@ -40,6 +45,148 @@ func TestRawTCPServiceEstablishConnection(t *testing.T) {
4045
require.Error(t, err)
4146
}
4247

48+
func TestProxyAwareDialer(t *testing.T) {
49+
tests := []struct {
50+
name string
51+
httpProxy string
52+
httpsProxy string
53+
socksProxy string
54+
expectDirect bool
55+
expectProxy bool
56+
}{
57+
{
58+
name: "no proxy configured",
59+
expectDirect: true,
60+
},
61+
{
62+
name: "HTTP proxy configured",
63+
httpProxy: "http://proxy.example.com:8080",
64+
expectProxy: true,
65+
},
66+
{
67+
name: "HTTPS proxy configured",
68+
httpsProxy: "http://proxy.example.com:8080",
69+
expectProxy: true,
70+
},
71+
{
72+
name: "SOCKS proxy configured",
73+
socksProxy: "socks5://proxy.example.com:1080",
74+
expectProxy: true,
75+
},
76+
}
77+
78+
for _, tt := range tests {
79+
t.Run(tt.name, func(t *testing.T) {
80+
origHTTP := os.Getenv("HTTP_PROXY")
81+
origHTTPS := os.Getenv("HTTPS_PROXY")
82+
origSOCKS := os.Getenv("ALL_PROXY")
83+
84+
defer func() {
85+
os.Setenv("HTTP_PROXY", origHTTP)
86+
os.Setenv("HTTPS_PROXY", origHTTPS)
87+
os.Setenv("ALL_PROXY", origSOCKS)
88+
}()
89+
90+
os.Setenv("HTTP_PROXY", tt.httpProxy)
91+
os.Setenv("HTTPS_PROXY", tt.httpsProxy)
92+
os.Setenv("ALL_PROXY", tt.socksProxy)
93+
94+
dialer := newProxyAwareDialer(30*time.Second, 30*time.Second, TestLogger)
95+
assert.NotNil(t, dialer)
96+
97+
if tt.expectDirect {
98+
_, ok := dialer.(*net.Dialer)
99+
assert.True(t, ok, "Expected net.Dialer when no proxy configured")
100+
} else if tt.expectProxy {
101+
assert.NotNil(t, dialer, "Expected proxy dialer when proxy configured")
102+
}
103+
})
104+
}
105+
}
106+
107+
func TestProxyAwareDialerHTTPConnect(t *testing.T) {
108+
proxyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
109+
if r.Method != "CONNECT" {
110+
w.WriteHeader(http.StatusMethodNotAllowed)
111+
return
112+
}
113+
w.WriteHeader(http.StatusOK)
114+
}))
115+
defer proxyServer.Close()
116+
117+
origHTTP := os.Getenv("HTTP_PROXY")
118+
defer os.Setenv("HTTP_PROXY", origHTTP)
119+
120+
os.Setenv("HTTP_PROXY", proxyServer.URL)
121+
122+
dialer := newProxyAwareDialer(5*time.Second, 5*time.Second, TestLogger)
123+
assert.NotNil(t, dialer)
124+
125+
// Test actual dial (this will fail because our mock proxy doesn't handle the full protocol)
126+
// but we can verify the proxy detection logic works
127+
proxyAwareDialer, ok := dialer.(*proxyAwareDialer)
128+
assert.True(t, ok, "Expected proxyAwareDialer when HTTP proxy configured")
129+
assert.NotNil(t, proxyAwareDialer.baseDialer)
130+
}
131+
132+
func TestGetEnvProxy(t *testing.T) {
133+
tests := []struct {
134+
name string
135+
upper string
136+
lower string
137+
upperVal string
138+
lowerVal string
139+
expected string
140+
}{
141+
{
142+
name: "upper case takes priority",
143+
upper: "TEST_PROXY",
144+
lower: "test_proxy",
145+
upperVal: "upper_value",
146+
lowerVal: "lower_value",
147+
expected: "upper_value",
148+
},
149+
{
150+
name: "lower case when upper not set",
151+
upper: "TEST_PROXY",
152+
lower: "test_proxy",
153+
lowerVal: "lower_value",
154+
expected: "lower_value",
155+
},
156+
{
157+
name: "empty when neither set",
158+
upper: "TEST_PROXY",
159+
lower: "test_proxy",
160+
expected: "",
161+
},
162+
}
163+
164+
for _, tt := range tests {
165+
t.Run(tt.name, func(t *testing.T) {
166+
// Save and restore environment
167+
origUpper := os.Getenv(tt.upper)
168+
origLower := os.Getenv(tt.lower)
169+
defer func() {
170+
os.Setenv(tt.upper, origUpper)
171+
os.Setenv(tt.lower, origLower)
172+
}()
173+
174+
os.Unsetenv(tt.upper)
175+
os.Unsetenv(tt.lower)
176+
177+
if tt.upperVal != "" {
178+
os.Setenv(tt.upper, tt.upperVal)
179+
}
180+
if tt.lowerVal != "" {
181+
os.Setenv(tt.lower, tt.lowerVal)
182+
}
183+
184+
result := getEnvProxy(tt.upper, tt.lower)
185+
assert.Equal(t, tt.expected, result)
186+
})
187+
}
188+
}
189+
43190
func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
44191
originListener, err := net.Listen("tcp", "127.0.0.1:0")
45192
require.NoError(t, err)

0 commit comments

Comments
 (0)