@@ -74,6 +74,8 @@ type ClientOptions struct {
7474 AdditionalArgs []string
7575 // Optional path to the user known hosts file.
7676 UserKnownHostsFile string
77+ // Liteswap header value for traffic routing (dev/test only).
78+ Liteswap string
7779}
7880
7981func (o * ClientOptions ) IsServerlessMode () bool {
@@ -107,7 +109,8 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt
107109 }
108110
109111 // Only check cluster state for dedicated clusters
110- if ! opts .IsServerlessMode () {
112+ // TODO: we can remove liteswap check when we can start serverless GPU clusters via API.
113+ if ! opts .IsServerlessMode () && opts .Liteswap == "" {
111114 err := checkClusterState (ctx , client , opts .ClusterID , opts .AutoStartCluster )
112115 if err != nil {
113116 return err
@@ -195,7 +198,7 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt
195198// sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless).
196199// For dedicated clusters, clusterID should be the same as sessionID.
197200// For serverless, clusterID is read from the workspace metadata.
198- func getServerMetadata (ctx context.Context , client * databricks.WorkspaceClient , sessionID , clusterID , version string ) (int , string , string , error ) {
201+ func getServerMetadata (ctx context.Context , client * databricks.WorkspaceClient , sessionID , clusterID , version , liteswap string ) (int , string , string , error ) {
199202 wsMetadata , err := sshWorkspace .GetWorkspaceMetadata (ctx , client , version , sessionID )
200203 if err != nil {
201204 return 0 , "" , "" , errors .Join (errServerMetadata , err )
@@ -222,6 +225,9 @@ func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient,
222225 if err != nil {
223226 return 0 , "" , "" , err
224227 }
228+ if liteswap != "" {
229+ req .Header .Set ("x-databricks-traffic-id" , "testenv://liteswap/" + liteswap )
230+ }
225231 if err := client .Config .Authenticate (req ); err != nil {
226232 return 0 , "" , "" , err
227233 }
@@ -356,7 +362,7 @@ func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, server
356362
357363func runSSHProxy (ctx context.Context , client * databricks.WorkspaceClient , serverPort int , clusterID string , opts ClientOptions ) error {
358364 createConn := func (ctx context.Context , connID string ) (* websocket.Conn , error ) {
359- return createWebsocketConnection (ctx , client , connID , clusterID , serverPort )
365+ return createWebsocketConnection (ctx , client , connID , clusterID , serverPort , opts . Liteswap )
360366 }
361367 requestHandoverTick := func () <- chan time.Time {
362368 return time .After (opts .HandoverTimeout )
@@ -389,7 +395,7 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC
389395 // For dedicated clusters, use clusterID; for serverless, it will be read from metadata
390396 clusterID := opts .ClusterID
391397
392- serverPort , userName , effectiveClusterID , err := getServerMetadata (ctx , client , sessionID , clusterID , version )
398+ serverPort , userName , effectiveClusterID , err := getServerMetadata (ctx , client , sessionID , clusterID , version , opts . Liteswap )
393399 if errors .Is (err , errServerMetadata ) {
394400 cmdio .LogString (ctx , "SSH server is not running, starting it now..." )
395401
@@ -405,7 +411,7 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC
405411 if ctx .Err () != nil {
406412 return "" , 0 , "" , ctx .Err ()
407413 }
408- serverPort , userName , effectiveClusterID , err = getServerMetadata (ctx , client , sessionID , clusterID , version )
414+ serverPort , userName , effectiveClusterID , err = getServerMetadata (ctx , client , sessionID , clusterID , version , opts . Liteswap )
409415 if err == nil {
410416 cmdio .LogString (ctx , "Health check successful, starting ssh WebSocket connection..." )
411417 break
0 commit comments