diff --git a/config/config.go b/config/config.go index 6ecc90ca9..130aa9e26 100644 --- a/config/config.go +++ b/config/config.go @@ -18,6 +18,7 @@ import ( "github.com/databricks/databricks-sdk-go/credentials/u2m" "github.com/databricks/databricks-sdk-go/httpclient" "github.com/databricks/databricks-sdk-go/logger" + "golang.org/x/exp/slices" "golang.org/x/oauth2" ) @@ -63,6 +64,8 @@ const ( InvalidConfig ConfigType = "INVALID_CONFIG" ) +var defaultScopes = []string{"all-apis"} + // Config represents configuration for Databricks Connectivity type Config struct { // Credentials holds an instance of Credentials Strategy to authenticate with Databricks REST APIs. @@ -135,6 +138,26 @@ type Config struct { ClientID string `name:"client_id" env:"DATABRICKS_CLIENT_ID" auth:"oauth" auth_types:"oauth-m2m"` ClientSecret string `name:"client_secret" env:"DATABRICKS_CLIENT_SECRET" auth:"oauth,sensitive" auth_types:"oauth-m2m"` + // Scopes is a list of OAuth scopes to request when authenticating. + // + // WARNING: + // - This feature is still in development and may not work as expected + // - This feature is EXPERIMENTAL and may change or be removed without notice. + // - Do NOT use this feature in production environments. + // + // Notes: + // - If Scopes is nil or empty, the default ["all-apis"] scope will be used for backward compatibility. + // - For U2M authentication, the "offline_access" scope will automatically be added to obtain a refresh token + // unless you set DisableOAuthRefreshToken to true. + // - You cannot set Scopes via environment variables. + // - The scopes list will be sorted in-place during configuration resolution. + // - The U2M token cache currently does NOT support differentiated caching for scopes. + Scopes []string `name:"scopes" auth:"-"` + + // DisableOAuthRefreshToken controls whether a refresh token should be requested + // during the U2M authentication flow (default to false). + DisableOAuthRefreshToken bool `name:"disable_oauth_refresh_token" env:"DATABRICKS_DISABLE_OAUTH_REFRESH_TOKEN" auth:"-"` + // Path to the Databricks CLI (version >= 0.100.0). DatabricksCliPath string `name:"databricks_cli_path" env:"DATABRICKS_CLI_PATH" auth_types:"databricks-cli"` @@ -445,6 +468,8 @@ func (c *Config) EnsureResolved() error { }, } } + slices.Sort(c.Scopes) + c.Scopes = slices.Compact(c.Scopes) c.resolved = true return nil } @@ -460,6 +485,13 @@ func (c *Config) CanonicalHostName() string { return c.Host } +func (c *Config) GetScopes() []string { + if len(c.Scopes) == 0 { + return defaultScopes + } + return c.Scopes +} + func (c *Config) wrapDebug(err error) error { debug := ConfigAttributes.DebugString(c) if debug == "" { diff --git a/config/config_attribute.go b/config/config_attribute.go index b3e4db1ae..ff47e90f9 100644 --- a/config/config_attribute.go +++ b/config/config_attribute.go @@ -5,8 +5,17 @@ import ( "os" "reflect" "strconv" + "strings" ) +// getenv is the function used to read environment variables. +// It defaults to os.Getenv but can be overwritten in tests. +var getenv = os.Getenv + +// getUserHomeDir is the function used to get user home directory. +// It defaults to os.UserHomeDir but can be overwritten in tests. +var getUserHomeDir = os.UserHomeDir + type Source struct { Type SourceType `json:"type"` Name string `json:"name,omitempty"` @@ -44,7 +53,7 @@ type ConfigAttribute struct { func (a *ConfigAttribute) ReadEnv() (string, string) { for _, envName := range a.EnvVars { - v := os.Getenv(envName) + v := getenv(envName) if v == "" { continue } @@ -69,6 +78,16 @@ func (a *ConfigAttribute) SetS(cfg *Config, v string) error { return err } return a.Set(cfg, vv) + case reflect.Slice: + rawParts := strings.Split(v, ",") + var parts []string + for _, part := range rawParts { + trimmed := strings.TrimSpace(part) + if trimmed != "" { + parts = append(parts, trimmed) + } + } + return a.Set(cfg, parts) default: return fmt.Errorf("cannot set %s of unknown type %s", a.Name, reflectKind(a.Kind)) @@ -85,6 +104,8 @@ func (a *ConfigAttribute) Set(cfg *Config, i interface{}) error { field.SetBool(i.(bool)) case reflect.Int: field.SetInt(int64(i.(int))) + case reflect.Slice: + field.Set(reflect.ValueOf(i.([]string))) default: // must extensively test with providerFixture to avoid this one return fmt.Errorf("cannot set %s of unknown type %s", a.Name, reflectKind(a.Kind)) diff --git a/config/config_attributes.go b/config/config_attributes.go index 1d821618f..009e3ffa9 100644 --- a/config/config_attributes.go +++ b/config/config_attributes.go @@ -2,7 +2,6 @@ package config import ( "fmt" - "os" "reflect" "sort" "strings" @@ -31,7 +30,7 @@ func (a attributes) DebugString(cfg *Config) string { } attrsUsed = append(attrsUsed, fmt.Sprintf("%s=%s", attr.Name, v)) for _, envName := range attr.EnvVars { - v := os.Getenv(envName) + v := getenv(envName) if v == "" { continue } diff --git a/config/config_attributes_test.go b/config/config_attributes_test.go new file mode 100644 index 000000000..b6b4f853f --- /dev/null +++ b/config/config_attributes_test.go @@ -0,0 +1,61 @@ +package config + +import ( + "testing" + + "github.com/google/go-cmp/cmp" +) + +// TestConfigFile_Configure_ListParsing tests that comma-separated list values +// in configuration files are correctly parsed into slices. +func TestConfigFile_Configure_ListParsing(t *testing.T) { + testCases := []struct { + name string + profile string + want []string + }{ + { + name: "single item", + profile: "single-item", + want: []string{"clusters"}, + }, + { + name: "multiple items", + profile: "multiple-items", + want: []string{"alpha", "beta", "gamma"}, + }, + { + name: "whitespace around items is trimmed", + profile: "whitespace-around-items", + want: []string{"alpha", "beta", "gamma"}, + }, + { + name: "empty items are filtered out", + profile: "empty-items-filtered", + want: []string{"alpha", "beta"}, + }, + { + name: "whitespace-only items are filtered out", + profile: "whitespace-only-items-filtered", + want: []string{"alpha", "beta"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + withMockEnv(t, map[string]string{}) + + cfg := &Config{ + Profile: tc.profile, + ConfigFile: "testdata/list-parsing/.databrickscfg", + } + err := ConfigFile.Configure(cfg) + if err != nil { + t.Fatalf("Configure failed: %v", err) + } + if diff := cmp.Diff(tc.want, cfg.Scopes); diff != "" { + t.Errorf("list mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/config/config_file.go b/config/config_file.go index 2c7c33803..a0923c3ad 100644 --- a/config/config_file.go +++ b/config/config_file.go @@ -38,7 +38,7 @@ func LoadFile(path string) (*File, error) { // Expand ~ to home directory. if strings.HasPrefix(path, "~") { - homedir, err := os.UserHomeDir() + homedir, err := getUserHomeDir() if err != nil { return nil, fmt.Errorf("cannot find homedir: %w", err) } diff --git a/config/config_file_test.go b/config/config_file_test.go index 2312b2954..48281cef4 100644 --- a/config/config_file_test.go +++ b/config/config_file_test.go @@ -3,14 +3,40 @@ package config import ( "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/google/go-cmp/cmp" ) +// withMockEnv mocks environment variables for testing config file loading +// without relying on the actual system environment or filesystem. +// getUserHomeDir falls back to the real implementation when HOME is not in +// the env map, allowing tests to optionally override HOME without breaking +// tests that don't need to control the home directory path. +func withMockEnv(t *testing.T, env map[string]string) { + original := getenv + originalUserHomeDir := getUserHomeDir + t.Cleanup(func() { + getenv = original + getUserHomeDir = originalUserHomeDir + }) + getenv = func(key string) string { + return env[key] + } + getUserHomeDir = func() (string, error) { + if home, ok := env["HOME"]; ok { + return home, nil + } + return originalUserHomeDir() + } +} + func TestConfigFileLoad(t *testing.T) { f, err := LoadFile("testdata/.databrickscfg") - require.NoError(t, err) - assert.NotNil(t, f) + if err != nil { + t.Fatalf("LoadFile failed: %v", err) + } + if f == nil { + t.Fatal("expected file to be non-nil") + } for _, name := range []string{ "password-with-double-quotes", @@ -18,7 +44,52 @@ func TestConfigFileLoad(t *testing.T) { "password-without-quotes", } { section := f.Section(name) - require.NotNil(t, section) - assert.Equal(t, "%Y#X$Z", section.Key("password").String()) + if section == nil { + t.Fatalf("expected section %q to be non-nil", name) + } + if got, want := section.Key("password").String(), "%Y#X$Z"; got != want { + t.Errorf("password mismatch for %q: got %q, want %q", name, got, want) + } + } +} + +func TestConfigFile_Scopes(t *testing.T) { + tests := []struct { + name string + profile string + want []string + }{ + { + name: "empty defaults to all-apis", + profile: "scope-empty", + want: []string{"all-apis"}, + }, + { + name: "single scope", + profile: "scope-single", + want: []string{"clusters"}, + }, + { + name: "multiple scopes sorted", + profile: "scope-multiple", + want: []string{"clusters", "files:read", "iam:read", "jobs", "mlflow", "model-serving", "pipelines"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + withMockEnv(t, map[string]string{ + "HOME": "testdata/scopes", + }) + + cfg := &Config{Profile: tt.profile} + err := cfg.EnsureResolved() + if err != nil { + t.Fatalf("EnsureResolved failed: %v", err) + } + if diff := cmp.Diff(tt.want, cfg.GetScopes()); diff != "" { + t.Errorf("GetScopes mismatch (-want +got):\n%s", diff) + } + }) } } diff --git a/config/config_test.go b/config/config_test.go index 3f6b93392..decf27d98 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -7,10 +7,22 @@ import ( "github.com/databricks/databricks-sdk-go/credentials/u2m" "github.com/databricks/databricks-sdk-go/httpclient/fixtures" + "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +// mockLoader is a test helper that implements the Loader interface. +type mockLoader func(cfg *Config) error + +func (m mockLoader) Name() string { + return "mockLoader" +} + +func (m mockLoader) Configure(cfg *Config) error { + return m(cfg) +} + func TestHostType_AwsAccount(t *testing.T) { c := &Config{ Host: "https://accounts.cloud.databricks.com", @@ -299,3 +311,68 @@ func TestConfig_getOAuthArgument_Unified(t *testing.T) { }) } } + +func TestConfig_EnsureResolved_scopeNormalization(t *testing.T) { + testCases := []struct { + desc string + scopes []string + want []string + }{ + { + desc: "nil scopes", + scopes: nil, + want: nil, + }, + { + desc: "empty scopes", + scopes: []string{}, + want: []string{}, + }, + { + desc: "single scope", + scopes: []string{"clusters"}, + want: []string{"clusters"}, + }, + { + desc: "already sorted no duplicates", + scopes: []string{"a", "b", "c"}, + want: []string{"a", "b", "c"}, + }, + { + desc: "unsorted scopes are sorted", + scopes: []string{"jobs", "clusters", "pipelines"}, + want: []string{"clusters", "jobs", "pipelines"}, + }, + { + desc: "duplicate scopes are removed", + scopes: []string{"clusters", "jobs", "clusters", "pipelines:read", "jobs"}, + want: []string{"clusters", "jobs", "pipelines:read"}, + }, + { + desc: "all duplicates reduced to one", + scopes: []string{"all-apis", "all-apis", "all-apis"}, + want: []string{"all-apis"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + cfg := &Config{ + Host: "https://example.cloud.databricks.com", + Loaders: []Loader{mockLoader(func(cfg *Config) error { + cfg.Scopes = tc.scopes + return nil + })}, + } + + err := cfg.EnsureResolved() + if err != nil { + t.Fatalf("EnsureResolved() error: %v", err) + } + + if diff := cmp.Diff(tc.want, cfg.Scopes); diff != "" { + t.Errorf("EnsureResolved() scopes mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/config/testdata/list-parsing/.databrickscfg b/config/testdata/list-parsing/.databrickscfg new file mode 100644 index 000000000..370db8bbc --- /dev/null +++ b/config/testdata/list-parsing/.databrickscfg @@ -0,0 +1,20 @@ +[single-item] +host = https://example.cloud.databricks.com +scopes = clusters + +[multiple-items] +host = https://example.cloud.databricks.com +scopes = alpha,beta,gamma + +[whitespace-around-items] +host = https://example.cloud.databricks.com +scopes = alpha, beta , gamma + +[empty-items-filtered] +host = https://example.cloud.databricks.com +scopes = alpha,,beta, + +[whitespace-only-items-filtered] +host = https://example.cloud.databricks.com +scopes = alpha, ,beta + diff --git a/config/testdata/scopes/.databrickscfg b/config/testdata/scopes/.databrickscfg new file mode 100644 index 000000000..aec957b51 --- /dev/null +++ b/config/testdata/scopes/.databrickscfg @@ -0,0 +1,10 @@ +[scope-empty] +host = https://example.cloud.databricks.com + +[scope-single] +host = https://example.cloud.databricks.com +scopes = clusters + +[scope-multiple] +host = https://example.cloud.databricks.com +scopes = clusters, jobs, pipelines, iam:read, files:read, mlflow, model-serving