diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index dd4e64546..c9d09aae6 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -4,6 +4,8 @@ ### New Features and Improvements +* Add `HTTPHeaders` and `HTTPPathPrefix` configuration options to support HTTP proxies that require custom headers or URL path rewriting. + ### Bug Fixes ### Documentation diff --git a/config/api_client.go b/config/api_client.go index 9b558bf34..e81102c4e 100644 --- a/config/api_client.go +++ b/config/api_client.go @@ -38,6 +38,17 @@ func HTTPClientConfigFromConfig(cfg *Config) (httpclient.ClientConfig, error) { } r.URL.Host = url.Host r.URL.Scheme = url.Scheme + // Prepend path prefix if configured + if cfg.HTTPPathPrefix != "" { + r.URL.Path = cfg.HTTPPathPrefix + r.URL.Path + } + return nil + }, + // Add custom HTTP headers if configured + func(r *http.Request) error { + for key, value := range cfg.HTTPHeaders { + r.Header.Set(key, value) + } return nil }, authInUserAgentVisitor(cfg), diff --git a/config/config.go b/config/config.go index 6ecc90ca9..68b267c77 100644 --- a/config/config.go +++ b/config/config.go @@ -155,6 +155,14 @@ type Config struct { // Number of seconds for HTTP timeout. Default is 60 (1 minute). HTTPTimeoutSeconds int `name:"http_timeout_seconds" auth:"-"` + // Custom HTTP headers to be added to each API request. + // Format for environment variable: key1=value1;key2=value2 + HTTPHeaders map[string]string `name:"http_headers" env:"DATABRICKS_HTTP_HEADERS" auth:"-"` + + // Path prefix to be prepended to all API URLs. Useful for HTTP proxies + // that require URL path rewriting. + HTTPPathPrefix string `name:"http_path_prefix" env:"DATABRICKS_HTTP_PATH_PREFIX" auth:"-"` + // Truncate JSON fields in JSON above this limit. Default is 96. DebugTruncateBytes int `name:"debug_truncate_bytes" env:"DATABRICKS_DEBUG_TRUNCATE_BYTES" auth:"-"` diff --git a/config/config_attribute.go b/config/config_attribute.go index b3e4db1ae..be191efd5 100644 --- a/config/config_attribute.go +++ b/config/config_attribute.go @@ -5,6 +5,7 @@ import ( "os" "reflect" "strconv" + "strings" ) type Source struct { @@ -69,12 +70,44 @@ func (a *ConfigAttribute) SetS(cfg *Config, v string) error { return err } return a.Set(cfg, vv) + case reflect.Map: + // Parse map from string format: key1=value1;key2=value2 + vv, err := parseMapFromString(v) + if err != nil { + return fmt.Errorf("cannot parse %s: %w", a.Name, err) + } + return a.Set(cfg, vv) default: return fmt.Errorf("cannot set %s of unknown type %s", a.Name, reflectKind(a.Kind)) } } +// parseMapFromString parses a string of format "key1=value1;key2=value2" into a map. +// Empty keys or values are allowed. If the string is empty, an empty map is returned. +func parseMapFromString(s string) (map[string]string, error) { + result := make(map[string]string) + if s == "" { + return result, nil + } + pairs := strings.Split(s, ";") + for _, pair := range pairs { + pair = strings.TrimSpace(pair) + if pair == "" { + continue + } + // Split only on the first '=' to allow '=' in values + idx := strings.Index(pair, "=") + if idx == -1 { + return nil, fmt.Errorf("invalid key-value pair %q: missing '='", pair) + } + key := pair[:idx] + value := pair[idx+1:] + result[key] = value + } + return result, nil +} + func (a *ConfigAttribute) Set(cfg *Config, i interface{}) error { rv := reflect.ValueOf(cfg) field := rv.Elem().Field(a.num) @@ -85,6 +118,8 @@ func (a *ConfigAttribute) Set(cfg *Config, i interface{}) error { field.SetBool(i.(bool)) case reflect.Int: field.SetInt(int64(i.(int))) + case reflect.Map: + field.Set(reflect.ValueOf(i)) default: // must extensively test with providerFixture to avoid this one return fmt.Errorf("cannot set %s of unknown type %s", a.Name, reflectKind(a.Kind)) diff --git a/config/config_test.go b/config/config_test.go index 3f6b93392..7d8be6c4b 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -3,6 +3,7 @@ package config import ( "context" "net/http" + "net/url" "testing" "github.com/databricks/databricks-sdk-go/credentials/u2m" @@ -299,3 +300,188 @@ func TestConfig_getOAuthArgument_Unified(t *testing.T) { }) } } + +func TestConfig_HTTPHeaders(t *testing.T) { + c := &Config{ + Host: "https://my-workspace.cloud.databricks.com", + Token: "test-token", + HTTPHeaders: map[string]string{ + "X-Custom-Header-1": "value1", + "X-Custom-Header-2": "value2", + }, + } + + cfg, err := HTTPClientConfigFromConfig(c) + require.NoError(t, err) + + // Create a test request and apply only the first two visitors (host/path and headers) + // Skip other visitors that require context to be set + req, err := http.NewRequestWithContext(context.Background(), "GET", "/api/2.0/clusters/list", nil) + require.NoError(t, err) + req.URL = &url.URL{Path: "/api/2.0/clusters/list"} + + // Apply only the first two visitors (host/path rewrite and custom headers) + for i := 0; i < 2 && i < len(cfg.Visitors); i++ { + err = cfg.Visitors[i](req) + require.NoError(t, err) + } + + // Verify custom headers are set + assert.Equal(t, "value1", req.Header.Get("X-Custom-Header-1")) + assert.Equal(t, "value2", req.Header.Get("X-Custom-Header-2")) +} + +func TestConfig_HTTPPathPrefix(t *testing.T) { + c := &Config{ + Host: "https://proxy.example.com", + Token: "test-token", + HTTPPathPrefix: "/prefix/path", + } + + cfg, err := HTTPClientConfigFromConfig(c) + require.NoError(t, err) + + // Create a test request + req, err := http.NewRequestWithContext(context.Background(), "GET", "/api/2.0/clusters/list", nil) + require.NoError(t, err) + req.URL = &url.URL{Path: "/api/2.0/clusters/list"} + + // Apply only the first visitor (host/path rewrite) + err = cfg.Visitors[0](req) + require.NoError(t, err) + + // Verify path prefix is prepended + assert.Equal(t, "/prefix/path/api/2.0/clusters/list", req.URL.Path) +} + +func TestConfig_HTTPHeadersAndPathPrefix(t *testing.T) { + c := &Config{ + Host: "https://proxy.example.com", + Token: "test-token", + HTTPHeaders: map[string]string{ + "X-Custom-Header": "custom-value", + }, + HTTPPathPrefix: "/prefix/path", + } + + cfg, err := HTTPClientConfigFromConfig(c) + require.NoError(t, err) + + // Create a test request + req, err := http.NewRequestWithContext(context.Background(), "GET", "/api/2.0/clusters/list", nil) + require.NoError(t, err) + req.URL = &url.URL{Path: "/api/2.0/clusters/list"} + + // Apply only the first two visitors (host/path rewrite and custom headers) + for i := 0; i < 2 && i < len(cfg.Visitors); i++ { + err = cfg.Visitors[i](req) + require.NoError(t, err) + } + + // Verify both custom header and path prefix + assert.Equal(t, "custom-value", req.Header.Get("X-Custom-Header")) + assert.Equal(t, "/prefix/path/api/2.0/clusters/list", req.URL.Path) +} + +func TestConfig_HTTPHeadersFromEnvVar(t *testing.T) { + t.Setenv("DATABRICKS_HTTP_HEADERS", "X-Custom-Header-1=value1;X-Custom-Header-2=value2") + t.Setenv("DATABRICKS_HOST", "https://my-workspace.cloud.databricks.com") + t.Setenv("DATABRICKS_TOKEN", "test-token") + + c := &Config{} + err := c.EnsureResolved() + require.NoError(t, err) + + assert.Equal(t, map[string]string{ + "X-Custom-Header-1": "value1", + "X-Custom-Header-2": "value2", + }, c.HTTPHeaders) +} + +func TestConfig_HTTPPathPrefixFromEnvVar(t *testing.T) { + t.Setenv("DATABRICKS_HTTP_PATH_PREFIX", "/prefix/path") + t.Setenv("DATABRICKS_HOST", "https://my-workspace.cloud.databricks.com") + t.Setenv("DATABRICKS_TOKEN", "test-token") + + c := &Config{} + err := c.EnsureResolved() + require.NoError(t, err) + + assert.Equal(t, "/prefix/path", c.HTTPPathPrefix) +} + +func TestParseMapFromString(t *testing.T) { + tests := []struct { + name string + input string + expected map[string]string + wantErr bool + }{ + { + name: "empty string", + input: "", + expected: map[string]string{}, + }, + { + name: "single key-value pair", + input: "key1=value1", + expected: map[string]string{ + "key1": "value1", + }, + }, + { + name: "multiple key-value pairs", + input: "key1=value1;key2=value2", + expected: map[string]string{ + "key1": "value1", + "key2": "value2", + }, + }, + { + name: "value with equals sign", + input: "key1=value=with=equals", + expected: map[string]string{ + "key1": "value=with=equals", + }, + }, + { + name: "empty value", + input: "key1=", + expected: map[string]string{ + "key1": "", + }, + }, + { + name: "whitespace around pairs", + input: " key1=value1 ; key2=value2 ", + expected: map[string]string{ + "key1": "value1", + "key2": "value2", + }, + }, + { + name: "trailing semicolon", + input: "key1=value1;", + expected: map[string]string{ + "key1": "value1", + }, + }, + { + name: "missing equals sign", + input: "key1value1", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseMapFromString(tt.input) + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.expected, got) + }) + } +}