Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions experimental/apps-mcp/lib/mcp/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ func TestServerMiddleware(t *testing.T) {
Name: "test-server",
Version: "1.0.0",
}
server := mcp.NewServer(impl, nil)
server := mcp.NewServer(impl, nil, nil)

var executionOrder []string

Expand Down Expand Up @@ -264,7 +264,7 @@ func TestServerSessionPersistence(t *testing.T) {
Name: "test-server",
Version: "1.0.0",
}
server := mcp.NewServer(impl, nil)
server := mcp.NewServer(impl, nil, nil)

// Add middleware that increments a counter
server.AddMiddlewareFunc(func(ctx *mcp.MiddlewareContext, next mcp.NextFunc) (*mcp.CallToolResult, error) {
Expand Down
8 changes: 6 additions & 2 deletions experimental/apps-mcp/lib/mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,15 @@ type serverTool struct {
}

// NewServer creates a new MCP server.
func NewServer(impl *Implementation, options any) *Server {
// If sess is nil, a new session will be created.
func NewServer(impl *Implementation, options any, sess *session.Session) *Server {
if sess == nil {
sess = session.NewSession()
}
return &Server{
impl: impl,
tools: make(map[string]*serverTool),
session: session.NewSession(),
session: sess,
}
}

Expand Down
59 changes: 51 additions & 8 deletions experimental/apps-mcp/lib/middlewares/databricks_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ import (
"github.com/databricks/cli/experimental/apps-mcp/lib/mcp"
"github.com/databricks/cli/experimental/apps-mcp/lib/prompts"
"github.com/databricks/cli/experimental/apps-mcp/lib/session"
"github.com/databricks/cli/libs/databrickscfg/profile"
"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/config"
"github.com/databricks/databricks-sdk-go/httpclient"
)

const (
DatabricksClientKey = "databricks_client"
DatabricksClientKey = "databricks_client"
DatabricksProfileKey = "databricks_profile"
)

func NewDatabricksClientMiddleware(unauthorizedToolNames []string) mcp.Middleware {
Expand All @@ -40,8 +42,41 @@ func NewDatabricksClientMiddleware(unauthorizedToolNames []string) mcp.Middlewar
})
}

func MustGetApiClient(ctx context.Context) (*httpclient.ApiClient, error) {
w := MustGetDatabricksClient(ctx)
func GetDatabricksProfile(ctx context.Context) string {
sess, err := session.GetSession(ctx)
if err != nil {
return ""
}
profile, ok := sess.Get(DatabricksProfileKey)
if !ok {
return ""
}
return profile.(string)
}

// GetAvailableProfiles returns all available profiles from ~/.databrickscfg.
func GetAvailableProfiles(ctx context.Context) profile.Profiles {
profiles, err := profile.DefaultProfiler.LoadProfiles(ctx, profile.MatchAllProfiles)
if err != nil {
// If we can't load profiles, return empty list (config file might not exist)
return profile.Profiles{}
}
return profiles
}

func MustGetApiClient(ctx context.Context) *httpclient.ApiClient {
client, err := GetApiClient(ctx)
if err != nil {
panic(err)
}
return client
}

func GetApiClient(ctx context.Context) (*httpclient.ApiClient, error) {
w, err := GetDatabricksClient(ctx)
if err != nil {
return nil, err
}
clientCfg, err := config.HTTPClientConfigFromConfig(w.Config)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client config: %w", err)
Expand All @@ -64,28 +99,36 @@ func GetDatabricksClient(ctx context.Context) (*databricks.WorkspaceClient, erro
}
w, ok := sess.Get(DatabricksClientKey)
if !ok {
return nil, errors.New(prompts.MustExecuteTemplate("auth_error.tmpl", nil))
return nil, newAuthError(ctx)
}
return w.(*databricks.WorkspaceClient), nil
}

func checkAuth(ctx context.Context) (*databricks.WorkspaceClient, error) {
w, err := databricks.NewWorkspaceClient()
if err != nil {
return nil, wrapAuthError(err)
return nil, WrapAuthError(ctx, err)
}

_, err = w.CurrentUser.Me(ctx)
if err != nil {
return nil, wrapAuthError(err)
return nil, WrapAuthError(ctx, err)
}

return w, nil
}

func wrapAuthError(err error) error {
func WrapAuthError(ctx context.Context, err error) error {
if errors.Is(err, config.ErrCannotConfigureDefault) {
return errors.New(prompts.MustExecuteTemplate("auth_error.tmpl", nil))
return newAuthError(ctx)
}
return err
}

func newAuthError(ctx context.Context) error {
// Prepare template data
data := map[string]any{
"Profiles": GetAvailableProfiles(ctx),
}
return errors.New(prompts.MustExecuteTemplate("auth_error.tmpl", data))
}
7 changes: 5 additions & 2 deletions experimental/apps-mcp/lib/middlewares/warehouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,10 @@ func getDefaultWarehouse(ctx context.Context) (*sql.EndpointInfo, error) {
// first resolve DATABRICKS_WAREHOUSE_ID env variable
warehouseID := env.Get(ctx, "DATABRICKS_WAREHOUSE_ID")
if warehouseID != "" {
w := MustGetDatabricksClient(ctx)
w, err := GetDatabricksClient(ctx)
if err != nil {
return nil, fmt.Errorf("get databricks client: %w", err)
}
warehouse, err := w.Warehouses.Get(ctx, sql.GetWarehouseRequest{
Id: warehouseID,
})
Expand All @@ -100,7 +103,7 @@ func getDefaultWarehouse(ctx context.Context) (*sql.EndpointInfo, error) {
}, nil
}

apiClient, err := MustGetApiClient(ctx)
apiClient, err := GetApiClient(ctx)
if err != nil {
return nil, err
}
Expand Down
9 changes: 7 additions & 2 deletions experimental/apps-mcp/lib/prompts/auth_error.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@
Not authenticated to Databricks

I need to know either the Databricks workspace URL or the Databricks profile name.
You can list the available profiles by running `databricks auth profiles`.

ASK the user which of the configured profiles or databricks workspace URL they want to use.
The available profiles are:

{{- range .Profiles }}
- {{ .Name }} ({{ .Host }})
{{- end }}

IMPORTANT: YOU MUST ASK the user which of the configured profiles or databricks workspace URL they want to use.
Only then call the `databricks_configure_auth` tool to configure the authentication.

Do not run anything else before authenticating successfully.
Expand Down
26 changes: 7 additions & 19 deletions experimental/apps-mcp/lib/providers/clitools/configure_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"github.com/databricks/cli/experimental/apps-mcp/lib/prompts"
"github.com/databricks/cli/experimental/apps-mcp/lib/session"
"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/config"
)

// ConfigureAuth creates and validates a Databricks workspace client with optional host and profile.
Expand All @@ -20,9 +19,8 @@ func ConfigureAuth(ctx context.Context, sess *session.Session, host, profile *st
return nil, nil
}

var cfg *databricks.Config
cfg := &databricks.Config{}
if host != nil || profile != nil {
cfg = &databricks.Config{}
if host != nil {
cfg.Host = *host
}
Expand All @@ -32,12 +30,7 @@ func ConfigureAuth(ctx context.Context, sess *session.Session, host, profile *st
}

var client *databricks.WorkspaceClient
var err error
if cfg != nil {
client, err = databricks.NewWorkspaceClient(cfg)
} else {
client, err = databricks.NewWorkspaceClient()
}
client, err := databricks.NewWorkspaceClient(cfg)
if err != nil {
return nil, err
}
Expand All @@ -49,19 +42,14 @@ func ConfigureAuth(ctx context.Context, sess *session.Session, host, profile *st
"WorkspaceURL": *host,
}))
}
return nil, wrapAuthError(err)
return nil, middlewares.WrapAuthError(ctx, err)
}

