Skip to content

Commit 4d8abe4

Browse files
authored
Improve ssh handover logic and tests (#3757)
## Changes With this, instead of doing a handover every N seconds flat, we first wait N seconds, then do a handover for however long it takes to complete, then wait N seconds again, etc. In tests we can now better control when handover should be performed, and the actual logic behaves better with small handover timeouts. Tests now don't use any hardcoded timeouts for initiating handover and should be more reliable. ## Tests Tested with `-count 1000 -parallel 100` and `-race` flags.
1 parent 7148fde commit 4d8abe4

File tree

4 files changed

+58
-68
lines changed

4 files changed

+58
-68
lines changed

experimental/ssh/internal/client/client.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,10 @@ func runSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, server
264264
createConn := func(ctx context.Context, connID string) (*websocket.Conn, error) {
265265
return createWebsocketConnection(ctx, client, connID, opts.ClusterID, serverPort)
266266
}
267-
return proxy.RunClientProxy(ctx, os.Stdin, os.Stdout, opts.HandoverTimeout, createConn)
267+
requestHandoverTick := func() <-chan time.Time {
268+
return time.After(opts.HandoverTimeout)
269+
}
270+
return proxy.RunClientProxy(ctx, os.Stdin, os.Stdout, requestHandoverTick, createConn)
268271
}
269272

