Skip to content

Commit a913b1a

Browse files
committed
properly pass along databricks profile
1 parent a93dc5a commit a913b1a

File tree

7 files changed

+81
-56
lines changed

7 files changed

+81
-56
lines changed

experimental/apps-mcp/lib/middlewares/databricks_client.go

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@ import (
99
"github.com/databricks/cli/experimental/apps-mcp/lib/mcp"
1010
"github.com/databricks/cli/experimental/apps-mcp/lib/prompts"
1111
"github.com/databricks/cli/experimental/apps-mcp/lib/session"
12+
"github.com/databricks/cli/libs/databrickscfg/profile"
1213
"github.com/databricks/databricks-sdk-go"
1314
"github.com/databricks/databricks-sdk-go/config"
1415
"github.com/databricks/databricks-sdk-go/httpclient"
1516
)
1617

1718
const (
18-
DatabricksClientKey = "databricks_client"
19+
DatabricksClientKey = "databricks_client"
20+
DatabricksProfileKey = "databricks_profile"
1921
)
2022

2123
func NewDatabricksClientMiddleware(unauthorizedToolNames []string) mcp.Middleware {
@@ -40,8 +42,41 @@ func NewDatabricksClientMiddleware(unauthorizedToolNames []string) mcp.Middlewar
4042
})
4143
}
4244

