@@ -18,28 +18,89 @@ import (
1818var errProxyEOF = errors .New ("proxy EOF error" )
1919
2020const (
21- proxyBufferSize = 4 * 1024 // Same as gorilla/websocket default read/write buffer sizes. Bigger payloads will be split into multiple ws frames.
22- proxyHandoverInitTimeout = 30 * time .Second
23- proxyHandoverAcceptTimeout = 25 * time .Second
24- proxyHandoverAckCloseConnTimeout = 15 * time .Second
21+ // Same as gorilla/websocket default read/write buffer sizes. Bigger payloads will be split into multiple ws frames.
22+ proxyBufferSize = 4 * 1024
23+ // Timeout for the full handover process, when initiated by the client.
24+ proxyHandoverInitTimeout = 30 * time .Second
25+ // Timeout for the handover process, when accepted by the server.
26+ proxyHandoverAcceptTimeout = 25 * time .Second
2527)
2628
29+ // handoverCoordination holds the context and channels used to coordinate a single handover operation
30+ // between the receiving loop and the handover initiator (initiateHandover or acceptHandover).
31+ type handoverCoordination struct {
32+ // Context with timeout for the entire handover operation.
33+ // Shared between the handover initiator and the receiving loop.
34+ ctx context.Context
35+ // Used by the receiving loop to signal about the closure of the current connection to the handover initiator.
36+ // After signalling, the receiving loop will block until connSwapped channel is signaled.
37+ connClosed chan error
38+ // Used by the handover initiator to signal the receiving loop that it's safe to start reading from the new connection.
39+ connSwapped chan struct {}
40+ }
41+
42+ func (c * handoverCoordination ) signalConnectionClosed (err error ) error {
43+ select {
44+ case c .connClosed <- err :
45+ return err
46+ case <- c .ctx .Done ():
47+ return c .ctx .Err ()
48+ }
49+ }
50+
51+ func (c * handoverCoordination ) waitForConnectionToClose () error {
52+ select {
53+ case err := <- c .connClosed :
54+ return err
55+ case <- c .ctx .Done ():
56+ return c .ctx .Err ()
57+ }
58+ }
59+
60+ func (c * handoverCoordination ) signalConnectionSwapped () error {
61+ select {
62+ case c .connSwapped <- struct {}{}:
63+ return nil
64+ case <- c .ctx .Done ():
65+ return c .ctx .Err ()
66+ }
67+ }
68+
69+ func (c * handoverCoordination ) waitForConnectionToSwap () error {
70+ select {
71+ case <- c .connSwapped :
72+ return nil
73+ case <- c .ctx .Done ():
74+ return c .ctx .Err ()
75+ }
76+ }
77+
78+ // proxyConnection is the main struct that manages the websocket connection and the handover process.
79+ // It works both on the client and the server side (see internal/client and internal/server packages).
80+ // It has 3 goroutines:
81+ // - Sending loop: reads from src and sends to the current connection.
82+ // - Receiving loop: reads from the current connection and writes to dst.
83+ // - Main: starts the other two (start method) and initiates or accepts handover (initiateHandover or acceptHandover).
2784type proxyConnection struct {
28- connID string
29- conn atomic.Value // *websocket.Conn
85+ // Each connection has a unique ID.
86+ connID string
87+ // Function to create a new websocket connection. Tests can override this to use a test websocket connection.
3088 createWebsocketConnection createWebsocketConnectionFunc
31-
32- handoverMutex sync.Mutex
33- isHandover atomic.Bool
34- currentConnectionClosed chan error
89+ // Atomic that keeps the currently active connection.
90+ // Can be swapped during handover.
91+ conn atomic.Pointer [websocket.Conn ]
92+ // Prevents multiple handover processes from running concurrently.
93+ // Blocks proxying any outgoing messages during the entire handover in the sending loop.
94+ handoverMutex sync.Mutex
95+ // Atomic that holds the current handover coordination channels, or nil if no handover is in progress.
96+ handoverState atomic.Pointer [handoverCoordination ]
3597}
3698
3799type createWebsocketConnectionFunc func (ctx context.Context , connID string ) (* websocket.Conn , error )
38100
39101func newProxyConnection (createConn createWebsocketConnectionFunc ) * proxyConnection {
40102 return & proxyConnection {
41103 connID : uuid .NewString (),
42- currentConnectionClosed : make (chan error ),
43104 createWebsocketConnection : createConn ,
44105 }
45106}
@@ -114,7 +175,7 @@ func (pc *proxyConnection) runSendingLoop(ctx context.Context, src io.Reader) er
114175func (pc * proxyConnection ) sendMessage (mt int , data []byte ) error {
115176 pc .handoverMutex .Lock ()
116177 defer pc .handoverMutex .Unlock ()
117- conn := pc .conn .Load ().( * websocket. Conn )
178+ conn := pc .conn .Load ()
118179 return conn .WriteMessage (mt , data )
119180}
120181
@@ -123,22 +184,23 @@ func (pc *proxyConnection) runReceivingLoop(ctx context.Context, dst io.Writer)
123184 if ctx .Err () != nil {
124185 return ctx .Err ()
125186 }
126- conn := pc .conn .Load ().( * websocket. Conn )
187+ conn := pc .conn .Load ()
127188 mt , data , err := conn .ReadMessage ()
128189 if err != nil {
129190 // During handover a normal closure is expected, but any other error must stop the read loop (and eventually terminate the ssh session).
130- if pc .isHandover .Load () {
191+ if handover := pc .handoverState .Load (); handover != nil {
131192 var closeConnSignal error
132193 if ! websocket .IsCloseError (err , websocket .CloseNormalClosure ) {
133194 closeConnSignal = fmt .Errorf ("failed to read from websocket during handover: %w" , err )
134195 }
135- if err := pc .signalClosedConnection (closeConnSignal ); err != nil {
196+ // Signal the current connection is closed to the handover initiator (initiateHandover or acceptHandover).
197+ if err := handover .signalConnectionClosed (closeConnSignal ); err != nil {
136198 return err
137199 }
138- // Next time we read, we want to read from the new connection.
200+ // Wait for the handover initiator to swap the connection.
139201 // While we wait for the handover to complete, the new connection might be getting incoming messages.
140202 // They will be buffered by the TCP stack and will be read by us after the handover is complete.
141- if err := pc . blockUntilHandoverComplete ( conn ); err != nil {
203+ if err := handover . waitForConnectionToSwap ( ); err != nil {
142204 return err
143205 }
144206 // Continue with the receiving loop, pc.conn is now the new connection.
@@ -161,25 +223,6 @@ func (pc *proxyConnection) runReceivingLoop(ctx context.Context, dst io.Writer)
161223 }
162224}
163225
164- func (pc * proxyConnection ) signalClosedConnection (err error ) error {
165- select {
166- case pc .currentConnectionClosed <- err :
167- return err
168- case <- time .After (proxyHandoverAckCloseConnTimeout ):
169- return fmt .Errorf ("timeout waiting for acknowledgement of old connection closed message: %w" , err )
170- }
171- }
172-
173- func (pc * proxyConnection ) blockUntilHandoverComplete (conn * websocket.Conn ) error {
174- pc .handoverMutex .Lock ()
175- defer pc .handoverMutex .Unlock ()
176- // Sanity check to ensure we are not in the middle of a handover - this should not happen.
177- if pc .conn .Load () == conn {
178- return errors .New ("handover mutex acquired while the old connection is still active" )
179- }
180- return nil
181- }
182-
183226func (pc * proxyConnection ) close () error {
184227 // Keep in mind that pc.sendMessage blocks during handover
185228 err := pc .sendMessage (websocket .CloseMessage , websocket .FormatCloseMessage (websocket .CloseNormalClosure , "" ))
@@ -198,35 +241,40 @@ func (pc *proxyConnection) initiateHandover(ctx context.Context) error {
198241 pc .handoverMutex .Lock ()
199242 defer pc .handoverMutex .Unlock ()
200243
201- // When handover flag is set, the receiving loop handles a close message from the current connection
202- // as a signal to finish the handover and switch to the new connection.
203- pc .isHandover .Store (true )
204- defer pc .isHandover .Store (false )
205-
206- ctx , cancel := context .WithTimeout (ctx , proxyHandoverInitTimeout )
244+ handoverCtx , cancel := context .WithTimeout (ctx , proxyHandoverInitTimeout )
207245 defer cancel ()
246+ handoverState := & handoverCoordination {
247+ ctx : handoverCtx ,
248+ connClosed : make (chan error ),
249+ connSwapped : make (chan struct {}),
250+ }
251+ // Existence of the handoverState indicates to the receiving loop that we are in the middle of a handover process,
252+ // and should treat close messages as a signal to finish the handover instead of erroring out.
253+ pc .handoverState .Store (handoverState )
254+ defer pc .handoverState .Store (nil )
208255
209256 // Create a new websocket connection by sending an /ssh?id=<connID> request to the server.
210257 // When server realises it's an ID of an existing connection, it will start AcceptHandover process.
211- newConn , err := pc .createWebsocketConnection (ctx , pc .connID )
258+ newConn , err := pc .createWebsocketConnection (handoverCtx , pc .connID )
212259 if err != nil {
213260 return fmt .Errorf ("failed to create new websocket connection: %w" , err )
214261 }
215262
216263 // Wait for the server to close the old connection
217- select {
218- case err := <- pc .currentConnectionClosed :
219- if err != nil {
220- newConn .Close ()
221- return fmt .Errorf ("connection handover failed: %w" , err )
222- }
223- case <- ctx .Done ():
264+ // (it does so when it receives an /ssh request with known connection ID and starts AcceptHandover process).
265+ // Receiving loop will signal about closed connection to the coord.connClosed channel.
266+ if err := handoverState .waitForConnectionToClose (); err != nil {
224267 newConn .Close ()
225- return ctx . Err ()
268+ return err
226269 }
227270
228271 pc .conn .Store (newConn )
229272
273+ // Let the receiving loop know that the current connection is swapped and it's safe to start reading from it.
274+ if err := handoverState .signalConnectionSwapped (); err != nil {
275+ newConn .Close ()
276+ return err
277+ }
230278 return nil
231279}
232280
@@ -235,13 +283,17 @@ func (pc *proxyConnection) acceptHandover(ctx context.Context, w http.ResponseWr
235283 pc .handoverMutex .Lock ()
236284 defer pc .handoverMutex .Unlock ()
237285
238- // When handover flag is set, the receiving loop handles a close message from the current connection
239- // as a signal to finish the handover and switch to the new connection.
240- pc .isHandover .Store (true )
241- defer pc .isHandover .Store (false )
242-
243- ctx , cancel := context .WithTimeout (ctx , proxyHandoverAcceptTimeout )
286+ handoverCtx , cancel := context .WithTimeout (ctx , proxyHandoverAcceptTimeout )
244287 defer cancel ()
288+ handoverState := & handoverCoordination {
289+ ctx : handoverCtx ,
290+ connClosed : make (chan error ),
291+ connSwapped : make (chan struct {}),
292+ }
293+ // Existence of the handoverState indicates to the receiving loop that we are in the middle of a handover process,
294+ // and should treat close messages as a signal to finish the handover instead of erroring out.
295+ pc .handoverState .Store (handoverState )
296+ defer pc .handoverState .Store (nil )
245297
246298 newConn , err := pc .acceptWebsocketConnection (w , r )
247299 if err != nil {
@@ -250,7 +302,7 @@ func (pc *proxyConnection) acceptHandover(ctx context.Context, w http.ResponseWr
250302
251303 // Signal the client to complete handover by closing the old connection.
252304 // Not using pc.sendMessage here, because it's blocked by the handover mutex.
253- currentConn := pc .conn .Load ().( * websocket. Conn )
305+ currentConn := pc .conn .Load ()
254306 err = currentConn .WriteMessage (websocket .CloseMessage , websocket .FormatCloseMessage (websocket .CloseNormalClosure , "handover" ))
255307 if err != nil {
256308 newConn .Close ()
@@ -259,20 +311,20 @@ func (pc *proxyConnection) acceptHandover(ctx context.Context, w http.ResponseWr
259311
260312 // Wait for the client to acknowledge the closure of the old connection.
261313 // On the client its done automatically by the websocket library with the default close handler.
262- // On the server we then receive a close error in the RunReceivingLoop and signal about it to the handoverOldConnClosed channel.
263- select {
264- case err := <- pc .currentConnectionClosed :
265- if err != nil {
266- newConn .Close ()
267- return fmt .Errorf ("connection handover failed: %w" , err )
268- }
269- case <- ctx .Done ():
314+ // On the server we then receive a close error in the RunReceivingLoop and signal about it to the coord.connClosed channel.
315+ if err := handoverState .waitForConnectionToClose (); err != nil {
270316 newConn .Close ()
271- return ctx . Err ()
317+ return err
272318 }
273319
274320 pc .conn .Store (newConn )
275321
322+ // Let the receiving loop know that the current connection is swapped and it's safe to start reading from it.
323+ if err := handoverState .signalConnectionSwapped (); err != nil {
324+ newConn .Close ()
325+ return err
326+ }
327+
276328 return nil
277329}
278330
0 commit comments