// Store client in session data
// Store client and profile in session data
sess.Set(middlewares.DatabricksClientKey, client)
if profile != nil {
sess.Set(middlewares.DatabricksProfileKey, *profile)
}

return client, nil
}

// wrapAuthError wraps configuration errors with helpful messages
func wrapAuthError(err error) error {
if errors.Is(err, config.ErrCannotConfigureDefault) {
return errors.New(prompts.MustExecuteTemplate("auth_error.tmpl", nil))
}
return err
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ func TestConfigureAuthWithCustomHost(t *testing.T) {
}

func TestWrapAuthError(t *testing.T) {
ctx := context.Background()

tests := []struct {
name string
err error
Expand All @@ -89,7 +91,7 @@ func TestWrapAuthError(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
wrapped := wrapAuthError(tt.err)
wrapped := middlewares.WrapAuthError(ctx, tt.err)
assert.Contains(t, wrapped.Error(), tt.expected)
})
}
Expand Down
26 changes: 3 additions & 23 deletions experimental/apps-mcp/lib/providers/clitools/explore.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"github.com/databricks/cli/experimental/apps-mcp/lib/prompts"
"github.com/databricks/cli/experimental/apps-mcp/lib/session"
"github.com/databricks/cli/libs/databrickscfg/profile"
"github.com/databricks/cli/libs/env"
"github.com/databricks/cli/libs/log"
"github.com/databricks/databricks-sdk-go/service/sql"
)
Expand All @@ -21,32 +20,12 @@ func Explore(ctx context.Context) (string, error) {
warehouse = nil
}

currentProfile := getCurrentProfile(ctx)
profiles := getAvailableProfiles(ctx)
currentProfile := middlewares.GetDatabricksProfile(ctx)
profiles := middlewares.GetAvailableProfiles(ctx)

return generateExploreGuidance(ctx, warehouse, currentProfile, profiles), nil
}

// getCurrentProfile returns the currently active profile name.
func getCurrentProfile(ctx context.Context) string {
// Check DATABRICKS_CONFIG_PROFILE env var
profileName := env.Get(ctx, "DATABRICKS_CONFIG_PROFILE")
if profileName == "" {
return "DEFAULT"
}
return profileName
}

