Skip to content

Commit f18e34a

Browse files
authored
Add lazy Databricks client initialization (#3972)
Implements lazy authentication removing need for upfront configuration. ## Changes - Add DatabricksClientMiddleware for on-demand auth - Add ConfigureAuth() and databricks_configure_auth tool - Add Get/MustGetDatabricksClient() helpers in session - Remove PreRunE and DatabricksHost from config - Add EngineGuideMiddleware and ToolCounterMiddleware - Update providers to use lazy client from session ## Dependencies - Requires PR #3970 (middleware infrastructure) - Includes PR #3971 (prompt templates) ## Testing - Databricks provider tests pass - Auth test added
1 parent 51e416d commit f18e34a

File tree

13 files changed

+410
-79
lines changed

13 files changed

+410
-79
lines changed

experimental/apps-mcp/cmd/apps_mcp.go

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,8 @@
11
package mcp
22

33
import (
4-
"errors"
5-
"os"
6-
7-
"github.com/databricks/cli/cmd/root"
84
mcplib "github.com/databricks/cli/experimental/apps-mcp/lib"
95
"github.com/databricks/cli/experimental/apps-mcp/lib/server"
10-
"github.com/databricks/cli/libs/cmdctx"
116
"github.com/databricks/cli/libs/log"
127
"github.com/spf13/cobra"
138
)
@@ -37,25 +32,13 @@ The server communicates via stdio using the Model Context Protocol.`,
3732
3833
# Start with deployment tools enabled
3934
databricks experimental apps-mcp --warehouse-id abc123 --allow-deployment`,
40-
PreRunE: root.MustWorkspaceClient,
4135
RunE: func(cmd *cobra.Command, args []string) error {
4236
ctx := cmd.Context()
4337

44-
if warehouseID == "" {
45-
warehouseID = os.Getenv("DATABRICKS_WAREHOUSE_ID")
46-
if warehouseID == "" {
47-
return errors.New("DATABRICKS_WAREHOUSE_ID environment variable is required")
48-
}
49-
}
50-
51-
w := cmdctx.WorkspaceClient(ctx)
52-
5338
// Build MCP config from flags
5439
cfg := &mcplib.Config{
5540
AllowDeployment: allowDeployment,
5641
WithWorkspaceTools: withWorkspaceTools,
57-
WarehouseID: warehouseID,
58-
DatabricksHost: w.Config.Host,
5942
IoConfig: &mcplib.IoConfig{
6043
Validation: &mcplib.ValidationConfig{},
6144
},

experimental/apps-mcp/lib/config.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@ package mcp
77
type Config struct {
88
AllowDeployment bool
99
WithWorkspaceTools bool
10-
WarehouseID string
11-
DatabricksHost string
1210
IoConfig *IoConfig
1311
}
1412

@@ -52,6 +50,5 @@ func DefaultConfig() *Config {
5250
},
5351
Validation: validationCfg,
5452
},
55-
WarehouseID: "",
5653
}
5754
}

experimental/apps-mcp/lib/mcp/server.go

Lines changed: 9 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"io"
88
"sync"
99

10+
"github.com/databricks/cli/experimental/apps-mcp/lib/errors"
1011
"github.com/databricks/cli/experimental/apps-mcp/lib/session"
1112
)
1213

