Skip to content

Commit 963e80a

Browse files
authored
Add serverless GPU compute support to SSH tunnel (#4162)
## Changes Add serverless GPU support for `databricks ssh connect` command All the new flags and options are hidden for now. The main change for existing users (on dedicated clusters) is config management. Based on the PR stack: #4452 #4453 ## Why <!-- Why are these changes needed? Provide the context that the reviewer might be missing. For example, were there any decisions behind the change that are not reflected in the code itself? --> ## Tests <!-- How have you tested the changes? --> <!-- If your PR needs to be included in the release notes for next release, add a separate entry in NEXT_CHANGELOG.md as part of your PR. -->
1 parent dccd96b commit 963e80a

File tree

17 files changed

+1146
-489
lines changed

17 files changed

+1146
-489
lines changed

experimental/ssh/cmd/connect.go

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package ssh
22

33
import (
4+
"errors"
45
"time"
56

67
"github.com/databricks/cli/cmd/root"
@@ -22,21 +23,31 @@ the SSH server and handling the connection proxy.
2223
}
2324

2425
var clusterID string
26+
var connectionName string
27+
var accelerator string
2528
var proxyMode bool
29+
var ide string
2630
var serverMetadata string
2731
var shutdownDelay time.Duration
2832
var maxClients int
2933
var handoverTimeout time.Duration
3034
var releasesDir string
3135
var autoStartCluster bool
3236
var userKnownHostsFile string
37+
var liteswap string
3338

34-
cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID (required)")
35-
cmd.MarkFlagRequired("cluster")
39+
cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID (for dedicated clusters)")
3640
cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down the server after the last client disconnects")
3741
cmd.Flags().IntVar(&maxClients, "max-clients", defaultMaxClients, "Maximum number of SSH clients")
3842
cmd.Flags().BoolVar(&autoStartCluster, "auto-start-cluster", true, "Automatically start the cluster if it is not running")
3943

44+
cmd.Flags().StringVar(&connectionName, "name", "", "Connection name (for serverless compute)")
45+
cmd.Flags().MarkHidden("name")
46+
cmd.Flags().StringVar(&accelerator, "accelerator", "", "GPU accelerator type (GPU_1xA10 or GPU_8xH100)")
47+
cmd.Flags().MarkHidden("accelerator")
48+
cmd.Flags().StringVar(&ide, "ide", "", "Open remote IDE window (vscode or cursor)")
49+
cmd.Flags().MarkHidden("ide")
50+
4051
cmd.Flags().BoolVar(&proxyMode, "proxy", false, "ProxyCommand mode")
4152
cmd.Flags().MarkHidden("proxy")
4253
cmd.Flags().StringVar(&serverMetadata, "metadata", "", "Metadata of the running SSH server (format: <user_name>,<port>)")
@@ -50,6 +61,9 @@ the SSH server and handling the connection proxy.
5061
cmd.Flags().StringVar(&userKnownHostsFile, "user-known-hosts-file", "", "Path to user known hosts file for SSH client")
5162
cmd.Flags().MarkHidden("user-known-hosts-file")
5263

64+
cmd.Flags().StringVar(&liteswap, "liteswap", "", "Liteswap header value for traffic routing (dev/test only)")
65+
cmd.Flags().MarkHidden("liteswap")
66+
5367
cmd.PreRunE = func(cmd *cobra.Command, args []string) error {
5468
// CLI in the proxy mode is executed by the ssh client and can't prompt for input
5569
if proxyMode {
@@ -64,20 +78,41 @@ the SSH server and handling the connection proxy.
6478
cmd.RunE = func(cmd *cobra.Command, args []string) error {
6579
ctx := cmd.Context()
6680
wsClient := cmdctx.WorkspaceClient(ctx)
81+
82+
if !proxyMode && clusterID == "" && connectionName == "" {
83+
return errors.New("please provide --cluster flag with the cluster ID, or --name flag with the connection name (for serverless compute)")
84+
}
85+
86+
if accelerator != "" && connectionName == "" {
87+
return errors.New("--accelerator flag can only be used with serverless compute (--name flag)")
88+
}
89+
90+
// Remove when we add support for serverless CPU
91+
if connectionName != "" && accelerator == "" {
92+
return errors.New("--name flag requires --accelerator to be set (for now we only support serverless GPU compute)")
93+
}
94+
95+
// TODO: validate connectionName if provided
96+
6797
opts := client.ClientOptions{
6898
Profile: wsClient.Config.Profile,
6999
ClusterID: clusterID,
100+
ConnectionName: connectionName,
101+
Accelerator: accelerator,
70102
ProxyMode: proxyMode,
103+
IDE: ide,
71104
ServerMetadata: serverMetadata,
72105
ShutdownDelay: shutdownDelay,
73106
MaxClients: maxClients,
74107
HandoverTimeout: handoverTimeout,
75108
ReleasesDir: releasesDir,
76109
ServerTimeout: max(serverTimeout, shutdownDelay),
110+
TaskStartupTimeout: taskStartupTimeout,
77111
AutoStartCluster: autoStartCluster,
78112
ClientPublicKeyName: clientPublicKeyName,
79113
ClientPrivateKeyName: clientPrivateKeyName,
80114
UserKnownHostsFile: userKnownHostsFile,
115+
Liteswap: liteswap,
81116
AdditionalArgs: args,
82117
}
83118
return client.Run(ctx, wsClient, opts)

experimental/ssh/cmd/constants.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ const (
99
defaultHandoverTimeout = 30 * time.Minute
1010

1111
serverTimeout = 24 * time.Hour
12+
taskStartupTimeout = 10 * time.Minute
1213
serverPortRange = 100
1314
serverConfigDir = ".ssh-tunnel"
1415
serverPrivateKeyName = "server-private-key"

experimental/ssh/cmd/server.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,15 @@ and proxies them to local SSH daemon processes.
2626
var maxClients int
2727
var shutdownDelay time.Duration
2828
var clusterID string
29+
var sessionID string
2930
var version string
3031
var secretScopeName string
3132
var authorizedKeySecretName string
3233

3334
cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID")
3435
cmd.MarkFlagRequired("cluster")
36+
cmd.Flags().StringVar(&sessionID, "session-id", "", "Session identifier (cluster ID or serverless connection name)")
37+
cmd.MarkFlagRequired("session-id")
3538
cmd.Flags().StringVar(&secretScopeName, "secret-scope-name", "", "Databricks secret scope name to store SSH keys")
3639
cmd.MarkFlagRequired("secret-scope-name")
3740
cmd.Flags().StringVar(&authorizedKeySecretName, "authorized-key-secret-name", "", "Name of the secret containing the client public key")
@@ -56,6 +59,7 @@ and proxies them to local SSH daemon processes.
5659
wsc := cmdctx.WorkspaceClient(ctx)
5760
opts := server.ServerOptions{
5861
ClusterID: clusterID,
62+
SessionID: sessionID,
5963
MaxClients: maxClients,
6064
ShutdownDelay: shutdownDelay,
6165
Version: version,

experimental/ssh/cmd/setup.go

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package ssh
22

33
import (
4+
"fmt"
45
"time"
56

67
"github.com/databricks/cli/cmd/root"
8+
"github.com/databricks/cli/experimental/ssh/internal/client"
79
"github.com/databricks/cli/experimental/ssh/internal/setup"
810
"github.com/databricks/cli/libs/cmdctx"
911
"github.com/spf13/cobra"
@@ -43,16 +45,27 @@ an SSH host configuration to your SSH config file.
4345

4446
cmd.RunE = func(cmd *cobra.Command, args []string) error {
4547
ctx := cmd.Context()
46-
client := cmdctx.WorkspaceClient(ctx)
47-
opts := setup.SetupOptions{
48+
wsClient := cmdctx.WorkspaceClient(ctx)
49+
setupOpts := setup.SetupOptions{
4850
HostName: hostName,
4951
ClusterID: clusterID,
5052
AutoStartCluster: autoStartCluster,
5153
SSHConfigPath: sshConfigPath,
5254
ShutdownDelay: shutdownDelay,
53-
Profile: client.Config.Profile,
55+
Profile: wsClient.Config.Profile,
5456
}
55-
return setup.Setup(ctx, client, opts)
57+
clientOpts := client.ClientOptions{
58+
ClusterID: setupOpts.ClusterID,
59+
AutoStartCluster: setupOpts.AutoStartCluster,
60+
ShutdownDelay: setupOpts.ShutdownDelay,
61+
Profile: setupOpts.Profile,
62+
}
63+
proxyCommand, err := clientOpts.ToProxyCommand()
64+
if err != nil {
65+
return fmt.Errorf("failed to generate ProxyCommand: %w", err)
66+
}
67+
setupOpts.ProxyCommand = proxyCommand
68+
return setup.Setup(ctx, wsClient, setupOpts)
5669
}
5770

5871
return cmd

0 commit comments

Comments
 (0)