43-
func MustGetApiClient(ctx context.Context) (*httpclient.ApiClient, error) {
44-
w := MustGetDatabricksClient(ctx)
45+
func GetDatabricksProfile(ctx context.Context) string {
46+
sess, err := session.GetSession(ctx)
47+
if err != nil {
48+
return ""
49+
}
50+
profile, ok := sess.Get(DatabricksProfileKey)
51+
if !ok {
52+
return ""
53+
}
54+
return profile.(string)
55+
}
56+
57+
// GetAvailableProfiles returns all available profiles from ~/.databrickscfg.
58+
func GetAvailableProfiles(ctx context.Context) profile.Profiles {
59+
profiles, err := profile.DefaultProfiler.LoadProfiles(ctx, profile.MatchAllProfiles)
60+
if err != nil {
61+
// If we can't load profiles, return empty list (config file might not exist)
62+
return profile.Profiles{}
63+
}
64+
return profiles
65+
}
66+
67+
func MustGetApiClient(ctx context.Context) *httpclient.ApiClient {
68+
client, err := GetApiClient(ctx)
69+
if err != nil {
70+
panic(err)
71+
}
72+
return client
73+
}
74+
75+
func GetApiClient(ctx context.Context) (*httpclient.ApiClient, error) {
76+
w, err := GetDatabricksClient(ctx)
77+
if err != nil {
78+
return nil, err
79+
}
4580
clientCfg, err := config.HTTPClientConfigFromConfig(w.Config)
4681
if err != nil {
4782
return nil, fmt.Errorf("failed to create HTTP client config: %w", err)
@@ -64,28 +99,36 @@ func GetDatabricksClient(ctx context.Context) (*databricks.WorkspaceClient, erro
6499
}
65100
w, ok := sess.Get(DatabricksClientKey)
66101
if !ok {
67-
return nil, errors.New(prompts.MustExecuteTemplate("auth_error.tmpl", nil))
102+
return nil, newAuthError(ctx)
68103
}
69104
return w.(*databricks.WorkspaceClient), nil
70105
}
71106

72107
func checkAuth(ctx context.Context) (*databricks.WorkspaceClient, error) {
73108
w, err := databricks.NewWorkspaceClient()
74109
if err != nil {
75-
return nil, wrapAuthError(err)
110+
return nil, WrapAuthError(ctx, err)
76111
}
77112

78113
_, err = w.CurrentUser.Me(ctx)
79114
if err != nil {
80-
return nil, wrapAuthError(err)
115+
return nil, WrapAuthError(ctx, err)
81116
}
82117

83118
return w, nil
84119
}
85120

86-
func wrapAuthError(err error) error {
121+
func WrapAuthError(ctx context.Context, err error) error {
87122
if errors.Is(err, config.ErrCannotConfigureDefault) {
88-
return errors.New(prompts.MustExecuteTemplate("auth_error.tmpl", nil))
123+
return newAuthError(ctx)
89124
}
90125
return err
91126
}
127+
128+
func newAuthError(ctx context.Context) error {
129+
// Prepare template data
130+
data := map[string]any{
131+
"Profiles": GetAvailableProfiles(ctx),
132+
}
133+
return errors.New(prompts.MustExecuteTemplate("auth_error.tmpl", data))
134+
}

experimental/apps-mcp/lib/middlewares/warehouse.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,10 @@ func getDefaultWarehouse(ctx context.Context) (*sql.EndpointInfo, error) {
8686
// first resolve DATABRICKS_WAREHOUSE_ID env variable
8787
warehouseID := env.Get(ctx, "DATABRICKS_WAREHOUSE_ID")
8888
if warehouseID != "" {
89-
w := MustGetDatabricksClient(ctx)
89+
w, err := GetDatabricksClient(ctx)
90+
if err != nil {
91+
return nil, fmt.Errorf("get databricks client: %w", err)
92+
}
9093
warehouse, err := w.Warehouses.Get(ctx, sql.GetWarehouseRequest{
9194
Id: warehouseID,
9295
})
@@ -100,7 +103,7 @@ func getDefaultWarehouse(ctx context.Context) (*sql.EndpointInfo, error) {
100103
}, nil
101104
}
102105

103-
apiClient, err := MustGetApiClient(ctx)
106+
apiClient, err := GetApiClient(ctx)
104107
if err != nil {
105108
return nil, err
106109
}

experimental/apps-mcp/lib/prompts/auth_error.tmpl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,14 @@
66
Not authenticated to Databricks
77

88
I need to know either the Databricks workspace URL or the Databricks profile name.
9-
You can list the available profiles by running `databricks auth profiles`.
109

11-
ASK the user which of the configured profiles or databricks workspace URL they want to use.
10+
The available profiles are:
11+
12+
{{- range .Profiles }}
13+
- {{ .Name }} ({{ .Host }})
14+
{{- end }}
15+
16+
IMPORTANT: YOU MUSTASK the user which of the configured profiles or databricks workspace URL they want to use.
1217
Only then call the `databricks_configure_auth` tool to configure the authentication.
1318

1419
Do not run anything else before authenticating successfully.

experimental/apps-mcp/lib/providers/clitools/configure_auth.go

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99
"github.com/databricks/cli/experimental/apps-mcp/lib/prompts"
1010
"github.com/databricks/cli/experimental/apps-mcp/lib/session"
1111
"github.com/databricks/databricks-sdk-go"
12-
"github.com/databricks/databricks-sdk-go/config"
1312
)
1413

1514
// ConfigureAuth creates and validates a Databricks workspace client with optional host and profile.
@@ -20,9 +19,8 @@ func ConfigureAuth(ctx context.Context, sess *session.Session, host, profile *st
2019
return nil, nil
2120
}
2221

23-
var cfg *databricks.Config
22+
cfg := &databricks.Config{}
2423
if host != nil || profile != nil {
25-
cfg = &databricks.Config{}
2624
if host != nil {
2725
cfg.Host = *host
2826
}
@@ -32,12 +30,7 @@ func ConfigureAuth(ctx context.Context, sess *session.Session, host, profile *st
3230
}
3331

3432
var client *databricks.WorkspaceClient
35-
var err error
36-
if cfg != nil {
37-
client, err = databricks.NewWorkspaceClient(cfg)
38-
} else {
39-
client, err = databricks.NewWorkspaceClient()
40-
}
33+
client, err := databricks.NewWorkspaceClient(cfg)
4134
if err != nil {
4235
return nil, err
4336
}
@@ -49,19 +42,14 @@ func ConfigureAuth(ctx context.Context, sess *session.Session, host, profile *st
4942
"WorkspaceURL": *host,
5043
}))
5144
}
52-
return nil, wrapAuthError(err)
45+
return nil, middlewares.WrapAuthError(ctx, err)
5346
}
5447

55-
// Store client in session data
48+
// Store client and profile in session data
5649
sess.Set(middlewares.DatabricksClientKey, client)
50+
if profile != nil {
51+
sess.Set(middlewares.DatabricksProfileKey, *profile)
52+
}
5753

