Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/databricks/databricks-sdk-go/credentials/u2m"
"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/logger"
"golang.org/x/exp/slices"
"golang.org/x/oauth2"
)

Expand Down Expand Up @@ -63,6 +64,8 @@ const (
InvalidConfig ConfigType = "INVALID_CONFIG"
)

var defaultScopes = []string{"all-apis"}

// Config represents configuration for Databricks Connectivity
type Config struct {
// Credentials holds an instance of Credentials Strategy to authenticate with Databricks REST APIs.
Expand Down Expand Up @@ -135,6 +138,26 @@ type Config struct {
ClientID string `name:"client_id" env:"DATABRICKS_CLIENT_ID" auth:"oauth" auth_types:"oauth-m2m"`
ClientSecret string `name:"client_secret" env:"DATABRICKS_CLIENT_SECRET" auth:"oauth,sensitive" auth_types:"oauth-m2m"`

// Scopes is a list of OAuth scopes to request when authenticating.
//
// WARNING:
// - This feature is still in development and may not work as expected
// - This feature is EXPERIMENTAL and may change or be removed without notice.
// - Do NOT use this feature in production environments.
//
// Notes:
// - If Scopes is nil or empty, the default ["all-apis"] scope will be used for backward compatibility.
// - For U2M authentication, the "offline_access" scope will automatically be added to obtain a refresh token
// unless you set DisableOAuthRefreshToken to true.
// - You cannot set Scopes via environment variables.
// - The scopes list will be sorted in-place during configuration resolution.
// - The U2M token cache currently does NOT support differentiated caching for scopes.
Scopes []string `name:"scopes" auth:"-"`

// DisableOAuthRefreshToken controls whether a refresh token should be requested
// during the U2M authentication flow (default to false).
DisableOAuthRefreshToken bool `name:"disable_oauth_refresh_token" env:"DATABRICKS_DISABLE_OAUTH_REFRESH_TOKEN" auth:"-"`

// Path to the Databricks CLI (version >= 0.100.0).
DatabricksCliPath string `name:"databricks_cli_path" env:"DATABRICKS_CLI_PATH" auth_types:"databricks-cli"`

Expand Down Expand Up @@ -445,6 +468,8 @@ func (c *Config) EnsureResolved() error {
},
}
}
slices.Sort(c.Scopes)
c.Scopes = slices.Compact(c.Scopes)
c.resolved = true
return nil
}
Expand All @@ -460,6 +485,13 @@ func (c *Config) CanonicalHostName() string {
return c.Host
}

func (c *Config) GetScopes() []string {
if len(c.Scopes) == 0 {
return defaultScopes
}
return c.Scopes
}

