Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

### New Features and Improvements

* Add `HTTPHeaders` and `HTTPPathPrefix` configuration options to support HTTP proxies that require custom headers or URL path rewriting.

### Bug Fixes

### Documentation
Expand Down
11 changes: 11 additions & 0 deletions config/api_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,17 @@ func HTTPClientConfigFromConfig(cfg *Config) (httpclient.ClientConfig, error) {
}
r.URL.Host = url.Host
r.URL.Scheme = url.Scheme
// Prepend path prefix if configured
if cfg.HTTPPathPrefix != "" {
r.URL.Path = cfg.HTTPPathPrefix + r.URL.Path
}
return nil
},
// Add custom HTTP headers if configured
func(r *http.Request) error {
for key, value := range cfg.HTTPHeaders {
r.Header.Set(key, value)
}
return nil
},
authInUserAgentVisitor(cfg),
Expand Down
8 changes: 8 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,14 @@ type Config struct {
// Number of seconds for HTTP timeout. Default is 60 (1 minute).
HTTPTimeoutSeconds int `name:"http_timeout_seconds" auth:"-"`

// Custom HTTP headers to be added to each API request.
// Format for environment variable: key1=value1;key2=value2
HTTPHeaders map[string]string `name:"http_headers" env:"DATABRICKS_HTTP_HEADERS" auth:"-"`

// Path prefix to be prepended to all API URLs. Useful for HTTP proxies
// that require URL path rewriting.
HTTPPathPrefix string `name:"http_path_prefix" env:"DATABRICKS_HTTP_PATH_PREFIX" auth:"-"`

// Truncate JSON fields in JSON above this limit. Default is 96.
DebugTruncateBytes int `name:"debug_truncate_bytes" env:"DATABRICKS_DEBUG_TRUNCATE_BYTES" auth:"-"`

Expand Down
35 changes: 35 additions & 0 deletions config/config_attribute.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"os"
"reflect"
"strconv"
"strings"
)

type Source struct {
Expand Down Expand Up @@ -69,12 +70,44 @@ func (a *ConfigAttribute) SetS(cfg *Config, v string) error {
return err
}
return a.Set(cfg, vv)
case reflect.Map:
// Parse map from string format: key1=value1;key2=value2
vv, err := parseMapFromString(v)
if err != nil {
return fmt.Errorf("cannot parse %s: %w", a.Name, err)
}
return a.Set(cfg, vv)
default:
return fmt.Errorf("cannot set %s of unknown type %s",
a.Name, reflectKind(a.Kind))
}
}

// parseMapFromString parses a string of format "key1=value1;key2=value2" into a map.
// Empty keys or values are allowed. If the string is empty, an empty map is returned.
func parseMapFromString(s string) (map[string]string, error) {
result := make(map[string]string)
if s == "" {
return result, nil
}
pairs := strings.Split(s, ";")
for _, pair := range pairs {
pair = strings.TrimSpace(pair)
if pair == "" {
continue
}
// Split only on the first '=' to allow '=' in values
idx := strings.Index(pair, "=")
if idx == -1 {
return nil, fmt.Errorf("invalid key-value pair %q: missing '='", pair)
}
key := pair[:idx]
value := pair[idx+1:]
result[key] = value
}
return result, nil
}

func (a *ConfigAttribute) Set(cfg *Config, i interface{}) error {
rv := reflect.ValueOf(cfg)
field := rv.Elem().Field(a.num)
Expand All @@ -85,6 +118,8 @@ func (a *ConfigAttribute) Set(cfg *Config, i interface{}) error {
field.SetBool(i.(bool))
case reflect.Int:
field.SetInt(int64(i.(int)))
case reflect.Map:
field.Set(reflect.ValueOf(i))
default:
// must extensively test with providerFixture to avoid this one
return fmt.Errorf("cannot set %s of unknown type %s", a.Name, reflectKind(a.Kind))
Expand Down
186 changes: 186 additions & 0 deletions config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package config
import (
"context"
"net/http"
"net/url"
"testing"

"github.com/databricks/databricks-sdk-go/credentials/u2m"
Expand Down Expand Up @@ -299,3 +300,188 @@ func TestConfig_getOAuthArgument_Unified(t *testing.T) {
})
}
}