5854
return client, nil
5955
}
60-
61-
// wrapAuthError wraps configuration errors with helpful messages
62-
func wrapAuthError(err error) error {
63-
if errors.Is(err, config.ErrCannotConfigureDefault) {
64-
return errors.New(prompts.MustExecuteTemplate("auth_error.tmpl", nil))
65-
}
66-
return err
67-
}

experimental/apps-mcp/lib/providers/clitools/explore.go

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88
"github.com/databricks/cli/experimental/apps-mcp/lib/prompts"
99
"github.com/databricks/cli/experimental/apps-mcp/lib/session"
1010
"github.com/databricks/cli/libs/databrickscfg/profile"
11-
"github.com/databricks/cli/libs/env"
1211
"github.com/databricks/cli/libs/log"
1312
"github.com/databricks/databricks-sdk-go/service/sql"
1413
)
@@ -21,32 +20,12 @@ func Explore(ctx context.Context) (string, error) {
2120
warehouse = nil
2221
}
2322

24-
currentProfile := getCurrentProfile(ctx)
25-
profiles := getAvailableProfiles(ctx)
23+
currentProfile := middlewares.GetDatabricksProfile(ctx)
24+
profiles := middlewares.GetAvailableProfiles(ctx)
2625

2726
return generateExploreGuidance(ctx, warehouse, currentProfile, profiles), nil
2827
}
2928

30-
// getCurrentProfile returns the currently active profile name.
31-
func getCurrentProfile(ctx context.Context) string {
32-
// Check DATABRICKS_CONFIG_PROFILE env var
33-
profileName := env.Get(ctx, "DATABRICKS_CONFIG_PROFILE")
34-
if profileName == "" {
35-
return "DEFAULT"
36-
}
37-
return profileName
38-
}
39-
40-
// getAvailableProfiles returns all available profiles from ~/.databrickscfg.
41-
func getAvailableProfiles(ctx context.Context) profile.Profiles {
42-
profiles, err := profile.DefaultProfiler.LoadProfiles(ctx, profile.MatchAllProfiles)
43-
if err != nil {
44-
// If we can't load profiles, return empty list (config file might not exist)
45-
return profile.Profiles{}
46-
}
47-
return profiles
48-
}
49-
5029
// generateExploreGuidance creates comprehensive guidance for data exploration.
5130
func generateExploreGuidance(ctx context.Context, warehouse *sql.EndpointInfo, currentProfile string, profiles profile.Profiles) string {
5231
// Build workspace/profile information

experimental/apps-mcp/lib/providers/clitools/invoke_databricks_cli.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,22 @@ func InvokeDatabricksCLI(ctx context.Context, command []string, workingDirectory
1717
return "", errors.New("command is required")
1818
}
1919

20-
workspaceClient := middlewares.MustGetDatabricksClient(ctx)
20+
workspaceClient, err := middlewares.GetDatabricksClient(ctx)
21+
if err != nil {
22+
return "", fmt.Errorf("get databricks client: %w", err)
23+
}
2124
host := workspaceClient.Config.Host
25+
profile := middlewares.GetDatabricksProfile(ctx)
2226

2327
// GetCLIPath returns the path to the current CLI executable
2428
cliPath := common.GetCLIPath()
2529
cmd := exec.CommandContext(ctx, cliPath, command...)
2630
cmd.Dir = workingDirectory
2731
env := os.Environ()
2832
env = append(env, "DATABRICKS_HOST="+host)
33+
if profile != "" {
34+
env = append(env, "DATABRICKS_CONFIG_PROFILE="+profile)
35+
}
2936
cmd.Env = env
3037

3138
output, err := cmd.CombinedOutput()

experimental/apps-mcp/lib/providers/clitools/provider.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ func (p *Provider) RegisterTools(server *mcpsdk.Server) error {
9393
},
9494
func(ctx context.Context, req *mcpsdk.CallToolRequest, args struct{}) (*mcpsdk.CallToolResult, any, error) {
9595
log.Debug(ctx, "explore called")
96-
result, err := Explore(session.WithSession(ctx, p.session))
96+
result, err := Explore(ctx)
9797
if err != nil {
9898
return nil, nil, err
9999
}

0 commit comments

Comments
 (0)