Skip to content

Commit ef07ad4

Browse files
committed
serialize conductor test server
1 parent cbf578d commit ef07ad4

File tree

1 file changed

+158
-44
lines changed

1 file changed

+158
-44
lines changed

dbos/conductor_test.go

Lines changed: 158 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,23 @@ import (
1919
"go.uber.org/goleak"
2020
)
2121

22+
// writeCommand represents a command to write to the WebSocket connection
23+
type writeCommand struct {
24+
messageType int
25+
data []byte
26+
response chan error // Channel to send back the result
27+
}
28+
2229
// mockWebSocketServer provides a controllable WebSocket server for testing
2330
type mockWebSocketServer struct {
2431
server *httptest.Server
2532
upgrader websocket.Upgrader
26-
connMu sync.Mutex
33+
connMu sync.Mutex // Only for connection assignment/reassignment
2734
conn *websocket.Conn
2835
closed atomic.Bool
2936
messages chan []byte
3037
pings chan struct{}
38+
writeCmds chan writeCommand // Channel for write commands
3139
stopHandler chan struct{}
3240
ignorePings atomic.Bool // When true, don't respond with pongs
3341
}
@@ -37,6 +45,7 @@ func newMockWebSocketServer() *mockWebSocketServer {
3745
upgrader: websocket.Upgrader{},
3846
messages: make(chan []byte, 100),
3947
pings: make(chan struct{}, 100),
48+
writeCmds: make(chan writeCommand, 10),
4049
stopHandler: make(chan struct{}),
4150
}
4251

@@ -56,6 +65,7 @@ func (m *mockWebSocketServer) handleWebSocket(w http.ResponseWriter, r *http.Req
5665
return
5766
}
5867

