Skip to content

Commit ca31089

Browse files
authored
MCP: properly pass along databricks profile (#4040)
## Changes Make sure all tools and CLI invocations use the same profile and workspace ID. ## Why <!-- Why are these changes needed? Provide the context that the reviewer might be missing. For example, were there any decisions behind the change that are not reflected in the code itself? --> ## Tests <!-- How have you tested the changes? --> <!-- If your PR needs to be included in the release notes for next release, add a separate entry in NEXT_CHANGELOG.md as part of your PR. -->
1 parent b28a03c commit ca31089

File tree

14 files changed

+138
-72
lines changed

14 files changed

+138
-72
lines changed

experimental/apps-mcp/lib/mcp/middleware_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ func TestServerMiddleware(t *testing.T) {
224224
Name: "test-server",
225225
Version: "1.0.0",
226226
}
227-
server := mcp.NewServer(impl, nil)
227+
server := mcp.NewServer(impl, nil, nil)
228228

229229
var executionOrder []string
230230

@@ -264,7 +264,7 @@ func TestServerSessionPersistence(t *testing.T) {
264264
Name: "test-server",
265265
Version: "1.0.0",
266266
}
267-
server := mcp.NewServer(impl, nil)
267+
server := mcp.NewServer(impl, nil, nil)
268268

269269
// Add middleware that increments a counter
270270
server.AddMiddlewareFunc(func(ctx *mcp.MiddlewareContext, next mcp.NextFunc) (*mcp.CallToolResult, error) {

experimental/apps-mcp/lib/mcp/server.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,15 @@ type serverTool struct {
3030
}
3131

3232
// NewServer creates a new MCP server.
33-
func NewServer(impl *Implementation, options any) *Server {
33+
// If sess is nil, a new session will be created.
34+
func NewServer(impl *Implementation, options any, sess *session.Session) *Server {
35+
if sess == nil {
36+
sess = session.NewSession()
37+
}
3438
return &Server{
3539
impl: impl,
3640
tools: make(map[string]*serverTool),
37-
session: session.NewSession(),
41+
session: sess,
3842
}
3943
}
4044

experimental/apps-mcp/lib/middlewares/databricks_client.go

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@ import (
99
"github.com/databricks/cli/experimental/apps-mcp/lib/mcp"
1010
"github.com/databricks/cli/experimental/apps-mcp/lib/prompts"
1111
"github.com/databricks/cli/experimental/apps-mcp/lib/session"
12+
"github.com/databricks/cli/libs/databrickscfg/profile"
1213
"github.com/databricks/databricks-sdk-go"
1314
"github.com/databricks/databricks-sdk-go/config"
1415
"github.com/databricks/databricks-sdk-go/httpclient"
1516
)
1617

1718
const (
18-
DatabricksClientKey = "databricks_client"
19+
DatabricksClientKey = "databricks_client"
20+
DatabricksProfileKey = "databricks_profile"
1921
)
2022

2123
func NewDatabricksClientMiddleware(unauthorizedToolNames []string) mcp.Middleware {
@@ -40,8 +42,41 @@ func NewDatabricksClientMiddleware(unauthorizedToolNames []string) mcp.Middlewar
4042
})
4143
}
4244

43-
func MustGetApiClient(ctx context.Context) (*httpclient.ApiClient, error) {
44-
w := MustGetDatabricksClient(ctx)
45+
func GetDatabricksProfile(ctx context.Context) string {
46+
sess, err := session.GetSession(ctx)
47+
if err != nil {
48+
return ""
49+
}
50+
profile, ok := sess.Get(DatabricksProfileKey)
51+
if !ok {
52+
return ""
53+
}
54+
return profile.(string)
55+
}
56+
57+
// GetAvailableProfiles returns all available profiles from ~/.databrickscfg.
58+
func GetAvailableProfiles(ctx context.Context) profile.Profiles {
59+
profiles, err := profile.DefaultProfiler.LoadProfiles(ctx, profile.MatchAllProfiles)
60+
if err != nil {
61+
// If we can't load profiles, return empty list (config file might not exist)
62+
return profile.Profiles{}
63+
}
64+
return profiles
65+
}
66+
67+
func MustGetApiClient(ctx context.Context) *httpclient.ApiClient {
68+
client, err := GetApiClient(ctx)
69+
if err != nil {
70+
panic(err)
71+
}
72+
return client
73+
}
74+
75+
func GetApiClient(ctx context.Context) (*httpclient.ApiClient, error) {
76+
w, err := GetDatabricksClient(ctx)
77+
if err != nil {
78+
return nil, err
79+
}
4580
clientCfg, err := config.HTTPClientConfigFromConfig(w.Config)
4681
if err != nil {
4782
return nil, fmt.Errorf("failed to create HTTP client config: %w", err)
@@ -64,28 +99,36 @@ func GetDatabricksClient(ctx context.Context) (*databricks.WorkspaceClient, erro
6499
}
65100
w, ok := sess.Get(DatabricksClientKey)
66101
if !ok {
67-
return nil, errors.New(prompts.MustExecuteTemplate("auth_error.tmpl", nil))
102+
return nil, newAuthError(ctx)
68103
}
69104
return w.(*databricks.WorkspaceClient), nil
70105
}
71106

72107
func checkAuth(ctx context.Context) (*databricks.WorkspaceClient, error) {
73108
w, err := databricks.NewWorkspaceClient()
74109
if err != nil {
75-
return nil, wrapAuthError(err)
110+
return nil, WrapAuthError(ctx, err)
76111
}
77112

78113
_, err = w.CurrentUser.Me(ctx)
79114
if err != nil {
80-
return nil, wrapAuthError(err)
115+
return nil, WrapAuthError(ctx, err)
81116
}
82117

83118
return w, nil
84119
}
85120

86-
func wrapAuthError(err error) error {
121+
func WrapAuthError(ctx context.Context, err error) error {
87122
if errors.Is(err, config.ErrCannotConfigureDefault) {
88-
return errors.New(prompts.MustExecuteTemplate("auth_error.tmpl", nil))
123+
return newAuthError(ctx)
89124
}
90125
return err
91126
}
127+
128+
func newAuthError(ctx context.Context) error {
129+
// Prepare template data
130+
data := map[string]any{
131+
"Profiles": GetAvailableProfiles(ctx),
132+
}
133+
return errors.New(prompts.MustExecuteTemplate("auth_error.tmpl", data))
134+
}

experimental/apps-mcp/lib/middlewares/warehouse.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,10 @@ func getDefaultWarehouse(ctx context.Context) (*sql.EndpointInfo, error) {
8686
// first resolve DATABRICKS_WAREHOUSE_ID env variable
8787
warehouseID := env.Get(ctx, "DATABRICKS_WAREHOUSE_ID")
8888
if warehouseID != "" {
89-
w := MustGetDatabricksClient(ctx)
89+
w, err := GetDatabricksClient(ctx)
90+
if err != nil {
91+
return nil, fmt.Errorf("get databricks client: %w", err)
92+
}
9093
warehouse, err := w.Warehouses.Get(ctx, sql.GetWarehouseRequest{
9194
Id: warehouseID,
9295
})
@@ -100,7 +103,7 @@ func getDefaultWarehouse(ctx context.Context) (*sql.EndpointInfo, error) {
100103
}, nil
101104
}
102105

103-
apiClient, err := MustGetApiClient(ctx)
106+
apiClient, err := GetApiClient(ctx)
104107
if err != nil {
105108
return nil, err
106109
}

experimental/apps-mcp/lib/prompts/auth_error.tmpl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,14 @@
66
Not authenticated to Databricks
77

88
I need to know either the Databricks workspace URL or the Databricks profile name.
9-
You can list the available profiles by running `databricks auth profiles`.
109

11-
ASK the user which of the configured profiles or databricks workspace URL they want to use.
10+
The available profiles are:
11+
12+
{{- range .Profiles }}
13+
- {{ .Name }} ({{ .Host }})
14+
{{- end }}
15+
16+
IMPORTANT: YOU MUST ASK the user which of the configured profiles or databricks workspace URL they want to use.
1217
Only then call the `databricks_configure_auth` tool to configure the authentication.
1318

1419
Do not run anything else before authenticating successfully.

experimental/apps-mcp/lib/providers/clitools/configure_auth.go

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99
"github.com/databricks/cli/experimental/apps-mcp/lib/prompts"
1010
"github.com/databricks/cli/experimental/apps-mcp/lib/session"
1111
"github.com/databricks/databricks-sdk-go"
12-
"github.com/databricks/databricks-sdk-go/config"
1312
)
1413

1514
// ConfigureAuth creates and validates a Databricks workspace client with optional host and profile.
@@ -20,9 +19,8 @@ func ConfigureAuth(ctx context.Context, sess *session.Session, host, profile *st
2019
return nil, nil
2120
}
2221

23-
var cfg *databricks.Config
22+
cfg := &databricks.Config{}
2423
if host != nil || profile != nil {
25-
cfg = &databricks.Config{}
2624
if host != nil {
2725
cfg.Host = *host
2826
}
@@ -32,12 +30,7 @@ func ConfigureAuth(ctx context.Context, sess *session.Session, host, profile *st
3230
}
3331

3432
var client *databricks.WorkspaceClient
35-
var err error
36-
if cfg != nil {
37-
client, err = databricks.NewWorkspaceClient(cfg)
38-
} else {
39-
client, err = databricks.NewWorkspaceClient()
40-
}
33+
client, err := databricks.NewWorkspaceClient(cfg)
4134
if err != nil {
4235
return nil, err
4336
}
@@ -49,19 +42,14 @@ func ConfigureAuth(ctx context.Context, sess *session.Session, host, profile *st
4942
"WorkspaceURL": *host,
5043
}))
5144
}
52-
return nil, wrapAuthError(err)
45+
return nil, middlewares.WrapAuthError(ctx, err)
5346
}
5447

55-
// Store client in session data
48+
// Store client and profile in session data
5649
sess.Set(middlewares.DatabricksClientKey, client)
50+
if profile != nil {
51+
sess.Set(middlewares.DatabricksProfileKey, *profile)
52+
}
5753

5854
return client, nil
5955
}
60-
61-
// wrapAuthError wraps configuration errors with helpful messages
62-
func wrapAuthError(err error) error {
63-
if errors.Is(err, config.ErrCannotConfigureDefault) {
64-
return errors.New(prompts.MustExecuteTemplate("auth_error.tmpl", nil))
65-
}
66-
return err
67-
}

experimental/apps-mcp/lib/providers/clitools/configure_auth_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ func TestConfigureAuthWithCustomHost(t *testing.T) {
7575
}
7676

7777
func TestWrapAuthError(t *testing.T) {
78+
ctx := context.Background()
79+
7880
tests := []struct {
7981
name string
8082
err error
@@ -89,7 +91,7 @@ func TestWrapAuthError(t *testing.T) {
8991

9092
for _, tt := range tests {
9193
t.Run(tt.name, func(t *testing.T) {
92-
wrapped := wrapAuthError(tt.err)
94+
wrapped := middlewares.WrapAuthError(ctx, tt.err)
9395
assert.Contains(t, wrapped.Error(), tt.expected)
9496
})
9597
}

experimental/apps-mcp/lib/providers/clitools/explore.go

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88
"github.com/databricks/cli/experimental/apps-mcp/lib/prompts"
99
"github.com/databricks/cli/experimental/apps-mcp/lib/session"
1010
"github.com/databricks/cli/libs/databrickscfg/profile"
11-
"github.com/databricks/cli/libs/env"
1211
"github.com/databricks/cli/libs/log"
1312
"github.com/databricks/databricks-sdk-go/service/sql"
1413
)
@@ -21,32 +20,12 @@ func Explore(ctx context.Context) (string, error) {
2120
warehouse = nil
2221
}
2322

24-
currentProfile := getCurrentProfile(ctx)
25-
profiles := getAvailableProfiles(ctx)
23+
currentProfile := middlewares.GetDatabricksProfile(ctx)
24+
profiles := middlewares.GetAvailableProfiles(ctx)
2625

2726
return generateExploreGuidance(ctx, warehouse, currentProfile, profiles), nil
2827
}
2928

30-
// getCurrentProfile returns the currently active profile name.
31-
func getCurrentProfile(ctx context.Context) string {
32-
// Check DATABRICKS_CONFIG_PROFILE env var
33-
profileName := env.Get(ctx, "DATABRICKS_CONFIG_PROFILE")
34-
if profileName == "" {
35-
return "DEFAULT"
36-
}
37-
return profileName
38-
}
39-
40-
// getAvailableProfiles returns all available profiles from ~/.databrickscfg.
41-
func getAvailableProfiles(ctx context.Context) profile.Profiles {
42-
profiles, err := profile.DefaultProfiler.LoadProfiles(ctx, profile.MatchAllProfiles)
43-
if err != nil {
44-
// If we can't load profiles, return empty list (config file might not exist)
45-
return profile.Profiles{}
46-
}
47-
return profiles
48-
}
49-
5029
// generateExploreGuidance creates comprehensive guidance for data exploration.
5130
func generateExploreGuidance(ctx context.Context, warehouse *sql.EndpointInfo, currentProfile string, profiles profile.Profiles) string {
5231
// Build workspace/profile information
@@ -102,6 +81,7 @@ func generateExploreGuidance(ctx context.Context, warehouse *sql.EndpointInfo, c
10281
"WarehouseName": warehouseName,
10382
"WarehouseID": warehouseID,
10483
"ProfilesInfo": profilesInfo,
84+
"Profile": currentProfile,
10585
}
10686

10787
// Render base explore template

experimental/apps-mcp/lib/providers/clitools/invoke_databricks_cli.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,22 @@ func InvokeDatabricksCLI(ctx context.Context, command []string, workingDirectory
1717
return "", errors.New("command is required")
1818
}
1919

20-
workspaceClient := middlewares.MustGetDatabricksClient(ctx)
20+
workspaceClient, err := middlewares.GetDatabricksClient(ctx)
21+
if err != nil {
22+
return "", fmt.Errorf("get databricks client: %w", err)
23+
}
2124
host := workspaceClient.Config.Host
25+
profile := middlewares.GetDatabricksProfile(ctx)
2226

2327
// GetCLIPath returns the path to the current CLI executable
2428
cliPath := common.GetCLIPath()
2529
cmd := exec.CommandContext(ctx, cliPath, command...)
2630
cmd.Dir = workingDirectory
2731
env := os.Environ()
2832
env = append(env, "DATABRICKS_HOST="+host)
33+
if profile != "" {
34+
env = append(env, "DATABRICKS_CONFIG_PROFILE="+profile)
35+
}
2936
cmd.Env = env
3037

3138
output, err := cmd.CombinedOutput()

experimental/apps-mcp/lib/providers/clitools/provider.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ func (p *Provider) RegisterTools(server *mcpsdk.Server) error {
9393
},
9494
func(ctx context.Context, req *mcpsdk.CallToolRequest, args struct{}) (*mcpsdk.CallToolResult, any, error) {
9595
log.Debug(ctx, "explore called")
96-
result, err := Explore(session.WithSession(ctx, p.session))
96+
result, err := Explore(ctx)
9797
if err != nil {
9898
return nil, nil, err
9999
}

0 commit comments

Comments
 (0)