// getAvailableProfiles returns all available profiles from ~/.databrickscfg.
func getAvailableProfiles(ctx context.Context) profile.Profiles {
profiles, err := profile.DefaultProfiler.LoadProfiles(ctx, profile.MatchAllProfiles)
if err != nil {
// If we can't load profiles, return empty list (config file might not exist)
return profile.Profiles{}
}
return profiles
}

// generateExploreGuidance creates comprehensive guidance for data exploration.
func generateExploreGuidance(ctx context.Context, warehouse *sql.EndpointInfo, currentProfile string, profiles profile.Profiles) string {
// Build workspace/profile information
Expand Down Expand Up @@ -102,6 +81,7 @@ func generateExploreGuidance(ctx context.Context, warehouse *sql.EndpointInfo, c
"WarehouseName": warehouseName,
"WarehouseID": warehouseID,
"ProfilesInfo": profilesInfo,
"Profile": currentProfile,
}

// Render base explore template
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,22 @@ func InvokeDatabricksCLI(ctx context.Context, command []string, workingDirectory
return "", errors.New("command is required")
}

workspaceClient := middlewares.MustGetDatabricksClient(ctx)
workspaceClient, err := middlewares.GetDatabricksClient(ctx)
if err != nil {
return "", fmt.Errorf("get databricks client: %w", err)
}
host := workspaceClient.Config.Host
profile := middlewares.GetDatabricksProfile(ctx)

// GetCLIPath returns the path to the current CLI executable
cliPath := common.GetCLIPath()
cmd := exec.CommandContext(ctx, cliPath, command...)
cmd.Dir = workingDirectory
env := os.Environ()
env = append(env, "DATABRICKS_HOST="+host)
if profile != "" {
env = append(env, "DATABRICKS_CONFIG_PROFILE="+profile)
}
cmd.Env = env

output, err := cmd.CombinedOutput()
Expand Down
2 changes: 1 addition & 1 deletion experimental/apps-mcp/lib/providers/clitools/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func (p *Provider) RegisterTools(server *mcpsdk.Server) error {
},
func(ctx context.Context, req *mcpsdk.CallToolRequest, args struct{}) (*mcpsdk.CallToolResult, any, error) {
log.Debug(ctx, "explore called")
result, err := Explore(session.WithSession(ctx, p.session))
result, err := Explore(ctx)
if err != nil {
return nil, nil, err
}
Expand Down
11 changes: 4 additions & 7 deletions experimental/apps-mcp/lib/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,8 @@ func NewServer(ctx context.Context, cfg *mcp.Config) *Server {
Version: build.GetInfo().Version,
}

server := mcpsdk.NewServer(impl, nil)
sess := session.NewSession()

// Set enabled capabilities for this MCP server
sess.Set(session.CapabilitiesDataKey, []string{"apps"})
server := mcpsdk.NewServer(impl, nil, sess)

tracker, err := trajectory.NewTracker(ctx, sess, cfg)
if err != nil {
Expand All @@ -50,6 +47,9 @@ func NewServer(ctx context.Context, cfg *mcp.Config) *Server {

sess.SetTracker(tracker)

// Set enabled capabilities for this MCP server
sess.Set(session.CapabilitiesDataKey, []string{"apps"})

return &Server{
server: server,
config: cfg,
Expand Down Expand Up @@ -85,9 +85,6 @@ func (s *Server) RegisterTools(ctx context.Context) error {
func (s *Server) registerCLIToolsProvider(ctx context.Context) error {
log.Info(ctx, "Registering CLI tools provider")

// Add session to context
ctx = session.WithSession(ctx, s.session)

provider, err := clitools.NewProvider(ctx, s.config, s.session)
if err != nil {
return err
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,18 @@
"description": "SQL Warehouse ID",
"order": 2
},
"profile": {
"type": "string",
"description": "Profile Name",
"default": "",
"order": 3
},
"app_description": {
"type": "string",
"description": "App Description (Optional)",
"default": "A Databricks App powered by Databricks AppKit",
"order": 3
"order": 4
}

},
"success_message": "\nYour new project has been created in the '{{.project_name}}' directory!"
"success_message": "\nYour new project has been created in the '{{.project_name}}' directory!\nYOU MUST read {{.project_name}}/CLAUDE.md immediately. It is STRONGLY RECOMMENDED to immediately run `npm install`, run `npm run dev` in the background, and open http://localhost:8000 in your browser before making changes to the app."
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
DATABRICKS_HOST={{workspace_host}}
{{if ne .profile ""}}DATABRICKS_CONFIG_PROFILE={{.profile}}{{else}}DATABRICKS_HOST={{workspace_host}}{{end}}
DATABRICKS_WAREHOUSE_ID={{.sql_warehouse_id}}
DATABRICKS_APP_PORT=8000
DATABRICKS_APP_NAME=minimal
Expand Down
Loading