@@ -19,15 +19,23 @@ import (
1919 "go.uber.org/goleak"
2020)
2121
// writeCommand is a single write request for the WebSocket connection.
//
// Commands are queued on mockWebSocketServer.writeCmds and executed by the
// connection handler, so that every write to the connection is issued from
// one goroutine.
type writeCommand struct {
	// messageType is the WebSocket message type handed to conn.WriteMessage
	// (for example websocket.BinaryMessage or websocket.CloseMessage).
	messageType int
	// data is the message payload.
	data []byte
	// response, when non-nil, receives the result of the write. Results are
	// delivered with a non-blocking send, so the channel needs a buffer (or a
	// receiver that is already waiting) or the result is dropped.
	response chan error
}
28+
2229// mockWebSocketServer provides a controllable WebSocket server for testing
2330type mockWebSocketServer struct {
2431 server * httptest.Server
2532 upgrader websocket.Upgrader
26- connMu sync.Mutex
33+ connMu sync.Mutex // Only for connection assignment/reassignment
2734 conn * websocket.Conn
2835 closed atomic.Bool
2936 messages chan []byte
3037 pings chan struct {}
38+ writeCmds chan writeCommand // Channel for write commands
3139 stopHandler chan struct {}
3240 ignorePings atomic.Bool // When true, don't respond with pongs
3341}
@@ -37,6 +45,7 @@ func newMockWebSocketServer() *mockWebSocketServer {
3745 upgrader : websocket.Upgrader {},
3846 messages : make (chan []byte , 100 ),
3947 pings : make (chan struct {}, 100 ),
48+ writeCmds : make (chan writeCommand , 10 ),
4049 stopHandler : make (chan struct {}),
4150 }
4251
@@ -56,6 +65,7 @@ func (m *mockWebSocketServer) handleWebSocket(w http.ResponseWriter, r *http.Req
5665 return
5766 }
5867
68+ // Connection assignment - this is the only place we need mutex
5969 m .connMu .Lock ()
6070 // Close any existing connection
6171 if m .conn != nil {
@@ -64,36 +74,86 @@ func (m *mockWebSocketServer) handleWebSocket(w http.ResponseWriter, r *http.Req
6474 m .conn = conn
6575 m .connMu .Unlock ()
6676
67- // Set up ping handler to capture pings
77+ // Ensure the connection gets cleared when this handler exits
78+ defer func () {
79+ m .connMu .Lock ()
80+ if m .conn == conn {
81+ m .conn = nil
82+ }
83+ m .connMu .Unlock ()
84+ conn .Close ()
85+ }()
86+
87+ // Handle connection lifecycle - this function owns all I/O on conn
88+ fmt .Println ("WebSocket connection established" )
89+ defer fmt .Println ("WebSocket connection handler exiting" )
90+
91+ // We need to handle pings manually since we can't use the ping handler
92+ // (it would cause concurrent writes with our main loop)
93+ pingReceived := make (chan struct {}, 10 )
94+
95+ // Custom ping handler that just signals - no writing
6896 conn .SetPingHandler (func (string ) error {
97+ fmt .Println ("received ping" )
6998 select {
7099 case m .pings <- struct {}{}:
71100 default :
72101 }
73- // Only send pong if not ignoring pings
74- if ! m . ignorePings . Load () {
75- return conn . WriteMessage ( websocket . PongMessage , nil )
102+ select {
103+ case pingReceived <- struct {}{}:
104+ default :
76105 }
77106 return nil
78107 })
79108
80- // Read messages until connection is closed
109+ // Start dedicated read goroutine - only reads, never writes
110+ readDone := make (chan error , 1 )
111+ go func () {
112+ defer close (readDone )
113+ for {
114+ _ , _ , err := conn .ReadMessage ()
115+ if err != nil {
116+ fmt .Printf ("WebSocket read error: %v\n " , err )
117+ readDone <- err
118+ return
119+ }
120+ }
121+ }()
122+
123+ // Main write loop - all writes happen here sequentially
81124 for {
82125 select {
83126 case <- m .stopHandler :
127+ fmt .Println ("WebSocket connection closed by stop signal" )
84128 return
85- default :
86- }
87129
88- _ , message , err := conn .ReadMessage ()
89- if err != nil {
90- // Connection closed
130+ case err := <- readDone :
131+ fmt .Printf ("WebSocket connection closed by read error: %v\n " , err )
91132 return
92- }
93133
94- select {
95- case m .messages <- message :
96- default :
134+ case writeCmd := <- m .writeCmds :
135+ // Handle write command
136+ err := conn .WriteMessage (writeCmd .messageType , writeCmd .data )
137+ if writeCmd .response != nil {
138+ select {
139+ case writeCmd .response <- err :
140+ default :
141+ }
142+ }
143+ if err != nil {
144+ fmt .Printf ("WebSocket write error: %v\n " , err )
145+ return
146+ }
147+
148+ case <- pingReceived :
149+ // Handle ping response (send pong)
150+ if ! m .ignorePings .Load () {
151+ err := conn .WriteMessage (websocket .PongMessage , nil )
152+ if err != nil {
153+ fmt .Printf ("WebSocket pong write error: %v\n " , err )
154+ return
155+ }
156+ }
97157 }
98158 }
99159}
@@ -111,12 +171,16 @@ func (m *mockWebSocketServer) close() {
111171 default :
112172 }
113173
114- m .connMu .Lock ()
115- if m .conn != nil {
116- m .conn .Close ()
117- m .conn = nil
118- }
119- m .connMu .Unlock ()
174+ // Connection will be closed by the handler when it receives stop signal
175+ // We just need to clear our reference after a brief delay
176+ go func () {
177+ time .Sleep (100 * time .Millisecond )
178+ m .connMu .Lock ()
179+ if m .conn != nil {
180+ m .conn = nil // Just clear reference, handler already closed
181+ }
182+ m .connMu .Unlock ()
183+ }()
120184}
121185
122186func (m * mockWebSocketServer ) shutdown () {
@@ -127,11 +191,26 @@ func (m *mockWebSocketServer) shutdown() {
127191func (m * mockWebSocketServer ) restart () {
128192 // Reset for new connections
129193 m .closed .Store (false )
130- // Drain and recreate stop handler channel
194+ // Drain stop handler channel and write command channel
131195 select {
132196 case <- m .stopHandler :
133197 default :
134198 }
199+ // Drain any pending write commands
200+ drainLoop:
201+ for {
202+ select {
203+ case cmd := <- m .writeCmds :
204+ if cmd .response != nil {
205+ select {
206+ case cmd .response <- fmt .Errorf ("server restarting" ):
207+ default :
208+ }
209+ }
210+ default :
211+ break drainLoop
212+ }
213+ }
135214}
136215
137216func (m * mockWebSocketServer ) waitForConnection (timeout time.Duration ) bool {
@@ -150,47 +229,78 @@ func (m *mockWebSocketServer) waitForConnection(timeout time.Duration) bool {
150229
151230// sendBinaryMessage sends a binary WebSocket message to the connected client
152231func (m * mockWebSocketServer ) sendBinaryMessage (data []byte ) error {
232+ // Check if we have a connection without blocking
153233 m .connMu .Lock ()
154- defer m .connMu .Unlock ()
234+ hasConn := m .conn != nil
235+ m .connMu .Unlock ()
155236
156- if m . conn == nil {
237+ if ! hasConn {
157238 return fmt .Errorf ("no connection" )
158239 }
159240
160- return m .conn .WriteMessage (websocket .BinaryMessage , data )
161- }
162-
163- // sendTextMessage sends a text WebSocket message to the connected client
164- func (m * mockWebSocketServer ) sendTextMessage (data string ) error {
165- m .connMu .Lock ()
166- defer m .connMu .Unlock ()
167-
168- if m .conn == nil {
169- return fmt .Errorf ("no connection" )
241+ // Send write command via channel
242+ response := make (chan error , 1 )
243+ cmd := writeCommand {
244+ messageType : websocket .BinaryMessage ,
245+ data : data ,
246+ response : response ,
170247 }
171248
172- return m .conn .WriteMessage (websocket .TextMessage , []byte (data ))
249+ select {
250+ case m .writeCmds <- cmd :
251+ // Wait for response
252+ select {
253+ case err := <- response :
254+ return err
255+ case <- time .After (1 * time .Second ):
256+ return fmt .Errorf ("write timeout" )
257+ }
258+ case <- time .After (1 * time .Second ):
259+ return fmt .Errorf ("write command queue full" )
260+ }
173261}
174262
175263// sendCloseMessage sends a WebSocket close message with specified code and reason
176264func (m * mockWebSocketServer ) sendCloseMessage (code int , text string ) error {
265+ // Check if we have a connection without blocking
177266 m .connMu .Lock ()
178- defer m .connMu .Unlock ()
267+ hasConn := m .conn != nil
268+ m .connMu .Unlock ()
179269
180- if m . conn == nil {
270+ if ! hasConn {
181271 return fmt .Errorf ("no connection" )
182272 }
183273
184- // Write the close message
274+ // Format close message
185275 message := websocket .FormatCloseMessage (code , text )
186- err := m .conn .WriteMessage (websocket .CloseMessage , message )
187276
188- // After sending close, we should close the connection
189- // This prevents concurrent writes and simulates proper WebSocket close behavior
190- m .conn .Close ()
191- m .conn = nil
277+ // Send write command via channel
278+ response := make (chan error , 1 )
279+ cmd := writeCommand {
280+ messageType : websocket .CloseMessage ,
281+ data : message ,
282+ response : response ,
283+ }
192284
193- return err
285+ select {
286+ case m .writeCmds <- cmd :
287+ // Wait for response
288+ select {
289+ case err := <- response :
290+ // After sending close, close the connection from our side too
291+ m .connMu .Lock ()
292+ if m .conn != nil {
293+ m .conn .Close ()
294+ m .conn = nil
295+ }
296+ m .connMu .Unlock ()
297+ return err
298+ case <- time .After (1 * time .Second ):
299+ return fmt .Errorf ("write timeout" )
300+ }
301+ case <- time .After (1 * time .Second ):
302+ return fmt .Errorf ("write command queue full" )
303+ }
194304}
195305
196306// TestConductorReconnection tests various reconnection scenarios for the conductor
@@ -518,6 +628,10 @@ func TestConductorReconnection(t *testing.T) {
518628 for _ , tc := range testCases {
519629 t .Logf ("Testing %s (code %d)" , tc .name , tc .code )
520630
631+ // Wait for stable connection before testing
632+ assert .True (t , mockServer .waitForConnection (5 * time .Second ), "Should have stable connection before %s" , tc .name )
633+ time .Sleep (300 * time .Millisecond ) // Give time for ping cycle to establish
634+
521635 // Collect pings before sending close message
522636 beforePings := 0
523637 timeout := time .After (200 * time .Millisecond )
0 commit comments