Skip to content

Commit c2d0e9d

Browse files
authored
Fix race condition in the ssh tunnel handover logic (#3888)
## Changes The proxy client can't initiate multiple handovers at the same time from its point of view, but the proxy server can still end up in a situation where two handovers race for handoverMutex simultaneously. The first one can still be waiting for the close-message ack from the client (on the old ws connection) while the client sends a new handover http request. The server's receiving loop has blockUntilHandoverComplete, which tries to grab handoverMutex and therefore races for it with the new acceptHandover call. If the latter executes first, we have a deadlock. The fix is better waiting logic on the receiving loop side - it needs to wait for a new connection to be activated, and now it does exactly this using the connSwapped channel. Refactored the overall coordination logic a bit: we now use the handoverState atomic to store the handover context (with a timeout) and the two channels necessary to coordinate between the receiving loop and the handover initiator. The `isHandover` flag is now gone, as `handoverState` itself is an indicator of when the handover is in progress. ### Testing Also ran tests with `-race` and `-count 100` flags + spawning manual ssh sessions with a short handover timeout
1 parent 6e1810a commit c2d0e9d

File tree

3 files changed

+162
-69
lines changed

3 files changed

+162
-69
lines changed

experimental/ssh/internal/proxy/client_server_test.go

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ func TestHandover(t *testing.T) {
158158
wg := sync.WaitGroup{}
159159
wg.Go(func() {
160160
for i := range TOTAL_MESSAGE_COUNT {
161-
if i > 0 && i%MESSAGES_PER_CHUNK == 0 {
161+
if i > 0 && i%MESSAGES_PER_CHUNK == 0 && i < TOTAL_MESSAGE_COUNT-1 {
162162
handoverChan <- time.Now()
163163
}
164164
message := fmt.Appendf(nil, "message %d\n", i)
@@ -178,3 +178,41 @@ func TestHandover(t *testing.T) {
178178
// clientOutput is created by appending incoming messages as they arrive, so we are also testing the correct order here
179179
assert.Equal(t, expectedOutput, clientOutput.String())
180180
}
181+
182+
// Tests handovers in quick succession with few messages in between.
183+
// Not a real world scenario, but it can help uncover potential race conditions or deadlocks.
184+
func TestQuickHandover(t *testing.T) {
185+
server := createTestServer(t, 2, time.Hour)
186+
defer server.Close()
187+
188+
handoverChan := make(chan time.Time)
189+
requestHandoverTick := func() <-chan time.Time {
190+
return handoverChan
191+
}
192+
clientInputWriter, clientOutput := createTestClient(t, server.URL, requestHandoverTick, nil)
193+
defer clientInputWriter.Close()
194+
195+
expectedOutput := ""
196+
197+
wg := sync.WaitGroup{}
198+
wg.Go(func() {
199+
for i := range 16 {
200+
if i == 4 || i == 8 || i == 12 {
201+
handoverChan <- time.Now()
202+
}
203+
message := fmt.Appendf(nil, "message %d\n", i)
204+
_, err := clientInputWriter.Write(message)
205+
if err != nil {
206+
t.Errorf("failed to write message %d: %v", i, err)
207+
}
208+
expectedOutput += string(message)
209+
}
210+
})
211+
212+
err := clientOutput.WaitForWrite(fmt.Appendf(nil, "message %d\n", 15))
213+
require.NoError(t, err, "failed to receive the last message (%d)", 15)
214+
215+
wg.Wait()
216+
217+
assert.Equal(t, expectedOutput, clientOutput.String())
218+
}

experimental/ssh/internal/proxy/proxy.go

Lines changed: 119 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -18,28 +18,89 @@ import (
1818
var errProxyEOF = errors.New("proxy EOF error")
1919

2020
const (
21-
proxyBufferSize = 4 * 1024 // Same as gorilla/websocket default read/write buffer sizes. Bigger payloads will be split into multiple ws frames.
22-
proxyHandoverInitTimeout = 30 * time.Second
23-
proxyHandoverAcceptTimeout = 25 * time.Second
24-
proxyHandoverAckCloseConnTimeout = 15 * time.Second
21+
// Same as gorilla/websocket default read/write buffer sizes. Bigger payloads will be split into multiple ws frames.
22+
proxyBufferSize = 4 * 1024
23+
// Timeout for the full handover process, when initiated by the client.
24+
proxyHandoverInitTimeout = 30 * time.Second
25+
// Timeout for the handover process, when accepted by the server.
26+
proxyHandoverAcceptTimeout = 25 * time.Second
2527
)
2628

29+
// handoverCoordination holds the context and channels used to coordinate a single handover operation
// between the receiving loop and the handover initiator (initiateHandover or acceptHandover).
//
// Every helper below gives up as soon as ctx is done. In that case the returned error wraps ctx.Err()
// (so errors.Is(err, context.DeadlineExceeded) and errors.Is(err, context.Canceled) keep working)
// and names the handover phase that was blocked.
type handoverCoordination struct {
	// Context with timeout for the entire handover operation.
	// Shared between the handover initiator and the receiving loop.
	ctx context.Context
	// Used by the receiving loop to signal about the closure of the current connection to the handover initiator.
	// A nil value means the closure was expected, a non-nil value describes why the handover must fail.
	// After signalling, the receiving loop will block until connSwapped channel is signaled.
	connClosed chan error
	// Used by the handover initiator to signal the receiving loop that it's safe to start reading from the new connection.
	connSwapped chan struct{}
}

// signalConnectionClosed hands err (nil for an expected closure) over to the handover initiator.
//
// Once the initiator has received the signal, err itself is returned, so a non-nil closure error
// also terminates the caller. If ctx is done first, the returned error wraps ctx.Err() and, when
// err is non-nil, mentions the close error that could not be delivered instead of dropping it.
func (c *handoverCoordination) signalConnectionClosed(err error) error {
	select {
	case c.connClosed <- err:
		return err
	case <-c.ctx.Done():
		if err != nil {
			return fmt.Errorf("handover: closed connection signal was not received: %w (pending close error: %v)", c.ctx.Err(), err)
		}
		return fmt.Errorf("handover: closed connection signal was not received: %w", c.ctx.Err())
	}
}

// waitForConnectionToClose blocks until the receiving loop reports that the current connection is closed
// and returns the reported value unchanged (nil for an expected closure).
// If ctx is done first, the returned error wraps ctx.Err().
func (c *handoverCoordination) waitForConnectionToClose() error {
	select {
	case err := <-c.connClosed:
		return err
	case <-c.ctx.Done():
		return fmt.Errorf("handover: waiting for the current connection to close: %w", c.ctx.Err())
	}
}

// signalConnectionSwapped tells the receiving loop that the new connection is in place.
// Returns nil once the receiving loop has received the signal, or an error wrapping ctx.Err() if ctx is done first.
func (c *handoverCoordination) signalConnectionSwapped() error {
	select {
	case c.connSwapped <- struct{}{}:
		return nil
	case <-c.ctx.Done():
		return fmt.Errorf("handover: swapped connection signal was not received: %w", c.ctx.Err())
	}
}

// waitForConnectionToSwap blocks until the handover initiator reports that the connection is swapped.
// Returns nil on success, or an error wrapping ctx.Err() if ctx is done first.
func (c *handoverCoordination) waitForConnectionToSwap() error {
	select {
	case <-c.connSwapped:
		return nil
	case <-c.ctx.Done():
		return fmt.Errorf("handover: waiting for the connection to be swapped: %w", c.ctx.Err())
	}
}
77+
78+
// proxyConnection is the main struct that manages the websocket connection and the handover process.
79+
// It works both on the client and the server side (see internal/client and internal/server packages).
80+
// It has 3 goroutines:
81+
// - Sending loop: reads from src and sends to the current connection.
82+
// - Receiving loop: reads from the current connection and writes to dst.
83+
// - Main: starts the other two (start method) and initiates or accepts handover (initiateHandover or acceptHandover).
2784
type proxyConnection struct {
28-
connID string
29-
conn atomic.Value // *websocket.Conn
85+
// Each connection has a unique ID.
86+
connID string
87+
// Function to create a new websocket connection. Tests can override this to use a test websocket connection.
3088
createWebsocketConnection createWebsocketConnectionFunc
31-
32-
handoverMutex sync.Mutex
33-
isHandover atomic.Bool
34-
currentConnectionClosed chan error
89+
// Atomic that keeps the currently active connection.
90+
// Can be swapped during handover.
91+
conn atomic.Pointer[websocket.Conn]
92+
// Prevents multiple handover processes from running concurrently.
93+
// Blocks proxying any outgoing messages during the entire handover in the sending loop.
94+
handoverMutex sync.Mutex
95+
// Atomic that holds the current handover coordination channels, or nil if no handover is in progress.
96+
handoverState atomic.Pointer[handoverCoordination]
3597
}
3698

3799
type createWebsocketConnectionFunc func(ctx context.Context, connID string) (*websocket.Conn, error)
38100

39101
func newProxyConnection(createConn createWebsocketConnectionFunc) *proxyConnection {
40102
return &proxyConnection{
41103
connID: uuid.NewString(),
42-
currentConnectionClosed: make(chan error),
43104
createWebsocketConnection: createConn,
44105
}
45106
}
@@ -114,7 +175,7 @@ func (pc *proxyConnection) runSendingLoop(ctx context.Context, src io.Reader) er
114175
func (pc *proxyConnection) sendMessage(mt int, data []byte) error {
115176
pc.handoverMutex.Lock()
116177
defer pc.handoverMutex.Unlock()
117-
conn := pc.conn.Load().(*websocket.Conn)
178+
conn := pc.conn.Load()
118179
return conn.WriteMessage(mt, data)
119180
}
120181

@@ -123,22 +184,23 @@ func (pc *proxyConnection) runReceivingLoop(ctx context.Context, dst io.Writer)
123184
if ctx.Err() != nil {
124185
return ctx.Err()
125186
}
126-
conn := pc.conn.Load().(*websocket.Conn)
187+
conn := pc.conn.Load()
127188
mt, data, err := conn.ReadMessage()
128189
if err != nil {
129190
// During handover a normal closure is expected, but any other error must stop the read loop (and eventually terminate the ssh session).
130-
if pc.isHandover.Load() {
191+
if handover := pc.handoverState.Load(); handover != nil {
131192
var closeConnSignal error
132193
if !websocket.IsCloseError(err, websocket.CloseNormalClosure) {
133194
closeConnSignal = fmt.Errorf("failed to read from websocket during handover: %w", err)
134195
}
135-
if err := pc.signalClosedConnection(closeConnSignal); err != nil {
196+
// Signal the current connection is closed to the handover initiator (initiateHandover or acceptHandover).
197+
if err := handover.signalConnectionClosed(closeConnSignal); err != nil {
136198
return err
137199
}
138-
// Next time we read, we want to read from the new connection.
200+
// Wait for the handover initiator to swap the connection.
139201
// While we wait for the handover to complete, the new connection might be getting incoming messages.
140202
// They will be buffered by the TCP stack and will be read by us after the handover is complete.
141-
if err := pc.blockUntilHandoverComplete(conn); err != nil {
203+
if err := handover.waitForConnectionToSwap(); err != nil {
142204
return err
143205
}
144206
// Continue with the receiving loop, pc.conn is now the new connection.
@@ -161,25 +223,6 @@ func (pc *proxyConnection) runReceivingLoop(ctx context.Context, dst io.Writer)
161223
}
162224
}
163225

164-
func (pc *proxyConnection) signalClosedConnection(err error) error {
165-
select {
166-
case pc.currentConnectionClosed <- err:
167-
return err
168-
case <-time.After(proxyHandoverAckCloseConnTimeout):
169-
return fmt.Errorf("timeout waiting for acknowledgement of old connection closed message: %w", err)
170-
}
171-
}
172-
173-
func (pc *proxyConnection) blockUntilHandoverComplete(conn *websocket.Conn) error {
174-
pc.handoverMutex.Lock()
175-
defer pc.handoverMutex.Unlock()
176-
// Sanity check to ensure we are not in the middle of a handover - this should not happen.
177-
if pc.conn.Load() == conn {
178-
return errors.New("handover mutex acquired while the old connection is still active")
179-
}
180-
return nil
181-
}
182-
183226
func (pc *proxyConnection) close() error {
184227
// Keep in mind that pc.sendMessage blocks during handover
185228
err := pc.sendMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
@@ -198,35 +241,40 @@ func (pc *proxyConnection) initiateHandover(ctx context.Context) error {
198241
pc.handoverMutex.Lock()
199242
defer pc.handoverMutex.Unlock()
200243

201-
// When handover flag is set, the receiving loop handles a close message from the current connection
202-
// as a signal to finish the handover and switch to the new connection.
203-
pc.isHandover.Store(true)
204-
defer pc.isHandover.Store(false)
205-
206-
ctx, cancel := context.WithTimeout(ctx, proxyHandoverInitTimeout)
244+
handoverCtx, cancel := context.WithTimeout(ctx, proxyHandoverInitTimeout)
207245
defer cancel()
246+
handoverState := &handoverCoordination{
247+
ctx: handoverCtx,
248+
connClosed: make(chan error),
249+
connSwapped: make(chan struct{}),
250+
}
251+
// Existence of the handoverState indicates to the receiving loop that we are in the middle of a handover process,
252+
// and should treat close messages as a signal to finish the handover instead of erroring out.
253+
pc.handoverState.Store(handoverState)
254+
defer pc.handoverState.Store(nil)
208255

209256
// Create a new websocket connection by sending an /ssh?id=<connID> request to the server.
210257
// When server realises it's an ID of an existing connection, it will start AcceptHandover process.
211-
newConn, err := pc.createWebsocketConnection(ctx, pc.connID)
258+
newConn, err := pc.createWebsocketConnection(handoverCtx, pc.connID)
212259
if err != nil {
213260
return fmt.Errorf("failed to create new websocket connection: %w", err)
214261
}
215262

216263
// Wait for the server to close the old connection
217-
select {
218-
case err := <-pc.currentConnectionClosed:
219-
if err != nil {
220-
newConn.Close()
221-
return fmt.Errorf("connection handover failed: %w", err)
222-
}
223-
case <-ctx.Done():
264+
// (it does so when it receives an /ssh request with known connection ID and starts AcceptHandover process).
265+
// The receiving loop signals the closed connection on the handoverState.connClosed channel.
266+
if err := handoverState.waitForConnectionToClose(); err != nil {
224267
newConn.Close()
225-
return ctx.Err()
268+
return err
226269
}
227270

228271
pc.conn.Store(newConn)
229272

273+
// Let the receiving loop know that the current connection is swapped and it's safe to start reading from it.
274+
if err := handoverState.signalConnectionSwapped(); err != nil {
275+
newConn.Close()
276+
return err
277+
}
230278
return nil
231279
}
232280

@@ -235,13 +283,17 @@ func (pc *proxyConnection) acceptHandover(ctx context.Context, w http.ResponseWr
235283
pc.handoverMutex.Lock()
236284
defer pc.handoverMutex.Unlock()
237285

238-
// When handover flag is set, the receiving loop handles a close message from the current connection
239-
// as a signal to finish the handover and switch to the new connection.
240-
pc.isHandover.Store(true)
241-
defer pc.isHandover.Store(false)
242-
243-
ctx, cancel := context.WithTimeout(ctx, proxyHandoverAcceptTimeout)
286+
handoverCtx, cancel := context.WithTimeout(ctx, proxyHandoverAcceptTimeout)
244287
defer cancel()
288+
handoverState := &handoverCoordination{
289+
ctx: handoverCtx,
290+
connClosed: make(chan error),
291+
connSwapped: make(chan struct{}),
292+
}
293+
// Existence of the handoverState indicates to the receiving loop that we are in the middle of a handover process,
294+
// and should treat close messages as a signal to finish the handover instead of erroring out.
295+
pc.handoverState.Store(handoverState)
296+
defer pc.handoverState.Store(nil)
245297

246298
newConn, err := pc.acceptWebsocketConnection(w, r)
247299
if err != nil {
@@ -250,7 +302,7 @@ func (pc *proxyConnection) acceptHandover(ctx context.Context, w http.ResponseWr
250302

251303
// Signal the client to complete handover by closing the old connection.
252304
// Not using pc.sendMessage here, because it's blocked by the handover mutex.
253-
currentConn := pc.conn.Load().(*websocket.Conn)
305+
currentConn := pc.conn.Load()
254306
err = currentConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "handover"))
255307
if err != nil {
256308
newConn.Close()
@@ -259,20 +311,20 @@ func (pc *proxyConnection) acceptHandover(ctx context.Context, w http.ResponseWr
259311

260312
// Wait for the client to acknowledge the closure of the old connection.
261313
// On the client it's done automatically by the websocket library with the default close handler.
262-
// On the server we then receive a close error in the RunReceivingLoop and signal about it to the handoverOldConnClosed channel.
263-
select {
264-
case err := <-pc.currentConnectionClosed:
265-
if err != nil {
266-
newConn.Close()
267-
return fmt.Errorf("connection handover failed: %w", err)
268-
}
269-
case <-ctx.Done():
314+
// On the server we then receive a close error in runReceivingLoop, which signals it on the handoverState.connClosed channel.
315+
if err := handoverState.waitForConnectionToClose(); err != nil {
270316
newConn.Close()
271-
return ctx.Err()
317+
return err
272318
}
273319

274320
pc.conn.Store(newConn)
275321

322+
// Let the receiving loop know that the current connection is swapped and it's safe to start reading from it.
323+
if err := handoverState.signalConnectionSwapped(); err != nil {
324+
newConn.Close()
325+
return err
326+
}
327+
276328
return nil
277329
}
278330

experimental/ssh/internal/proxy/proxy_test.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"testing"
1313
"time"
1414

15+
"github.com/databricks/cli/libs/log"
1516
"github.com/gorilla/websocket"
1617
"github.com/stretchr/testify/assert"
1718
"github.com/stretchr/testify/require"
@@ -98,6 +99,7 @@ func setupTestServer(ctx context.Context, t *testing.T) *TestProxy {
9899
serverOutput := newTestBuffer(t)
99100
var serverProxy *proxyConnection
100101
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
102+
ctx := log.NewContext(ctx, log.GetLogger(ctx).With("Server", true))
101103
if serverProxy != nil {
102104
err := serverProxy.acceptHandover(ctx, w, r)
103105
if err != nil {
@@ -138,6 +140,7 @@ func createTestWebsocketConnection(url string) (*websocket.Conn, error) {
138140
}
139141

140142
func setupTestClient(ctx context.Context, t *testing.T, serverURL string) *TestProxy {
143+
ctx = log.NewContext(ctx, log.GetLogger(ctx).With("Client", true))
141144
clientInput, clientInputWriter := io.Pipe()
142145
clientOutput := newTestBuffer(t)
143146
wsURL := "ws" + serverURL[4:]
@@ -221,7 +224,7 @@ func TestConnectionHandover(t *testing.T) {
221224
for i := range TOTAL_MESSAGE_COUNT {
222225
client.Input.Write(createTestMessage("client", i)) // nolint:errcheck
223226
server.Input.Write(createTestMessage("server", i)) // nolint:errcheck
224-
if i > 0 && i%MESSAGES_PER_CHUNK == 0 {
227+
if i > 0 && i%MESSAGES_PER_CHUNK == 0 && i < TOTAL_MESSAGE_COUNT-1 {
225228
handoverChan <- struct{}{}
226229
}
227230
}

0 commit comments

Comments
 (0)