270273
func checkClusterState(ctx context.Context, client *databricks.WorkspaceClient, clusterID string, autoStart bool) error {

experimental/ssh/internal/proxy/client.go

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import (
1010
"golang.org/x/sync/errgroup"
1111
)
1212

13-
func RunClientProxy(ctx context.Context, src io.Reader, dst io.Writer, handoverTimeout time.Duration, createConn createWebsocketConnectionFunc) error {
13+
func RunClientProxy(ctx context.Context, src io.Reader, dst io.Writer, requestHandoverTick func() <-chan time.Time, createConn createWebsocketConnectionFunc) error {
1414
proxy := newProxyConnection(createConn)
1515
cmdio.LogString(ctx, "Establishing SSH proxy connection...")
1616
g, gCtx := errgroup.WithContext(ctx)
@@ -20,16 +20,12 @@ func RunClientProxy(ctx context.Context, src io.Reader, dst io.Writer, handoverT
2020
defer proxy.close()
2121
cmdio.LogString(ctx, "SSH proxy connection established")
2222

23-
cmdio.LogString(ctx, fmt.Sprintf("Connection handover timeout: %v", handoverTimeout))
24-
handoverTicker := time.NewTicker(handoverTimeout)
25-
defer handoverTicker.Stop()
26-
2723
g.Go(func() error {
2824
for {
2925
select {
3026
case <-gCtx.Done():
3127
return gCtx.Err()
32-
case <-handoverTicker.C:
28+
case <-requestHandoverTick():
3329
err := proxy.initiateHandover(gCtx)
3430
if err != nil {
3531
return err

experimental/ssh/internal/proxy/client_server_test.go

Lines changed: 31 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
"io"
1111
"net/http/httptest"
1212
"os/exec"
13-
"strings"
1413
"sync"
1514
"testing"
1615
"time"
@@ -30,21 +29,23 @@ func createTestServer(t *testing.T, maxClients int, shutdownDelay time.Duration)
3029
return httptest.NewServer(proxyServer)
3130
}
3231

33-
func createTestClient(t *testing.T, serverURL string, handoverTimeout time.Duration, errChan, createConnChan chan error) (io.WriteCloser, *testBuffer) {
32+
func createTestClient(t *testing.T, serverURL string, requestHandoverTick func() <-chan time.Time, errChan chan error) (io.WriteCloser, *testBuffer) {
3433
ctx := t.Context()
3534
clientInput, clientInputWriter := io.Pipe()
3635
clientOutput := newTestBuffer(t)
3736
wsURL := "ws" + serverURL[4:]
3837
createConn := func(ctx context.Context, connID string) (*websocket.Conn, error) {
3938
url := fmt.Sprintf("%s?id=%s", wsURL, connID)
4039
conn, _, err := websocket.DefaultDialer.Dial(url, nil) // nolint:bodyclose
41-
if createConnChan != nil {
42-
createConnChan <- err
43-
}
4440
return conn, err
4541
}
42+
if requestHandoverTick == nil {
43+
requestHandoverTick = func() <-chan time.Time {
44+
return time.After(time.Hour)
45+
}
46+
}
4647
go func() {
47-
err := RunClientProxy(ctx, clientInput, clientOutput, handoverTimeout, createConn)
48+
err := RunClientProxy(ctx, clientInput, clientOutput, requestHandoverTick, createConn)
4849
if err != nil && !isNormalClosure(err) && !errors.Is(err, context.Canceled) {
4950
if errChan != nil {
5051
errChan <- err
@@ -59,7 +60,7 @@ func createTestClient(t *testing.T, serverURL string, handoverTimeout time.Durat
5960
func TestClientServerEcho(t *testing.T) {
6061
server := createTestServer(t, 2, time.Hour)
6162
defer server.Close()
62-
clientInputWriter, clientOutput := createTestClient(t, server.URL, time.Hour, nil, nil)
63+
clientInputWriter, clientOutput := createTestClient(t, server.URL, nil, nil)
6364
defer clientInputWriter.Close()
6465

6566
testMsg1 := []byte("test message 1\n")
@@ -81,9 +82,9 @@ func TestClientServerEcho(t *testing.T) {
8182
func TestMultipleClients(t *testing.T) {
8283
server := createTestServer(t, 2, time.Hour)
8384
defer server.Close()
84-
clientInputWriter1, clientOutput1 := createTestClient(t, server.URL, time.Hour, nil, nil)
85+
clientInputWriter1, clientOutput1 := createTestClient(t, server.URL, nil, nil)
8586
defer clientInputWriter1.Close()
86-
clientInputWriter2, clientOutput2 := createTestClient(t, server.URL, time.Hour, nil, nil)
87+
clientInputWriter2, clientOutput2 := createTestClient(t, server.URL, nil, nil)
8788
defer clientInputWriter2.Close()
8889

8990
messageCount := 10
@@ -113,9 +114,9 @@ func TestMaxClients(t *testing.T) {
113114
maxClients := 2
114115
server := createTestServer(t, maxClients, time.Hour)
115116
defer server.Close()
116-
clientInputWriter1, clientOutput1 := createTestClient(t, server.URL, time.Hour, nil, nil)
117+
clientInputWriter1, clientOutput1 := createTestClient(t, server.URL, nil, nil)
117118
defer clientInputWriter1.Close()
118-
clientInputWriter2, clientOutput2 := createTestClient(t, server.URL, time.Hour, nil, nil)
119+
clientInputWriter2, clientOutput2 := createTestClient(t, server.URL, nil, nil)
119120
defer clientInputWriter2.Close()
120121

121122
testMsg1 := []byte("test message 1\n")
@@ -129,7 +130,7 @@ func TestMaxClients(t *testing.T) {
129130
require.NoError(t, err)
130131

131132
errChan := make(chan error, 1)
132-
clientInputWriter3, _ := createTestClient(t, server.URL, time.Hour, errChan, nil)
133+
clientInputWriter3, _ := createTestClient(t, server.URL, nil, errChan)
133134
defer clientInputWriter3.Close()
134135
select {
135136
case err = <-errChan:
@@ -143,60 +144,35 @@ func TestHandover(t *testing.T) {
143144
server := createTestServer(t, 2, time.Hour)
144145
defer server.Close()
145146

146-
maxHandoverCount := 3
147-
handoverTimeout := 500 * time.Millisecond
148-
createConnChan := make(chan error, 1)
149-
clientInputWriter, clientOutput := createTestClient(t, server.URL, handoverTimeout, nil, createConnChan)
147+
handoverChan := make(chan time.Time)
148+
requestHandoverTick := func() <-chan time.Time {
149+
return handoverChan
150+
}
151+
clientInputWriter, clientOutput := createTestClient(t, server.URL, requestHandoverTick, nil)
150152
defer clientInputWriter.Close()
151153

152-
messageCount := 0
153154
expectedOutput := ""
154-
sendMessage := func() {
155-
message := fmt.Appendf(nil, "message %d\n", messageCount)
156-
_, err := clientInputWriter.Write(message)
157-
if err != nil {
158-
t.Errorf("failed to write message %d: %v", messageCount, err)
159-
}
160-
messageCount++
161-
if messageCount > TOTAL_MESSAGE_COUNT {
162-
t.Errorf("exceeded total message count, test buffer won't work correctly")
163-
}
164-
expectedOutput += string(message)
165-
}
166155

167156
wg := sync.WaitGroup{}
168157
wg.Go(func() {
169-
handoverCount := 0
170-
for {
171-
select {
172-
case <-createConnChan:
173-
sendMessage()
174-
handoverCount++
175-
if handoverCount >= maxHandoverCount {
176-
return
177-
}
178-
default:
179-
sendMessage()
180-
time.Sleep(time.Millisecond)
158+
for i := range TOTAL_MESSAGE_COUNT {
159+
if i > 0 && i%MESSAGES_PER_CHUNK == 0 {
160+
handoverChan <- time.Now()
181161
}
162+
message := fmt.Appendf(nil, "message %d\n", i)
163+
_, err := clientInputWriter.Write(message)
164+
if err != nil {
165+
t.Errorf("failed to write message %d: %v", i, err)
166+
}
167+
expectedOutput += string(message)
182168
}
183169
})
184170

185-
wg.Wait()
171+
err := clientOutput.WaitForWrite(fmt.Appendf(nil, "message %d\n", TOTAL_MESSAGE_COUNT-1))
172+
require.NoError(t, err, "failed to receive the last message (%d)", TOTAL_MESSAGE_COUNT-1)
186173

187-
for i := 0; i < messageCount; {
188-
// Client can receive multiple echo messages in one response,
189-
// so we split them again and verify each one.
190-
data, err := clientOutput.WaitForWrite()
191-
require.NoError(t, err, "failed to receive message %d", i)
192-
lines := strings.SplitSeq(string(data), "\n")
193-
for line := range lines {
194-
if line != "" {
195-
assert.Equal(t, fmt.Sprintf("message %d\n", i), line+"\n")
196-
i++
197-
}
198-
}
199-
}
174+
wg.Wait()
200175

176+
// clientOutput is created by appending incoming messages as they arrive, so we are also test correct order here
201177
assert.Equal(t, expectedOutput, clientOutput.String())
202178
}

experimental/ssh/internal/proxy/proxy_test.go

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"io"
99
"net/http"
1010
"net/http/httptest"
11+
"sync"
1112
"testing"
1213
"time"
1314

@@ -18,6 +19,7 @@ import (
1819

1920
type testBuffer struct {
2021
t *testing.T
22+
m sync.Mutex
2123
buff *bytes.Buffer
2224
OnWrite chan []byte
2325
}
@@ -31,6 +33,7 @@ const (
3133
func newTestBuffer(t *testing.T) *testBuffer {
3234
return &testBuffer{
3335
t: t,
36+
m: sync.Mutex{},
3437
buff: new(bytes.Buffer),
3538
OnWrite: make(chan []byte, TOTAL_MESSAGE_COUNT),
3639
}
@@ -45,7 +48,9 @@ func (tb *testBuffer) Read(p []byte) (n int, err error) {
4548
}
4649

4750
func (tb *testBuffer) Write(p []byte) (n int, err error) {
51+
tb.m.Lock()
4852
n, err = tb.buff.Write(p)
53+
tb.m.Unlock()
4954
require.NoError(tb.t, err)
5055
tb.OnWrite <- p
5156
return n, err
@@ -61,12 +66,22 @@ func (tb *testBuffer) AssertWrite(expected []byte) error {
6166
}
6267
}
6368

64-
func (tb *testBuffer) WaitForWrite() ([]byte, error) {
65-
select {
66-
case data := <-tb.OnWrite:
67-
return data, nil
68-
case <-time.After(2 * time.Second):
69-
return nil, errors.New("timeout waiting for write")
69+
func (tb *testBuffer) Contains(data []byte) bool {
70+
tb.m.Lock()
71+
defer tb.m.Unlock()
72+
return bytes.Contains(tb.buff.Bytes(), data)
73+
}
74+
75+
func (tb *testBuffer) WaitForWrite(expected []byte) error {
76+
for {
77+
select {
78+
case <-tb.OnWrite:
79+
if tb.Contains(expected) {
80+
return nil
81+
}
82+
case <-time.After(2 * time.Second):
83+
return errors.New("timeout waiting for write")
84+
}
7085
}
7186
}
7287

0 commit comments

Comments
 (0)