func TestConfig_HTTPHeaders(t *testing.T) {
c := &Config{
Host: "https://my-workspace.cloud.databricks.com",
Token: "test-token",
HTTPHeaders: map[string]string{
"X-Custom-Header-1": "value1",
"X-Custom-Header-2": "value2",
},
}

cfg, err := HTTPClientConfigFromConfig(c)
require.NoError(t, err)

// Create a test request and apply only the first two visitors (host/path and headers)
// Skip other visitors that require context to be set
req, err := http.NewRequestWithContext(context.Background(), "GET", "/api/2.0/clusters/list", nil)
require.NoError(t, err)
req.URL = &url.URL{Path: "/api/2.0/clusters/list"}

// Apply only the first two visitors (host/path rewrite and custom headers)
for i := 0; i < 2 && i < len(cfg.Visitors); i++ {
err = cfg.Visitors[i](req)
require.NoError(t, err)
}

// Verify custom headers are set
assert.Equal(t, "value1", req.Header.Get("X-Custom-Header-1"))
assert.Equal(t, "value2", req.Header.Get("X-Custom-Header-2"))
}

func TestConfig_HTTPPathPrefix(t *testing.T) {
c := &Config{
Host: "https://proxy.example.com",
Token: "test-token",
HTTPPathPrefix: "/prefix/path",
}

cfg, err := HTTPClientConfigFromConfig(c)
require.NoError(t, err)

// Create a test request
req, err := http.NewRequestWithContext(context.Background(), "GET", "/api/2.0/clusters/list", nil)
require.NoError(t, err)
req.URL = &url.URL{Path: "/api/2.0/clusters/list"}

// Apply only the first visitor (host/path rewrite)
err = cfg.Visitors[0](req)
require.NoError(t, err)

// Verify path prefix is prepended
assert.Equal(t, "/prefix/path/api/2.0/clusters/list", req.URL.Path)
}

func TestConfig_HTTPHeadersAndPathPrefix(t *testing.T) {
c := &Config{
Host: "https://proxy.example.com",
Token: "test-token",
HTTPHeaders: map[string]string{
"X-Custom-Header": "custom-value",
},
HTTPPathPrefix: "/prefix/path",
}

cfg, err := HTTPClientConfigFromConfig(c)
require.NoError(t, err)

// Create a test request
req, err := http.NewRequestWithContext(context.Background(), "GET", "/api/2.0/clusters/list", nil)
require.NoError(t, err)
req.URL = &url.URL{Path: "/api/2.0/clusters/list"}

// Apply only the first two visitors (host/path rewrite and custom headers)
for i := 0; i < 2 && i < len(cfg.Visitors); i++ {
err = cfg.Visitors[i](req)
require.NoError(t, err)
}

// Verify both custom header and path prefix
assert.Equal(t, "custom-value", req.Header.Get("X-Custom-Header"))
assert.Equal(t, "/prefix/path/api/2.0/clusters/list", req.URL.Path)
}

func TestConfig_HTTPHeadersFromEnvVar(t *testing.T) {
t.Setenv("DATABRICKS_HTTP_HEADERS", "X-Custom-Header-1=value1;X-Custom-Header-2=value2")
t.Setenv("DATABRICKS_HOST", "https://my-workspace.cloud.databricks.com")
t.Setenv("DATABRICKS_TOKEN", "test-token")

c := &Config{}
err := c.EnsureResolved()
require.NoError(t, err)

assert.Equal(t, map[string]string{
"X-Custom-Header-1": "value1",
"X-Custom-Header-2": "value2",
}, c.HTTPHeaders)
}

func TestConfig_HTTPPathPrefixFromEnvVar(t *testing.T) {
t.Setenv("DATABRICKS_HTTP_PATH_PREFIX", "/prefix/path")
t.Setenv("DATABRICKS_HOST", "https://my-workspace.cloud.databricks.com")
t.Setenv("DATABRICKS_TOKEN", "test-token")

c := &Config{}
err := c.EnsureResolved()
require.NoError(t, err)

assert.Equal(t, "/prefix/path", c.HTTPPathPrefix)
}

func TestParseMapFromString(t *testing.T) {
tests := []struct {
name string
input string
expected map[string]string
wantErr bool
}{
{
name: "empty string",
input: "",
expected: map[string]string{},
},
{
name: "single key-value pair",
input: "key1=value1",
expected: map[string]string{
"key1": "value1",
},
},
{
name: "multiple key-value pairs",
input: "key1=value1;key2=value2",
expected: map[string]string{
"key1": "value1",
"key2": "value2",
},
},
{
name: "value with equals sign",
input: "key1=value=with=equals",
expected: map[string]string{
"key1": "value=with=equals",
},
},
{
name: "empty value",
input: "key1=",
expected: map[string]string{
"key1": "",
},
},
{
name: "whitespace around pairs",
input: " key1=value1 ; key2=value2 ",
expected: map[string]string{
"key1": "value1",
"key2": "value2",
},
},
{
name: "trailing semicolon",
input: "key1=value1;",
expected: map[string]string{
"key1": "value1",
},
},
{
name: "missing equals sign",
input: "key1value1",
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := parseMapFromString(tt.input)
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.expected, got)
})
}
}