Skip to content

Commit a65a5c1

Browse files
committed
have a cancel method for the ping go routine so we always have one, plus set a write deadline
1 parent 7852609 commit a65a5c1

File tree

1 file changed

+42
-12
lines changed

1 file changed

+42
-12
lines changed

dbos/conductor.go

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,30 @@ type Conductor struct {
4545
pingTimeout time.Duration
4646
reconnectWait time.Duration
4747

48+
// writeMu protects concurrent writes to the WebSocket connection (pings + handling messages)
49+
writeMu sync.Mutex
50+
51+
// pingCancel cancels the ping goroutine context
52+
pingCancel context.CancelFunc
53+
4854
logger *slog.Logger
4955
}
5056

5157
// closeConn closes the connection and signals that reconnection is needed
5258
func (c *Conductor) closeConn() {
59+
// Cancel ping goroutine first
60+
if c.pingCancel != nil {
61+
c.pingCancel()
62+
c.pingCancel = nil
63+
}
64+
5365
if c.conn != nil {
54-
err := c.conn.Close()
66+
c.conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) // Make sure the write doesn't block
67+
err := c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "shutting down"))
68+
if err != nil {
69+
c.logger.Warn("Failed to send close message", "error", err)
70+
}
71+
err = c.conn.Close()
5572
if err != nil {
5673
c.logger.Warn("Failed to close connection", "error", err)
5774
}
@@ -148,13 +165,7 @@ func (c *Conductor) run() {
148165
select {
149166
case <-c.dbosCtx.Done():
150167
c.logger.Info("DBOS context done, stopping conductor", "cause", context.Cause(c.dbosCtx))
151-
if c.conn != nil {
152-
err := c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "shutting down"))
153-
if err != nil {
154-
c.logger.Warn("Failed to send close message", "error", err)
155-
}
156-
c.closeConn()
157-
}
168+
c.closeConn()
158169
return
159170
default:
160171
}
@@ -180,13 +191,13 @@ func (c *Conductor) run() {
180191
c.needsReconnect.Store(false)
181192
}
182193

183-
// Read message (will timeout based on read deadline set in connect)
194+
// This shouldn't happen but check anyway
184195
if c.conn == nil {
185-
// This shouldn't happen but check anyway
186196
c.needsReconnect.Store(true)
187197
continue
188198
}
189199

200+
// Read message (will timeout based on read deadline set in connect)
190201
messageType, message, err := c.conn.ReadMessage()
191202
if err != nil {
192203
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
@@ -246,6 +257,10 @@ func (c *Conductor) connect() error {
246257
// Store the connection
247258
c.conn = conn
248259

260+
// Create a cancellable context for the ping goroutine
261+
pingCtx, pingCancel := context.WithCancel(c.dbosCtx)
262+
c.pingCancel = pingCancel
263+
249264
// Start ping goroutine
250265
c.wg.Add(1)
251266
go func() {
@@ -255,8 +270,8 @@ func (c *Conductor) connect() error {
255270

256271
for {
257272
select {
258-
case <-c.dbosCtx.Done():
259-
c.logger.Debug("Exiting Conductor ping goroutine", "cause", context.Cause(c.dbosCtx))
273+
case <-pingCtx.Done():
274+
c.logger.Debug("Exiting Conductor ping goroutine", "cause", context.Cause(pingCtx))
260275
return
261276
case <-ticker.C:
262277
if err := c.ping(); err != nil {
@@ -281,9 +296,14 @@ func (c *Conductor) ping() error {
281296

282297
c.logger.Debug("Sending ping to conductor")
283298

299+
c.writeMu.Lock()
300+
defer c.writeMu.Unlock()
301+
302+
c.conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) // Make sure the write doesn't block
284303
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
285304
return fmt.Errorf("failed to send ping: %w", err)
286305
}
306+
c.conn.SetWriteDeadline(time.Time{}) // Clear the write deadline
287307

288308
return nil
289309
}
@@ -548,10 +568,15 @@ func (c *Conductor) sendResponse(response any, responseType string) error {
548568

549569
c.logger.Debug("Sending response", "type", responseType)
550570

571+
c.writeMu.Lock()
572+
defer c.writeMu.Unlock()
573+
574+
c.conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) // Make sure the write doesn't block
551575
if err := c.conn.WriteMessage(websocket.TextMessage, data); err != nil {
552576
c.logger.Error("Failed to send response", "type", responseType, "error", err)
553577
return fmt.Errorf("failed to send message: %w", err)
554578
}
579+
c.conn.SetWriteDeadline(time.Time{}) // Clear the write deadline
555580

556581
return nil
557582
}
@@ -944,10 +969,15 @@ func (c *Conductor) handleUnknownMessageType(requestID string, msgType messageTy
944969

945970
c.logger.Debug("Sending error response", "data", response)
946971

972+
c.writeMu.Lock()
973+
defer c.writeMu.Unlock()
974+
975+
c.conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) // Make sure the write doesn't block
947976
if err := c.conn.WriteMessage(websocket.TextMessage, data); err != nil {
948977
c.logger.Error("Failed to send error response", "error", err)
949978
return fmt.Errorf("failed to send message: %w", err)
950979
}
980+
c.conn.SetWriteDeadline(time.Time{}) // Clear the write deadline
951981

952982
return nil
953983
}

0 commit comments

Comments
 (0)