Skip to content

Commit 282e214

Browse files
committed
Fix race condition in the ssh tunnel handover logic
From its own point of view, the proxy client can't initiate multiple handovers at the same time, but the proxy server can still end up in a situation where two handovers are racing for handoverMutex simultaneously. The first one can still be waiting for the close-message ack from the client (on the old ws connection) while the client sends a new handover HTTP request. The server's receiving loop calls blockUntilHandoverComplete, which tries to grab handoverMutex and thus races for it with the new acceptHandover call. If the latter executes first, we have a deadlock. The fix is better waiting logic on the receiving-loop side: all it needs is to wait for a new connection to be activated, and now it does exactly that using the connChanged condition variable.
1 parent 8126230 commit 282e214

File tree

3 files changed

+55
-12
lines changed

3 files changed

+55
-12
lines changed

experimental/ssh/internal/proxy/client_server_test.go

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ func TestHandover(t *testing.T) {
158158
wg := sync.WaitGroup{}
159159
wg.Go(func() {
160160
for i := range TOTAL_MESSAGE_COUNT {
161-
if i > 0 && i%MESSAGES_PER_CHUNK == 0 {
161+
if i > 0 && i%MESSAGES_PER_CHUNK == 0 && i < TOTAL_MESSAGE_COUNT-1 {
162162
handoverChan <- time.Now()
163163
}
164164
message := fmt.Appendf(nil, "message %d\n", i)
@@ -178,3 +178,41 @@ func TestHandover(t *testing.T) {
178178
// clientOutput is created by appending incoming messages as they arrive, so we are also test correct order here
179179
assert.Equal(t, expectedOutput, clientOutput.String())
180180
}
181+
182+
// Tests handovers in quick succession with few messages in between.
183+
// Not a real world scenario, but it can help uncover potential race conditions or deadlocks.
184+
func TestQuickHandover(t *testing.T) {
185+
server := createTestServer(t, 2, time.Hour)
186+
defer server.Close()
187+
188+
handoverChan := make(chan time.Time)
189+
requestHandoverTick := func() <-chan time.Time {
190+
return handoverChan
191+
}
192+
clientInputWriter, clientOutput := createTestClient(t, server.URL, requestHandoverTick, nil)
193+
defer clientInputWriter.Close()
194+
195+
expectedOutput := ""
196+
197+
wg := sync.WaitGroup{}
198+
wg.Go(func() {
199+
for i := range 16 {
200+
if i == 4 || i == 8 || i == 12 {
201+
handoverChan <- time.Now()
202+
}
203+
message := fmt.Appendf(nil, "message %d\n", i)
204+
_, err := clientInputWriter.Write(message)
205+
if err != nil {
206+
t.Errorf("failed to write message %d: %v", i, err)
207+
}
208+
expectedOutput += string(message)
209+
}
210+
})
211+
212+
err := clientOutput.WaitForWrite(fmt.Appendf(nil, "message %d\n", 15))
213+
require.NoError(t, err, "failed to receive the last message (%d)", 15)
214+
215+
wg.Wait()
216+
217+
assert.Equal(t, expectedOutput, clientOutput.String())
218+
}

experimental/ssh/internal/proxy/proxy.go

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ const (
2727
type proxyConnection struct {
2828
connID string
2929
conn atomic.Value // *websocket.Conn
30+
connChanged sync.Cond
3031
createWebsocketConnection createWebsocketConnectionFunc
3132

3233
handoverMutex sync.Mutex
@@ -39,6 +40,7 @@ type createWebsocketConnectionFunc func(ctx context.Context, connID string) (*we
3940
func newProxyConnection(createConn createWebsocketConnectionFunc) *proxyConnection {
4041
return &proxyConnection{
4142
connID: uuid.NewString(),
43+
connChanged: sync.Cond{L: &sync.Mutex{}},
4244
currentConnectionClosed: make(chan error),
4345
createWebsocketConnection: createConn,
4446
}
@@ -65,6 +67,7 @@ func (pc *proxyConnection) connect(ctx context.Context) error {
6567
return err
6668
}
6769
pc.conn.Store(conn)
70+
pc.connChanged.Broadcast()
6871
return nil
6972
}
7073

@@ -74,6 +77,7 @@ func (pc *proxyConnection) accept(w http.ResponseWriter, r *http.Request) error
7477
return err
7578
}
7679
pc.conn.Store(conn)
80+
pc.connChanged.Broadcast()
7781
return nil
7882
}
7983

@@ -138,9 +142,7 @@ func (pc *proxyConnection) runReceivingLoop(ctx context.Context, dst io.Writer)
138142
// Next time we read, we want to read from the new connection.
139143
// While we wait for the handover to complete, the new connection might be getting incoming messages.
140144
// They will be buffered by the TCP stack and will be read by us after the handover is complete.
141-
if err := pc.blockUntilHandoverComplete(conn); err != nil {
142-
return err
143-
}
145+
pc.waitForNewConnection(conn)
144146
// Continue with the receiving loop, pc.conn is now the new connection.
145147
continue
146148
} else {
@@ -170,14 +172,12 @@ func (pc *proxyConnection) signalClosedConnection(err error) error {
170172
}
171173
}
172174

173-
func (pc *proxyConnection) blockUntilHandoverComplete(conn *websocket.Conn) error {
174-
pc.handoverMutex.Lock()
175-
defer pc.handoverMutex.Unlock()
176-
// Sanity check to ensure we are not in the middle of a handover - this should not happen.
177-
if pc.conn.Load() == conn {
178-
return errors.New("handover mutex acquired while the old connection is still active")
175+
func (pc *proxyConnection) waitForNewConnection(conn *websocket.Conn) {
176+
pc.connChanged.L.Lock()
177+
defer pc.connChanged.L.Unlock()
178+
for pc.conn.Load() == conn {
179+
pc.connChanged.Wait()
179180
}
180-
return nil
181181
}
182182

183183
func (pc *proxyConnection) close() error {
@@ -226,6 +226,7 @@ func (pc *proxyConnection) initiateHandover(ctx context.Context) error {
226226
}
227227

228228
pc.conn.Store(newConn)
229+
pc.connChanged.Broadcast()
229230

230231
return nil
231232
}
@@ -272,6 +273,7 @@ func (pc *proxyConnection) acceptHandover(ctx context.Context, w http.ResponseWr
272273
}
273274

274275
pc.conn.Store(newConn)
276+
pc.connChanged.Broadcast()
275277

276278
return nil
277279
}

experimental/ssh/internal/proxy/proxy_test.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"testing"
1313
"time"
1414

15+
"github.com/databricks/cli/libs/log"
1516
"github.com/gorilla/websocket"
1617
"github.com/stretchr/testify/assert"
1718
"github.com/stretchr/testify/require"
@@ -98,6 +99,7 @@ func setupTestServer(ctx context.Context, t *testing.T) *TestProxy {
9899
serverOutput := newTestBuffer(t)
99100
var serverProxy *proxyConnection
100101
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
102+
ctx := log.NewContext(ctx, log.GetLogger(ctx).With("Server", true))
101103
if serverProxy != nil {
102104
err := serverProxy.acceptHandover(ctx, w, r)
103105
if err != nil {
@@ -138,6 +140,7 @@ func createTestWebsocketConnection(url string) (*websocket.Conn, error) {
138140
}
139141

140142
func setupTestClient(ctx context.Context, t *testing.T, serverURL string) *TestProxy {
143+
ctx = log.NewContext(ctx, log.GetLogger(ctx).With("Client", true))
141144
clientInput, clientInputWriter := io.Pipe()
142145
clientOutput := newTestBuffer(t)
143146
wsURL := "ws" + serverURL[4:]
@@ -221,7 +224,7 @@ func TestConnectionHandover(t *testing.T) {
221224
for i := range TOTAL_MESSAGE_COUNT {
222225
client.Input.Write(createTestMessage("client", i)) // nolint:errcheck
223226
server.Input.Write(createTestMessage("server", i)) // nolint:errcheck
224-
if i > 0 && i%MESSAGES_PER_CHUNK == 0 {
227+
if i > 0 && i%MESSAGES_PER_CHUNK == 0 && i < TOTAL_MESSAGE_COUNT-1 {
225228
handoverChan <- struct{}{}
226229
}
227230
}

0 commit comments

Comments
 (0)