Skip to content

Commit 6f4b03f

Browse files
authored
Add lazy warehouse resolution with smart discovery (#3973)
Implements intelligent warehouse auto-discovery removing need for configuration. ## Changes - Add GetWarehouseID() with smart fallback chain - Auto-discover warehouses via API - Prefer RUNNING, fall back to STOPPED (auto-start) - Cache warehouse ID in session - Remove warehouse ID from config - Update providers to use GetWarehouseID ## Dependencies - Requires PR #3972 (lazy auth) ## Testing - Databricks provider tests pass
1 parent 9bd2f89 commit 6f4b03f

File tree

6 files changed

+152
-8
lines changed

6 files changed

+152
-8
lines changed

experimental/apps-mcp/lib/middlewares/databricks_client.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ func NewDatabricksClientMiddleware(unauthorizedToolNames []string) mcp.Middlewar
3131
return mcp.CreateNewTextContentResultError(err), nil
3232
}
3333
ctx.Session.Set(DatabricksClientKey, w)
34+
35+
// Start background warehouse loading once client is initialized
36+
go loadWarehouseInBackground(ctx.Ctx)
3437
}
3538

3639
return next()
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
package middlewares
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"net/url"
8+
"sort"
9+
"sync"
10+
11+
"github.com/databricks/cli/experimental/apps-mcp/lib/session"
12+
"github.com/databricks/cli/libs/env"
13+
"github.com/databricks/databricks-sdk-go/httpclient"
14+
"github.com/databricks/databricks-sdk-go/service/sql"
15+
)
16+
17+
const (
18+
warehouseLoadingKey = "warehouse_loading"
19+
warehouseErrorKey = "warehouse_error"
20+
)
21+
22+
// loadWarehouseInBackground loads the default warehouse in a background goroutine.
23+
func loadWarehouseInBackground(ctx context.Context) {
24+
sess, err := session.GetSession(ctx)
25+
if err != nil {
26+
return
27+
}
28+
29+
// Create a WaitGroup to track loading state
30+
var wg sync.WaitGroup
31+
wg.Add(1)
32+
sess.Set(warehouseLoadingKey, &wg)
33+
34+
defer wg.Done()
35+
36+
warehouse, err := getDefaultWarehouse(ctx)
37+
if err != nil {
38+
sess.Set(warehouseErrorKey, err)
39+
return
40+
}
41+
42+
sess.Set("warehouse_id", warehouse.Id)
43+
}
44+
45+
func GetWarehouseID(ctx context.Context) (string, error) {
46+
sess, err := session.GetSession(ctx)
47+
if err != nil {
48+
return "", err
49+
}
50+
51+
// Wait for background loading if in progress
52+
if wgRaw, ok := sess.Get(warehouseLoadingKey); ok {
53+
wg := wgRaw.(*sync.WaitGroup)
54+
wg.Wait()
55+
sess.Delete(warehouseLoadingKey)
56+
57+
// Check if there was an error during background loading
58+
if errRaw, ok := sess.Get(warehouseErrorKey); ok {
59+
sess.Delete(warehouseErrorKey)
60+
return "", errRaw.(error)
61+
}
62+
}
63+
64+
warehouseID, ok := sess.Get("warehouse_id")
65+
if !ok {
66+
// Fallback: synchronously load if background loading didn't happen
67+
warehouse, err := getDefaultWarehouse(ctx)
68+
if err != nil {
69+
return "", err
70+
}
71+
warehouseID = warehouse.Id
72+
sess.Set("warehouse_id", warehouseID.(string))
73+
}
74+
75+
return warehouseID.(string), nil
76+
}
77+
78+
func getDefaultWarehouse(ctx context.Context) (*sql.EndpointInfo, error) {
79+
// first resolve DATABRICKS_WAREHOUSE_ID env variable
80+
warehouseID := env.Get(ctx, "DATABRICKS_WAREHOUSE_ID")
81+
if warehouseID != "" {
82+
w := MustGetDatabricksClient(ctx)
83+
warehouse, err := w.Warehouses.Get(ctx, sql.GetWarehouseRequest{
84+
Id: warehouseID,
85+
})
86+
if err != nil {
87+
return nil, fmt.Errorf("get warehouse: %w", err)
88+
}
89+
return &sql.EndpointInfo{
90+
Id: warehouse.Id,
91+
}, nil
92+
}
93+
94+
apiClient, err := MustGetApiClient(ctx)
95+
if err != nil {
96+
return nil, err
97+
}
98+
99+
apiPath := "/api/2.0/sql/warehouses"
100+
params := url.Values{}
101+
params.Add("skip_cannot_use", "true")
102+
fullPath := fmt.Sprintf("%s?%s", apiPath, params.Encode())
103+
104+
var response sql.ListWarehousesResponse
105+
err = apiClient.Do(ctx, "GET", fullPath, httpclient.WithResponseUnmarshal(&response))
106+
if err != nil {
107+
return nil, err
108+
}
109+
110+
priorities := map[sql.State]int{
111+
sql.StateRunning: 1,
112+
sql.StateStarting: 2,
113+
sql.StateStopped: 3,
114+
sql.StateStopping: 4,
115+
sql.StateDeleted: 99,
116+
sql.StateDeleting: 99,
117+
}
118+
119+
warehouses := response.Warehouses
120+
sort.Slice(warehouses, func(i, j int) bool {
121+
return priorities[warehouses[i].State] < priorities[warehouses[j].State]
122+
})
123+
124+
if len(warehouses) == 0 {
125+
return nil, errNoWarehouses()
126+
}
127+
128+
firstWarehouse := warehouses[0]
129+
if firstWarehouse.State == sql.StateDeleted || firstWarehouse.State == sql.StateDeleting {
130+
return nil, errNoWarehouses()
131+
}
132+
133+
return &firstWarehouse, nil
134+
}
135+
136+
func errNoWarehouses() error {
137+
return errors.New("no warehouse found. You can explicitly set the warehouse ID using the DATABRICKS_WAREHOUSE_ID environment variable")
138+
}

