Skip to content

Commit 5bc507d

Browse files
Revert apps-mcp warehouse changes (for separate PR)
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent 076c18e commit 5bc507d

File tree

1 file changed

+52
-7
lines changed

1 file changed

+52
-7
lines changed

experimental/apps-mcp/lib/middlewares/warehouse.go

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@ package middlewares
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
7+
"net/url"
8+
"sort"
69
"sync"
710

811
"github.com/databricks/cli/experimental/apps-mcp/lib/session"
9-
"github.com/databricks/cli/libs/databrickscfg/cfgpickers"
1012
"github.com/databricks/cli/libs/env"
13+
"github.com/databricks/databricks-sdk-go/httpclient"
1114
"github.com/databricks/databricks-sdk-go/service/sql"
1215
)
1316

@@ -80,14 +83,13 @@ func GetWarehouseID(ctx context.Context) (string, error) {
8083
}
8184

8285
func getDefaultWarehouse(ctx context.Context) (*sql.EndpointInfo, error) {
83-
w, err := GetDatabricksClient(ctx)
84-
if err != nil {
85-
return nil, fmt.Errorf("get databricks client: %w", err)
86-
}
87-
8886
// first resolve DATABRICKS_WAREHOUSE_ID env variable
8987
warehouseID := env.Get(ctx, "DATABRICKS_WAREHOUSE_ID")
9088
if warehouseID != "" {
89+
w, err := GetDatabricksClient(ctx)
90+
if err != nil {
91+
return nil, fmt.Errorf("get databricks client: %w", err)
92+
}
9193
warehouse, err := w.Warehouses.Get(ctx, sql.GetWarehouseRequest{
9294
Id: warehouseID,
9395
})
@@ -101,5 +103,48 @@ func getDefaultWarehouse(ctx context.Context) (*sql.EndpointInfo, error) {
101103
}, nil
102104
}
103105

104-
return cfgpickers.GetDefaultWarehouse(ctx, w)
106+
apiClient, err := GetApiClient(ctx)
107+
if err != nil {
108+
return nil, err
109+
}
110+
111+
apiPath := "/api/2.0/sql/warehouses"
112+
params := url.Values{}
113+
params.Add("skip_cannot_use", "true")
114+
fullPath := fmt.Sprintf("%s?%s", apiPath, params.Encode())
115+
116+
var response sql.ListWarehousesResponse
117+
err = apiClient.Do(ctx, "GET", fullPath, httpclient.WithResponseUnmarshal(&response))
118+
if err != nil {
119+
return nil, err
120+
}
121+
122+
priorities := map[sql.State]int{
123+
sql.StateRunning: 1,
124+
sql.StateStarting: 2,
125+
sql.StateStopped: 3,
126+
sql.StateStopping: 4,
127+
sql.StateDeleted: 99,
128+
sql.StateDeleting: 99,
129+
}
130+
131+
warehouses := response.Warehouses
132+
sort.Slice(warehouses, func(i, j int) bool {
133+
return priorities[warehouses[i].State] < priorities[warehouses[j].State]
134+
})
135+
136+
if len(warehouses) == 0 {
137+
return nil, errNoWarehouses()
138+
}
139+
140+
firstWarehouse := warehouses[0]
141+
if firstWarehouse.State == sql.StateDeleted || firstWarehouse.State == sql.StateDeleting {
142+
return nil, errNoWarehouses()
143+
}
144+
145+
return &firstWarehouse, nil
146+
}
147+
148+
func errNoWarehouses() error {
149+
return errors.New("no warehouse found. You can explicitly set the warehouse ID using the DATABRICKS_WAREHOUSE_ID environment variable")
105150
}

0 commit comments

Comments
 (0)