func (c *Config) wrapDebug(err error) error {
debug := ConfigAttributes.DebugString(c)
if debug == "" {
Expand Down
23 changes: 22 additions & 1 deletion config/config_attribute.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,17 @@ import (
"os"
"reflect"
"strconv"
"strings"
)

// getenv is the function used to read environment variables.
// It defaults to os.Getenv but can be overwritten in tests.
var getenv = os.Getenv

// getUserHomeDir is the function used to get user home directory.
// It defaults to os.UserHomeDir but can be overwritten in tests.
var getUserHomeDir = os.UserHomeDir

type Source struct {
Type SourceType `json:"type"`
Name string `json:"name,omitempty"`
Expand Down Expand Up @@ -44,7 +53,7 @@ type ConfigAttribute struct {

func (a *ConfigAttribute) ReadEnv() (string, string) {
for _, envName := range a.EnvVars {
v := os.Getenv(envName)
v := getenv(envName)
if v == "" {
continue
}
Expand All @@ -69,6 +78,16 @@ func (a *ConfigAttribute) SetS(cfg *Config, v string) error {
return err
}
return a.Set(cfg, vv)
case reflect.Slice:
rawParts := strings.Split(v, ",")
var parts []string
for _, part := range rawParts {
trimmed := strings.TrimSpace(part)
if trimmed != "" {
parts = append(parts, trimmed)
}
}
return a.Set(cfg, parts)
default:
return fmt.Errorf("cannot set %s of unknown type %s",
a.Name, reflectKind(a.Kind))
Expand All @@ -85,6 +104,8 @@ func (a *ConfigAttribute) Set(cfg *Config, i interface{}) error {
field.SetBool(i.(bool))
case reflect.Int:
field.SetInt(int64(i.(int)))
case reflect.Slice:
field.Set(reflect.ValueOf(i.([]string)))
default:
// must extensively test with providerFixture to avoid this one
return fmt.Errorf("cannot set %s of unknown type %s", a.Name, reflectKind(a.Kind))
Expand Down
3 changes: 1 addition & 2 deletions config/config_attributes.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package config

import (
"fmt"
"os"
"reflect"
"sort"
"strings"
Expand Down Expand Up @@ -31,7 +30,7 @@ func (a attributes) DebugString(cfg *Config) string {
}
attrsUsed = append(attrsUsed, fmt.Sprintf("%s=%s", attr.Name, v))
for _, envName := range attr.EnvVars {
v := os.Getenv(envName)
v := getenv(envName)
if v == "" {
continue
}
Expand Down
61 changes: 61 additions & 0 deletions config/config_attributes_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package config

import (
"testing"

"github.com/google/go-cmp/cmp"
)

// TestConfigFile_Configure_ListParsing tests that comma-separated list values
// in configuration files are correctly parsed into slices.
func TestConfigFile_Configure_ListParsing(t *testing.T) {
testCases := []struct {
name string
profile string
want []string
}{
{
name: "single item",
profile: "single-item",
want: []string{"clusters"},
},
{
name: "multiple items",
profile: "multiple-items",
want: []string{"alpha", "beta", "gamma"},
},
{
name: "whitespace around items is trimmed",
profile: "whitespace-around-items",
want: []string{"alpha", "beta", "gamma"},
},
{
name: "empty items are filtered out",
profile: "empty-items-filtered",
want: []string{"alpha", "beta"},
},
{
name: "whitespace-only items are filtered out",
profile: "whitespace-only-items-filtered",
want: []string{"alpha", "beta"},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
withMockEnv(t, map[string]string{})

cfg := &Config{
Profile: tc.profile,
ConfigFile: "testdata/list-parsing/.databrickscfg",
}
err := ConfigFile.Configure(cfg)
if err != nil {
t.Fatalf("Configure failed: %v", err)
}
if diff := cmp.Diff(tc.want, cfg.Scopes); diff != "" {
t.Errorf("list mismatch (-want +got):\n%s", diff)
}
})
}
}
2 changes: 1 addition & 1 deletion config/config_file.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func LoadFile(path string) (*File, error) {

// Expand ~ to home directory.
if strings.HasPrefix(path, "~") {
homedir, err := os.UserHomeDir()
homedir, err := getUserHomeDir()
if err != nil {
return nil, fmt.Errorf("cannot find homedir: %w", err)
}
Expand Down
83 changes: 77 additions & 6 deletions config/config_file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,93 @@ package config
import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/google/go-cmp/cmp"
)

// withMockEnv mocks environment variables for testing config file loading
// without relying on the actual system environment or filesystem.
// getUserHomeDir falls back to the real implementation when HOME is not in
// the env map, allowing tests to optionally override HOME without breaking
// tests that don't need to control the home directory path.
func withMockEnv(t *testing.T, env map[string]string) {
original := getenv
originalUserHomeDir := getUserHomeDir
t.Cleanup(func() {
getenv = original
getUserHomeDir = originalUserHomeDir
})
getenv = func(key string) string {
return env[key]
}
getUserHomeDir = func() (string, error) {
if home, ok := env["HOME"]; ok {
return home, nil
}
return originalUserHomeDir()
}
}

func TestConfigFileLoad(t *testing.T) {
f, err := LoadFile("testdata/.databrickscfg")
require.NoError(t, err)
assert.NotNil(t, f)
if err != nil {
t.Fatalf("LoadFile failed: %v", err)
}
if f == nil {
t.Fatal("expected file to be non-nil")
}

for _, name := range []string{
"password-with-double-quotes",
"password-with-single-quotes",
"password-without-quotes",
} {
section := f.Section(name)
require.NotNil(t, section)
assert.Equal(t, "%Y#X$Z", section.Key("password").String())
if section == nil {
t.Fatalf("expected section %q to be non-nil", name)
}
if got, want := section.Key("password").String(), "%Y#X$Z"; got != want {
t.Errorf("password mismatch for %q: got %q, want %q", name, got, want)
}
}
}

func TestConfigFile_Scopes(t *testing.T) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we remove this test? It looks redundant with the two tests we have added. If you decide to keep it, please update:

			withMockEnv(t, map[string]string{})
			t.Setenv("HOME", "testdata/scopes")

to

			withMockEnv(t, map[string]string{"HOME": "testdata/scopes"})

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer keeping it as it acts as an integration test and ensures backwards compatibility (the first case checks that "all-apis" is used if no scopes are provided in the config file).

Made the change.

tests := []struct {
name string
profile string
want []string
}{
{
name: "empty defaults to all-apis",
profile: "scope-empty",
want: []string{"all-apis"},
},
{
name: "single scope",
profile: "scope-single",
want: []string{"clusters"},
},
{
name: "multiple scopes sorted",
profile: "scope-multiple",
want: []string{"clusters", "files:read", "iam:read", "jobs", "mlflow", "model-serving", "pipelines"},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
withMockEnv(t, map[string]string{
"HOME": "testdata/scopes",
})

cfg := &Config{Profile: tt.profile}
err := cfg.EnsureResolved()
if err != nil {
t.Fatalf("EnsureResolved failed: %v", err)
}
if diff := cmp.Diff(tt.want, cfg.GetScopes()); diff != "" {
t.Errorf("GetScopes mismatch (-want +got):\n%s", diff)
}
})
}
}
77 changes: 77 additions & 0 deletions config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,22 @@ import (

"github.com/databricks/databricks-sdk-go/credentials/u2m"
"github.com/databricks/databricks-sdk-go/httpclient/fixtures"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// mockLoader is a test helper that implements the Loader interface.
type mockLoader func(cfg *Config) error

func (m mockLoader) Name() string {
return "mockLoader"
}

func (m mockLoader) Configure(cfg *Config) error {
return m(cfg)
}

func TestHostType_AwsAccount(t *testing.T) {
c := &Config{
Host: "https://accounts.cloud.databricks.com",
Expand Down Expand Up @@ -299,3 +311,68 @@ func TestConfig_getOAuthArgument_Unified(t *testing.T) {
})
}
}

func TestConfig_EnsureResolved_scopeNormalization(t *testing.T) {
testCases := []struct {
desc string
scopes []string
want []string
}{
{
desc: "nil scopes",
scopes: nil,
want: nil,
},
{
desc: "empty scopes",
scopes: []string{},
want: []string{},
},
{
desc: "single scope",
scopes: []string{"clusters"},
want: []string{"clusters"},
},
{
desc: "already sorted no duplicates",
scopes: []string{"a", "b", "c"},
want: []string{"a", "b", "c"},
},
{
desc: "unsorted scopes are sorted",
scopes: []string{"jobs", "clusters", "pipelines"},
want: []string{"clusters", "jobs", "pipelines"},
},
{
desc: "duplicate scopes are removed",
scopes: []string{"clusters", "jobs", "clusters", "pipelines:read", "jobs"},
want: []string{"clusters", "jobs", "pipelines:read"},
},
{
desc: "all duplicates reduced to one",
scopes: []string{"all-apis", "all-apis", "all-apis"},
want: []string{"all-apis"},
},
}

for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
cfg := &Config{
Host: "https://example.cloud.databricks.com",
Loaders: []Loader{mockLoader(func(cfg *Config) error {
cfg.Scopes = tc.scopes
return nil
})},
}

err := cfg.EnsureResolved()
if err != nil {
t.Fatalf("EnsureResolved() error: %v", err)
}

if diff := cmp.Diff(tc.want, cfg.Scopes); diff != "" {
t.Errorf("EnsureResolved() scopes mismatch (-want +got):\n%s", diff)
}
})
}
}
Loading
Loading