Commit 1012632

Add cluster selector UI and auto-start option to ssh commands (#3607)
## Changes

- Add a cluster selector UI to the `ssh setup` command. We don't filter or check the clusters during that stage, as that is done later by `validateClusterAccess`. The cluster option is no longer required for the `setup` command; a few customers have already asked for this.
- Add an `--auto-start-cluster` flag to both the `ssh setup` and `ssh connect` commands, which defaults to true. Useful for PyCharm users, where the IDE checks every host in the config to see if the connection can be established, starting all the clusters.
- Change a few `client.go` functions to accept `ClientOptions` instead of a long list of separate arguments (see the sketch below).

Based on #3569

## Tests

No unit tests; integration tests will be in follow-ups.
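To make the third bullet concrete, here is a minimal, hypothetical sketch of driving the refactored `client.Run` with a single `ClientOptions` value. The field names come from the diff below; the import path is inferred from the file layout (the package is internal to the CLI repo) and all values are made up.

```go
package main

import (
	"context"
	"log"
	"time"

	"github.com/databricks/databricks-sdk-go"

	// Import path inferred from this commit's file layout; the package is
	// internal, so this sketch only compiles from inside the CLI repo.
	"github.com/databricks/cli/experimental/ssh/internal/client"
)

func main() {
	ctx := context.Background()

	// Authenticates using the SDK's default config resolution (env vars, profiles, ...).
	wsClient, err := databricks.NewWorkspaceClient()
	if err != nil {
		log.Fatal(err)
	}

	// All per-connection knobs now travel in one ClientOptions value instead of
	// a long positional argument list. Every value below is hypothetical.
	opts := client.ClientOptions{
		ClusterID:           "0123-456789-abcdefgh",
		ClientPublicKeyName: "client-public-key",
		ShutdownDelay:       10 * time.Minute,
		MaxClients:          10,
		ServerTimeout:       2 * time.Hour,
		HandoverTimeout:     30 * time.Second,
		AutoStartCluster:    false, // new option: fail fast instead of starting a stopped cluster
	}

	if err := client.Run(ctx, wsClient, opts); err != nil {
		log.Fatal(err)
	}
}
```

With `AutoStartCluster` set to false, `Run` goes through the new `checkClusterState` helper and returns an error if the cluster is not already running, instead of starting it.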
1 parent f9c6279 commit 1012632

File tree: 4 files changed (+97, −36 lines)

experimental/ssh/cmd/connect.go (3 additions, 0 deletions)

@@ -28,11 +28,13 @@ the SSH server and handling the connection proxy.
 	var maxClients int
 	var handoverTimeout time.Duration
 	var releasesDir string
+	var autoStartCluster bool

 	cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID (required)")
 	cmd.MarkFlagRequired("cluster")
 	cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down the server after the last client disconnects")
 	cmd.Flags().IntVar(&maxClients, "max-clients", defaultMaxClients, "Maximum number of SSH clients")
+	cmd.Flags().BoolVar(&autoStartCluster, "auto-start-cluster", true, "Automatically start the cluster if it is not running")

 	cmd.Flags().BoolVar(&proxyMode, "proxy", false, "ProxyCommand mode")
 	cmd.Flags().MarkHidden("proxy")
@@ -59,6 +61,7 @@ the SSH server and handling the connection proxy.
 			AdditionalArgs:      args,
 			ClientPublicKeyName: defaultClientPublicKeyName,
 			ServerTimeout:       serverTimeout,
+			AutoStartCluster:    autoStartCluster,
 		}
 		return client.Run(ctx, wsClient, opts)
 	}

experimental/ssh/cmd/setup.go (8 additions, 6 deletions)

@@ -25,11 +25,12 @@ an SSH host configuration to your SSH config file.
 	var clusterID string
 	var sshConfigPath string
 	var shutdownDelay time.Duration
+	var autoStartCluster bool

 	cmd.Flags().StringVar(&hostName, "name", "", "Host name to use in SSH config")
 	cmd.MarkFlagRequired("name")
 	cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID")
-	cmd.MarkFlagRequired("cluster")
+	cmd.Flags().BoolVar(&autoStartCluster, "auto-start-cluster", true, "Automatically start the cluster when establishing the ssh connection")
 	cmd.Flags().StringVar(&sshConfigPath, "ssh-config", "", "Path to SSH config file (default ~/.ssh/config)")
 	cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "SSH server will terminate after this delay if there are no active connections")

@@ -38,11 +39,12 @@ an SSH host configuration to your SSH config file.
 		ctx := cmd.Context()
 		client := cmdctx.WorkspaceClient(ctx)
 		opts := setup.SetupOptions{
-			HostName:      hostName,
-			ClusterID:     clusterID,
-			SSHConfigPath: sshConfigPath,
-			ShutdownDelay: shutdownDelay,
-			Profile:       client.Config.Profile,
+			HostName:         hostName,
+			ClusterID:        clusterID,
+			AutoStartCluster: autoStartCluster,
+			SSHConfigPath:    sshConfigPath,
+			ShutdownDelay:    shutdownDelay,
+			Profile:          client.Config.Profile,
 		}
 		return setup.Setup(ctx, client, opts)
 	}

experimental/ssh/internal/client/client.go (50 additions, 28 deletions)

@@ -23,6 +23,7 @@ import (
 	"github.com/databricks/cli/internal/build"
 	"github.com/databricks/cli/libs/cmdio"
 	"github.com/databricks/databricks-sdk-go"
+	"github.com/databricks/databricks-sdk-go/service/compute"
 	"github.com/databricks/databricks-sdk-go/service/jobs"
 	"github.com/databricks/databricks-sdk-go/service/workspace"
 	"github.com/gorilla/websocket"
@@ -58,6 +59,8 @@ type ClientOptions struct {
 	SSHKeysDir string
 	// Name of the client public key file to be used in the ssh-tunnel secrets scope.
 	ClientPublicKeyName string
+	// If true, the CLI will attempt to start the cluster if it is not running.
+	AutoStartCluster bool
 	// Additional arguments to pass to the SSH client in the non proxy mode.
 	AdditionalArgs []string
 }
@@ -74,6 +77,11 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt
 		cancel()
 	}()

+	err := checkClusterState(ctx, client, opts.ClusterID, opts.AutoStartCluster)
+	if err != nil {
+		return err
+	}
+
 	keyPath, err := keys.GetLocalSSHKeyPath(opts.ClusterID, opts.SSHKeysDir)
 	if err != nil {
 		return fmt.Errorf("failed to get local keys folder: %w", err)
@@ -100,7 +108,7 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt
 	if err := UploadTunnelReleases(ctx, client, version, opts.ReleasesDir); err != nil {
 		return fmt.Errorf("failed to upload ssh-tunnel binaries: %w", err)
 	}
-	userName, serverPort, err = ensureSSHServerIsRunning(ctx, client, opts.ClusterID, keysSecretScopeName, opts.ClientPublicKeyName, version, opts.ShutdownDelay, opts.MaxClients, opts.ServerTimeout)
+	userName, serverPort, err = ensureSSHServerIsRunning(ctx, client, version, keysSecretScopeName, opts)
 	if err != nil {
 		return fmt.Errorf("failed to ensure that ssh server is running: %w", err)
 	}
@@ -123,10 +131,10 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt
 	cmdio.LogString(ctx, fmt.Sprintf("Server port: %d", serverPort))

 	if opts.ProxyMode {
-		return runSSHProxy(ctx, client, opts.ClusterID, serverPort, opts.HandoverTimeout)
+		return runSSHProxy(ctx, client, serverPort, opts)
 	} else {
 		cmdio.LogString(ctx, fmt.Sprintf("Additional SSH arguments: %v", opts.AdditionalArgs))
-		return spawnSSHClient(ctx, opts.ClusterID, userName, privateKeyPath, serverPort, opts.HandoverTimeout, opts.AdditionalArgs)
+		return spawnSSHClient(ctx, userName, privateKeyPath, serverPort, opts)
 	}
 }

@@ -164,8 +172,8 @@ func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient,
 	return serverPort, string(bodyBytes), nil
 }

-func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, clusterID, keysSecretScopeName, publicKeySecretName, version string, shutdownDelay time.Duration, maxClients int, serverTimeout time.Duration) (int64, error) {
-	contentDir, err := sshWorkspace.GetWorkspaceContentDir(ctx, client, version, clusterID)
+func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, version, keysSecretScopeName string, opts ClientOptions) (int64, error) {
+	contentDir, err := sshWorkspace.GetWorkspaceContentDir(ctx, client, version, opts.ClusterID)
 	if err != nil {
 		return 0, fmt.Errorf("failed to get workspace content directory: %w", err)
 	}
@@ -175,7 +183,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
 		return 0, fmt.Errorf("failed to create directory in the remote workspace: %w", err)
 	}

-	sshTunnelJobName := "ssh-server-bootstrap-" + clusterID
+	sshTunnelJobName := "ssh-server-bootstrap-" + opts.ClusterID
 	jobNotebookPath := filepath.ToSlash(filepath.Join(contentDir, "ssh-server-bootstrap"))
 	notebookContent := "# Databricks notebook source\n" + sshServerBootstrapScript
 	encodedContent := base64.StdEncoding.EncodeToString([]byte(notebookContent))
@@ -193,7 +201,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,

 	submitRun := jobs.SubmitRun{
 		RunName: sshTunnelJobName,
-		TimeoutSeconds: int(serverTimeout.Seconds()),
+		TimeoutSeconds: int(opts.ServerTimeout.Seconds()),
 		Tasks: []jobs.SubmitTask{
 			{
 				TaskKey: "start_ssh_server",
@@ -202,13 +210,13 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
 					BaseParameters: map[string]string{
 						"version": version,
 						"keysSecretScopeName": keysSecretScopeName,
-						"authorizedKeySecretName": publicKeySecretName,
-						"shutdownDelay": shutdownDelay.String(),
-						"maxClients": strconv.Itoa(maxClients),
+						"authorizedKeySecretName": opts.ClientPublicKeyName,
+						"shutdownDelay": opts.ShutdownDelay.String(),
+						"maxClients": strconv.Itoa(opts.MaxClients),
 					},
 				},
-				TimeoutSeconds: int(serverTimeout.Seconds()),
-				ExistingClusterId: clusterID,
+				TimeoutSeconds: int(opts.ServerTimeout.Seconds()),
+				ExistingClusterId: opts.ClusterID,
 			},
 		},
 	}
@@ -222,24 +230,24 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
 	return runResult.Response.RunId, nil
 }

-func spawnSSHClient(ctx context.Context, clusterID, userName, privateKeyPath string, serverPort int, handoverTimeout time.Duration, additionalArgs []string) error {
+func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, serverPort int, opts ClientOptions) error {
 	executablePath, err := os.Executable()
 	if err != nil {
 		return fmt.Errorf("failed to get current executable path: %w", err)
 	}

-	proxyCommand := fmt.Sprintf("%s ssh connect --proxy --cluster=%s --handover-timeout=%s --metadata=%s,%d",
-		executablePath, clusterID, handoverTimeout.String(), userName, serverPort)
+	proxyCommand := fmt.Sprintf("%s ssh connect --proxy --cluster=%s --handover-timeout=%s --metadata=%s,%d --auto-start-cluster=%t",
+		executablePath, opts.ClusterID, opts.HandoverTimeout.String(), userName, serverPort, opts.AutoStartCluster)

 	sshArgs := []string{
 		"-l", userName,
 		"-i", privateKeyPath,
 		"-o", "StrictHostKeyChecking=accept-new",
 		"-o", "ConnectTimeout=360",
 		"-o", "ProxyCommand=" + proxyCommand,
-		clusterID,
+		opts.ClusterID,
 	}
-	sshArgs = append(sshArgs, additionalArgs...)
+	sshArgs = append(sshArgs, opts.AdditionalArgs...)

 	cmdio.LogString(ctx, "Launching SSH client: ssh "+strings.Join(sshArgs, " "))

@@ -252,25 +260,39 @@ func spawnSSHClient(ctx context.Context, clusterID, userName, privateKeyPath str
 	return sshCmd.Run()
 }

-func runSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, clusterID string, serverPort int, handoverTimeout time.Duration) error {
+func runSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, serverPort int, opts ClientOptions) error {
 	createConn := func(ctx context.Context, connID string) (*websocket.Conn, error) {
-		return createWebsocketConnection(ctx, client, connID, clusterID, serverPort)
+		return createWebsocketConnection(ctx, client, connID, opts.ClusterID, serverPort)
 	}
-	return proxy.RunClientProxy(ctx, os.Stdin, os.Stdout, handoverTimeout, createConn)
+	return proxy.RunClientProxy(ctx, os.Stdin, os.Stdout, opts.HandoverTimeout, createConn)
 }

-func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceClient, clusterID, keysSecretScopeName, publicKeySecretName, version string, shutdownDelay time.Duration, maxClients int, serverTimeout time.Duration) (string, int, error) {
-	cmdio.LogString(ctx, "Ensuring the cluster is running: "+clusterID)
-	err := client.Clusters.EnsureClusterIsRunning(ctx, clusterID)
-	if err != nil {
-		return "", 0, fmt.Errorf("failed to ensure that the cluster is running: %w", err)
+func checkClusterState(ctx context.Context, client *databricks.WorkspaceClient, clusterID string, autoStart bool) error {
+	if autoStart {
+		cmdio.LogString(ctx, "Ensuring the cluster is running: "+clusterID)
+		err := client.Clusters.EnsureClusterIsRunning(ctx, clusterID)
+		if err != nil {
+			return fmt.Errorf("failed to ensure that the cluster is running: %w", err)
+		}
+	} else {
+		cmdio.LogString(ctx, "Checking cluster state: "+clusterID)
+		cluster, err := client.Clusters.GetByClusterId(ctx, clusterID)
+		if err != nil {
+			return fmt.Errorf("failed to get cluster info: %w", err)
+		}
+		if cluster.State != compute.StateRunning {
+			return fmt.Errorf("cluster %s is not running, current state: %s. Use --auto-start-cluster to start it automatically", clusterID, cluster.State)
+		}
 	}
+	return nil
+}

-	serverPort, userName, err := getServerMetadata(ctx, client, clusterID, version)
+func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceClient, version, keysSecretScopeName string, opts ClientOptions) (string, int, error) {
+	serverPort, userName, err := getServerMetadata(ctx, client, opts.ClusterID, version)
 	if errors.Is(err, errServerMetadata) {
 		cmdio.LogString(ctx, "SSH server is not running, starting it now...")

-		runID, err := submitSSHTunnelJob(ctx, client, clusterID, keysSecretScopeName, publicKeySecretName, version, shutdownDelay, maxClients, serverTimeout)
+		runID, err := submitSSHTunnelJob(ctx, client, version, keysSecretScopeName, opts)
 		if err != nil {
 			return "", 0, fmt.Errorf("failed to submit ssh server job: %w", err)
 		}
@@ -282,7 +304,7 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC
 		if ctx.Err() != nil {
 			return "", 0, ctx.Err()
 		}
-		serverPort, userName, err = getServerMetadata(ctx, client, clusterID, version)
+		serverPort, userName, err = getServerMetadata(ctx, client, opts.ClusterID, version)
 		if err == nil {
 			cmdio.LogString(ctx, "Health check successful, starting ssh WebSocket connection...")
 			break

experimental/ssh/internal/setup/setup.go (36 additions, 2 deletions)

@@ -2,6 +2,7 @@ package setup

 import (
 	"context"
+	"errors"
 	"fmt"
 	"os"
 	"path/filepath"
@@ -20,6 +21,8 @@ type SetupOptions struct {
 	HostName string
 	// The cluster ID to connect to
 	ClusterID string
+	// Whether to automatically start the cluster during ssh connection if it is not running
+	AutoStartCluster bool
 	// Delay before shutting down the SSH tunnel, will be added as a --shutdown-delay flag to the ProxyCommand
 	ShutdownDelay time.Duration
 	// Optional path to the local ssh config. Defaults to ~/.ssh/config
@@ -74,8 +77,8 @@ Host %s
   ConnectTimeout 360
   StrictHostKeyChecking accept-new
   IdentityFile %q
-  ProxyCommand %q ssh connect --proxy --cluster=%s --shutdown-delay=%s %s
-`, opts.HostName, identityFilePath, execPath, opts.ClusterID, opts.ShutdownDelay, profileOption)
+  ProxyCommand %q ssh connect --proxy --cluster=%s --auto-start-cluster=%t --shutdown-delay=%s %s
+`, opts.HostName, identityFilePath, execPath, opts.ClusterID, opts.AutoStartCluster, opts.ShutdownDelay, profileOption)

 	return hostConfig, nil
 }
@@ -141,7 +144,38 @@ func updateSSHConfigFile(configPath, hostConfig, hostName string) error {
 	return nil
 }

+func clusterSelectionPrompt(ctx context.Context, client *databricks.WorkspaceClient) (string, error) {
+	spinnerChan := cmdio.Spinner(ctx)
+	spinnerChan <- "Loading clusters."
+	clusters, err := client.Clusters.ClusterDetailsClusterNameToClusterIdMap(ctx, compute.ListClustersRequest{
+		FilterBy: &compute.ListClustersFilterBy{
+			ClusterSources: []compute.ClusterSource{compute.ClusterSourceApi, compute.ClusterSourceUi},
+		},
+	})
+	close(spinnerChan)
+	if err != nil {
+		return "", fmt.Errorf("failed to load names for Clusters drop-down. Please manually specify cluster argument. Original error: %w", err)
+	}
+	id, err := cmdio.Select(ctx, clusters, "The cluster to connect to")
+	if err != nil {
+		return "", err
+	}
+	return id, nil
+}
+
 func Setup(ctx context.Context, client *databricks.WorkspaceClient, opts SetupOptions) error {
+	if opts.ClusterID == "" {
+		id, err := clusterSelectionPrompt(ctx, client)
+		if err != nil {
+			return err
+		}
+		opts.ClusterID = id
+	}
+
+	if opts.ClusterID == "" {
+		return errors.New("cluster ID is required")
+	}
+
 	err := validateClusterAccess(ctx, client, opts.ClusterID)
 	if err != nil {
 		return err
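For the setup side, a similar and equally hypothetical sketch: leaving `ClusterID` empty is what now routes `Setup` through the interactive cluster selector shown above. The import path is again inferred from the file layout, and all values are made up.

```go
package main

import (
	"context"
	"log"
	"time"

	"github.com/databricks/databricks-sdk-go"

	// Import path inferred from this commit's file layout (internal package).
	"github.com/databricks/cli/experimental/ssh/internal/setup"
)

func main() {
	ctx := context.Background()

	wsClient, err := databricks.NewWorkspaceClient()
	if err != nil {
		log.Fatal(err)
	}

	opts := setup.SetupOptions{
		HostName:         "my-dev-box",  // hypothetical SSH host alias
		ClusterID:        "",            // empty: triggers clusterSelectionPrompt
		AutoStartCluster: false,         // written into the ProxyCommand as --auto-start-cluster=false
		ShutdownDelay:    4 * time.Hour, // hypothetical
		SSHConfigPath:    "",            // default: ~/.ssh/config
		Profile:          wsClient.Config.Profile,
	}

	if err := setup.Setup(ctx, wsClient, opts); err != nil {
		log.Fatal(err)
	}
}
```

The generated host entry then carries the flag in its ProxyCommand, so an IDE that probes every host in the SSH config (the PyCharm case from the description) should get a fast error for a stopped cluster rather than silently starting it.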
