Skip to content

Commit 664ed3f

Browse files
committed
Add HTTP proxy support for tunnel connections
1 parent 1cedefa commit 664ed3f

File tree

5 files changed

+345
-22
lines changed

5 files changed

+345
-22
lines changed

ingress/origin_dialer.go

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"time"
1010

1111
"github.com/rs/zerolog"
12+
"golang.org/x/net/proxy"
1213
)
1314

1415
// OriginTCPDialer provides a TCP dial operation to a requested address.
@@ -115,20 +116,33 @@ func (d *OriginDialerService) DialUDP(addr netip.AddrPort) (net.Conn, error) {
115116
}
116117

117118
type Dialer struct {
118-
Dialer net.Dialer
119+
Dialer proxy.Dialer
119120
}
120121

121122
func NewDialer(config WarpRoutingConfig) *Dialer {
123+
// Create proxy-aware dialer for warp routing
124+
proxyDialer := createProxyDialer(config.ConnectTimeout.Duration, config.TCPKeepAlive.Duration, nil)
122125
return &Dialer{
123-
Dialer: net.Dialer{
124-
Timeout: config.ConnectTimeout.Duration,
125-
KeepAlive: config.TCPKeepAlive.Duration,
126-
},
126+
Dialer: proxyDialer,
127127
}
128128
}
129129

130+
// createProxyDialer creates a proxy.Dialer that respects proxy environment variables
131+
func createProxyDialer(timeout, keepAlive time.Duration, logger *zerolog.Logger) proxy.Dialer {
132+
// Reuse the unified proxy logic from origin_service.go
133+
return newProxyAwareDialer(timeout, keepAlive, logger)
134+
}
135+
130136
func (d *Dialer) DialTCP(ctx context.Context, dest netip.AddrPort) (net.Conn, error) {
131-
conn, err := d.Dialer.DialContext(ctx, "tcp", dest.String())
137+
var conn net.Conn
138+
var err error
139+
140+
if contextDialer, ok := d.Dialer.(proxy.ContextDialer); ok {
141+
conn, err = contextDialer.DialContext(ctx, "tcp", dest.String())
142+
} else {
143+
conn, err = d.Dialer.Dial("tcp", dest.String())
144+
}
145+
132146
if err != nil {
133147
return nil, fmt.Errorf("unable to dial tcp to origin %s: %w", dest, err)
134148
}

ingress/origin_proxy.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"net/http"
99

1010
"github.com/rs/zerolog"
11+
"golang.org/x/net/proxy"
1112
)
1213

1314
// HTTPOriginProxy can be implemented by origin services that want to proxy http requests.
@@ -86,7 +87,15 @@ func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
8687
}
8788

8889
func (o *rawTCPService) EstablishConnection(ctx context.Context, dest string, logger *zerolog.Logger) (OriginConnection, error) {
89-
conn, err := o.dialer.DialContext(ctx, "tcp", dest)
90+
var conn net.Conn
91+
var err error
92+
93+
if contextDialer, ok := o.dialer.(proxy.ContextDialer); ok {
94+
conn, err = contextDialer.DialContext(ctx, "tcp", dest)
95+
} else {
96+
conn, err = o.dialer.Dial("tcp", dest)
97+
}
98+
9099
if err != nil {
91100
return nil, err
92101
}
@@ -105,7 +114,13 @@ func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string,
105114
dest = o.dest
106115
}
107116

108-
conn, err := o.dialer.DialContext(ctx, "tcp", dest)
117+
var conn net.Conn
118+
if contextDialer, ok := o.dialer.(proxy.ContextDialer); ok {
119+
conn, err = contextDialer.DialContext(ctx, "tcp", dest)
120+
} else {
121+
conn, err = o.dialer.Dial("tcp", dest)
122+
}
123+
109124
if err != nil {
110125
return nil, err
111126
}

ingress/origin_proxy_test.go

