diff --git a/PROXY_SUPPORT.md b/PROXY_SUPPORT.md new file mode 100644 index 000000000..9677f5da9 --- /dev/null +++ b/PROXY_SUPPORT.md @@ -0,0 +1,35 @@ +# Agent Proxy Support + +This document describes how to configure the NGINX Agent to connect to the management plane through an explicit forward proxy (EFP), via HTTP/1.1 and authentication. + +--- + +## 1. Basic Proxy Configuration + +Add a `proxy` section under the `server` block in your agent config file: + +```yaml +server: + host: mgmt.example.com + port: 443 + type: 1 + proxy: + url: "http://proxy.example.com:3128" + timeout: 10s +``` + +- `url`: Proxy URL (http supported) +- `timeout`: Dial timeout for connecting to the proxy + +--- + +## 2. Proxy Authentication + +### Basic Auth +```yaml +proxy: + url: "http://proxy.example.com:3128" + auth_method: "basic" + username: "user" + password: "pass" +``` diff --git a/internal/collector/otel_collector_plugin.go b/internal/collector/otel_collector_plugin.go index 121d7ae2e..bcdbd21e7 100644 --- a/internal/collector/otel_collector_plugin.go +++ b/internal/collector/otel_collector_plugin.go @@ -10,6 +10,7 @@ import ( "fmt" "log/slog" "net" + "net/url" "os" "strings" "sync" @@ -231,6 +232,10 @@ func (oc *Collector) bootup(ctx context.Context) error { return } + if oc.config.IsCommandServerProxyConfigured() { + oc.setProxyIfNeeded(ctx) + } + appErr := oc.service.Run(ctx) if appErr != nil { errChan <- appErr @@ -394,6 +399,10 @@ func (oc *Collector) restartCollector(ctx context.Context) { } oc.service = oTelCollector + if oc.config.IsCommandServerProxyConfigured() { + oc.setProxyIfNeeded(ctx) + } + var runCtx context.Context runCtx, oc.cancel = context.WithCancel(ctx) @@ -409,6 +418,14 @@ func (oc *Collector) restartCollector(ctx context.Context) { } } +func (oc *Collector) setProxyIfNeeded(ctx context.Context) { + if oc.config.Collector.Exporters.OtlpExporters != nil || + oc.config.Collector.Exporters.PrometheusExporter != nil { + // Set proxy env vars for OTLP exporter if proxy is configured. + oc.setExporterProxyEnvVars(ctx) + } +} + func (oc *Collector) checkForNewReceivers(ctx context.Context, nginxConfigContext *model.NginxConfigContext) bool { nginxReceiverFound, reloadCollector := oc.updateExistingNginxPlusReceiver(nginxConfigContext) @@ -740,3 +757,55 @@ func escapeString(input string) string { return output } + +func (oc *Collector) setExporterProxyEnvVars(ctx context.Context) { + proxy := oc.config.Command.Server.Proxy + proxyURL := proxy.URL + parsedProxyURL, err := url.Parse(proxyURL) + if err != nil { + slog.ErrorContext(ctx, "Malformed proxy URL; skipping Proxy setup", "url", proxyURL, "error", err) + return + } + + if parsedProxyURL.Scheme == "https" { + slog.ErrorContext(ctx, "Protocol not supported, unable to configure proxy", "url", proxyURL) + } + + auth := "" + if proxy.AuthMethod != "" && strings.TrimSpace(proxy.AuthMethod) != "" { + auth = strings.TrimSpace(proxy.AuthMethod) + } + + // Use the standalone setProxyWithBasicAuth function + if auth == "" { + setProxyEnvs(ctx, proxyURL, "Setting Proxy from command.Proxy (no auth)") + return + } + authLower := strings.ToLower(auth) + if authLower == "basic" { + setProxyWithBasicAuth(ctx, proxy, parsedProxyURL) + } else { + slog.ErrorContext(ctx, "Unknown auth type for proxy; unable to configure proxy", "auth", auth, "url", proxyURL) + } +} + +// setProxyEnvs sets the HTTP_PROXY and HTTPS_PROXY environment variables and logs the action. +func setProxyEnvs(ctx context.Context, proxyEnvURL, msg string) { + slog.DebugContext(ctx, msg, "url", proxyEnvURL) + if setenvErr := os.Setenv("HTTP_PROXY", proxyEnvURL); setenvErr != nil { + slog.ErrorContext(ctx, "Failed to set Proxy", "error", setenvErr) + } +} + +// setProxyWithBasicAuth sets the proxy environment variables with basic auth credentials. +func setProxyWithBasicAuth(ctx context.Context, proxy *config.Proxy, parsedProxyURL *url.URL) { + username := proxy.Username + password := proxy.Password + if username == "" || password == "" { + slog.ErrorContext(ctx, "Username or password missing for basic auth") + return + } + parsedProxyURL.User = url.UserPassword(username, password) + proxyURL := parsedProxyURL.String() + setProxyEnvs(ctx, proxyURL, "Setting Proxy with basic auth") +} diff --git a/internal/collector/otel_collector_plugin_test.go b/internal/collector/otel_collector_plugin_test.go index 59b71e990..ab5054fbb 100644 --- a/internal/collector/otel_collector_plugin_test.go +++ b/internal/collector/otel_collector_plugin_test.go @@ -9,6 +9,8 @@ import ( "context" "errors" "net" + "net/url" + "os" "path/filepath" "testing" @@ -246,7 +248,11 @@ func TestCollector_ProcessNginxConfigUpdateTopic(t *testing.T) { conf := types.OTelConfig(t) - conf.Command = nil + conf.Command = &config.Command{ + Server: &config.ServerConfig{ + Proxy: &config.Proxy{}, + }, + } conf.Collector.Log.Path = "" conf.Collector.Receivers.HostMetrics = nil @@ -783,6 +789,127 @@ func TestCollector_updateNginxAppProtectTcplogReceivers(t *testing.T) { }) } +func Test_setProxyEnvs(t *testing.T) { + ctx := context.Background() + proxyURL := "http://localhost:8080" + msg := "Setting test proxy" + + // Unset first to ensure clean state + _ = os.Unsetenv("HTTP_PROXY") + + setProxyEnvs(ctx, proxyURL, msg) + + httpProxy := os.Getenv("HTTP_PROXY") + assert.Equal(t, proxyURL, httpProxy) +} + +func Test_setProxyWithBasicAuth(t *testing.T) { + ctx := context.Background() + u, _ := url.Parse("http://localhost:8080") + proxy := &config.Proxy{ + URL: "http://localhost:8080", + Username: "user", + Password: "pass", + } + + // Unset first to ensure clean state + _ = os.Unsetenv("HTTP_PROXY") + + setProxyWithBasicAuth(ctx, proxy, u) + + proxyURL := u.String() + httpProxy := os.Getenv("HTTP_PROXY") + assert.Equal(t, proxyURL, httpProxy) + + // Test missing username/password + proxyMissing := &config.Proxy{URL: "http://localhost:8080"} + setProxyWithBasicAuth(ctx, proxyMissing, u) // Should not panic +} + +func TestSetExporterProxyEnvVars(t *testing.T) { + ctx := context.Background() + logBuf := &bytes.Buffer{} + stub.StubLoggerWith(logBuf) + + tests := []struct { + name string + proxy *config.Proxy + expectedLog string + setEnv bool + }{ + { + name: "No proxy config", + proxy: nil, + expectedLog: "Proxy configuration is not setup. Unable to configure proxy for OTLP exporter", + setEnv: false, + }, + { + name: "Malformed proxy URL", + proxy: &config.Proxy{URL: "://bad_url"}, + expectedLog: "Malformed proxy URL; skipping Proxy setup", + setEnv: false, + }, + { + name: "No auth, valid URL", + proxy: &config.Proxy{URL: "http://proxy.example.com:8080"}, + expectedLog: "Setting Proxy from command.Proxy (no auth)", + setEnv: true, + }, + { + name: "Basic auth, valid URL", + proxy: &config.Proxy{ + URL: "http://proxy.example.com:8080", + AuthMethod: "basic", + Username: "user", + Password: "pass", + }, + expectedLog: "Setting Proxy with basic auth", + setEnv: true, + }, + { + name: "Unknown auth method", + proxy: &config.Proxy{URL: "http://proxy.example.com:8080", AuthMethod: "digest"}, + expectedLog: "Unknown auth type for proxy; unable to configure proxy", + setEnv: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logBuf.Reset() + + _ = os.Unsetenv("HTTP_PROXY") + + tmpDir := t.TempDir() + cfg := types.OTelConfig(t) + cfg.Collector.Log.Path = filepath.Join(tmpDir, "otel-collector-test.log") + cfg.Command.Server.Proxy = tt.proxy + + // If the proxy is nil, the production code would never call the setter functions. + // added this check to prevent the panic error in UT. + if cfg.Command.Server.Proxy == nil { + // For the nil proxy case, we expect nothing to happen. + assert.Empty(t, os.Getenv("HTTP_PROXY")) + + return + } + + collector, err := NewCollector(cfg) + require.NoError(t, err) + + collector.setExporterProxyEnvVars(ctx) + + helpers.ValidateLog(t, tt.expectedLog, logBuf) + + if tt.setEnv { + assert.NotEmpty(t, os.Getenv("HTTP_PROXY")) + } else { + assert.Empty(t, os.Getenv("HTTP_PROXY")) + } + }) + } +} + func TestCollector_findAvailableSyslogServers(t *testing.T) { conf := types.OTelConfig(t) conf.Collector.Log.Path = "" diff --git a/internal/config/config.go b/internal/config/config.go index dfd8d1e8f..3d8b1b4ef 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -628,6 +628,65 @@ func registerCommandFlags(fs *flag.FlagSet) { DefCommandTLServerNameKey, "Specifies the name of the server sent in the TLS configuration.", ) + fs.Duration( + CommandServerProxyTimeoutKey, + DefCommandServerProxyTimeoutKey, + "The explicit forward proxy HTTP Timeout, value in seconds") + fs.String( + CommandServerProxyURLKey, + DefCommandServerProxyURlKey, + "The Proxy URL to use for explicit forward proxy.", + ) + fs.String( + CommandServerProxyNoProxyKey, + DefCommandServerProxyNoProxyKey, + "The No-Proxy URL to use for explicit forward proxy.", + ) + fs.String( + CommandServerProxyAuthMethodKey, + DefCommandServerProxyAuthMethodKey, + "The Authentication method used for explicit forward proxy.", + ) + fs.String( + CommandServerProxyUsernameKey, + DefCommandServerProxyUsernameKey, + "The Username used for basic authentication for explicit forward proxy.", + ) + fs.String( + CommandServerProxyPasswordKey, + DefCommandServerProxyPasswordKey, + "The Password used for basic authentication for explicit forward proxy.", + ) + fs.String( + CommandServerProxyTokenKey, + DefCommandServerProxyTokenKey, + "The bearer token used for authentication for explicit forward proxy.", + ) + fs.String( + CommandServerProxyTLSCertKey, + DefCommandServerProxyTLSCertKey, + "The path to the certificate file to use for TLS communication with the command server.", + ) + fs.String( + CommandServerProxyTLSKeyKey, + DefCommandServerProxyTLSKeyKey, + "The path to the certificate key file to use for TLS communication with the command server.", + ) + fs.String( + CommandServerProxyTLSCaKey, + DefCommandServerProxyTLSCaKey, + "The path to CA certificate file to use for TLS communication with the command server.", + ) + fs.Bool( + CommandServerProxyTLSSkipVerifyKey, + DefCommandServerProxyTLSSkipVerifyKey, + "Testing only. Skip verify controls client verification of a server's certificate chain and host name.", + ) + fs.String( + CommandServerProxyTLSServerNameKey, + DefCommandServerProxyTLServerNameKey, + "Specifies the name of the server sent in the TLS configuration.", + ) } func registerAuxiliaryCommandFlags(fs *flag.FlagSet) { @@ -1235,9 +1294,10 @@ func resolveCommand() *Command { command := &Command{ Server: &ServerConfig{ - Host: viperInstance.GetString(CommandServerHostKey), - Port: viperInstance.GetInt(CommandServerPortKey), - Type: serverType, + Host: viperInstance.GetString(CommandServerHostKey), + Port: viperInstance.GetInt(CommandServerPortKey), + Type: serverType, + Proxy: resolveProxy(), }, } @@ -1365,3 +1425,47 @@ func resolveMapStructure(key string, object any) error { return nil } + +func resolveProxy() *Proxy { + proxy := &Proxy{ + Timeout: viperInstance.GetDuration(CommandServerProxyTimeoutKey), + URL: viperInstance.GetString(CommandServerProxyURLKey), + NoProxy: viperInstance.GetString(CommandServerProxyNoProxyKey), + Username: viperInstance.GetString(CommandServerProxyUsernameKey), + Password: viperInstance.GetString(CommandServerProxyPasswordKey), + Token: viperInstance.GetString(CommandServerProxyTokenKey), + AuthMethod: viperInstance.GetString(CommandServerProxyAuthMethodKey), + } + + if areCommandServerProxyTLSSettingsSet() { + proxy.TLS = &TLSConfig{ + Cert: viperInstance.GetString(CommandServerProxyTLSCertKey), + Key: viperInstance.GetString(CommandServerProxyTLSKeyKey), + Ca: viperInstance.GetString(CommandServerProxyTLSCaKey), + SkipVerify: viperInstance.GetBool(CommandServerProxyTLSSkipVerifyKey), + ServerName: viperInstance.GetString(CommandServerProxyTLSServerNameKey), + } + } + + // If all fields are zero/nil/empty, return nil + if proxy.TLS == nil && + proxy.Timeout == 0 && + proxy.URL == "" && + proxy.NoProxy == "" && + proxy.AuthMethod == "" && + proxy.Username == "" && + proxy.Password == "" && + proxy.Token == "" { + return nil + } + + return proxy +} + +func areCommandServerProxyTLSSettingsSet() bool { + return viperInstance.IsSet(CommandServerProxyTLSCertKey) || + viperInstance.IsSet(CommandServerProxyTLSKeyKey) || + viperInstance.IsSet(CommandServerProxyTLSCaKey) || + viperInstance.IsSet(CommandServerProxyTLSSkipVerifyKey) || + viperInstance.IsSet(CommandServerProxyTLSServerNameKey) +} diff --git a/internal/config/defaults.go b/internal/config/defaults.go index 7a279d749..e0e676ed8 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -23,16 +23,28 @@ const ( DefNginxReloadBackoffMaxInterval = 10 * time.Second DefNginxReloadBackoffMaxElapsedTime = 30 * time.Second - DefCommandServerHostKey = "" - DefCommandServerPortKey = 0 - DefCommandServerTypeKey = "grpc" - DefCommandAuthTokenKey = "" - DefCommandAuthTokenPathKey = "" - DefCommandTLSCertKey = "" - DefCommandTLSKeyKey = "" - DefCommandTLSCaKey = "" - DefCommandTLSSkipVerifyKey = false - DefCommandTLServerNameKey = "" + DefCommandServerHostKey = "" + DefCommandServerPortKey = 0 + DefCommandServerTypeKey = "grpc" + DefCommandAuthTokenKey = "" + DefCommandAuthTokenPathKey = "" + DefCommandTLSCertKey = "" + DefCommandTLSKeyKey = "" + DefCommandTLSCaKey = "" + DefCommandTLSSkipVerifyKey = false + DefCommandTLServerNameKey = "" + DefCommandServerProxyTimeoutKey = 0 + DefCommandServerProxyURlKey = "" + DefCommandServerProxyNoProxyKey = "" + DefCommandServerProxyAuthMethodKey = "" + DefCommandServerProxyUsernameKey = "" + DefCommandServerProxyPasswordKey = "" + DefCommandServerProxyTokenKey = "" + DefCommandServerProxyTLSCertKey = "" + DefCommandServerProxyTLSKeyKey = "" + DefCommandServerProxyTLSCaKey = "" + DefCommandServerProxyTLSSkipVerifyKey = false + DefCommandServerProxyTLServerNameKey = "" DefAuxiliaryCommandServerHostKey = "" DefAuxiliaryCommandServerPortKey = 0 diff --git a/internal/config/flags.go b/internal/config/flags.go index 172879518..b8e08dcd6 100644 --- a/internal/config/flags.go +++ b/internal/config/flags.go @@ -81,19 +81,31 @@ var ( CollectorLogLevelKey = pre(CollectorLogKey) + "level" CollectorLogPathKey = pre(CollectorLogKey) + "path" - CommandAuthKey = pre(CommandRootKey) + "auth" - CommandAuthTokenKey = pre(CommandAuthKey) + "token" - CommandAuthTokenPathKey = pre(CommandAuthKey) + "tokenpath" - CommandServerHostKey = pre(CommandServerKey) + "host" - CommandServerKey = pre(CommandRootKey) + "server" - CommandServerPortKey = pre(CommandServerKey) + "port" - CommandServerTypeKey = pre(CommandServerKey) + "type" - CommandTLSKey = pre(CommandRootKey) + "tls" - CommandTLSCaKey = pre(CommandTLSKey) + "ca" - CommandTLSCertKey = pre(CommandTLSKey) + "cert" - CommandTLSKeyKey = pre(CommandTLSKey) + "key" - CommandTLSServerNameKey = pre(CommandTLSKey) + "server_name" - CommandTLSSkipVerifyKey = pre(CommandTLSKey) + "skip_verify" + CommandAuthKey = pre(CommandRootKey) + "auth" + CommandAuthTokenKey = pre(CommandAuthKey) + "token" + CommandAuthTokenPathKey = pre(CommandAuthKey) + "tokenpath" + CommandServerHostKey = pre(CommandServerKey) + "host" + CommandServerKey = pre(CommandRootKey) + "server" + CommandServerPortKey = pre(CommandServerKey) + "port" + CommandServerTypeKey = pre(CommandServerKey) + "type" + CommandTLSKey = pre(CommandRootKey) + "tls" + CommandTLSCaKey = pre(CommandTLSKey) + "ca" + CommandTLSCertKey = pre(CommandTLSKey) + "cert" + CommandTLSKeyKey = pre(CommandTLSKey) + "key" + CommandTLSServerNameKey = pre(CommandTLSKey) + "server_name" + CommandTLSSkipVerifyKey = pre(CommandTLSKey) + "skip_verify" + CommandServerProxyTimeoutKey = pre(CommandServerKey) + "proxy_timeout" + CommandServerProxyURLKey = pre(CommandServerKey) + "proxy_url" + CommandServerProxyUsernameKey = pre(CommandServerKey) + "proxy_username" + CommandServerProxyPasswordKey = pre(CommandServerKey) + "proxy_password" + CommandServerProxyNoProxyKey = pre(CommandServerKey) + "proxy_no_proxy" + CommandServerProxyTokenKey = pre(CommandServerKey) + "proxy_token" + CommandServerProxyAuthMethodKey = pre(CommandServerKey) + "proxy_auth_method" + CommandServerProxyTLSCertKey = pre(CommandServerKey) + "proxy_tls_cert" + CommandServerProxyTLSKeyKey = pre(CommandServerKey) + "proxy_tls_key" + CommandServerProxyTLSCaKey = pre(CommandServerKey) + "proxy_tls_ca" + CommandServerProxyTLSSkipVerifyKey = pre(CommandServerKey) + "proxy_tls_skip_verify" + CommandServerProxyTLSServerNameKey = pre(CommandServerKey) + "proxy_tls_server_name" AuxiliaryCommandAuthKey = pre(AuxiliaryCommandRootKey) + "auth" AuxiliaryCommandAuthTokenKey = pre(AuxiliaryCommandAuthKey) + "token" diff --git a/internal/config/types.go b/internal/config/types.go index 4159c6a83..718dd89db 100644 --- a/internal/config/types.go +++ b/internal/config/types.go @@ -290,9 +290,10 @@ type ( } ServerConfig struct { - Type ServerType `yaml:"type" mapstructure:"type"` - Host string `yaml:"host" mapstructure:"host"` - Port int `yaml:"port" mapstructure:"port"` + Proxy *Proxy `yaml:"proxy" mapstructure:"proxy"` + Type ServerType `yaml:"type" mapstructure:"type"` + Host string `yaml:"host" mapstructure:"host"` + Port int `yaml:"port" mapstructure:"port"` } AuthConfig struct { @@ -338,6 +339,18 @@ type ( ExcludeFiles []string `yaml:"exclude_files" mapstructure:"exclude_files"` MonitoringFrequency time.Duration `yaml:"monitoring_frequency" mapstructure:"monitoring_frequency"` } + + // nolint: govet + Proxy struct { + TLS *TLSConfig `yaml:"tls,omitempty" mapstructure:"tls"` + Timeout time.Duration `yaml:"timeout" mapstructure:"timeout"` + URL string `yaml:"url" mapstructure:"url"` + NoProxy string `yaml:"no_proxy,omitempty" mapstructure:"no_proxy"` + AuthMethod string `yaml:"auth_method,omitempty" mapstructure:"auth_method"` + Username string `yaml:"username,omitempty" mapstructure:"username"` + Password string `yaml:"password,omitempty" mapstructure:"password"` + Token string `yaml:"token,omitempty" mapstructure:"token"` + } ) func (col *Collector) Validate(allowedDirectories []string) error { @@ -460,6 +473,14 @@ func (c *Config) NewContextWithLabels(ctx context.Context) context.Context { return metadata.NewOutgoingContext(ctx, md) } +func (c *Config) IsCommandServerProxyConfigured() bool { + if c.Command.Server.Proxy == nil { + return false + } + + return c.Command.Server.Proxy.URL != "" +} + // isAllowedDir checks if the given path is in the list of allowed directories. // It returns true if the path is allowed, false otherwise. // If the path is allowed but does not exist, it also logs a warning. diff --git a/internal/grpc/grpc.go b/internal/grpc/grpc.go index fe0feaf62..d1f766632 100644 --- a/internal/grpc/grpc.go +++ b/internal/grpc/grpc.go @@ -214,6 +214,14 @@ func DialOptions(agentConfig *config.Config, commandConfig *config.Command, reso opts = append(opts, sendRecOpts...) + // Proxy support: If proxy config exists, use HTTP CONNECT dialer + if commandConfig.Server.Proxy != nil && commandConfig.Server.Proxy.URL != "" { + opts = append(opts, grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + slog.InfoContext(ctx, "Dialing grpc server via proxy") + return DialViaHTTPProxy(ctx, commandConfig.Server.Proxy, addr) + })) + } + opts, skipToken := addTransportCredentials(commandConfig, opts) if commandConfig.Auth != nil && !skipToken { diff --git a/internal/grpc/proxy_dialer.go b/internal/grpc/proxy_dialer.go new file mode 100644 index 000000000..6825af47c --- /dev/null +++ b/internal/grpc/proxy_dialer.go @@ -0,0 +1,156 @@ +// Copyright (c) F5, Inc. +// +// This source code is licensed under the Apache License, Version 2.0 license found in the +// LICENSE file in the root directory of this source tree. + +package grpc + +import ( + "bufio" + "context" + "crypto/tls" + "encoding/base64" + "errors" + "fmt" + "io" + "log/slog" + "net" + "net/http" + "net/url" + "time" + + "github.com/nginx/agent/v3/internal/config" +) + +// DialViaHTTPProxy establishes a tunnel via HTTP CONNECT and returns a net.Conn +func DialViaHTTPProxy(ctx context.Context, proxyConf *config.Proxy, targetAddr string) (net.Conn, error) { + proxyURL, err := url.Parse(proxyConf.URL) + if err != nil { + return nil, wrapProxyError(ctx, "Invalid proxy URL", err, proxyConf.URL) + } + + dialConn, err := dialProxy(ctx, proxyURL, proxyConf) + if err != nil { + return nil, err + } + + if err = writeConnectRequest(dialConn, targetAddr, proxyConf); err != nil { + dialConn.Close() + return nil, wrapProxyError(ctx, "Failed to write CONNECT request", err, proxyConf.URL) + } + + resp, err := readConnectResponse(dialConn) + if err != nil { + dialConn.Close() + return nil, wrapProxyError(ctx, "Failed to read CONNECT response", err, proxyConf.URL) + } + + if err = validateProxyResponse(ctx, resp, dialConn); err != nil { + return nil, err + } + + slog.InfoContext(ctx, "Established proxy tunnel", "proxy_url", proxyConf.URL, "target_addr", targetAddr) + + return dialConn, nil +} + +func buildProxyTLSConfig(proxyConf *config.Proxy) (*tls.Config, error) { + tlsConf := &tls.Config{} + if proxyConf.TLS == nil { + return tlsConf, nil + } + + if err := addRootCAs(tlsConf, proxyConf.TLS.Ca); err != nil { + return nil, err + } + if err := addCertKeyPair(tlsConf, proxyConf.TLS.Cert, proxyConf.TLS.Key); err != nil { + return nil, err + } + setServerName(tlsConf, proxyConf.TLS.ServerName) + tlsConf.InsecureSkipVerify = proxyConf.TLS.SkipVerify + + return tlsConf, nil +} + +func dialToProxyTLS(proxyURL *url.URL, tlsConf *tls.Config, timeout time.Duration) (net.Conn, error) { + dialer := &net.Dialer{Timeout: timeout} + return tls.DialWithDialer(dialer, "tcp", proxyURL.Host, tlsConf) +} + +func dialToProxyTCP(ctx context.Context, proxyURL *url.URL, timeout time.Duration) (net.Conn, error) { + dialer := &net.Dialer{Timeout: timeout} + return dialer.DialContext(ctx, "tcp", proxyURL.Host) +} + +func writeConnectRequest(conn net.Conn, targetAddr string, proxyConf *config.Proxy) error { + req := &http.Request{ + Method: http.MethodConnect, + URL: &url.URL{Opaque: targetAddr}, + Host: targetAddr, + Header: make(http.Header), + } + if proxyConf.AuthMethod == "basic" && proxyConf.Username != "" && proxyConf.Password != "" { + auth := base64.StdEncoding.EncodeToString([]byte(proxyConf.Username + ":" + proxyConf.Password)) + req.Header.Set("Proxy-Authorization", "Basic "+auth) + } else if proxyConf.AuthMethod == "bearer" && proxyConf.Token != "" { + req.Header.Set("Proxy-Authorization", "Bearer "+proxyConf.Token) + } + + return req.Write(conn) +} + +func readConnectResponse(conn net.Conn) (*http.Response, error) { + return http.ReadResponse(bufio.NewReader(conn), nil) +} + +func wrapProxyError(ctx context.Context, msg string, err error, proxyURL string) error { + return fmt.Errorf("%s: %s : proxyurl : %s : %w", ctx, msg, proxyURL, err) +} + +func dialProxy(ctx context.Context, proxyURL *url.URL, proxyConf *config.Proxy) (net.Conn, error) { + if proxyURL.Scheme == "https" { + tlsConf, err := buildProxyTLSConfig(proxyConf) + if err != nil { + return nil, wrapProxyError(ctx, "Failed to build TLS config", err, proxyConf.URL) + } + + return dialToProxyTLS(proxyURL, tlsConf, proxyConf.Timeout) + } + + return dialToProxyTCP(ctx, proxyURL, proxyConf.Timeout) +} + +func validateProxyResponse(ctx context.Context, resp *http.Response, dialConn net.Conn) error { + if resp.StatusCode == http.StatusOK { + return nil + } + if _, err := io.Copy(io.Discard, resp.Body); err != nil { + slog.ErrorContext(ctx, "Failed to discard response body", "error", err) + } + resp.Body.Close() + dialConn.Close() + + return errors.New("proxy CONNECT failed: " + resp.Status) +} + +func addRootCAs(tlsConf *tls.Config, caPath string) error { + if caPath == "" { + return nil + } + + return appendRootCAs(tlsConf, caPath) +} + +func addCertKeyPair(tlsConf *tls.Config, certPath, keyPath string) error { + if certPath == "" || keyPath == "" { + return nil + } + + return appendCertKeyPair(tlsConf, certPath, keyPath) +} + +func setServerName(tlsConf *tls.Config, serverName string) { + if serverName != "" { + tlsConf.ServerName = serverName + } +} diff --git a/internal/grpc/proxy_dialer_test.go b/internal/grpc/proxy_dialer_test.go new file mode 100644 index 000000000..b8c3f5d74 --- /dev/null +++ b/internal/grpc/proxy_dialer_test.go @@ -0,0 +1,181 @@ +// Copyright (c) F5, Inc. +// +// This source code is licensed under the Apache License, Version 2.0 license found in the +// LICENSE file in the root directory of this source tree. + +package grpc + +import ( + "bufio" + "context" + "net" + "os" + "strings" + "testing" + "time" + + "github.com/nginx/agent/v3/internal/config" + "github.com/stretchr/testify/require" +) + +func TestDialViaHTTPProxy_NoProxy(t *testing.T) { + // This test attempts to connect directly to a known open port (localhost:80) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + proxyConf := &config.Proxy{ + URL: "", + Timeout: 2 * time.Second, + } + _, err := DialViaHTTPProxy(ctx, proxyConf, "localhost:80") + require.Error(t, err, "expected failure with empty proxy URL") +} + +func TestDialViaHTTPProxy_InvalidProxy(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + proxyConf := &config.Proxy{ + URL: "http://invalid:9999", + Timeout: 2 * time.Second, + } + _, err := DialViaHTTPProxy(ctx, proxyConf, "localhost:80") + require.Error(t, err, "expected failure with invalid proxy") +} + +// To fully test with a real proxy, set the env var TEST_HTTP_PROXY_URL and have a proxy listening. +func TestDialViaHTTPProxy_RealProxy(t *testing.T) { + proxyURL := os.Getenv("TEST_HTTP_PROXY_URL") + if proxyURL == "" { + t.Skip("TEST_HTTP_PROXY_URL not set") + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + proxyConf := &config.Proxy{ + URL: proxyURL, + Timeout: 3 * time.Second, + } + conn, err := DialViaHTTPProxy(ctx, proxyConf, "example.com:80") + require.NoError(t, err, "failed to connect via proxy") + defer conn.Close() + + // Basic write/read to check tunnel + if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")); err != nil { + t.Errorf("failed to write to tunnel: %v", err) + } + buf := make([]byte, 128) + if deadlineErr := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); deadlineErr != nil { + // Optionally log + t.Logf("Failed to set read deadline: %v", deadlineErr) + } + _, err = conn.Read(buf) + if err != nil && err != context.DeadlineExceeded && !isTimeout(err) { + t.Errorf("failed to read from tunnel: %v", err) + } +} + +func TestDialViaHTTPProxy_BearerTokenHeader(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err, "failed to listen") + defer ln.Close() + + done := make(chan struct{}) + go func() { + conn, acceptErr := ln.Accept() + if acceptErr != nil { + t.Errorf("Failed to accept connection: %v", acceptErr) + return + } + defer conn.Close() + reader := bufio.NewReader(conn) + headerLines := readHeaders(reader) + if hasBearerHeader(headerLines, "testtoken") { + close(done) + return + } + close(done) + }() + + proxyConf := &config.Proxy{ + URL: "http://" + ln.Addr().String(), + AuthMethod: "bearer", + Token: "testtoken", + Timeout: 2 * time.Second, + } + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _, _ = DialViaHTTPProxy(ctx, proxyConf, "example.com:443") + + select { + case <-done: + // success + case <-time.After(1 * time.Second): + t.Errorf("Proxy-Authorization Bearer header was not sent") + } +} + +func isTimeout(err error) bool { + nerr, ok := err.(net.Error) + return ok && nerr.Timeout() +} + +func readHeaders(reader *bufio.Reader) []string { + var headerLines []string + for { + line, err := reader.ReadString('\n') + if err != nil || line == "\r\n" { + break + } + headerLines = append(headerLines, line) + } + + return headerLines +} + +func hasBearerHeader(headerLines []string, token string) bool { + expected := "Proxy-Authorization: Bearer " + token + for _, h := range headerLines { + if strings.HasPrefix(h, expected) { + return true + } + } + + return false +} + +func TestDialViaHTTPProxy_InvalidCAPath(t *testing.T) { + proxyConf := &config.Proxy{ + URL: "https://localhost:9999", + TLS: &config.TLSConfig{Ca: "/invalid/path/to/ca.pem"}, + Timeout: 1 * time.Second, + } + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + _, err := DialViaHTTPProxy(ctx, proxyConf, "example.com:443") + require.Error(t, err, "expected error for invalid CA path") +} + +func TestDialViaHTTPProxy_MissingCertKey(t *testing.T) { + proxyConf := &config.Proxy{ + URL: "https://localhost:9999", + TLS: &config.TLSConfig{}, // No cert/key + Timeout: 1 * time.Second, + } + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + _, err := DialViaHTTPProxy(ctx, proxyConf, "example.com:443") + // No assert needed: just covers the branch + require.Error(t, err, "expected error for missing cert") +} + +func TestDialViaHTTPProxy_InvalidProxyURL(t *testing.T) { + proxyConf := &config.Proxy{ + URL: "://bad-url", + Timeout: 1 * time.Second, + } + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + _, err := DialViaHTTPProxy(ctx, proxyConf, "example.com:443") + require.Error(t, err, "expected error for invalid proxy URL") +}