Skip to content

Commit 497bc87

Browse files
authored
ssh/proxy: improve cancellation and resource cleanup (#3965)
## Changes
<!-- Brief summary of your changes that is easy to understand -->
There are edge cases where the proxy connection could hang indefinitely during shutdown if the sending or receiving loops were blocked on I/O (e.g. reading from stdin or the websocket).

- Introduce a dedicated cleanup goroutine in `start` that waits for context cancellation and then explicitly closes both the websocket connection and the source. This ensures that blocked `Read` calls return immediately with an error, allowing the `errgroup` to finish.
- Ensure that the sending and receiving loops always return a non-nil error, so that the `errgroup` context is cancelled and the other loop is stopped as well.
- Update `RunClientProxy` and `proxyConnection.start` to accept `io.ReadCloser` instead of `io.Reader`, allowing the proxy to close the input source.

## Tests
Existing tests, run with the additional `-count` and `-race` flags.
<!-- If your PR needs to be included in the release notes for next release, add a separate entry in NEXT_CHANGELOG.md as part of your PR. -->
1 parent d689a0d commit 497bc87

File tree

4 files changed

+52
-8
lines changed

4 files changed

+52
-8
lines changed

experimental/ssh/internal/proxy/client.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import (
1010
"golang.org/x/sync/errgroup"
1111
)
1212

13-
func RunClientProxy(ctx context.Context, src io.Reader, dst io.Writer, requestHandoverTick func() <-chan time.Time, createConn createWebsocketConnectionFunc) error {
13+
func RunClientProxy(ctx context.Context, src io.ReadCloser, dst io.Writer, requestHandoverTick func() <-chan time.Time, createConn createWebsocketConnectionFunc) error {
1414
proxy := newProxyConnection(createConn)
1515
cmdio.LogString(ctx, "Establishing SSH proxy connection...")
1616
g, gCtx := errgroup.WithContext(ctx)

experimental/ssh/internal/proxy/client_server_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func createTestClient(t *testing.T, serverURL string, requestHandoverTick func()
4848
}
4949
go func() {
5050
err := RunClientProxy(ctx, clientInput, clientOutput, requestHandoverTick, createConn)
51-
if err != nil && !isNormalClosure(err) && !errors.Is(err, context.Canceled) {
51+
if err != nil && !errors.Is(err, context.Canceled) {
5252
if errChan != nil {
5353
errChan <- err
5454
} else {

experimental/ssh/internal/proxy/proxy.go

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"io"
88
"net/http"
9+
"os"
910
"sync"
1011
"sync/atomic"
1112
"time"
@@ -15,7 +16,11 @@ import (
1516
"golang.org/x/sync/errgroup"
1617
)
1718

18-
var errProxyEOF = errors.New("proxy EOF error")
19+
var (
20+
errProxyEOF = errors.New("proxy EOF error")
21+
errSendingLoopStopped = errors.New("sending loop stopped")
22+
errReceivingLoopStopped = errors.New("receiving loop stopped")
23+
)
1924

2025
const (
2126
// Same as gorilla/websocket default read/write buffer sizes. Bigger payloads will be split into multiple ws frames.
@@ -94,6 +99,9 @@ type proxyConnection struct {
9499
handoverMutex sync.Mutex
95100
// Atomic that holds the current handover coordination channels, or nil if no handover is in progress.
96101
handoverState atomic.Pointer[handoverCoordination]
102+
// Channel that is closed once the initial connection attempt has finished, whether it succeeded or failed.
103+
// Prevents race conditions where handover is accepted before the initial connection is ready.
104+
ready chan struct{}
97105
}
98106

99107
type createWebsocketConnectionFunc func(ctx context.Context, connID string) (*websocket.Conn, error)
@@ -102,16 +110,31 @@ func newProxyConnection(createConn createWebsocketConnectionFunc) *proxyConnecti
102110
return &proxyConnection{
103111
connID: uuid.NewString(),
104112
createWebsocketConnection: createConn,
113+
ready: make(chan struct{}),
105114
}
106115
}
107116

108-
func (pc *proxyConnection) start(ctx context.Context, src io.Reader, dst io.Writer) error {
117+
func (pc *proxyConnection) start(ctx context.Context, src io.ReadCloser, dst io.Writer) error {
109118
g, gCtx := errgroup.WithContext(ctx)
110119
g.Go(func() error {
111-
return pc.runSendingLoop(gCtx, src)
120+
err := pc.runSendingLoop(gCtx, src)
121+
// Always return a non-nil error to cancel the errgroup context
122+
return errors.Join(err, errSendingLoopStopped)
123+
})
124+
g.Go(func() error {
125+
err := pc.runReceivingLoop(gCtx, dst)
126+
// Always return a non-nil error to cancel the errgroup context
127+
return errors.Join(err, errReceivingLoopStopped)
112128
})
113129
g.Go(func() error {
114-
return pc.runReceivingLoop(gCtx, dst)
130+
// Wait for the context to be cancelled. There can be multiple reasons:
131+
// - Sending loop finished (e.g. EOF from source)
132+
// - Receiving loop finished (e.g. connection closed)
133+
// - Parent context cancelled
134+
// Both loops can still be stuck on conn.ReadMessage or src.Read and won't notice context cancellation,
135+
// so we close the connection and the source (sshd stdout pipe or ssh client stdio) to unblock them.
136+
<-gCtx.Done()
137+
return errors.Join(pc.close(), pc.closeSource(src))
115138
})
116139
err := g.Wait()
117140
if err == nil || isNormalClosure(err) {
@@ -121,6 +144,7 @@ func (pc *proxyConnection) start(ctx context.Context, src io.Reader, dst io.Writ
121144
}
122145

123146
func (pc *proxyConnection) connect(ctx context.Context) error {
147+
defer close(pc.ready)
124148
conn, err := pc.createWebsocketConnection(ctx, pc.connID)
125149
if err != nil {
126150
return err
@@ -130,6 +154,7 @@ func (pc *proxyConnection) connect(ctx context.Context) error {
130154
}
131155

132156
func (pc *proxyConnection) accept(w http.ResponseWriter, r *http.Request) error {
157+
defer close(pc.ready)
133158
conn, err := pc.acceptWebsocketConnection(w, r)
134159
if err != nil {
135160
return err
@@ -236,6 +261,14 @@ func (pc *proxyConnection) close() error {
236261
return nil
237262
}
238263

264+
func (pc *proxyConnection) closeSource(src io.ReadCloser) error {
265+
err := src.Close()
266+
if err != nil && (errors.Is(err, os.ErrClosed) || errors.Is(err, io.ErrClosedPipe)) {
267+
return nil
268+
}
269+
return err
270+
}
271+
239272
func (pc *proxyConnection) initiateHandover(ctx context.Context) error {
240273
// Blocks proxying any outgoing messages during the entire handover
241274
pc.handoverMutex.Lock()
@@ -283,6 +316,13 @@ func (pc *proxyConnection) acceptHandover(ctx context.Context, w http.ResponseWr
283316
pc.handoverMutex.Lock()
284317
defer pc.handoverMutex.Unlock()
285318

319+
// Wait for the initial connection to be ready
320+
select {
321+
case <-pc.ready:
322+
case <-ctx.Done():
323+
return ctx.Err()
324+
}
325+
286326
handoverCtx, cancel := context.WithTimeout(ctx, proxyHandoverAcceptTimeout)
287327
defer cancel()
288328
handoverState := &handoverCoordination{
@@ -303,6 +343,10 @@ func (pc *proxyConnection) acceptHandover(ctx context.Context, w http.ResponseWr
303343
// Signal the client to complete handover by closing the old connection.
304344
// Not using pc.sendMessage here, because it's blocked by the handover mutex.
305345
currentConn := pc.conn.Load()
346+
if currentConn == nil {
347+
newConn.Close()
348+
return errors.New("initial connection not established")
349+
}
306350
err = currentConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "handover"))
307351
if err != nil {
308352
newConn.Close()

experimental/ssh/internal/proxy/proxy_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ func setupTestServer(ctx context.Context, t *testing.T) *TestProxy {
115115
}
116116
defer serverProxy.close()
117117
err = serverProxy.start(ctx, serverInput, serverOutput)
118-
if err != nil && !errors.Is(err, errProxyEOF) {
118+
if err != nil && !errors.Is(err, context.Canceled) {
119119
t.Errorf("server error: %v", err)
120120
return
121121
}
@@ -152,7 +152,7 @@ func setupTestClient(ctx context.Context, t *testing.T, serverURL string) *TestP
152152

153153
go func() {
154154
err := clientProxy.start(ctx, clientInput, clientOutput)
155-
if err != nil && !errors.Is(err, errProxyEOF) {
155+
if err != nil && !errors.Is(err, context.Canceled) {
156156
t.Errorf("proxy error: %v", err)
157157
}
158158
}()

0 commit comments

Comments
 (0)