@@ -26,7 +26,6 @@ import (
2626 "github.com/databricks/databricks-sdk-go/service/jobs"
2727 "github.com/databricks/databricks-sdk-go/service/workspace"
2828 "github.com/gorilla/websocket"
29- "golang.org/x/sync/errgroup"
3029)
3130
3231//go:embed ssh-server-bootstrap.py
@@ -63,7 +62,7 @@ type ClientOptions struct {
6362 AdditionalArgs []string
6463}
6564
66- func RunClient (ctx context.Context , client * databricks.WorkspaceClient , opts ClientOptions ) error {
65+ func Run (ctx context.Context , client * databricks.WorkspaceClient , opts ClientOptions ) error {
6766 ctx , cancel := context .WithCancel (ctx )
6867 defer cancel ()
6968
@@ -85,11 +84,11 @@ func RunClient(ctx context.Context, client *databricks.WorkspaceClient, opts Cli
8584 }
8685 cmdio .LogString (ctx , "Using SSH key: " + privateKeyPath )
8786
88- secretsScopeName , err := keys .PutSecretInScope (ctx , client , opts .ClusterID , opts .ClientPublicKeyName , publicKey )
87+ keysSecretScopeName , err := keys .PutSecretInScope (ctx , client , opts .ClusterID , opts .ClientPublicKeyName , publicKey )
8988 if err != nil {
9089 return fmt .Errorf ("failed to store public key in secret scope: %w" , err )
9190 }
92- cmdio .LogString (ctx , fmt .Sprintf ("Secrets scope: %s, key name: %s" , secretsScopeName , opts .ClientPublicKeyName ))
91+ cmdio .LogString (ctx , fmt .Sprintf ("Secrets scope: %s, key name: %s" , keysSecretScopeName , opts .ClientPublicKeyName ))
9392
9493 var userName string
9594 var serverPort int
@@ -101,7 +100,7 @@ func RunClient(ctx context.Context, client *databricks.WorkspaceClient, opts Cli
101100 if err := UploadTunnelReleases (ctx , client , version , opts .ReleasesDir ); err != nil {
102101 return fmt .Errorf ("failed to upload ssh-tunnel binaries: %w" , err )
103102 }
104- userName , serverPort , err = ensureSSHServerIsRunning (ctx , client , opts .ClusterID , secretsScopeName , opts .ClientPublicKeyName , version , opts .ShutdownDelay , opts .MaxClients , opts .ServerTimeout )
103+ userName , serverPort , err = ensureSSHServerIsRunning (ctx , client , opts .ClusterID , keysSecretScopeName , opts .ClientPublicKeyName , version , opts .ShutdownDelay , opts .MaxClients , opts .ServerTimeout )
105104 if err != nil {
106105 return fmt .Errorf ("failed to ensure that ssh server is running: %w" , err )
107106 }
@@ -124,7 +123,7 @@ func RunClient(ctx context.Context, client *databricks.WorkspaceClient, opts Cli
124123 cmdio .LogString (ctx , fmt .Sprintf ("Server port: %d" , serverPort ))
125124
126125 if opts .ProxyMode {
127- return startSSHProxy (ctx , client , opts .ClusterID , serverPort , opts .HandoverTimeout )
126+ return runSSHProxy (ctx , client , opts .ClusterID , serverPort , opts .HandoverTimeout )
128127 } else {
129128 cmdio .LogString (ctx , fmt .Sprintf ("Additional SSH arguments: %v" , opts .AdditionalArgs ))
130129 return spawnSSHClient (ctx , opts .ClusterID , userName , privateKeyPath , serverPort , opts .HandoverTimeout , opts .AdditionalArgs )
@@ -165,7 +164,7 @@ func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient,
165164 return serverPort , string (bodyBytes ), nil
166165}
167166
168- func submitSSHTunnelJob (ctx context.Context , client * databricks.WorkspaceClient , clusterID , secretsScope , publicKeySecretName , version string , shutdownDelay time.Duration , maxClients int , serverTimeout time.Duration ) (int64 , error ) {
167+ func submitSSHTunnelJob (ctx context.Context , client * databricks.WorkspaceClient , clusterID , keysSecretScopeName , publicKeySecretName , version string , shutdownDelay time.Duration , maxClients int , serverTimeout time.Duration ) (int64 , error ) {
169168 contentDir , err := sshWorkspace .GetWorkspaceContentDir (ctx , client , version , clusterID )
170169 if err != nil {
171170 return 0 , fmt .Errorf ("failed to get workspace content directory: %w" , err )
@@ -201,11 +200,11 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
201200 NotebookTask : & jobs.NotebookTask {
202201 NotebookPath : jobNotebookPath ,
203202 BaseParameters : map [string ]string {
204- "version" : version ,
205- "secretsScope " : secretsScope ,
206- "publicKeySecretName " : publicKeySecretName ,
207- "shutdownDelay" : shutdownDelay .String (),
208- "maxClients" : strconv .Itoa (maxClients ),
203+ "version" : version ,
204+ "keysSecretScopeName " : keysSecretScopeName ,
205+ "authorizedKeySecretName " : publicKeySecretName ,
206+ "shutdownDelay" : shutdownDelay .String (),
207+ "maxClients" : strconv .Itoa (maxClients ),
209208 },
210209 },
211210 TimeoutSeconds : int (serverTimeout .Seconds ()),
@@ -253,45 +252,14 @@ func spawnSSHClient(ctx context.Context, clusterID, userName, privateKeyPath str
253252 return sshCmd .Run ()
254253}
255254
256- func startSSHProxy (ctx context.Context , client * databricks.WorkspaceClient , clusterID string , serverPort int , handoverTimeout time.Duration ) error {
257- g , gCtx := errgroup .WithContext (ctx )
258-
259- cmdio .LogString (ctx , "Establishing SSH proxy connection..." )
260- conn := proxy .NewProxyConnection (func (ctx context.Context , connID string ) (* websocket.Conn , error ) {
255+ func runSSHProxy (ctx context.Context , client * databricks.WorkspaceClient , clusterID string , serverPort int , handoverTimeout time.Duration ) error {
256+ createConn := func (ctx context.Context , connID string ) (* websocket.Conn , error ) {
261257 return createWebsocketConnection (ctx , client , connID , clusterID , serverPort )
262- })
263- if err := conn .Connect (gCtx ); err != nil {
264- return fmt .Errorf ("failed to connect to proxy: %w" , err )
265258 }
266- defer conn .Close ()
267- cmdio .LogString (ctx , "SSH proxy connection established" )
268-
269- cmdio .LogString (ctx , fmt .Sprintf ("Connection handover timeout: %v" , handoverTimeout ))
270- handoverTicker := time .NewTicker (handoverTimeout )
271- defer handoverTicker .Stop ()
272-
273- g .Go (func () error {
274- for {
275- select {
276- case <- gCtx .Done ():
277- return gCtx .Err ()
278- case <- handoverTicker .C :
279- err := conn .InitiateHandover (gCtx )
280- if err != nil {
281- return err
282- }
283- }
284- }
285- })
286-
287- g .Go (func () error {
288- return conn .Start (gCtx , os .Stdin , os .Stdout )
289- })
290-
291- return g .Wait ()
259+ return proxy .RunClientProxy (ctx , os .Stdin , os .Stdout , handoverTimeout , createConn )
292260}
293261
294- func ensureSSHServerIsRunning (ctx context.Context , client * databricks.WorkspaceClient , clusterID , secretsScope , publicKeySecretName , version string , shutdownDelay time.Duration , maxClients int , serverTimeout time.Duration ) (string , int , error ) {
262+ func ensureSSHServerIsRunning (ctx context.Context , client * databricks.WorkspaceClient , clusterID , keysSecretScopeName , publicKeySecretName , version string , shutdownDelay time.Duration , maxClients int , serverTimeout time.Duration ) (string , int , error ) {
295263 cmdio .LogString (ctx , "Ensuring the cluster is running: " + clusterID )
296264 err := client .Clusters .EnsureClusterIsRunning (ctx , clusterID )
297265 if err != nil {
@@ -302,7 +270,7 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC
302270 if errors .Is (err , errServerMetadata ) {
303271 cmdio .LogString (ctx , "SSH server is not running, starting it now..." )
304272
305- runID , err := submitSSHTunnelJob (ctx , client , clusterID , secretsScope , publicKeySecretName , version , shutdownDelay , maxClients , serverTimeout )
273+ runID , err := submitSSHTunnelJob (ctx , client , clusterID , keysSecretScopeName , publicKeySecretName , version , shutdownDelay , maxClients , serverTimeout )
306274 if err != nil {
307275 return "" , 0 , fmt .Errorf ("failed to submit ssh server job: %w" , err )
308276 }
0 commit comments