@@ -2,12 +2,15 @@ package middlewares
22
33import (
44 "context"
5+ "errors"
56 "fmt"
7+ "net/url"
8+ "sort"
69 "sync"
710
811 "github.com/databricks/cli/experimental/apps-mcp/lib/session"
9- "github.com/databricks/cli/libs/databrickscfg/cfgpickers"
1012 "github.com/databricks/cli/libs/env"
13+ "github.com/databricks/databricks-sdk-go/httpclient"
1114 "github.com/databricks/databricks-sdk-go/service/sql"
1215)
1316
@@ -80,14 +83,13 @@ func GetWarehouseID(ctx context.Context) (string, error) {
8083}
8184
8285func getDefaultWarehouse (ctx context.Context ) (* sql.EndpointInfo , error ) {
83- w , err := GetDatabricksClient (ctx )
84- if err != nil {
85- return nil , fmt .Errorf ("get databricks client: %w" , err )
86- }
87-
8886 // first resolve DATABRICKS_WAREHOUSE_ID env variable
8987 warehouseID := env .Get (ctx , "DATABRICKS_WAREHOUSE_ID" )
9088 if warehouseID != "" {
89+ w , err := GetDatabricksClient (ctx )
90+ if err != nil {
91+ return nil , fmt .Errorf ("get databricks client: %w" , err )
92+ }
9193 warehouse , err := w .Warehouses .Get (ctx , sql.GetWarehouseRequest {
9294 Id : warehouseID ,
9395 })
@@ -101,5 +103,48 @@ func getDefaultWarehouse(ctx context.Context) (*sql.EndpointInfo, error) {
101103 }, nil
102104 }
103105
104- return cfgpickers .GetDefaultWarehouse (ctx , w )
106+ apiClient , err := GetApiClient (ctx )
107+ if err != nil {
108+ return nil , err
109+ }
110+
111+ apiPath := "/api/2.0/sql/warehouses"
112+ params := url.Values {}
113+ params .Add ("skip_cannot_use" , "true" )
114+ fullPath := fmt .Sprintf ("%s?%s" , apiPath , params .Encode ())
115+
116+ var response sql.ListWarehousesResponse
117+ err = apiClient .Do (ctx , "GET" , fullPath , httpclient .WithResponseUnmarshal (& response ))
118+ if err != nil {
119+ return nil , err
120+ }
121+
122+ priorities := map [sql.State ]int {
123+ sql .StateRunning : 1 ,
124+ sql .StateStarting : 2 ,
125+ sql .StateStopped : 3 ,
126+ sql .StateStopping : 4 ,
127+ sql .StateDeleted : 99 ,
128+ sql .StateDeleting : 99 ,
129+ }
130+
131+ warehouses := response .Warehouses
132+ sort .Slice (warehouses , func (i , j int ) bool {
133+ return priorities [warehouses [i ].State ] < priorities [warehouses [j ].State ]
134+ })
135+
136+ if len (warehouses ) == 0 {
137+ return nil , errNoWarehouses ()
138+ }
139+
140+ firstWarehouse := warehouses [0 ]
141+ if firstWarehouse .State == sql .StateDeleted || firstWarehouse .State == sql .StateDeleting {
142+ return nil , errNoWarehouses ()
143+ }
144+
145+ return & firstWarehouse , nil
146+ }
147+
148+ func errNoWarehouses () error {
149+ return errors .New ("no warehouse found. You can explicitly set the warehouse ID using the DATABRICKS_WAREHOUSE_ID environment variable" )
105150}
0 commit comments