68+
// Connection assignment - this is the only place we need mutex
5969
m.connMu.Lock()
6070
// Close any existing connection
6171
if m.conn != nil {
@@ -64,36 +74,86 @@ func (m *mockWebSocketServer) handleWebSocket(w http.ResponseWriter, r *http.Req
6474
m.conn = conn
6575
m.connMu.Unlock()
6676

67-
// Set up ping handler to capture pings
77+
// Ensure the connection gets cleared when this handler exits
78+
defer func() {
79+
m.connMu.Lock()
80+
if m.conn == conn {
81+
m.conn = nil
82+
}
83+
m.connMu.Unlock()
84+
conn.Close()
85+
}()
86+
87+
// Handle connection lifecycle - this function owns all I/O on conn
88+
fmt.Println("WebSocket connection established")
89+
defer fmt.Println("WebSocket connection handler exiting")
90+
91+
// We need to handle pings manually since we can't use the ping handler
92+
// (it would cause concurrent writes with our main loop)
93+
pingReceived := make(chan struct{}, 10)
94+
95+
// Custom ping handler that just signals - no writing
6896
conn.SetPingHandler(func(string) error {
97+
fmt.Println("received ping")
6998
select {
7099
case m.pings <- struct{}{}:
71100
default:
72101
}
73-
// Only send pong if not ignoring pings
74-
if !m.ignorePings.Load() {
75-
return conn.WriteMessage(websocket.PongMessage, nil)
102+
select {
103+
case pingReceived <- struct{}{}:
104+
default:
76105
}
77106
return nil
78107
})
79108

80-
// Read messages until connection is closed
109+
// Start dedicated read goroutine - only reads, never writes
110+
readDone := make(chan error, 1)
111+
go func() {
112+
defer close(readDone)
113+
for {
114+
_, _, err := conn.ReadMessage()
115+
if err != nil {
116+
fmt.Printf("WebSocket read error: %v\n", err)
117+
readDone <- err
118+
return
119+
}
120+
}
121+
}()
122+
123+
// Main write loop - all writes happen here sequentially
81124
for {
82125
select {
83126
case <-m.stopHandler:
127+
fmt.Println("WebSocket connection closed by stop signal")
84128
return
85-
default:
86-
}
87129

88-
_, message, err := conn.ReadMessage()
89-
if err != nil {
90-
// Connection closed
130+
case err := <-readDone:
131+
fmt.Printf("WebSocket connection closed by read error: %v\n", err)
91132
return
92-
}
93133

94-
select {
95-
case m.messages <- message:
96-
default:
134+
case writeCmd := <-m.writeCmds:
135+
// Handle write command
136+
err := conn.WriteMessage(writeCmd.messageType, writeCmd.data)
137+
if writeCmd.response != nil {
138+
select {
139+
case writeCmd.response <- err:
140+
default:
141+
}
142+
}
143+
if err != nil {
144+
fmt.Printf("WebSocket write error: %v\n", err)
145+
return
146+
}
147+
148+
case <-pingReceived:
149+
// Handle ping response (send pong)
150+
if !m.ignorePings.Load() {
151+
err := conn.WriteMessage(websocket.PongMessage, nil)
152+
if err != nil {
153+
fmt.Printf("WebSocket pong write error: %v\n", err)
154+
return
155+
}
156+
}
97157
}
98158
}
99159
}
@@ -111,12 +171,16 @@ func (m *mockWebSocketServer) close() {
111171
default:
112172
}
113173

114-
m.connMu.Lock()
115-
if m.conn != nil {
116-
m.conn.Close()
117-
m.conn = nil
118-
}
119-
m.connMu.Unlock()
174+
// Connection will be closed by the handler when it receives stop signal
175+
// We just need to clear our reference after a brief delay
176+
go func() {
177+
time.Sleep(100 * time.Millisecond)
178+
m.connMu.Lock()
179+
if m.conn != nil {
180+
m.conn = nil // Just clear reference, handler already closed
181+
}
182+
m.connMu.Unlock()
183+
}()
120184
}
121185

122186
func (m *mockWebSocketServer) shutdown() {
@@ -127,11 +191,26 @@ func (m *mockWebSocketServer) shutdown() {
127191
func (m *mockWebSocketServer) restart() {
128192
// Reset for new connections
129193
m.closed.Store(false)
130-
// Drain and recreate stop handler channel
194+
// Drain stop handler channel and write command channel
131195
select {
132196
case <-m.stopHandler:
133197
default:
134198
}
199+
// Drain any pending write commands
200+
drainLoop:
201+
for {
202+
select {
203+
case cmd := <-m.writeCmds:
204+
if cmd.response != nil {
205+
select {
206+
case cmd.response <- fmt.Errorf("server restarting"):
207+
default:
208+
}
209+
}
210+
default:
211+
break drainLoop
212+
}
213+
}
135214
}
136215

137216
func (m *mockWebSocketServer) waitForConnection(timeout time.Duration) bool {
@@ -150,47 +229,78 @@ func (m *mockWebSocketServer) waitForConnection(timeout time.Duration) bool {
150229

151230
// sendBinaryMessage sends a binary WebSocket message to the connected client
152231
func (m *mockWebSocketServer) sendBinaryMessage(data []byte) error {
232+
// Check if we have a connection without blocking
153233
m.connMu.Lock()
154-
defer m.connMu.Unlock()
234+
hasConn := m.conn != nil
235+
m.connMu.Unlock()
155236

156-
if m.conn == nil {
237+
if !hasConn {
157238
return fmt.Errorf("no connection")
158239
}
159240

160-
return m.conn.WriteMessage(websocket.BinaryMessage, data)
161-
}
162-
163-
// sendTextMessage sends a text WebSocket message to the connected client
164-
func (m *mockWebSocketServer) sendTextMessage(data string) error {
165-
m.connMu.Lock()
166-
defer m.connMu.Unlock()
167-
168-
if m.conn == nil {
169-
return fmt.Errorf("no connection")
241+
// Send write command via channel
242+
response := make(chan error, 1)
243+
cmd := writeCommand{
244+
messageType: websocket.BinaryMessage,
245+
data: data,
246+
response: response,
170247
}
171248

172-
return m.conn.WriteMessage(websocket.TextMessage, []byte(data))
249+
select {
250+
case m.writeCmds <- cmd:
251+
// Wait for response
252+
select {
253+
case err := <-response:
254+
return err
255+
case <-time.After(1 * time.Second):
256+
return fmt.Errorf("write timeout")
257+
}
258+
case <-time.After(1 * time.Second):
259+
return fmt.Errorf("write command queue full")
260+
}
173261
}
174262

175263
// sendCloseMessage sends a WebSocket close message with specified code and reason
176264
func (m *mockWebSocketServer) sendCloseMessage(code int, text string) error {
265+
// Check if we have a connection without blocking
177266
m.connMu.Lock()
178-
defer m.connMu.Unlock()
267+
hasConn := m.conn != nil
268+
m.connMu.Unlock()
179269

180-
if m.conn == nil {
270+
if !hasConn {
181271
return fmt.Errorf("no connection")
182272
}
183273

184-
// Write the close message
274+
// Format close message
185275
message := websocket.FormatCloseMessage(code, text)
186-
err := m.conn.WriteMessage(websocket.CloseMessage, message)
187276

188-
// After sending close, we should close the connection
189-
// This prevents concurrent writes and simulates proper WebSocket close behavior
190-
m.conn.Close()
191-
m.conn = nil
277+
// Send write command via channel
278+
response := make(chan error, 1)
279+
cmd := writeCommand{
280+
messageType: websocket.CloseMessage,
281+
data: message,
282+
response: response,
283+
}
192284

193-
return err
285+
select {
286+
case m.writeCmds <- cmd:
287+
// Wait for response
288+
select {
289+
case err := <-response:
290+
// After sending close, close the connection from our side too
291+
m.connMu.Lock()
292+
if m.conn != nil {
293+
m.conn.Close()
294+
m.conn = nil
295+
}
296+
m.connMu.Unlock()
297+
return err
298+
case <-time.After(1 * time.Second):
299+
return fmt.Errorf("write timeout")
300+
}
301+
case <-time.After(1 * time.Second):
302+
return fmt.Errorf("write command queue full")
303+
}
194304
}
195305

196306
// TestConductorReconnection tests various reconnection scenarios for the conductor
@@ -518,6 +628,10 @@ func TestConductorReconnection(t *testing.T) {
518628
for _, tc := range testCases {
519629
t.Logf("Testing %s (code %d)", tc.name, tc.code)
520630

631+
// Wait for stable connection before testing
632+
assert.True(t, mockServer.waitForConnection(5*time.Second), "Should have stable connection before %s", tc.name)
633+
time.Sleep(300 * time.Millisecond) // Give time for ping cycle to establish
634+
521635
// Collect pings before sending close message
522636
beforePings := 0
523637
timeout := time.After(200 * time.Millisecond)

0 commit comments

Comments
 (0)