Lines changed: 159 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ import (
88
"net/http"
99
"net/http/httptest"
1010
"net/url"
11+
"os"
1112
"testing"
13+
"time"
1214

1315
"github.com/stretchr/testify/assert"
1416
"github.com/stretchr/testify/require"
@@ -24,7 +26,10 @@ func TestRawTCPServiceEstablishConnection(t *testing.T) {
2426
listenerClosed := make(chan struct{})
2527
tcpListenRoutine(originListener, listenerClosed)
2628

27-
rawTCPService := &rawTCPService{name: ServiceWarpRouting}
29+
rawTCPService := &rawTCPService{
30+
name: ServiceWarpRouting,
31+
dialer: newProxyAwareDialer(30*time.Second, 30*time.Second, nil),
32+
}
2833

2934
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil)
3035
require.NoError(t, err)
@@ -40,6 +45,159 @@ func TestRawTCPServiceEstablishConnection(t *testing.T) {
4045
require.Error(t, err)
4146
}
4247

48+
func TestProxyAwareDialer(t *testing.T) {
49+
tests := []struct {
50+
name string
51+
httpProxy string
52+
httpsProxy string
53+
socksProxy string
54+
expectDirect bool
55+
expectProxy bool
56+
}{
57+
{
58+
name: "no proxy configured",
59+
expectDirect: true,
60+
},
61+
{
62+
name: "HTTP proxy configured",
63+
httpProxy: "http://proxy.example.com:8080",
64+
expectProxy: true,
65+
},
66+
{
67+
name: "HTTPS proxy configured",
68+
httpsProxy: "http://proxy.example.com:8080",
69+
expectProxy: true,
70+
},
71+
{
72+
name: "SOCKS proxy configured",
73+
socksProxy: "socks5://proxy.example.com:1080",
74+
expectProxy: true,
75+
},
76+
}
77+
78+
for _, tt := range tests {
79+
t.Run(tt.name, func(t *testing.T) {
80+
// Save original environment
81+
origHTTP := os.Getenv("HTTP_PROXY")
82+
origHTTPS := os.Getenv("HTTPS_PROXY")
83+
origSOCKS := os.Getenv("ALL_PROXY")
84+
85+
defer func() {
86+
os.Setenv("HTTP_PROXY", origHTTP)
87+
os.Setenv("HTTPS_PROXY", origHTTPS)
88+
os.Setenv("ALL_PROXY", origSOCKS)
89+
}()
90+
91+
// Set test environment
92+
os.Setenv("HTTP_PROXY", tt.httpProxy)
93+
os.Setenv("HTTPS_PROXY", tt.httpsProxy)
94+
os.Setenv("ALL_PROXY", tt.socksProxy)
95+
96+
dialer := newProxyAwareDialer(30*time.Second, 30*time.Second, TestLogger)
97+
assert.NotNil(t, dialer)
98+
99+
// Test that dialer implements expected interface
100+
if tt.expectDirect {
101+
// Should return base net.Dialer when no proxy configured
102+
_, ok := dialer.(*net.Dialer)
103+
assert.True(t, ok, "Expected net.Dialer when no proxy configured")
104+
} else if tt.expectProxy {
105+
// Should return some proxy dialer when proxy configured
106+
assert.NotNil(t, dialer, "Expected proxy dialer when proxy configured")
107+
}
108+
})
109+
}
110+
}
111+
112+
func TestProxyAwareDialerHTTPConnect(t *testing.T) {
113+
// Create a mock HTTP CONNECT proxy server
114+
proxyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
115+
if r.Method != "CONNECT" {
116+
w.WriteHeader(http.StatusMethodNotAllowed)
117+
return
118+
}
119+
// Simulate successful CONNECT
120+
w.WriteHeader(http.StatusOK)
121+
}))
122+
defer proxyServer.Close()
123+
124+
// Save original environment
125+
origHTTP := os.Getenv("HTTP_PROXY")
126+
defer os.Setenv("HTTP_PROXY", origHTTP)
127+
128+
// Set proxy environment
129+
os.Setenv("HTTP_PROXY", proxyServer.URL)
130+
131+
dialer := newProxyAwareDialer(5*time.Second, 5*time.Second, TestLogger)
132+
assert.NotNil(t, dialer)
133+
134+
// Test actual dial (this will fail because our mock proxy doesn't handle the full protocol)
135+
// but we can verify the proxy detection logic works
136+
proxyAwareDialer, ok := dialer.(*proxyAwareDialer)
137+
assert.True(t, ok, "Expected proxyAwareDialer when HTTP proxy configured")
138+
assert.NotNil(t, proxyAwareDialer.baseDialer)
139+
}
140+
141+
func TestGetEnvProxy(t *testing.T) {
142+
tests := []struct {
143+
name string
144+
upper string
145+
lower string
146+
upperVal string
147+
lowerVal string
148+
expected string
149+
}{
150+
{
151+
name: "upper case takes priority",
152+
upper: "TEST_PROXY",
153+
lower: "test_proxy",
154+
upperVal: "upper_value",
155+
lowerVal: "lower_value",
156+
expected: "upper_value",
157+
},
158+
{
159+
name: "lower case when upper not set",
160+
upper: "TEST_PROXY",
161+
lower: "test_proxy",
162+
lowerVal: "lower_value",
163+
expected: "lower_value",
164+
},
165+
{
166+
name: "empty when neither set",
167+
upper: "TEST_PROXY",
168+
lower: "test_proxy",
169+
expected: "",
170+
},
171+
}
172+
173+
for _, tt := range tests {
174+
t.Run(tt.name, func(t *testing.T) {
175+
// Save and restore environment
176+
origUpper := os.Getenv(tt.upper)
177+
origLower := os.Getenv(tt.lower)
178+
defer func() {
179+
os.Setenv(tt.upper, origUpper)
180+
os.Setenv(tt.lower, origLower)
181+
}()
182+
183+
// Clear environment first
184+
os.Unsetenv(tt.upper)
185+
os.Unsetenv(tt.lower)
186+
187+
// Set test values
188+
if tt.upperVal != "" {
189+
os.Setenv(tt.upper, tt.upperVal)
190+
}
191+
if tt.lowerVal != "" {
192+
os.Setenv(tt.lower, tt.lowerVal)
193+
}
194+
195+
result := getEnvProxy(tt.upper, tt.lower)
196+
assert.Equal(t, tt.expected, result)
197+
})
198+
}
199+
}
200+
43201
func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
44202
originListener, err := net.Listen("tcp", "127.0.0.1:0")
45203
require.NoError(t, err)

0 commit comments

Comments
 (0)