Skip to content

Commit 8a9960a

Browse files
zxch3nCopilotCopilot
authored
feat: join room with auth providers (#37)
* fix: resolve auth before join * fix: dedupe concurrent joins before auth resolution * Update packages/loro-websocket/src/client/index.ts Co-authored-by: Copilot <[email protected]> * fix: re-resolve auth option on join retries * fix: add explicit pendingRooms.delete before cleanupRoom in reject handler Co-authored-by: zxch3n <[email protected]> --------- Co-authored-by: Copilot <[email protected]> Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: zxch3n <[email protected]>
1 parent 6570a2e commit 8a9960a

File tree

2 files changed

+161
-27
lines changed

2 files changed

+161
-27
lines changed

packages/loro-websocket/src/client/index.ts

Lines changed: 79 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ import type { CrdtDocAdaptor } from "loro-adaptors";
2424

2525
export * from "loro-adaptors";
2626

27+
export type AuthProvider = () => Uint8Array | Promise<Uint8Array>;
28+
type AuthOption = Uint8Array | AuthProvider;
29+
2730
interface FragmentBatch {
2831
header: DocUpdateFragmentHeader;
2932
fragments: Map<number, Uint8Array>;
@@ -36,7 +39,7 @@ interface PendingRoom {
3639
reject: (error: Error) => void;
3740
adaptor: CrdtDocAdaptor;
3841
roomId: string;
39-
auth?: Uint8Array;
42+
auth?: AuthOption;
4043
isRejoin?: boolean;
4144
}
4245

@@ -173,7 +176,7 @@ export class LoroWebsocketClient {
173176
private roomAdaptors: Map<string, CrdtDocAdaptor> = new Map();
174177
// Track roomId for each active id so we can rejoin on reconnect
175178
private roomIds: Map<string, string> = new Map();
176-
private roomAuth: Map<string, Uint8Array | undefined> = new Map();
179+
private roomAuth: Map<string, AuthOption | undefined> = new Map();
177180
private roomStatusListeners: Map<
178181
string,
179182
Set<(s: RoomJoinStatusValue) => void>
@@ -206,6 +209,17 @@ export class LoroWebsocketClient {
206209
void this.connect();
207210
}
208211

212+
private async resolveAuth(auth?: AuthOption): Promise<Uint8Array> {
213+
if (typeof auth === "function") {
214+
const value = await auth();
215+
if (!(value instanceof Uint8Array)) {
216+
throw new Error("Auth provider must return Uint8Array");
217+
}
218+
return value;
219+
}
220+
return auth ?? new Uint8Array();
221+
}
222+
209223
get socket(): WebSocket {
210224
return this.ws;
211225
}
@@ -562,17 +576,27 @@ export class LoroWebsocketClient {
562576
if (!roomId) continue;
563577
const active = this.activeRooms.get(id);
564578
if (!active) continue;
565-
this.sendRejoinRequest(id, roomId, adaptor, active.room, this.roomAuth.get(id));
579+
void this.sendRejoinRequest(id, roomId, adaptor, active.room, this.roomAuth.get(id));
566580
}
567581
}
568582

569-
private sendRejoinRequest(
583+
private async sendRejoinRequest(
570584
id: string,
571585
roomId: string,
572586
adaptor: CrdtDocAdaptor,
573587
room: LoroWebsocketClientRoom,
574-
auth?: Uint8Array
588+
auth?: AuthOption
575589
) {
590+
let authValue: Uint8Array;
591+
try {
592+
authValue = await this.resolveAuth(auth);
593+
} catch (e) {
594+
console.error("Failed to resolve auth for rejoin:", e);
595+
this.cleanupRoom(roomId, adaptor.crdtType);
596+
this.emitRoomStatus(id, RoomJoinStatus.Error);
597+
return;
598+
}
599+
576600
// Prepare a lightweight pending entry so JoinError handling can retry version formats
577601
const pending: PendingRoom = {
578602
room: Promise.resolve(room),
@@ -589,6 +613,7 @@ export class LoroWebsocketClient {
589613
},
590614
reject: (error: Error) => {
591615
console.error("Rejoin failed:", error);
616+
this.pendingRooms.delete(id);
592617
this.cleanupRoom(roomId, adaptor.crdtType);
593618
this.emitRoomStatus(id, RoomJoinStatus.Error);
594619
},
@@ -603,7 +628,7 @@ export class LoroWebsocketClient {
603628
type: MessageType.JoinRequest,
604629
crdt: adaptor.crdtType,
605630
roomId,
606-
auth: auth ?? new Uint8Array(),
631+
auth: authValue,
607632
version: adaptor.getVersion(),
608633
} as JoinRequest);
609634

@@ -677,7 +702,7 @@ export class LoroWebsocketClient {
677702
// Drop any in-flight join since the server explicitly removed us
678703
this.pendingRooms.delete(roomId);
679704
if (shouldRejoin && active && adaptor) {
680-
this.sendRejoinRequest(roomId, msg.roomId, adaptor, active.room, auth);
705+
void this.sendRejoinRequest(roomId, msg.roomId, adaptor, active.room, auth);
681706
} else {
682707
// Remove local room state so client does not auto-retry unless requested
683708
this.cleanupRoom(msg.roomId, msg.crdt);
@@ -815,6 +840,19 @@ export class LoroWebsocketClient {
815840
roomId: string
816841
) {
817842
if (msg.code === JoinErrorCode.VersionUnknown) {
843+
let authValue: Uint8Array;
844+
try {
845+
authValue = await this.resolveAuth(pending.auth);
846+
} catch (e) {
847+
pending.reject(e as Error);
848+
this.pendingRooms.delete(roomId);
849+
this.emitRoomStatus(
850+
pending.adaptor.crdtType + pending.roomId,
851+
RoomJoinStatus.Error
852+
);
853+
return;
854+
}
855+
818856
// Try alternative version format
819857
const currentVersion = pending.adaptor.getVersion();
820858
const alternativeVersion =
@@ -826,7 +864,7 @@ export class LoroWebsocketClient {
826864
type: MessageType.JoinRequest,
827865
crdt: pending.adaptor.crdtType,
828866
roomId: pending.roomId,
829-
auth: pending.auth ?? new Uint8Array(),
867+
auth: authValue,
830868
version: alternativeVersion,
831869
} as JoinRequest)
832870
);
@@ -838,7 +876,7 @@ export class LoroWebsocketClient {
838876
type: MessageType.JoinRequest,
839877
crdt: pending.adaptor.crdtType,
840878
roomId: pending.roomId,
841-
auth: pending.auth ?? new Uint8Array(),
879+
auth: authValue,
842880
version: new Uint8Array(),
843881
} as JoinRequest)
844882
);
@@ -915,7 +953,11 @@ export class LoroWebsocketClient {
915953
}
916954

917955
/**
918-
* Join a room; `auth` carries application-defined join metadata forwarded to the server.
956+
* Join a room.
957+
* - `auth` may be a `Uint8Array` or a provider function.
958+
* - The provider is invoked on the initial join and again on protocol-driven retries
959+
* (e.g. `VersionUnknown`) and reconnect rejoins, so it can refresh short-lived tokens.
960+
* If callers need a stable token, memoize in the provider.
919961
*/
920962
join({
921963
roomId,
@@ -925,7 +967,7 @@ export class LoroWebsocketClient {
925967
}: {
926968
roomId: string;
927969
crdtAdaptor: CrdtDocAdaptor;
928-
auth?: Uint8Array;
970+
auth?: AuthOption;
929971
onStatusChange?: (s: RoomJoinStatusValue) => void;
930972
}): Promise<LoroWebsocketClientRoom> {
931973
const id = crdtAdaptor.crdtType + roomId;
@@ -940,8 +982,8 @@ export class LoroWebsocketClient {
940982
return Promise.resolve(active.room);
941983
}
942984

943-
let resolve: (res: JoinResponseOk) => void;
944-
let reject: (error: Error) => void;
985+
let resolve!: (res: JoinResponseOk) => void;
986+
let reject!: (error: Error) => void;
945987

946988
const response = new Promise<JoinResponseOk>((resolve_, reject_) => {
947989
resolve = resolve_;
@@ -1005,6 +1047,7 @@ export class LoroWebsocketClient {
10051047
return room;
10061048
});
10071049

1050+
// Register pending room immediately so concurrent join calls dedupe
10081051
this.pendingRooms.set(id, {
10091052
room,
10101053
resolve: resolve!,
@@ -1015,21 +1058,30 @@ export class LoroWebsocketClient {
10151058
});
10161059
this.roomAuth.set(id, auth);
10171060

1018-
const joinPayload = encode({
1019-
type: MessageType.JoinRequest,
1020-
crdt: crdtAdaptor.crdtType,
1021-
roomId,
1022-
auth: auth ?? new Uint8Array(),
1023-
version: crdtAdaptor.getVersion(),
1024-
} as JoinRequest);
1061+
void this.resolveAuth(auth)
1062+
.then(authValue => {
1063+
const joinPayload = encode({
1064+
type: MessageType.JoinRequest,
1065+
crdt: crdtAdaptor.crdtType,
1066+
roomId,
1067+
auth: authValue,
1068+
version: crdtAdaptor.getVersion(),
1069+
} as JoinRequest);
10251070

1026-
if (this.ws && this.ws.readyState === WebSocket.OPEN) {
1027-
this.ws.send(joinPayload);
1028-
} else {
1029-
this.enqueueJoin(joinPayload);
1030-
// ensure a connection attempt is running
1031-
void this.connect();
1032-
}
1071+
if (this.ws && this.ws.readyState === WebSocket.OPEN) {
1072+
this.ws.send(joinPayload);
1073+
} else {
1074+
this.enqueueJoin(joinPayload);
1075+
// ensure a connection attempt is running
1076+
void this.connect();
1077+
}
1078+
})
1079+
.catch(err => {
1080+
const error = err instanceof Error ? err : new Error(String(err));
1081+
this.emitRoomStatus(id, RoomJoinStatus.Error);
1082+
reject(error);
1083+
this.cleanupRoom(roomId, crdtAdaptor.crdtType);
1084+
});
10331085

10341086
return room;
10351087
}

packages/loro-websocket/tests/e2e.test.ts

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,88 @@ describe("E2E: Client-Server Sync", () => {
571571
await authServer.stop();
572572
}, 15000);
573573

574+
it("fetches fresh auth on rejoin when auth provider is used", async () => {
575+
const port = await getPort();
576+
const tokens: string[] = [];
577+
578+
const server = new SimpleServer({
579+
port,
580+
authenticate: async (_roomId, _crdt, auth) => {
581+
tokens.push(new TextDecoder().decode(auth));
582+
return "write";
583+
},
584+
});
585+
await server.start();
586+
587+
const client = new LoroWebsocketClient({
588+
url: `ws://localhost:${port}`,
589+
reconnect: { initialDelayMs: 20, maxDelayMs: 100, jitter: 0 },
590+
});
591+
592+
let room: LoroWebsocketClientRoom | undefined;
593+
try {
594+
await client.waitConnected();
595+
let call = 0;
596+
const adaptor = new LoroAdaptor();
597+
598+
room = await client.join({
599+
roomId: "auth-refresh",
600+
crdtAdaptor: adaptor,
601+
auth: async () => new TextEncoder().encode(`token-${++call}`),
602+
});
603+
604+
await waitUntil(() => tokens.length >= 1, 5000, 25);
605+
606+
await server.stop();
607+
await new Promise(resolve => setTimeout(resolve, 60));
608+
await server.start();
609+
610+
await waitUntil(() => tokens.some(t => t === "token-2"), 10000, 50);
611+
612+
expect(tokens[0]).toBe("token-1");
613+
expect(tokens.some(t => t === "token-2")).toBe(true);
614+
} finally {
615+
await room?.destroy();
616+
client.destroy();
617+
await server.stop();
618+
}
619+
}, 15000);
620+
621+
it("dedupes concurrent join calls even before auth resolves", async () => {
622+
const port = await getPort();
623+
const tokens: string[] = [];
624+
625+
const server = new SimpleServer({
626+
port,
627+
authenticate: async (_roomId, _crdt, auth) => {
628+
tokens.push(new TextDecoder().decode(auth));
629+
return "write";
630+
},
631+
});
632+
await server.start();
633+
634+
const client = new LoroWebsocketClient({ url: `ws://localhost:${port}` });
635+
await client.waitConnected();
636+
637+
const adaptor = new LoroAdaptor();
638+
const auth = () => new TextEncoder().encode("token-once");
639+
640+
const joinPromise1 = client.join({ roomId: "dedupe", crdtAdaptor: adaptor, auth });
641+
const joinPromise2 = client.join({ roomId: "dedupe", crdtAdaptor: adaptor, auth });
642+
643+
expect(joinPromise1).toBe(joinPromise2);
644+
645+
const [room1, room2] = await Promise.all([joinPromise1, joinPromise2]);
646+
expect(room1).toBe(room2);
647+
648+
await waitUntil(() => tokens.length >= 1, 5000, 25);
649+
expect(tokens).toHaveLength(1);
650+
651+
await room1.destroy();
652+
client.destroy();
653+
await server.stop();
654+
}, 15000);
655+
574656
it("destroy rejects pending ping waiters", async () => {
575657
const client = new LoroWebsocketClient({ url: `ws://localhost:${port}` });
576658
await client.waitConnected();

0 commit comments

Comments
 (0)