diff --git a/ingress/origin_dialer.go b/ingress/origin_dialer.go index 36ade327c09..c826be92f66 100644 --- a/ingress/origin_dialer.go +++ b/ingress/origin_dialer.go @@ -9,6 +9,7 @@ import ( "time" "github.com/rs/zerolog" + "golang.org/x/net/proxy" ) // OriginTCPDialer provides a TCP dial operation to a requested address. @@ -115,20 +116,33 @@ func (d *OriginDialerService) DialUDP(addr netip.AddrPort) (net.Conn, error) { } type Dialer struct { - Dialer net.Dialer + Dialer proxy.Dialer } func NewDialer(config WarpRoutingConfig) *Dialer { + // Create proxy-aware dialer for warp routing + proxyDialer := createProxyDialer(config.ConnectTimeout.Duration, config.TCPKeepAlive.Duration, nil) return &Dialer{ - Dialer: net.Dialer{ - Timeout: config.ConnectTimeout.Duration, - KeepAlive: config.TCPKeepAlive.Duration, - }, + Dialer: proxyDialer, } } +// createProxyDialer creates a proxy.Dialer that respects proxy environment variables +func createProxyDialer(timeout, keepAlive time.Duration, logger *zerolog.Logger) proxy.Dialer { + // Reuse the unified proxy logic from origin_service.go + return newProxyAwareDialer(timeout, keepAlive, logger) +} + func (d *Dialer) DialTCP(ctx context.Context, dest netip.AddrPort) (net.Conn, error) { - conn, err := d.Dialer.DialContext(ctx, "tcp", dest.String()) + var conn net.Conn + var err error + + if contextDialer, ok := d.Dialer.(proxy.ContextDialer); ok { + conn, err = contextDialer.DialContext(ctx, "tcp", dest.String()) + } else { + conn, err = d.Dialer.Dial("tcp", dest.String()) + } + if err != nil { return nil, fmt.Errorf("unable to dial tcp to origin %s: %w", dest, err) } diff --git a/ingress/origin_proxy.go b/ingress/origin_proxy.go index 7371eac92ec..4372b5a5393 100644 --- a/ingress/origin_proxy.go +++ b/ingress/origin_proxy.go @@ -8,6 +8,7 @@ import ( "net/http" "github.com/rs/zerolog" + "golang.org/x/net/proxy" ) // HTTPOriginProxy can be implemented by origin services that want to proxy http requests. @@ -86,7 +87,15 @@ func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) { } func (o *rawTCPService) EstablishConnection(ctx context.Context, dest string, logger *zerolog.Logger) (OriginConnection, error) { - conn, err := o.dialer.DialContext(ctx, "tcp", dest) + var conn net.Conn + var err error + + if contextDialer, ok := o.dialer.(proxy.ContextDialer); ok { + conn, err = contextDialer.DialContext(ctx, "tcp", dest) + } else { + conn, err = o.dialer.Dial("tcp", dest) + } + if err != nil { return nil, err } @@ -105,7 +114,13 @@ func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string, dest = o.dest } - conn, err := o.dialer.DialContext(ctx, "tcp", dest) + var conn net.Conn + if contextDialer, ok := o.dialer.(proxy.ContextDialer); ok { + conn, err = contextDialer.DialContext(ctx, "tcp", dest) + } else { + conn, err = o.dialer.Dial("tcp", dest) + } + if err != nil { return nil, err } diff --git a/ingress/origin_proxy_test.go b/ingress/origin_proxy_test.go index 7a6170a2a68..a953f6ec01e 100644 --- a/ingress/origin_proxy_test.go +++ b/ingress/origin_proxy_test.go @@ -8,7 +8,9 @@ import ( "net/http" "net/http/httptest" "net/url" + "os" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -24,7 +26,10 @@ func TestRawTCPServiceEstablishConnection(t *testing.T) { listenerClosed := make(chan struct{}) tcpListenRoutine(originListener, listenerClosed) - rawTCPService := &rawTCPService{name: ServiceWarpRouting} + rawTCPService := &rawTCPService{ + name: ServiceWarpRouting, + dialer: newProxyAwareDialer(30*time.Second, 30*time.Second, nil), + } req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil) require.NoError(t, err) @@ -40,6 +45,148 @@ func TestRawTCPServiceEstablishConnection(t *testing.T) { require.Error(t, err) } +func TestProxyAwareDialer(t *testing.T) { + tests := []struct { + name string + httpProxy string + httpsProxy string + socksProxy string + expectDirect bool + expectProxy bool + }{ + { + name: "no proxy configured", + expectDirect: true, + }, + { + name: "HTTP proxy configured", + httpProxy: "http://proxy.example.com:8080", + expectProxy: true, + }, + { + name: "HTTPS proxy configured", + httpsProxy: "http://proxy.example.com:8080", + expectProxy: true, + }, + { + name: "SOCKS proxy configured", + socksProxy: "socks5://proxy.example.com:1080", + expectProxy: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + origHTTP := os.Getenv("HTTP_PROXY") + origHTTPS := os.Getenv("HTTPS_PROXY") + origSOCKS := os.Getenv("ALL_PROXY") + + defer func() { + os.Setenv("HTTP_PROXY", origHTTP) + os.Setenv("HTTPS_PROXY", origHTTPS) + os.Setenv("ALL_PROXY", origSOCKS) + }() + + os.Setenv("HTTP_PROXY", tt.httpProxy) + os.Setenv("HTTPS_PROXY", tt.httpsProxy) + os.Setenv("ALL_PROXY", tt.socksProxy) + + dialer := newProxyAwareDialer(30*time.Second, 30*time.Second, TestLogger) + assert.NotNil(t, dialer) + + if tt.expectDirect { + _, ok := dialer.(*net.Dialer) + assert.True(t, ok, "Expected net.Dialer when no proxy configured") + } else if tt.expectProxy { + assert.NotNil(t, dialer, "Expected proxy dialer when proxy configured") + } + }) + } +} + +func TestProxyAwareDialerHTTPConnect(t *testing.T) { + proxyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "CONNECT" { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + w.WriteHeader(http.StatusOK) + })) + defer proxyServer.Close() + + origHTTP := os.Getenv("HTTP_PROXY") + defer os.Setenv("HTTP_PROXY", origHTTP) + + os.Setenv("HTTP_PROXY", proxyServer.URL) + + dialer := newProxyAwareDialer(5*time.Second, 5*time.Second, TestLogger) + assert.NotNil(t, dialer) + + // Test actual dial (this will fail because our mock proxy doesn't handle the full protocol) + // but we can verify the proxy detection logic works + proxyAwareDialer, ok := dialer.(*proxyAwareDialer) + assert.True(t, ok, "Expected proxyAwareDialer when HTTP proxy configured") + assert.NotNil(t, proxyAwareDialer.baseDialer) +} + +func TestGetEnvProxy(t *testing.T) { + tests := []struct { + name string + upper string + lower string + upperVal string + lowerVal string + expected string + }{ + { + name: "upper case takes priority", + upper: "TEST_PROXY", + lower: "test_proxy", + upperVal: "upper_value", + lowerVal: "lower_value", + expected: "upper_value", + }, + { + name: "lower case when upper not set", + upper: "TEST_PROXY", + lower: "test_proxy", + lowerVal: "lower_value", + expected: "lower_value", + }, + { + name: "empty when neither set", + upper: "TEST_PROXY", + lower: "test_proxy", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Save and restore environment + origUpper := os.Getenv(tt.upper) + origLower := os.Getenv(tt.lower) + defer func() { + os.Setenv(tt.upper, origUpper) + os.Setenv(tt.lower, origLower) + }() + + os.Unsetenv(tt.upper) + os.Unsetenv(tt.lower) + + if tt.upperVal != "" { + os.Setenv(tt.upper, tt.upperVal) + } + if tt.lowerVal != "" { + os.Setenv(tt.lower, tt.lowerVal) + } + + result := getEnvProxy(tt.upper, tt.lower) + assert.Equal(t, tt.expected, result) + }) + } +} + func TestTCPOverWSServiceEstablishConnection(t *testing.T) { originListener, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) diff --git a/ingress/origin_service.go b/ingress/origin_service.go index e13204c5789..bc325d6175d 100644 --- a/ingress/origin_service.go +++ b/ingress/origin_service.go @@ -1,6 +1,7 @@ package ingress import ( + "bufio" "context" "crypto/tls" "encoding/json" @@ -9,11 +10,13 @@ import ( "net" "net/http" "net/url" + "os" "strconv" "time" "github.com/pkg/errors" "github.com/rs/zerolog" + "golang.org/x/net/proxy" "github.com/cloudflare/cloudflared/hello" "github.com/cloudflare/cloudflared/ipaccess" @@ -97,7 +100,7 @@ func (o httpService) MarshalJSON() ([]byte, error) { // It's used by warp routing type rawTCPService struct { name string - dialer net.Dialer + dialer proxy.Dialer writeTimeout time.Duration logger *zerolog.Logger } @@ -114,14 +117,145 @@ func (o rawTCPService) MarshalJSON() ([]byte, error) { return json.Marshal(o.String()) } +// proxyAwareDialer wraps net.Dialer with proxy support for both HTTP CONNECT and SOCKS +type proxyAwareDialer struct { + baseDialer *net.Dialer + logger *zerolog.Logger +} + +// newProxyAwareDialer creates a dialer that supports proxy settings from environment +func newProxyAwareDialer(timeout, keepAlive time.Duration, logger *zerolog.Logger) proxy.Dialer { + baseDialer := &net.Dialer{ + Timeout: timeout, + KeepAlive: keepAlive, + } + + // Check for SOCKS proxy first using standard proxy package + if socksDialer := proxy.FromEnvironmentUsing(baseDialer); socksDialer != baseDialer { + if logger != nil { + logger.Debug().Msg("proxy: using SOCKS proxy from environment") + } + return socksDialer + } + + // Check for HTTP proxy environment variables + httpProxy := getEnvProxy("HTTP_PROXY", "http_proxy") + httpsProxy := getEnvProxy("HTTPS_PROXY", "https_proxy") + + if httpProxy == "" && httpsProxy == "" { + if logger != nil { + logger.Debug().Msg("proxy: no proxy configured, using direct connection") + } + return baseDialer + } + + if logger != nil { + logger.Debug().Str("HTTP_PROXY", httpProxy).Str("HTTPS_PROXY", httpsProxy).Msg("proxy: using HTTP proxy from environment") + } + return &proxyAwareDialer{ + baseDialer: baseDialer, + logger: logger, + } +} + +func getEnvProxy(upper, lower string) string { + if v := os.Getenv(upper); v != "" { + return v + } + return os.Getenv(lower) +} + +func (p *proxyAwareDialer) Dial(network, addr string) (net.Conn, error) { + return p.DialContext(context.Background(), network, addr) +} + +func (p *proxyAwareDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + if network != "tcp" { + return p.baseDialer.DialContext(ctx, network, addr) + } + + req := &http.Request{URL: &url.URL{Scheme: "http", Host: addr}} + proxyURL, err := http.ProxyFromEnvironment(req) + if err != nil || proxyURL == nil { + if p.logger != nil { + p.logger.Debug().Str("addr", addr).Msg("proxy: direct connection to") + } + return p.baseDialer.DialContext(ctx, network, addr) + } + + if p.logger != nil { + p.logger.Debug().Str("proxy_url", proxyURL.String()).Str("addr", addr).Msg("proxy: using proxy") + } + + switch proxyURL.Scheme { + case "socks4", "socks5": + return p.dialSOCKS(ctx, proxyURL, network, addr) + case "http", "https": + return p.dialHTTPConnect(ctx, proxyURL, addr) + default: + return nil, fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme) + } +} + +func (p *proxyAwareDialer) dialSOCKS(ctx context.Context, proxyURL *url.URL, network, addr string) (net.Conn, error) { + socksDialer, err := proxy.FromURL(proxyURL, p.baseDialer) + if err != nil { + return nil, fmt.Errorf("SOCKS proxy error: %w", err) + } + + if contextDialer, ok := socksDialer.(proxy.ContextDialer); ok { + return contextDialer.DialContext(ctx, network, addr) + } + return socksDialer.Dial(network, addr) +} + +func (p *proxyAwareDialer) dialHTTPConnect(ctx context.Context, proxyURL *url.URL, addr string) (net.Conn, error) { + proxyAddr := proxyURL.Host + if proxyURL.Port() == "" { + if proxyURL.Scheme == "https" { + proxyAddr = net.JoinHostPort(proxyURL.Hostname(), "443") + } else { + proxyAddr = net.JoinHostPort(proxyURL.Hostname(), "80") + } + } + + conn, err := p.baseDialer.DialContext(ctx, "tcp", proxyAddr) + if err != nil { + return nil, fmt.Errorf("proxy connection failed: %w", err) + } + + connectReq := fmt.Sprintf("CONNECT %s HTTP/1.1\r\nHost: %s\r\n\r\n", addr, addr) + if _, err := conn.Write([]byte(connectReq)); err != nil { + conn.Close() + return nil, fmt.Errorf("CONNECT request failed: %w", err) + } + + br := bufio.NewReader(conn) + resp, err := http.ReadResponse(br, &http.Request{Method: "CONNECT"}) + if err != nil { + conn.Close() + return nil, fmt.Errorf("CONNECT response failed: %w", err) + } + resp.Body.Close() + + if resp.StatusCode != 200 { + conn.Close() + return nil, fmt.Errorf("proxy CONNECT failed: %s", resp.Status) + } + + if p.logger != nil { + p.logger.Debug().Str("addr", addr).Msg("proxy: HTTP CONNECT successful") + } + return conn, nil +} + // tcpOverWSService models TCP origins serving eyeballs connecting over websocket, such as -// cloudflared access commands. type tcpOverWSService struct { scheme string dest string isBastion bool streamHandler streamHandlerFunc - dialer net.Dialer + dialer proxy.Dialer } type socksProxyOverWSService struct { @@ -142,12 +276,14 @@ func newTCPOverWSService(url *url.URL) *tcpOverWSService { return &tcpOverWSService{ scheme: url.Scheme, dest: url.Host, + dialer: newProxyAwareDialer(30*time.Second, 30*time.Second, nil), } } func newBastionService() *tcpOverWSService { return &tcpOverWSService{ isBastion: true, + dialer: newProxyAwareDialer(30*time.Second, 30*time.Second, nil), } } @@ -187,8 +323,8 @@ func (o *tcpOverWSService) start(log *zerolog.Logger, _ <-chan struct{}, cfg Ori } else { o.streamHandler = DefaultStreamHandler } - o.dialer.Timeout = cfg.ConnectTimeout.Duration - o.dialer.KeepAlive = cfg.TCPKeepAlive.Duration + // Recreate dialer with new timeout and keepalive settings + o.dialer = newProxyAwareDialer(cfg.ConnectTimeout.Duration, cfg.TCPKeepAlive.Duration, log) return nil } @@ -291,11 +427,8 @@ type WarpRoutingService struct { func NewWarpRoutingService(config WarpRoutingConfig, writeTimeout time.Duration) *WarpRoutingService { svc := &rawTCPService{ - name: ServiceWarpRouting, - dialer: net.Dialer{ - Timeout: config.ConnectTimeout.Duration, - KeepAlive: config.TCPKeepAlive.Duration, - }, + name: ServiceWarpRouting, + dialer: newProxyAwareDialer(config.ConnectTimeout.Duration, config.TCPKeepAlive.Duration, nil), writeTimeout: writeTimeout, } diff --git a/ingress/origins/dns.go b/ingress/origins/dns.go index c09c581dfd1..26c2e997175 100644 --- a/ingress/origins/dns.go +++ b/ingress/origins/dns.go @@ -205,8 +205,8 @@ func (r *resolver) peekDial(ctx context.Context, network, address string) (net.C // NewDNSDialer creates a custom dialer for the DNS resolver service to utilize. func NewDNSDialer() *ingress.Dialer { - return &ingress.Dialer{ - Dialer: net.Dialer{ + // For DNS, use direct connection to avoid circular dependencies + netDialer := &net.Dialer{ // We want short timeouts for the DNS requests Timeout: 5 * time.Second, // We do not want keep alive since the edge will not reuse TCP connections per request @@ -214,6 +214,9 @@ func NewDNSDialer() *ingress.Dialer { KeepAliveConfig: net.KeepAliveConfig{ Enable: false, }, - }, + } + + return &ingress.Dialer{ + Dialer: netDialer, } }