Skip to content

Commit 130407d

Browse files
committed
fix: guard websocket reconnects against stale sockets
1 parent 5b8184c commit 130407d

File tree

2 files changed

+148
-42
lines changed

2 files changed

+148
-42
lines changed

packages/loro-websocket/src/client/index.ts

Lines changed: 132 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,13 @@ interface ActiveRoom {
4646
handler: InternalRoomHandler;
4747
}
4848

49+
interface SocketListeners {
50+
open: () => void;
51+
error: (event: Event) => void;
52+
close: () => void;
53+
message: (event: MessageEvent<string | ArrayBuffer>) => void;
54+
}
55+
4956
/**
5057
* The websocket client's high-level connection status.
5158
* - `Connecting`: initial connect or a manual `connect()` in progress.
@@ -114,6 +121,7 @@ export class LoroWebsocketClient {
114121
// Track roomId for each active id so we can rejoin on reconnect
115122
private roomIds: Map<string, string> = new Map();
116123
private roomAuth: Map<string, Uint8Array | undefined> = new Map();
124+
private socketListeners = new WeakMap<WebSocket, SocketListeners>();
117125

118126
private pingTimer?: ReturnType<typeof setInterval>;
119127
private pingWaiters: Array<{
@@ -195,7 +203,13 @@ export class LoroWebsocketClient {
195203
async connect(): Promise<void> {
196204
// Ensure future unexpected closes will auto-reconnect again
197205
this.shouldReconnect = true;
198-
if (this.ws && this.ws.readyState === WebSocket.OPEN) return; // already connected
206+
const current = this.ws;
207+
if (current) {
208+
const state = current.readyState;
209+
if (state === WebSocket.OPEN || state === WebSocket.CONNECTING) {
210+
return this.connectedPromise;
211+
}
212+
}
199213
this.clearReconnectTimer();
200214
this.setStatus(
201215
this.reconnectAttempts > 0
@@ -206,55 +220,75 @@ export class LoroWebsocketClient {
206220
// Reset the connected promise for this attempt
207221
this.connectedPromise = this.createConnectedPromise();
208222

209-
this.ws = new WebSocket(this.ops.url);
223+
const ws = new WebSocket(this.ops.url);
224+
this.ws = ws;
225+
226+
if (current && current !== ws) {
227+
this.detachSocketListeners(current);
228+
}
229+
230+
this.attachSocketListeners(ws);
231+
210232
try {
211-
this.ws.binaryType = "arraybuffer";
233+
ws.binaryType = "arraybuffer";
212234
} catch {}
213235

214-
this.ws.addEventListener("open", this.handleOpen);
215-
this.ws.addEventListener("error", this.handleError);
216-
this.ws.addEventListener("close", this.handleClose);
217-
this.ws.addEventListener("message", this.handleWsMessage);
218-
219236
return this.connectedPromise;
220237
}
221238

222-
private handleWsMessage = async (
223-
event: MessageEvent<string | ArrayBuffer>
224-
) => {
225-
if (typeof event.data === "string") {
226-
if (event.data === "ping") {
227-
try {
228-
this.ws.send("pong");
229-
} catch {}
230-
return;
231-
}
232-
if (event.data === "pong") {
233-
this.handlePong();
234-
return;
235-
}
236-
return; // ignore other texts
237-
}
238-
const dataU8 = new Uint8Array(event.data);
239-
const msg = tryDecode(dataU8);
240-
if (msg != null) await this.handleMessage(msg);
241-
};
239+
private attachSocketListeners(ws: WebSocket): void {
240+
const open = () => this.onSocketOpen(ws);
241+
const error = (event: Event) => this.onSocketError(ws, event);
242+
const close = () => this.onSocketClose(ws);
243+
const message = (event: MessageEvent<string | ArrayBuffer>) => {
244+
void this.onSocketMessage(ws, event);
245+
};
246+
247+
ws.addEventListener("open", open);
248+
ws.addEventListener("error", error);
249+
ws.addEventListener("close", close);
250+
ws.addEventListener("message", message);
242251

243-
private handleOpen = () => {
252+
this.socketListeners.set(ws, {
253+
open,
254+
error,
255+
close,
256+
message,
257+
});
258+
}
259+
260+
private onSocketOpen(ws: WebSocket): void {
261+
if (ws !== this.ws) {
262+
// TODO: REVIEW stale sockets bail early so they can't tear down the new connection
263+
this.detachSocketListeners(ws);
264+
try {
265+
ws.close(1000, "Superseded");
266+
} catch {}
267+
return;
268+
}
244269
this.clearReconnectTimer();
245270
this.reconnectAttempts = 0;
246271
this.setStatus(ClientStatus.Connected);
247272
this.startPingTimer();
248273
this.resolveConnected?.();
249274
// Rejoin rooms after reconnect
250275
this.rejoinActiveRooms();
251-
};
276+
}
252277

253-
private handleError = () => {
254-
// Leave for now; close event will drive reconnection
255-
};
278+
private onSocketError(ws: WebSocket, _event: Event): void {
279+
if (ws !== this.ws) {
280+
this.detachSocketListeners(ws);
281+
}
282+
// Leave further handling to the close event for the active socket
283+
}
284+
285+
private onSocketClose(ws: WebSocket): void {
286+
const isCurrent = ws === this.ws;
287+
this.detachSocketListeners(ws);
288+
if (!isCurrent) {
289+
return;
290+
}
256291

257-
private handleClose = () => {
258292
this.clearPingTimer();
259293
// Clear any pending fragment reassembly timers to avoid late callbacks
260294
if (this.fragmentBatches.size) {
@@ -269,7 +303,6 @@ export class LoroWebsocketClient {
269303
this.awaitingPongSince = undefined;
270304
this.ops.onWsClose?.();
271305
this.rejectAllPingWaiters(new Error("WebSocket closed"));
272-
this.detachSocketListeners(this.ws);
273306
if (!this.shouldReconnect) {
274307
this.setStatus(ClientStatus.Disconnected);
275308
this.rejectConnected?.(new Error("Disconnected"));
@@ -278,7 +311,32 @@ export class LoroWebsocketClient {
278311
// Start (or continue) exponential backoff retries
279312
this.setStatus(ClientStatus.Reconnecting);
280313
this.scheduleReconnect();
281-
};
314+
}
315+
316+
private async onSocketMessage(
317+
ws: WebSocket,
318+
event: MessageEvent<string | ArrayBuffer>
319+
): Promise<void> {
320+
if (ws !== this.ws) {
321+
return;
322+
}
323+
if (typeof event.data === "string") {
324+
if (event.data === "ping") {
325+
try {
326+
ws.send("pong");
327+
} catch {}
328+
return;
329+
}
330+
if (event.data === "pong") {
331+
this.handlePong();
332+
return;
333+
}
334+
return; // ignore other texts
335+
}
336+
const dataU8 = new Uint8Array(event.data);
337+
const msg = tryDecode(dataU8);
338+
if (msg != null) await this.handleMessage(msg);
339+
}
282340

283341
private scheduleReconnect() {
284342
if (this.reconnectTimer) return;
@@ -758,7 +816,22 @@ export class LoroWebsocketClient {
758816
this.rejectConnected?.(new Error("Disconnected"));
759817
this.rejectConnected = undefined;
760818
this.resolveConnected = undefined;
761-
this.flushAndCloseWebSocket(this.ws, {
819+
this.rejectAllPingWaiters(new Error("Disconnected"));
820+
if (this.fragmentBatches.size) {
821+
for (const [, batch] of this.fragmentBatches) {
822+
try {
823+
clearTimeout(batch.timeoutId);
824+
} catch {}
825+
}
826+
this.fragmentBatches.clear();
827+
}
828+
this.awaitingPongSince = undefined;
829+
const ws = this.ws;
830+
if (ws && this.socketListeners.has(ws)) {
831+
this.ops.onWsClose?.();
832+
}
833+
this.detachSocketListeners(ws);
834+
this.flushAndCloseWebSocket(ws, {
762835
code: 1000,
763836
reason: "Client closed",
764837
});
@@ -846,6 +919,20 @@ export class LoroWebsocketClient {
846919
this.rejectConnected = undefined;
847920
this.resolveConnected = undefined;
848921
this.rejectAllPingWaiters(new Error("Destroyed"));
922+
if (this.fragmentBatches.size) {
923+
for (const [, batch] of this.fragmentBatches) {
924+
try {
925+
clearTimeout(batch.timeoutId);
926+
} catch {}
927+
}
928+
this.fragmentBatches.clear();
929+
}
930+
this.awaitingPongSince = undefined;
931+
const ws = this.ws;
932+
if (ws && this.socketListeners.has(ws)) {
933+
this.ops.onWsClose?.();
934+
}
935+
this.detachSocketListeners(ws);
849936
// Remove window event listeners if present
850937
try {
851938
if (
@@ -858,7 +945,7 @@ export class LoroWebsocketClient {
858945
} catch {}
859946
// Close websocket after flushing pending frames
860947
try {
861-
this.flushAndCloseWebSocket(this.ws, {
948+
this.flushAndCloseWebSocket(ws, {
862949
code: 1000,
863950
reason: "Client destroyed",
864951
});
@@ -915,12 +1002,15 @@ export class LoroWebsocketClient {
9151002

9161003
private detachSocketListeners(ws: WebSocket | undefined): void {
9171004
if (!ws) return;
1005+
const handlers = this.socketListeners.get(ws);
1006+
if (!handlers) return;
9181007
try {
919-
ws.removeEventListener?.("open", this.handleOpen);
920-
ws.removeEventListener?.("error", this.handleError);
921-
ws.removeEventListener?.("close", this.handleClose);
922-
ws.removeEventListener?.("message", this.handleWsMessage);
1008+
ws.removeEventListener?.("open", handlers.open);
1009+
ws.removeEventListener?.("error", handlers.error);
1010+
ws.removeEventListener?.("close", handlers.close);
1011+
ws.removeEventListener?.("message", handlers.message);
9231012
} catch {}
1013+
this.socketListeners.delete(ws);
9241014
}
9251015

9261016
private startPingTimer(): void {

packages/loro-websocket/tests/e2e.test.ts

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,22 @@ describe("E2E: Client-Server Sync", () => {
324324

325325
clientWithPong.handlePong = originalHandlePong;
326326
}, 10000);
327+
328+
it("allows immediate reconnect after close without hanging", async () => {
329+
const client = new LoroWebsocketClient({ url: `ws://localhost:${port}` });
330+
await client.waitConnected();
331+
332+
client.close();
333+
const reconnect = client.connect();
334+
335+
// Let the previous socket dispatch its close event before awaiting connect
336+
await new Promise(resolve => setTimeout(resolve, 10));
337+
338+
await reconnect;
339+
await waitUntil(() => client.getStatus() === ClientStatus.Connected, 5000);
340+
341+
client.destroy();
342+
}, 15000);
327343
});
328344

329345
// Small polling helper for this file

0 commit comments

Comments
 (0)