@@ -45,13 +45,30 @@ type Conductor struct {
4545 pingTimeout time.Duration
4646 reconnectWait time.Duration
4747
48+ // writeMu serializes writes to the WebSocket connection (pings and message responses), which must not happen concurrently
49+ writeMu sync.Mutex
50+
51+ // pingCancel cancels the ping goroutine context
52+ pingCancel context.CancelFunc
53+
4854 logger * slog.Logger
4955}
5056
5157// closeConn closes the connection and signals that reconnection is needed
5258func (c * Conductor ) closeConn () {
59+ // Cancel ping goroutine first
60+ if c .pingCancel != nil {
61+ c .pingCancel ()
62+ c .pingCancel = nil
63+ }
64+
5365 if c .conn != nil {
54- err := c .conn .Close ()
66+ c .conn .SetWriteDeadline (time .Now ().Add (5 * time .Second )) // Make sure the write doesn't block
67+ err := c .conn .WriteMessage (websocket .CloseMessage , websocket .FormatCloseMessage (websocket .CloseNormalClosure , "shutting down" ))
68+ if err != nil {
69+ c .logger .Warn ("Failed to send close message" , "error" , err )
70+ }
71+ err = c .conn .Close ()
5572 if err != nil {
5673 c .logger .Warn ("Failed to close connection" , "error" , err )
5774 }
@@ -148,13 +165,7 @@ func (c *Conductor) run() {
148165 select {
149166 case <- c .dbosCtx .Done ():
150167 c .logger .Info ("DBOS context done, stopping conductor" , "cause" , context .Cause (c .dbosCtx ))
151- if c .conn != nil {
152- err := c .conn .WriteMessage (websocket .CloseMessage , websocket .FormatCloseMessage (websocket .CloseNormalClosure , "shutting down" ))
153- if err != nil {
154- c .logger .Warn ("Failed to send close message" , "error" , err )
155- }
156- c .closeConn ()
157- }
168+ c .closeConn ()
158169 return
159170 default :
160171 }
@@ -180,13 +191,13 @@ func (c *Conductor) run() {
180191 c .needsReconnect .Store (false )
181192 }
182193
183- // Read message (will timeout based on read deadline set in connect)
194+ // This shouldn't happen but check anyway
184195 if c .conn == nil {
185- // This shouldn't happen but check anyway
186196 c .needsReconnect .Store (true )
187197 continue
188198 }
189199
200+ // Read message (will timeout based on read deadline set in connect)
190201 messageType , message , err := c .conn .ReadMessage ()
191202 if err != nil {
192203 if websocket .IsUnexpectedCloseError (err , websocket .CloseGoingAway , websocket .CloseAbnormalClosure ) {
@@ -246,6 +257,10 @@ func (c *Conductor) connect() error {
246257 // Store the connection
247258 c .conn = conn
248259
260+ // Create a cancellable context for the ping goroutine
261+ pingCtx , pingCancel := context .WithCancel (c .dbosCtx )
262+ c .pingCancel = pingCancel
263+
249264 // Start ping goroutine
250265 c .wg .Add (1 )
251266 go func () {
@@ -255,8 +270,8 @@ func (c *Conductor) connect() error {
255270
256271 for {
257272 select {
258- case <- c . dbosCtx .Done ():
259- c .logger .Debug ("Exiting Conductor ping goroutine" , "cause" , context .Cause (c . dbosCtx ))
273+ case <- pingCtx .Done ():
274+ c .logger .Debug ("Exiting Conductor ping goroutine" , "cause" , context .Cause (pingCtx ))
260275 return
261276 case <- ticker .C :
262277 if err := c .ping (); err != nil {
@@ -281,9 +296,14 @@ func (c *Conductor) ping() error {
281296
282297 c .logger .Debug ("Sending ping to conductor" )
283298
299+ c .writeMu .Lock ()
300+ defer c .writeMu .Unlock ()
301+
302+ c .conn .SetWriteDeadline (time .Now ().Add (5 * time .Second )) // Make sure the write doesn't block
284303 if err := c .conn .WriteMessage (websocket .PingMessage , nil ); err != nil {
285304 return fmt .Errorf ("failed to send ping: %w" , err )
286305 }
306+ c .conn .SetWriteDeadline (time.Time {}) // Clear the write deadline
287307
288308 return nil
289309}
@@ -548,10 +568,15 @@ func (c *Conductor) sendResponse(response any, responseType string) error {
548568
549569 c .logger .Debug ("Sending response" , "type" , responseType )
550570
571+ c .writeMu .Lock ()
572+ defer c .writeMu .Unlock ()
573+
574+ c .conn .SetWriteDeadline (time .Now ().Add (5 * time .Second )) // Make sure the write doesn't block
551575 if err := c .conn .WriteMessage (websocket .TextMessage , data ); err != nil {
552576 c .logger .Error ("Failed to send response" , "type" , responseType , "error" , err )
553577 return fmt .Errorf ("failed to send message: %w" , err )
554578 }
579+ c .conn .SetWriteDeadline (time.Time {}) // Clear the write deadline
555580
556581 return nil
557582}
@@ -944,10 +969,15 @@ func (c *Conductor) handleUnknownMessageType(requestID string, msgType messageTy
944969
945970 c .logger .Debug ("Sending error response" , "data" , response )
946971
972+ c .writeMu .Lock ()
973+ defer c .writeMu .Unlock ()
974+
975+ c .conn .SetWriteDeadline (time .Now ().Add (5 * time .Second )) // Make sure the write doesn't block
947976 if err := c .conn .WriteMessage (websocket .TextMessage , data ); err != nil {
948977 c .logger .Error ("Failed to send error response" , "error" , err )
949978 return fmt .Errorf ("failed to send message: %w" , err )
950979 }
980+ c .conn .SetWriteDeadline (time.Time {}) // Clear the write deadline
951981
952982 return nil
953983}
0 commit comments