experimental/apps-mcp/lib/providers/databricks/databricks.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -435,9 +435,9 @@ type DatabricksRestClient struct {
435435
func NewDatabricksRestClient(ctx context.Context, cfg *mcp.Config) (*DatabricksRestClient, error) {
436436
client := middlewares.MustGetDatabricksClient(ctx)
437437

438-
warehouseID := os.Getenv("DATABRICKS_WAREHOUSE_ID")
439-
if warehouseID == "" {
440-
return nil, errors.New("DATABRICKS_WAREHOUSE_ID not configured")
438+
warehouseID, err := middlewares.GetWarehouseID(ctx)
439+
if err != nil {
440+
return nil, fmt.Errorf("failed to get warehouse ID: %w", err)
441441
}
442442

443443
return &DatabricksRestClient{

experimental/apps-mcp/lib/providers/databricks/deployment.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@ package databricks
33
import (
44
"context"
55
"fmt"
6-
"os"
76
"os/exec"
87
"time"
98

109
mcp "github.com/databricks/cli/experimental/apps-mcp/lib"
10+
"github.com/databricks/cli/experimental/apps-mcp/lib/middlewares"
1111
"github.com/databricks/cli/libs/cmdctx"
1212
"github.com/databricks/databricks-sdk-go/service/apps"
1313
"github.com/databricks/databricks-sdk-go/service/iam"
@@ -102,8 +102,11 @@ func DeployApp(ctx context.Context, cfg *mcp.Config, appInfo *apps.App) error {
102102
return nil
103103
}
104104

105-
func ResourcesFromEnv(cfg *mcp.Config) (*apps.AppResource, error) {
106-
warehouseID := os.Getenv("DATABRICKS_WAREHOUSE_ID")
105+
func ResourcesFromEnv(ctx context.Context, cfg *mcp.Config) (*apps.AppResource, error) {
106+
warehouseID, err := middlewares.GetWarehouseID(ctx)
107+
if err != nil {
108+
return nil, fmt.Errorf("failed to get warehouse ID: %w", err)
109+
}
107110

108111
return &apps.AppResource{
109112
Name: "base",

experimental/apps-mcp/lib/providers/databricks/provider.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ func (p *Provider) RegisterTools(server *mcpsdk.Server) error {
6060
mcpsdk.AddTool(server,
6161
&mcpsdk.Tool{
6262
Name: "databricks_configure_auth",
63-
Description: "Configure authentication for Databricks. Validates credentials and stores the authenticated client in the session. Must be called before using other Databricks tools if using non-default host or profile.",
63+
Description: "Configure authentication for Databricks. Only call when Databricks authentication has has failed to authenticate automatically or when the user explicitly asks for using a specific host or profile. Validates credentials and stores the authenticated client in the session.",
6464
},
6565
func(ctx context.Context, req *mcpsdk.CallToolRequest, args ConfigureAuthInput) (*mcpsdk.CallToolResult, any, error) {
6666
log.Debug(ctx, "databricks_configure_auth called")

experimental/apps-mcp/lib/providers/deployment/provider.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ func (p *Provider) getOrCreateApp(ctx context.Context, name, description string,
289289

290290
log.Infof(ctx, "App not found, creating new app: name=%s", name)
291291

292-
resources, err := databricks.ResourcesFromEnv(p.config)
292+
resources, err := databricks.ResourcesFromEnv(ctx, p.config)
293293
if err != nil {
294294
return nil, err
295295
}

0 commit comments

Comments
 (0)