@@ -23,6 +23,7 @@ import (
2323 "github.com/databricks/cli/internal/build"
2424 "github.com/databricks/cli/libs/cmdio"
2525 "github.com/databricks/databricks-sdk-go"
26+ "github.com/databricks/databricks-sdk-go/service/compute"
2627 "github.com/databricks/databricks-sdk-go/service/jobs"
2728 "github.com/databricks/databricks-sdk-go/service/workspace"
2829 "github.com/gorilla/websocket"
@@ -58,6 +59,8 @@ type ClientOptions struct {
5859 SSHKeysDir string
5960 // Name of the client public key file to be used in the ssh-tunnel secrets scope.
6061 ClientPublicKeyName string
62+ // If true, the CLI will attempt to start the cluster if it is not running.
63+ AutoStartCluster bool
6164 // Additional arguments to pass to the SSH client in the non proxy mode.
6265 AdditionalArgs []string
6366}
@@ -74,6 +77,11 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt
7477 cancel ()
7578 }()
7679
80+ err := checkClusterState (ctx , client , opts .ClusterID , opts .AutoStartCluster )
81+ if err != nil {
82+ return err
83+ }
84+
7785 keyPath , err := keys .GetLocalSSHKeyPath (opts .ClusterID , opts .SSHKeysDir )
7886 if err != nil {
7987 return fmt .Errorf ("failed to get local keys folder: %w" , err )
@@ -100,7 +108,7 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt
100108 if err := UploadTunnelReleases (ctx , client , version , opts .ReleasesDir ); err != nil {
101109 return fmt .Errorf ("failed to upload ssh-tunnel binaries: %w" , err )
102110 }
103- userName , serverPort , err = ensureSSHServerIsRunning (ctx , client , opts . ClusterID , keysSecretScopeName , opts . ClientPublicKeyName , version , opts . ShutdownDelay , opts . MaxClients , opts . ServerTimeout )
111+ userName , serverPort , err = ensureSSHServerIsRunning (ctx , client , version , keysSecretScopeName , opts )
104112 if err != nil {
105113 return fmt .Errorf ("failed to ensure that ssh server is running: %w" , err )
106114 }
@@ -123,10 +131,10 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt
123131 cmdio .LogString (ctx , fmt .Sprintf ("Server port: %d" , serverPort ))
124132
125133 if opts .ProxyMode {
126- return runSSHProxy (ctx , client , opts . ClusterID , serverPort , opts . HandoverTimeout )
134+ return runSSHProxy (ctx , client , serverPort , opts )
127135 } else {
128136 cmdio .LogString (ctx , fmt .Sprintf ("Additional SSH arguments: %v" , opts .AdditionalArgs ))
129- return spawnSSHClient (ctx , opts . ClusterID , userName , privateKeyPath , serverPort , opts . HandoverTimeout , opts . AdditionalArgs )
137+ return spawnSSHClient (ctx , userName , privateKeyPath , serverPort , opts )
130138 }
131139}
132140
@@ -164,8 +172,8 @@ func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient,
164172 return serverPort , string (bodyBytes ), nil
165173}
166174
167- func submitSSHTunnelJob (ctx context.Context , client * databricks.WorkspaceClient , clusterID , keysSecretScopeName , publicKeySecretName , version string , shutdownDelay time. Duration , maxClients int , serverTimeout time. Duration ) (int64 , error ) {
168- contentDir , err := sshWorkspace .GetWorkspaceContentDir (ctx , client , version , clusterID )
175+ func submitSSHTunnelJob (ctx context.Context , client * databricks.WorkspaceClient , version , keysSecretScopeName string , opts ClientOptions ) (int64 , error ) {
176+ contentDir , err := sshWorkspace .GetWorkspaceContentDir (ctx , client , version , opts . ClusterID )
169177 if err != nil {
170178 return 0 , fmt .Errorf ("failed to get workspace content directory: %w" , err )
171179 }
@@ -175,7 +183,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
175183 return 0 , fmt .Errorf ("failed to create directory in the remote workspace: %w" , err )
176184 }
177185
178- sshTunnelJobName := "ssh-server-bootstrap-" + clusterID
186+ sshTunnelJobName := "ssh-server-bootstrap-" + opts . ClusterID
179187 jobNotebookPath := filepath .ToSlash (filepath .Join (contentDir , "ssh-server-bootstrap" ))
180188 notebookContent := "# Databricks notebook source\n " + sshServerBootstrapScript
181189 encodedContent := base64 .StdEncoding .EncodeToString ([]byte (notebookContent ))
@@ -193,7 +201,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
193201
194202 submitRun := jobs.SubmitRun {
195203 RunName : sshTunnelJobName ,
196- TimeoutSeconds : int (serverTimeout .Seconds ()),
204+ TimeoutSeconds : int (opts . ServerTimeout .Seconds ()),
197205 Tasks : []jobs.SubmitTask {
198206 {
199207 TaskKey : "start_ssh_server" ,
@@ -202,13 +210,13 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
202210 BaseParameters : map [string ]string {
203211 "version" : version ,
204212 "keysSecretScopeName" : keysSecretScopeName ,
205- "authorizedKeySecretName" : publicKeySecretName ,
206- "shutdownDelay" : shutdownDelay .String (),
207- "maxClients" : strconv .Itoa (maxClients ),
213+ "authorizedKeySecretName" : opts . ClientPublicKeyName ,
214+ "shutdownDelay" : opts . ShutdownDelay .String (),
215+ "maxClients" : strconv .Itoa (opts . MaxClients ),
208216 },
209217 },
210- TimeoutSeconds : int (serverTimeout .Seconds ()),
211- ExistingClusterId : clusterID ,
218+ TimeoutSeconds : int (opts . ServerTimeout .Seconds ()),
219+ ExistingClusterId : opts . ClusterID ,
212220 },
213221 },
214222 }
@@ -222,24 +230,24 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
222230 return runResult .Response .RunId , nil
223231}
224232
225- func spawnSSHClient (ctx context.Context , clusterID , userName , privateKeyPath string , serverPort int , handoverTimeout time. Duration , additionalArgs [] string ) error {
233+ func spawnSSHClient (ctx context.Context , userName , privateKeyPath string , serverPort int , opts ClientOptions ) error {
226234 executablePath , err := os .Executable ()
227235 if err != nil {
228236 return fmt .Errorf ("failed to get current executable path: %w" , err )
229237 }
230238
231- proxyCommand := fmt .Sprintf ("%s ssh connect --proxy --cluster=%s --handover-timeout=%s --metadata=%s,%d" ,
232- executablePath , clusterID , handoverTimeout . String (), userName , serverPort )
239+ proxyCommand := fmt .Sprintf ("%s ssh connect --proxy --cluster=%s --handover-timeout=%s --metadata=%s,%d --auto-start-cluster=%t " ,
240+ executablePath , opts . ClusterID , opts . HandoverTimeout . String (), userName , serverPort , opts . AutoStartCluster )
233241
234242 sshArgs := []string {
235243 "-l" , userName ,
236244 "-i" , privateKeyPath ,
237245 "-o" , "StrictHostKeyChecking=accept-new" ,
238246 "-o" , "ConnectTimeout=360" ,
239247 "-o" , "ProxyCommand=" + proxyCommand ,
240- clusterID ,
248+ opts . ClusterID ,
241249 }
242- sshArgs = append (sshArgs , additionalArgs ... )
250+ sshArgs = append (sshArgs , opts . AdditionalArgs ... )
243251
244252 cmdio .LogString (ctx , "Launching SSH client: ssh " + strings .Join (sshArgs , " " ))
245253
@@ -252,25 +260,39 @@ func spawnSSHClient(ctx context.Context, clusterID, userName, privateKeyPath str
252260 return sshCmd .Run ()
253261}
254262
255- func runSSHProxy (ctx context.Context , client * databricks.WorkspaceClient , clusterID string , serverPort int , handoverTimeout time. Duration ) error {
263+ func runSSHProxy (ctx context.Context , client * databricks.WorkspaceClient , serverPort int , opts ClientOptions ) error {
256264 createConn := func (ctx context.Context , connID string ) (* websocket.Conn , error ) {
257- return createWebsocketConnection (ctx , client , connID , clusterID , serverPort )
265+ return createWebsocketConnection (ctx , client , connID , opts . ClusterID , serverPort )
258266 }
259- return proxy .RunClientProxy (ctx , os .Stdin , os .Stdout , handoverTimeout , createConn )
267+ return proxy .RunClientProxy (ctx , os .Stdin , os .Stdout , opts . HandoverTimeout , createConn )
260268}
261269
262- func ensureSSHServerIsRunning (ctx context.Context , client * databricks.WorkspaceClient , clusterID , keysSecretScopeName , publicKeySecretName , version string , shutdownDelay time.Duration , maxClients int , serverTimeout time.Duration ) (string , int , error ) {
263- cmdio .LogString (ctx , "Ensuring the cluster is running: " + clusterID )
264- err := client .Clusters .EnsureClusterIsRunning (ctx , clusterID )
265- if err != nil {
266- return "" , 0 , fmt .Errorf ("failed to ensure that the cluster is running: %w" , err )
270+ func checkClusterState (ctx context.Context , client * databricks.WorkspaceClient , clusterID string , autoStart bool ) error {
271+ if autoStart {
272+ cmdio .LogString (ctx , "Ensuring the cluster is running: " + clusterID )
273+ err := client .Clusters .EnsureClusterIsRunning (ctx , clusterID )
274+ if err != nil {
275+ return fmt .Errorf ("failed to ensure that the cluster is running: %w" , err )
276+ }
277+ } else {
278+ cmdio .LogString (ctx , "Checking cluster state: " + clusterID )
279+ cluster , err := client .Clusters .GetByClusterId (ctx , clusterID )
280+ if err != nil {
281+ return fmt .Errorf ("failed to get cluster info: %w" , err )
282+ }
283+ if cluster .State != compute .StateRunning {
284+ return fmt .Errorf ("cluster %s is not running, current state: %s. Use --auto-start-cluster to start it automatically" , clusterID , cluster .State )
285+ }
267286 }
287+ return nil
288+ }
268289
269- serverPort , userName , err := getServerMetadata (ctx , client , clusterID , version )
290+ func ensureSSHServerIsRunning (ctx context.Context , client * databricks.WorkspaceClient , version , keysSecretScopeName string , opts ClientOptions ) (string , int , error ) {
291+ serverPort , userName , err := getServerMetadata (ctx , client , opts .ClusterID , version )
270292 if errors .Is (err , errServerMetadata ) {
271293 cmdio .LogString (ctx , "SSH server is not running, starting it now..." )
272294
273- runID , err := submitSSHTunnelJob (ctx , client , clusterID , keysSecretScopeName , publicKeySecretName , version , shutdownDelay , maxClients , serverTimeout )
295+ runID , err := submitSSHTunnelJob (ctx , client , version , keysSecretScopeName , opts )
274296 if err != nil {
275297 return "" , 0 , fmt .Errorf ("failed to submit ssh server job: %w" , err )
276298 }
@@ -282,7 +304,7 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC
282304 if ctx .Err () != nil {
283305 return "" , 0 , ctx .Err ()
284306 }
285- serverPort , userName , err = getServerMetadata (ctx , client , clusterID , version )
307+ serverPort , userName , err = getServerMetadata (ctx , client , opts . ClusterID , version )
286308 if err == nil {
287309 cmdio .LogString (ctx , "Health check successful, starting ssh WebSocket connection..." )
288310 break
0 commit comments