@@ -133,7 +134,7 @@ func (s *Server) handleRequest(ctx context.Context, req *JSONRPCRequest) *JSONRP
133134
JSONRPC: "2.0",
134135
ID: req.ID,
135136
Error: &JSONRPCError{
136-
Code: -32601,
137+
Code: errors.CodeMethodNotFound,
137138
Message: "method not found: " + req.Method,
138139
},
139140
}
@@ -158,7 +159,7 @@ func (s *Server) handleInitialize(req *JSONRPCRequest) *JSONRPCResponse {
158159
JSONRPC: "2.0",
159160
ID: req.ID,
160161
Error: &JSONRPCError{
161-
Code: -32603,
162+
Code: errors.CodeInternalError,
162163
Message: fmt.Sprintf("failed to marshal result: %v", err),
163164
},
164165
}
@@ -187,14 +188,7 @@ func (s *Server) handleToolsList(req *JSONRPCRequest) *JSONRPCResponse {
187188

188189
data, err := json.Marshal(result)
189190
if err != nil {
190-
return &JSONRPCResponse{
191-
JSONRPC: "2.0",
192-
ID: req.ID,
193-
Error: &JSONRPCError{
194-
Code: -32603,
195-
Message: fmt.Sprintf("failed to marshal result: %v", err),
196-
},
197-
}
191+
return CreateNewErrorResponse(req.ID, errors.CodeInternalError, fmt.Sprintf("failed to marshal result: %v", err))
198192
}
199193

200194
return &JSONRPCResponse{
@@ -208,46 +202,26 @@ func (s *Server) handleToolsList(req *JSONRPCRequest) *JSONRPCResponse {
208202
func (s *Server) handleToolsCall(ctx context.Context, req *JSONRPCRequest) *JSONRPCResponse {
209203
var params CallToolParams
210204
if err := json.Unmarshal(req.Params, &params); err != nil {
211-
return &JSONRPCResponse{
212-
JSONRPC: "2.0",
213-
ID: req.ID,
214-
Error: &JSONRPCError{
215-
Code: -32602,
216-
Message: fmt.Sprintf("invalid params: %v", err),
217-
},
218-
}
205+
return CreateNewErrorResponse(req.ID, errors.CodeInvalidParams, fmt.Sprintf("invalid params: %v", err))
219206
}
220207

221208
s.toolsMu.RLock()
222209
st, ok := s.tools[params.Name]
223210
s.toolsMu.RUnlock()
224211

225212
if !ok {
226-
return &JSONRPCResponse{
227-
JSONRPC: "2.0",
228-
ID: req.ID,
229-
Error: &JSONRPCError{
230-
Code: -32602,
231-
Message: "tool not found: " + params.Name,
232-
},
233-
}
213+
return CreateNewErrorResponse(req.ID, errors.CodeInvalidParams, "tool not found: "+params.Name)
234214
}
235215

236216
toolReq := &CallToolRequest{
217+
ID: req.ID,
237218
Tool: st.tool,
238219
Params: params,
239220
}
240221

241222
result, err := st.handler(ctx, toolReq)
242223
if err != nil {
243-
return &JSONRPCResponse{
244-
JSONRPC: "2.0",
245-
ID: req.ID,
246-
Error: &JSONRPCError{
247-
Code: -32603,
248-
Message: fmt.Sprintf("tool execution error: %v", err),
249-
},
250-
}
224+
result = CreateNewTextContentResultError(err)
251225
}
252226

253227
// Convert Content slice to []any for JSON marshaling
@@ -266,14 +240,7 @@ func (s *Server) handleToolsCall(ctx context.Context, req *JSONRPCRequest) *JSON
266240

267241
data, err := json.Marshal(resultData)
268242
if err != nil {
269-
return &JSONRPCResponse{
270-
JSONRPC: "2.0",
271-
ID: req.ID,
272-
Error: &JSONRPCError{
273-
Code: -32603,
274-
Message: fmt.Sprintf("failed to marshal result: %v", err),
275-
},
276-
}
243+
return CreateNewErrorResponse(req.ID, errors.CodeInternalError, fmt.Sprintf("failed to marshal result: %v", err))
277244
}
278245

279246
return &JSONRPCResponse{

experimental/apps-mcp/lib/mcp/types.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ type Tool struct {
2121

2222
// CallToolRequest represents a request to call a tool.
2323
type CallToolRequest struct {
24+
ID any
2425
Tool *Tool
2526
Params CallToolParams
2627
}
@@ -56,3 +57,14 @@ func CreateNewTextContentResultError(err error) *CallToolResult {
5657
IsError: true,
5758
}
5859
}
60+
61+
func CreateNewErrorResponse(id any, code int, message string) *JSONRPCResponse {
62+
return &JSONRPCResponse{
63+
JSONRPC: "2.0",
64+
ID: id,
65+
Error: &JSONRPCError{
66+
Code: code,
67+
Message: message,
68+
},
69+
}
70+
}
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
package middlewares
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"slices"
8+
9+
"github.com/databricks/cli/experimental/apps-mcp/lib/mcp"
10+
"github.com/databricks/cli/experimental/apps-mcp/lib/prompts"
11+
"github.com/databricks/cli/experimental/apps-mcp/lib/session"
12+
"github.com/databricks/databricks-sdk-go"
13+
"github.com/databricks/databricks-sdk-go/config"
14+
"github.com/databricks/databricks-sdk-go/httpclient"
15+
)
16+
17+
const (
18+
DatabricksClientKey = "databricks_client"
19+
)
20+
21+
func NewDatabricksClientMiddleware(unauthorizedToolNames []string) mcp.Middleware {
22+
return mcp.NewMiddleware(func(ctx *mcp.MiddlewareContext, next mcp.NextFunc) (*mcp.CallToolResult, error) {
23+
if slices.Contains(unauthorizedToolNames, ctx.Request.Tool.Name) {
24+
return next()
25+
}
26+
27+
_, ok := ctx.Session.Get(DatabricksClientKey)
28+
if !ok {
29+
w, err := checkAuth(ctx.Ctx)
30+
if err != nil {
31+
return mcp.CreateNewTextContentResultError(err), nil
32+
}
33+
ctx.Session.Set(DatabricksClientKey, w)
34+
}
35+
36+
return next()
37+
})
38+
}
39+
40+
func MustGetApiClient(ctx context.Context) (*httpclient.ApiClient, error) {
41+
w := MustGetDatabricksClient(ctx)
42+
clientCfg, err := config.HTTPClientConfigFromConfig(w.Config)
43+
if err != nil {
44+
return nil, fmt.Errorf("failed to create HTTP client config: %w", err)
45+
}
46+
return httpclient.NewApiClient(clientCfg), nil
47+
}
48+
49+
func MustGetDatabricksClient(ctx context.Context) *databricks.WorkspaceClient {
50+
w, err := GetDatabricksClient(ctx)
51+
if err != nil {
52+
panic(err)
53+
}
54+
return w
55+
}
56+
57+
func GetDatabricksClient(ctx context.Context) (*databricks.WorkspaceClient, error) {
58+
sess, err := session.GetSession(ctx)
59+
if err != nil {
60+
return nil, err
61+
}
62+
w, ok := sess.Get(DatabricksClientKey)
63+
if !ok {
64+
return nil, errors.New(prompts.MustExecuteTemplate("auth_error.tmpl", nil))
65+
}
66+
return w.(*databricks.WorkspaceClient), nil
67+
}
68+
69+
func checkAuth(ctx context.Context) (*databricks.WorkspaceClient, error) {
70+
w, err := databricks.NewWorkspaceClient()
71+
if err != nil {
72+
return nil, wrapAuthError(err)
73+
}
74+
75+
_, err = w.CurrentUser.Me(ctx)
76+
if err != nil {
77+
return nil, wrapAuthError(err)
78+
}
79+
80+
return w, nil
81+
}
82+
83+
func wrapAuthError(err error) error {
84+
if errors.Is(err, config.ErrCannotConfigureDefault) {
85+
return errors.New(prompts.MustExecuteTemplate("auth_error.tmpl", nil))
86+
}
87+
return err
88+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package middlewares
2+
3+
import (
4+
"github.com/databricks/cli/experimental/apps-mcp/lib/mcp"
5+
"github.com/databricks/cli/experimental/apps-mcp/lib/prompts"
6+
)
7+
8+
// NewEngineGuideMiddleware creates middleware that injects the initialization message on the first tool call.
9+
func NewEngineGuideMiddleware() mcp.Middleware {
10+
return mcp.NewMiddleware(func(ctx *mcp.MiddlewareContext, next mcp.NextFunc) (*mcp.CallToolResult, error) {
11+
isFirst := ctx.Session.GetBool("isFirstToolCall", true)
12+
initializationMessage := prompts.MustExecuteTemplate("initialization_message.tmpl", nil)
13+
14+
// If this was the first call and execution was successful, prepend the guide
15+
if isFirst {
16+
ctx.Session.Set("isFirstToolCall", false)
17+
result, err := next()
18+
if err != nil {
19+
result = mcp.CreateNewTextContentResultError(err)
20+
}
21+
if result != nil && len(result.Content) > 0 {
22+
if textContent, ok := result.Content[0].(*mcp.TextContent); ok {
23+
textContent.Text = initializationMessage + "\n\n---\n\n" + textContent.Text
24+
}
25+
}
26+
27+
return result, nil
28+
}
29+
return next()
30+
})
31+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package middlewares
2+
3+
import (
4+
"github.com/databricks/cli/experimental/apps-mcp/lib/mcp"
5+
"github.com/databricks/cli/experimental/apps-mcp/lib/session"
6+
"github.com/databricks/cli/experimental/apps-mcp/lib/trajectory"
7+
)
8+
9+
// NewTrajectoryMiddleware creates middleware that records tool calls in a trajectory tracker.
10+
func NewTrajectoryMiddleware(tracker *trajectory.Tracker) mcp.Middleware {
11+
return mcp.NewMiddleware(func(ctx *mcp.MiddlewareContext, next mcp.NextFunc) (*mcp.CallToolResult, error) {
12+
result, err := next()
13+
if tracker != nil {
14+
tracker.RecordToolCall(ctx.Request.Params.Name, ctx.Request.Params.Arguments, result, err)
15+
}
16+
17+
return result, err
18+
})
19+
}
20+
21+
// NewToolCounterMiddleware creates middleware that increments the tool call counter.
22+
func NewToolCounterMiddleware(session *session.Session) mcp.Middleware {
23+
return mcp.NewMiddleware(func(ctx *mcp.MiddlewareContext, next mcp.NextFunc) (*mcp.CallToolResult, error) {
24+
session.IncrementToolCalls()
25+
return next()
26+
})
27+
}

0 commit comments

Comments
 (0)