1- package ssh
1+ package client
22
33import (
44 "context"
55 _ "embed"
66 "encoding/base64"
7- "encoding/json"
87 "errors"
98 "fmt"
109 "io"
@@ -18,6 +17,9 @@ import (
1817 "syscall"
1918 "time"
2019
20+ "github.com/databricks/cli/experimental/ssh/internal/keys"
21+ "github.com/databricks/cli/experimental/ssh/internal/proxy"
22+ sshWorkspace "github.com/databricks/cli/experimental/ssh/internal/workspace"
2123 "github.com/databricks/cli/internal/build"
2224 "github.com/databricks/cli/libs/cmdio"
2325 "github.com/databricks/databricks-sdk-go"
@@ -27,17 +29,11 @@ import (
2729 "golang.org/x/sync/errgroup"
2830)
2931
30- type PortMetadata struct {
31- Port int `json:"port"`
32- }
33-
3432//go:embed ssh-server-bootstrap.py
3533var sshServerBootstrapScript string
3634
3735var errServerMetadata = errors .New ("server metadata error" )
3836
39- const serverJobTimeoutSeconds = 24 * 60 * 60
40-
4137type ClientOptions struct {
4238 // Id of the cluster to connect to
4339 ClusterID string
@@ -54,6 +50,8 @@ type ClientOptions struct {
5450 ServerMetadata string
5551 // How often the CLI should reconnect to the server with new auth.
5652 HandoverTimeout time.Duration
53+ // Max amount of time the server process is allowed to live
54+ ServerTimeout time.Duration
5755 // Directory for local SSH tunnel development releases.
5856 // If not present, the CLI will use github releases with the current version.
5957 ReleasesDir string
@@ -77,16 +75,17 @@ func RunClient(ctx context.Context, client *databricks.WorkspaceClient, opts Cli
7775 cancel ()
7876 }()
7977
80- keyPath , err := getLocalSSHKeyPath (opts .ClusterID , opts .SSHKeysDir )
78+ keyPath , err := keys . GetLocalSSHKeyPath (opts .ClusterID , opts .SSHKeysDir )
8179 if err != nil {
8280 return fmt .Errorf ("failed to get local keys folder: %w" , err )
8381 }
84- privateKeyPath , publicKey , err := checkAndGenerateSSHKeyPair (ctx , keyPath )
82+ privateKeyPath , publicKey , err := keys . CheckAndGenerateSSHKeyPair (ctx , keyPath )
8583 if err != nil {
8684 return fmt .Errorf ("failed to check or generate SSH key pair: %w" , err )
8785 }
86+ cmdio .LogString (ctx , "Using SSH key: " + privateKeyPath )
8887
89- secretsScopeName , err := putSecretInScope (ctx , client , opts .ClusterID , opts .ClientPublicKeyName , publicKey )
88+ secretsScopeName , err := keys . PutSecretInScope (ctx , client , opts .ClusterID , opts .ClientPublicKeyName , publicKey )
9089 if err != nil {
9190 return fmt .Errorf ("failed to store public key in secret scope: %w" , err )
9291 }
@@ -99,10 +98,10 @@ func RunClient(ctx context.Context, client *databricks.WorkspaceClient, opts Cli
9998
10099 if opts .ServerMetadata == "" {
101100 cmdio .LogString (ctx , "Checking for ssh-tunnel binaries to upload..." )
102- if err := uploadTunnelBinaries (ctx , client , version , opts .ReleasesDir ); err != nil {
101+ if err := UploadTunnelReleases (ctx , client , version , opts .ReleasesDir ); err != nil {
103102 return fmt .Errorf ("failed to upload ssh-tunnel binaries: %w" , err )
104103 }
105- userName , serverPort , err = ensureSSHServerIsRunning (ctx , client , opts .ClusterID , secretsScopeName , opts .ClientPublicKeyName , version , opts .ShutdownDelay , opts .MaxClients )
104+ userName , serverPort , err = ensureSSHServerIsRunning (ctx , client , opts .ClusterID , secretsScopeName , opts .ClientPublicKeyName , version , opts .ShutdownDelay , opts .MaxClients , opts . ServerTimeout )
106105 if err != nil {
107106 return fmt .Errorf ("failed to ensure that ssh server is running: %w" , err )
108107 }
@@ -132,36 +131,8 @@ func RunClient(ctx context.Context, client *databricks.WorkspaceClient, opts Cli
132131 }
133132}
134133
135- func getWorkspaceMetadata (ctx context.Context , client * databricks.WorkspaceClient , version , clusterID string ) (int , error ) {
136- contentDir , err := getWorkspaceContentDir (ctx , client , version , clusterID )
137- if err != nil {
138- return 0 , fmt .Errorf ("failed to get workspace content directory: %w" , err )
139- }
140-
141- metadataPath := filepath .ToSlash (filepath .Join (contentDir , "metadata.json" ))
142-
143- content , err := client .Workspace .Download (ctx , metadataPath )
144- if err != nil {
145- return 0 , fmt .Errorf ("failed to download metadata file: %w" , err )
146- }
147- defer content .Close ()
148-
149- metadataBytes , err := io .ReadAll (content )
150- if err != nil {
151- return 0 , fmt .Errorf ("failed to read metadata content: %w" , err )
152- }
153-
154- var metadata PortMetadata
155- err = json .Unmarshal (metadataBytes , & metadata )
156- if err != nil {
157- return 0 , fmt .Errorf ("failed to parse metadata JSON: %w" , err )
158- }
159-
160- return metadata .Port , nil
161- }
162-
163134func getServerMetadata (ctx context.Context , client * databricks.WorkspaceClient , clusterID , version string ) (int , string , error ) {
164- serverPort , err := getWorkspaceMetadata (ctx , client , version , clusterID )
135+ serverPort , err := sshWorkspace . GetWorkspaceMetadata (ctx , client , version , clusterID )
165136 if err != nil {
166137 return 0 , "" , errors .Join (errServerMetadata , err )
167138 }
@@ -194,8 +165,8 @@ func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient,
194165 return serverPort , string (bodyBytes ), nil
195166}
196167
197- func submitSSHTunnelJob (ctx context.Context , client * databricks.WorkspaceClient , clusterID , secretsScope , publicKeySecretName , version string , shutdownDelay time.Duration , maxClients int ) (int64 , error ) {
198- contentDir , err := getWorkspaceContentDir (ctx , client , version , clusterID )
168+ func submitSSHTunnelJob (ctx context.Context , client * databricks.WorkspaceClient , clusterID , secretsScope , publicKeySecretName , version string , shutdownDelay time.Duration , maxClients int , serverTimeout time. Duration ) (int64 , error ) {
169+ contentDir , err := sshWorkspace . GetWorkspaceContentDir (ctx , client , version , clusterID )
199170 if err != nil {
200171 return 0 , fmt .Errorf ("failed to get workspace content directory: %w" , err )
201172 }
@@ -223,7 +194,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
223194
224195 submitRun := jobs.SubmitRun {
225196 RunName : sshTunnelJobName ,
226- TimeoutSeconds : serverJobTimeoutSeconds ,
197+ TimeoutSeconds : int ( serverTimeout . Seconds ()) ,
227198 Tasks : []jobs.SubmitTask {
228199 {
229200 TaskKey : "start_ssh_server" ,
@@ -237,7 +208,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
237208 "maxClients" : strconv .Itoa (maxClients ),
238209 },
239210 },
240- TimeoutSeconds : serverJobTimeoutSeconds ,
211+ TimeoutSeconds : int ( serverTimeout . Seconds ()) ,
241212 ExistingClusterId : clusterID ,
242213 },
243214 },
@@ -286,13 +257,13 @@ func startSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, clus
286257 g , gCtx := errgroup .WithContext (ctx )
287258
288259 cmdio .LogString (ctx , "Establishing SSH proxy connection..." )
289- proxy := newProxyConnection (func (ctx context.Context , connID string ) (* websocket.Conn , error ) {
260+ conn := proxy . NewProxyConnection (func (ctx context.Context , connID string ) (* websocket.Conn , error ) {
290261 return createWebsocketConnection (ctx , client , connID , clusterID , serverPort )
291262 })
292- if err := proxy .Connect (gCtx ); err != nil {
263+ if err := conn .Connect (gCtx ); err != nil {
293264 return fmt .Errorf ("failed to connect to proxy: %w" , err )
294265 }
295- defer proxy .Close ()
266+ defer conn .Close ()
296267 cmdio .LogString (ctx , "SSH proxy connection established" )
297268
298269 cmdio .LogString (ctx , fmt .Sprintf ("Connection handover timeout: %v" , handoverTimeout ))
@@ -305,7 +276,7 @@ func startSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, clus
305276 case <- gCtx .Done ():
306277 return gCtx .Err ()
307278 case <- handoverTicker .C :
308- err := proxy .InitiateHandover (gCtx )
279+ err := conn .InitiateHandover (gCtx )
309280 if err != nil {
310281 return err
311282 }
@@ -314,13 +285,13 @@ func startSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, clus
314285 })
315286
316287 g .Go (func () error {
317- return proxy .Start (gCtx , os .Stdin , os .Stdout )
288+ return conn .Start (gCtx , os .Stdin , os .Stdout )
318289 })
319290
320291 return g .Wait ()
321292}
322293
323- func ensureSSHServerIsRunning (ctx context.Context , client * databricks.WorkspaceClient , clusterID , secretsScope , publicKeySecretName , version string , shutdownDelay time.Duration , maxClients int ) (string , int , error ) {
294+ func ensureSSHServerIsRunning (ctx context.Context , client * databricks.WorkspaceClient , clusterID , secretsScope , publicKeySecretName , version string , shutdownDelay time.Duration , maxClients int , serverTimeout time. Duration ) (string , int , error ) {
324295 cmdio .LogString (ctx , "Ensuring the cluster is running: " + clusterID )
325296 err := client .Clusters .EnsureClusterIsRunning (ctx , clusterID )
326297 if err != nil {
@@ -331,7 +302,7 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC
331302 if errors .Is (err , errServerMetadata ) {
332303 cmdio .LogString (ctx , "SSH server is not running, starting it now..." )
333304
334- runID , err := submitSSHTunnelJob (ctx , client , clusterID , secretsScope , publicKeySecretName , version , shutdownDelay , maxClients )
305+ runID , err := submitSSHTunnelJob (ctx , client , clusterID , secretsScope , publicKeySecretName , version , shutdownDelay , maxClients , serverTimeout )
335306 if err != nil {
336307 return "" , 0 , fmt .Errorf ("failed to submit ssh server job: %w" , err )
337308 }
0 commit comments