Skip to content

Commit 2e6e2a7

Browse files
committed
Add liteswap header value for traffic routing (dev/test only).
1 parent dd68f9d commit 2e6e2a7

File tree

3 files changed

+20
-6
lines changed

3 files changed

+20
-6
lines changed

experimental/ssh/cmd/connect.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ For serverless compute:
3939
var releasesDir string
4040
var autoStartCluster bool
4141
var userKnownHostsFile string
42+
var liteswap string
4243

4344
cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID (for dedicated clusters)")
4445
cmd.Flags().StringVar(&connectionName, "name", "", "Connection name (for serverless compute)")
@@ -59,6 +60,9 @@ For serverless compute:
5960
cmd.Flags().StringVar(&userKnownHostsFile, "user-known-hosts-file", "", "Path to user known hosts file for SSH client")
6061
cmd.Flags().MarkHidden("user-known-hosts-file")
6162

63+
cmd.Flags().StringVar(&liteswap, "liteswap", "", "Liteswap header value for traffic routing (dev/test only)")
64+
cmd.Flags().MarkHidden("liteswap")
65+
6266
cmd.PreRunE = func(cmd *cobra.Command, args []string) error {
6367
// CLI in the proxy mode is executed by the ssh client and can't prompt for input
6468
if proxyMode {
@@ -95,6 +99,7 @@ For serverless compute:
9599
ClientPublicKeyName: clientPublicKeyName,
96100
ClientPrivateKeyName: clientPrivateKeyName,
97101
UserKnownHostsFile: userKnownHostsFile,
102+
Liteswap: liteswap,
98103
AdditionalArgs: args,
99104
}
100105
return client.Run(ctx, wsClient, opts)

experimental/ssh/internal/client/client.go

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ type ClientOptions struct {
7474
AdditionalArgs []string
7575
// Optional path to the user known hosts file.
7676
UserKnownHostsFile string
77+
// Liteswap header value for traffic routing (dev/test only).
78+
Liteswap string
7779
}
7880

7981
func (o *ClientOptions) IsServerlessMode() bool {
@@ -107,7 +109,8 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt
107109
}
108110

109111
// Only check cluster state for dedicated clusters
110-
if !opts.IsServerlessMode() {
112+
// TODO: we can remove liteswap check when we can start serverless GPU clusters via API.
113+
if !opts.IsServerlessMode() && opts.Liteswap == "" {
111114
err := checkClusterState(ctx, client, opts.ClusterID, opts.AutoStartCluster)
112115
if err != nil {
113116
return err
@@ -195,7 +198,7 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt
195198
// sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless).
196199
// For dedicated clusters, clusterID should be the same as sessionID.
197200
// For serverless, clusterID is read from the workspace metadata.
198-
func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient, sessionID, clusterID, version string) (int, string, string, error) {
201+
func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient, sessionID, clusterID, version, liteswap string) (int, string, string, error) {
199202
wsMetadata, err := sshWorkspace.GetWorkspaceMetadata(ctx, client, version, sessionID)
200203
if err != nil {
201204
return 0, "", "", errors.Join(errServerMetadata, err)
@@ -222,6 +225,9 @@ func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient,
222225
if err != nil {
223226
return 0, "", "", err
224227
}
228+
if liteswap != "" {
229+
req.Header.Set("x-databricks-traffic-id", "testenv://liteswap/"+liteswap)
230+
}
225231
if err := client.Config.Authenticate(req); err != nil {
226232
return 0, "", "", err
227233
}
@@ -356,7 +362,7 @@ func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, server
356362

357363
func runSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, serverPort int, clusterID string, opts ClientOptions) error {
358364
createConn := func(ctx context.Context, connID string) (*websocket.Conn, error) {
359-
return createWebsocketConnection(ctx, client, connID, clusterID, serverPort)
365+
return createWebsocketConnection(ctx, client, connID, clusterID, serverPort, opts.Liteswap)
360366
}
361367
requestHandoverTick := func() <-chan time.Time {
362368
return time.After(opts.HandoverTimeout)
@@ -389,7 +395,7 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC
389395
// For dedicated clusters, use clusterID; for serverless, it will be read from metadata
390396
clusterID := opts.ClusterID
391397

392-
serverPort, userName, effectiveClusterID, err := getServerMetadata(ctx, client, sessionID, clusterID, version)
398+
serverPort, userName, effectiveClusterID, err := getServerMetadata(ctx, client, sessionID, clusterID, version, opts.Liteswap)
393399
if errors.Is(err, errServerMetadata) {
394400
cmdio.LogString(ctx, "SSH server is not running, starting it now...")
395401

@@ -405,7 +411,7 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC
405411
if ctx.Err() != nil {
406412
return "", 0, "", ctx.Err()
407413
}
408-
serverPort, userName, effectiveClusterID, err = getServerMetadata(ctx, client, sessionID, clusterID, version)
414+
serverPort, userName, effectiveClusterID, err = getServerMetadata(ctx, client, sessionID, clusterID, version, opts.Liteswap)
409415
if err == nil {
410416
cmdio.LogString(ctx, "Health check successful, starting ssh WebSocket connection...")
411417
break

experimental/ssh/internal/client/websockets.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99
"github.com/gorilla/websocket"
1010
)
1111

12-
func createWebsocketConnection(ctx context.Context, client *databricks.WorkspaceClient, connID, clusterID string, serverPort int) (*websocket.Conn, error) {
12+
func createWebsocketConnection(ctx context.Context, client *databricks.WorkspaceClient, connID, clusterID string, serverPort int, liteswap string) (*websocket.Conn, error) {
1313
url, err := getProxyURL(ctx, client, connID, clusterID, serverPort)
1414
if err != nil {
1515
return nil, fmt.Errorf("failed to get proxy URL: %w", err)
@@ -20,6 +20,9 @@ func createWebsocketConnection(ctx context.Context, client *databricks.Workspace
2020
return nil, fmt.Errorf("failed to create request: %w", err)
2121
}
2222

23+
if liteswap != "" {
24+
req.Header.Set("x-databricks-traffic-id", "testenv://liteswap/"+liteswap)
25+
}
2326
if err := client.Config.Authenticate(req); err != nil {
2427
return nil, fmt.Errorf("failed to authenticate: %w", err)
2528
}

0 commit comments

Comments
 (0)