Skip to content

Commit c2e3322

Browse files
authored
Split and move ssh logic to experimental/ssh (#3544)
## Changes Previous PRs in the stack: - #3471 - #3475 Moved from a flat structure to `cmd` and internal `client`/`server`/`proxy`/`keys`/`workspace` sub-packages. Moved a number of hardcoded values to constants (some hardcoded values remain in the `keys` and `workspace` sub-packages; these will be improved in a follow-up). Added unit tests for the server's connection manager logic and updated the proxy tests. The coverage isn't great yet; the goal is to avoid unit tests with excessive mocking and focus on integration tests from now on. ## Why `/experimental` already hosts the databricks-bundles Python project, so it seems like a good place to put the SSH code too. ## Tests Unit and manual.
1 parent d976ed3 commit c2e3322

File tree

25 files changed

+814
-505
lines changed

25 files changed

+814
-505
lines changed

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
default: checks fmt lint
22

3-
PACKAGES=./acceptance/... ./libs/... ./internal/... ./cmd/... ./bundle/... .
3+
PACKAGES=./acceptance/... ./libs/... ./internal/... ./cmd/... ./bundle/... ./experimental/ssh/... .
44

55
GOTESTSUM_FORMAT ?= pkgname-and-test-fails
66
GOTESTSUM_CMD ?= go tool gotestsum --format ${GOTESTSUM_FORMAT} --no-summary=skipped --jsonfile test-output.json
@@ -136,4 +136,4 @@ generate:
136136
$(GENKIT_BINARY) update-sdk
137137

138138

139-
.PHONY: lint lintfull tidy lintcheck fmt fmtfull test cover showcover build snapshot schema integration integration-short acc-cover acc-showcover docs ws links checks test-update test-update-aws test-update-all generate-validation
139+
.PHONY: lint lintfull tidy lintcheck fmt fmtfull test cover showcover build snapshot snapshot-release schema integration integration-short acc-cover acc-showcover docs ws links checks test-update test-update-aws test-update-all generate-validation

cmd/cmd.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import (
55
"strings"
66

77
"github.com/databricks/cli/cmd/psql"
8-
"github.com/databricks/cli/cmd/ssh"
8+
ssh "github.com/databricks/cli/experimental/ssh/cmd"
99

1010
"github.com/databricks/cli/cmd/account"
1111
"github.com/databricks/cli/cmd/api"
File renamed without changes.
Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,12 @@ import (
44
"time"
55

66
"github.com/databricks/cli/cmd/root"
7+
"github.com/databricks/cli/experimental/ssh/internal/client"
8+
"github.com/databricks/cli/experimental/ssh/internal/proxy"
79
"github.com/databricks/cli/libs/cmdctx"
8-
"github.com/databricks/cli/libs/ssh"
910
"github.com/spf13/cobra"
1011
)
1112

12-
const (
13-
defaultClientPublicKeyName = "client-public-key"
14-
defaultShutdownDelay = 10 * time.Minute
15-
defaultHandoverTimeout = 30 * time.Minute
16-
defaultMaxClients = 10
17-
)
18-
1913
func newConnectCommand() *cobra.Command {
2014
cmd := &cobra.Command{
2115
Use: "connect",
@@ -54,8 +48,8 @@ the SSH server and handling the connection proxy.
5448
cmd.PreRunE = root.MustWorkspaceClient
5549
cmd.RunE = func(cmd *cobra.Command, args []string) error {
5650
ctx := cmd.Context()
57-
client := cmdctx.WorkspaceClient(ctx)
58-
opts := ssh.ClientOptions{
51+
wsClient := cmdctx.WorkspaceClient(ctx)
52+
opts := client.ClientOptions{
5953
ClusterID: clusterID,
6054
ProxyMode: proxyMode,
6155
ServerMetadata: serverMetadata,
@@ -65,8 +59,13 @@ the SSH server and handling the connection proxy.
6559
ReleasesDir: releasesDir,
6660
AdditionalArgs: args,
6761
ClientPublicKeyName: defaultClientPublicKeyName,
62+
ServerTimeout: serverTimeout,
63+
}
64+
err := client.RunClient(ctx, wsClient, opts)
65+
if err != nil && proxy.IsNormalClosure(err) {
66+
return nil
6867
}
69-
return ssh.RunClient(ctx, client, opts)
68+
return err
7069
}
7170

7271
return cmd

experimental/ssh/cmd/constants.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package ssh
2+
3+
import "time"
4+
5+
const (
6+
defaultServerPort = 7772
7+
defaultMaxClients = 10
8+
defaultShutdownDelay = 10 * time.Minute
9+
defaultHandoverTimeout = 30 * time.Minute
10+
11+
serverTimeout = 24 * time.Hour
12+
serverPortRange = 100
13+
serverConfigDir = ".ssh-tunnel"
14+
serverPrivateKeyName = "server-private-key"
15+
serverPublicKeyName = "server-public-key"
16+
defaultClientPublicKeyName = "client-public-key"
17+
)
Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ import (
44
"time"
55

66
"github.com/databricks/cli/cmd/root"
7+
"github.com/databricks/cli/experimental/ssh/internal/server"
78
"github.com/databricks/cli/libs/cmdctx"
8-
"github.com/databricks/cli/libs/ssh"
99
"github.com/spf13/cobra"
1010
)
1111

@@ -27,33 +27,38 @@ and proxies them to local SSH daemon processes.
2727
var shutdownDelay time.Duration
2828
var clusterID string
2929
var version string
30+
var secretsScope string
31+
var publicKeySecretName string
3032

3133
cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID")
3234
cmd.MarkFlagRequired("cluster")
33-
cmd.Flags().IntVar(&maxClients, "max-clients", 10, "Maximum number of SSH clients")
34-
cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", 10*time.Minute, "Delay before shutting down after no pings from clients")
35+
cmd.Flags().StringVar(&secretsScope, "secrets-scope-name", "", "Databricks secrets scope name")
36+
cmd.MarkFlagRequired("secrets-scope-name")
37+
cmd.Flags().StringVar(&publicKeySecretName, "client-key-name", "", "Databricks client key name")
38+
cmd.MarkFlagRequired("client-key-name")
39+
40+
cmd.Flags().IntVar(&maxClients, "max-clients", defaultMaxClients, "Maximum number of SSH clients")
41+
cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down after no pings from clients")
3542
cmd.Flags().StringVar(&version, "version", "", "Client version of the Databricks CLI")
3643

3744
cmd.PreRunE = root.MustWorkspaceClient
3845
cmd.RunE = func(cmd *cobra.Command, args []string) error {
3946
ctx := cmd.Context()
4047
client := cmdctx.WorkspaceClient(ctx)
41-
opts := ssh.ServerOptions{
48+
opts := server.ServerOptions{
4249
ClusterID: clusterID,
4350
MaxClients: maxClients,
4451
ShutdownDelay: shutdownDelay,
4552
Version: version,
46-
ConfigDir: ".ssh-tunnel",
47-
ServerPrivateKeyName: "server-private-key",
48-
ServerPublicKeyName: "server-public-key",
49-
DefaultPort: 7772,
50-
PortRange: 100,
51-
}
52-
err := ssh.RunServer(ctx, client, opts)
53-
if err != nil && ssh.IsNormalClosure(err) {
54-
return nil
53+
ConfigDir: serverConfigDir,
54+
SecretsScope: secretsScope,
55+
ClientPublicKeyName: publicKeySecretName,
56+
ServerPrivateKeyName: serverPrivateKeyName,
57+
ServerPublicKeyName: serverPublicKeyName,
58+
DefaultPort: defaultServerPort,
59+
PortRange: serverPortRange,
5560
}
56-
return err
61+
return server.RunServer(ctx, client, opts)
5762
}
5863

5964
return cmd
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ import (
44
"time"
55

66
"github.com/databricks/cli/cmd/root"
7+
"github.com/databricks/cli/experimental/ssh/internal/setup"
78
"github.com/databricks/cli/libs/cmdctx"
8-
"github.com/databricks/cli/libs/ssh"
99
"github.com/spf13/cobra"
1010
)
1111

@@ -31,20 +31,20 @@ an SSH host configuration to your SSH config file.
3131
cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID")
3232
cmd.MarkFlagRequired("cluster")
3333
cmd.Flags().StringVar(&sshConfigPath, "ssh-config", "", "Path to SSH config file (default ~/.ssh/config)")
34-
cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", 10*time.Minute, "SSH server will terminate after this delay if there are no active connections")
34+
cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "SSH server will terminate after this delay if there are no active connections")
3535

3636
cmd.PreRunE = root.MustWorkspaceClient
3737
cmd.RunE = func(cmd *cobra.Command, args []string) error {
3838
ctx := cmd.Context()
3939
client := cmdctx.WorkspaceClient(ctx)
40-
opts := ssh.SetupOptions{
40+
opts := setup.SetupOptions{
4141
HostName: hostName,
4242
ClusterID: clusterID,
4343
SSHConfigPath: sshConfigPath,
4444
ShutdownDelay: shutdownDelay,
4545
Profile: client.Config.Profile,
4646
}
47-
return ssh.Setup(ctx, client, opts)
47+
return setup.Setup(ctx, client, opts)
4848
}
4949

5050
return cmd

libs/ssh/client.go renamed to experimental/ssh/internal/client/client.go

Lines changed: 24 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
package ssh
1+
package client
22

33
import (
44
"context"
55
_ "embed"
66
"encoding/base64"
7-
"encoding/json"
87
"errors"
98
"fmt"
109
"io"
@@ -18,6 +17,9 @@ import (
1817
"syscall"
1918
"time"
2019

20+
"github.com/databricks/cli/experimental/ssh/internal/keys"
21+
"github.com/databricks/cli/experimental/ssh/internal/proxy"
22+
sshWorkspace "github.com/databricks/cli/experimental/ssh/internal/workspace"
2123
"github.com/databricks/cli/internal/build"
2224
"github.com/databricks/cli/libs/cmdio"
2325
"github.com/databricks/databricks-sdk-go"
@@ -27,17 +29,11 @@ import (
2729
"golang.org/x/sync/errgroup"
2830
)
2931

30-
type PortMetadata struct {
31-
Port int `json:"port"`
32-
}
33-
3432
//go:embed ssh-server-bootstrap.py
3533
var sshServerBootstrapScript string
3634

3735
var errServerMetadata = errors.New("server metadata error")
3836

39-
const serverJobTimeoutSeconds = 24 * 60 * 60
40-
4137
type ClientOptions struct {
4238
// Id of the cluster to connect to
4339
ClusterID string
@@ -54,6 +50,8 @@ type ClientOptions struct {
5450
ServerMetadata string
5551
// How often the CLI should reconnect to the server with new auth.
5652
HandoverTimeout time.Duration
53+
// Max amount of time the server process is allowed to live
54+
ServerTimeout time.Duration
5755
// Directory for local SSH tunnel development releases.
5856
// If not present, the CLI will use github releases with the current version.
5957
ReleasesDir string
@@ -77,16 +75,17 @@ func RunClient(ctx context.Context, client *databricks.WorkspaceClient, opts Cli
7775
cancel()
7876
}()
7977

80-
keyPath, err := getLocalSSHKeyPath(opts.ClusterID, opts.SSHKeysDir)
78+
keyPath, err := keys.GetLocalSSHKeyPath(opts.ClusterID, opts.SSHKeysDir)
8179
if err != nil {
8280
return fmt.Errorf("failed to get local keys folder: %w", err)
8381
}
84-
privateKeyPath, publicKey, err := checkAndGenerateSSHKeyPair(ctx, keyPath)
82+
privateKeyPath, publicKey, err := keys.CheckAndGenerateSSHKeyPair(ctx, keyPath)
8583
if err != nil {
8684
return fmt.Errorf("failed to check or generate SSH key pair: %w", err)
8785
}
86+
cmdio.LogString(ctx, "Using SSH key: "+privateKeyPath)
8887

89-
secretsScopeName, err := putSecretInScope(ctx, client, opts.ClusterID, opts.ClientPublicKeyName, publicKey)
88+
secretsScopeName, err := keys.PutSecretInScope(ctx, client, opts.ClusterID, opts.ClientPublicKeyName, publicKey)
9089
if err != nil {
9190
return fmt.Errorf("failed to store public key in secret scope: %w", err)
9291
}
@@ -99,10 +98,10 @@ func RunClient(ctx context.Context, client *databricks.WorkspaceClient, opts Cli
9998

10099
if opts.ServerMetadata == "" {
101100
cmdio.LogString(ctx, "Checking for ssh-tunnel binaries to upload...")
102-
if err := uploadTunnelBinaries(ctx, client, version, opts.ReleasesDir); err != nil {
101+
if err := UploadTunnelReleases(ctx, client, version, opts.ReleasesDir); err != nil {
103102
return fmt.Errorf("failed to upload ssh-tunnel binaries: %w", err)
104103
}
105-
userName, serverPort, err = ensureSSHServerIsRunning(ctx, client, opts.ClusterID, secretsScopeName, opts.ClientPublicKeyName, version, opts.ShutdownDelay, opts.MaxClients)
104+
userName, serverPort, err = ensureSSHServerIsRunning(ctx, client, opts.ClusterID, secretsScopeName, opts.ClientPublicKeyName, version, opts.ShutdownDelay, opts.MaxClients, opts.ServerTimeout)
106105
if err != nil {
107106
return fmt.Errorf("failed to ensure that ssh server is running: %w", err)
108107
}
@@ -132,36 +131,8 @@ func RunClient(ctx context.Context, client *databricks.WorkspaceClient, opts Cli
132131
}
133132
}
134133

135-
func getWorkspaceMetadata(ctx context.Context, client *databricks.WorkspaceClient, version, clusterID string) (int, error) {
136-
contentDir, err := getWorkspaceContentDir(ctx, client, version, clusterID)
137-
if err != nil {
138-
return 0, fmt.Errorf("failed to get workspace content directory: %w", err)
139-
}
140-
141-
metadataPath := filepath.ToSlash(filepath.Join(contentDir, "metadata.json"))
142-
143-
content, err := client.Workspace.Download(ctx, metadataPath)
144-
if err != nil {
145-
return 0, fmt.Errorf("failed to download metadata file: %w", err)
146-
}
147-
defer content.Close()
148-
149-
metadataBytes, err := io.ReadAll(content)
150-
if err != nil {
151-
return 0, fmt.Errorf("failed to read metadata content: %w", err)
152-
}
153-
154-
var metadata PortMetadata
155-
err = json.Unmarshal(metadataBytes, &metadata)
156-
if err != nil {
157-
return 0, fmt.Errorf("failed to parse metadata JSON: %w", err)
158-
}
159-
160-
return metadata.Port, nil
161-
}
162-
163134
func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient, clusterID, version string) (int, string, error) {
164-
serverPort, err := getWorkspaceMetadata(ctx, client, version, clusterID)
135+
serverPort, err := sshWorkspace.GetWorkspaceMetadata(ctx, client, version, clusterID)
165136
if err != nil {
166137
return 0, "", errors.Join(errServerMetadata, err)
167138
}
@@ -194,8 +165,8 @@ func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient,
194165
return serverPort, string(bodyBytes), nil
195166
}
196167

197-
func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, clusterID, secretsScope, publicKeySecretName, version string, shutdownDelay time.Duration, maxClients int) (int64, error) {
198-
contentDir, err := getWorkspaceContentDir(ctx, client, version, clusterID)
168+
func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, clusterID, secretsScope, publicKeySecretName, version string, shutdownDelay time.Duration, maxClients int, serverTimeout time.Duration) (int64, error) {
169+
contentDir, err := sshWorkspace.GetWorkspaceContentDir(ctx, client, version, clusterID)
199170
if err != nil {
200171
return 0, fmt.Errorf("failed to get workspace content directory: %w", err)
201172
}
@@ -223,7 +194,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
223194

224195
submitRun := jobs.SubmitRun{
225196
RunName: sshTunnelJobName,
226-
TimeoutSeconds: serverJobTimeoutSeconds,
197+
TimeoutSeconds: int(serverTimeout.Seconds()),
227198
Tasks: []jobs.SubmitTask{
228199
{
229200
TaskKey: "start_ssh_server",
@@ -237,7 +208,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
237208
"maxClients": strconv.Itoa(maxClients),
238209
},
239210
},
240-
TimeoutSeconds: serverJobTimeoutSeconds,
211+
TimeoutSeconds: int(serverTimeout.Seconds()),
241212
ExistingClusterId: clusterID,
242213
},
243214
},
@@ -286,13 +257,13 @@ func startSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, clus
286257
g, gCtx := errgroup.WithContext(ctx)
287258

288259
cmdio.LogString(ctx, "Establishing SSH proxy connection...")
289-
proxy := newProxyConnection(func(ctx context.Context, connID string) (*websocket.Conn, error) {
260+
conn := proxy.NewProxyConnection(func(ctx context.Context, connID string) (*websocket.Conn, error) {
290261
return createWebsocketConnection(ctx, client, connID, clusterID, serverPort)
291262
})
292-
if err := proxy.Connect(gCtx); err != nil {
263+
if err := conn.Connect(gCtx); err != nil {
293264
return fmt.Errorf("failed to connect to proxy: %w", err)
294265
}
295-
defer proxy.Close()
266+
defer conn.Close()
296267
cmdio.LogString(ctx, "SSH proxy connection established")
297268

298269
cmdio.LogString(ctx, fmt.Sprintf("Connection handover timeout: %v", handoverTimeout))
@@ -305,7 +276,7 @@ func startSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, clus
305276
case <-gCtx.Done():
306277
return gCtx.Err()
307278
case <-handoverTicker.C:
308-
err := proxy.InitiateHandover(gCtx)
279+
err := conn.InitiateHandover(gCtx)
309280
if err != nil {
310281
return err
311282
}
@@ -314,13 +285,13 @@ func startSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, clus
314285
})
315286

316287
g.Go(func() error {
317-
return proxy.Start(gCtx, os.Stdin, os.Stdout)
288+
return conn.Start(gCtx, os.Stdin, os.Stdout)
318289
})
319290

320291
return g.Wait()
321292
}
322293

323-
func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceClient, clusterID, secretsScope, publicKeySecretName, version string, shutdownDelay time.Duration, maxClients int) (string, int, error) {
294+
func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceClient, clusterID, secretsScope, publicKeySecretName, version string, shutdownDelay time.Duration, maxClients int, serverTimeout time.Duration) (string, int, error) {
324295
cmdio.LogString(ctx, "Ensuring the cluster is running: "+clusterID)
325296
err := client.Clusters.EnsureClusterIsRunning(ctx, clusterID)
326297
if err != nil {
@@ -331,7 +302,7 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC
331302
if errors.Is(err, errServerMetadata) {
332303
cmdio.LogString(ctx, "SSH server is not running, starting it now...")
333304

334-
runID, err := submitSSHTunnelJob(ctx, client, clusterID, secretsScope, publicKeySecretName, version, shutdownDelay, maxClients)
305+
runID, err := submitSSHTunnelJob(ctx, client, clusterID, secretsScope, publicKeySecretName, version, shutdownDelay, maxClients, serverTimeout)
335306
if err != nil {
336307
return "", 0, fmt.Errorf("failed to submit ssh server job: %w", err)
337308
}

0 commit comments

Comments
 (0)