Skip to content

Commit 903eb9c

Browse files
authored
feat: Support cookies when authenticating websockets (#44)
* Support Cookies when authenticating new sockets * Removing the yarn.lock file * Include request struct in handshake_auth instead of cookies * Exposing conn_id to hooks to recognize user sessions * Added support for a new "on_close_connection" hook.
1 parent b4695cc commit 903eb9c

File tree

14 files changed

+325
-21
lines changed

14 files changed

+325
-21
lines changed

packages/loro-websocket/src/server/simple-server.ts

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import { WebSocketServer, WebSocket } from "ws";
22
import { randomBytes } from "node:crypto";
33
import type { RawData } from "ws";
4+
import type { IncomingMessage } from "http";
45
// no direct CRDT imports here; handled by CrdtDoc implementations
56
import {
67
encode,
@@ -47,6 +48,13 @@ export interface SimpleServerConfig {
4748
crdtType: CrdtType,
4849
auth: Uint8Array
4950
) => Promise<Permission | null>;
51+
/**
52+
* Optional handshake auth: called during WS HTTP upgrade.
53+
* Return true to accept, false to reject.
54+
*/
55+
handshakeAuth?: (
56+
req: IncomingMessage
57+
) => boolean | Promise<boolean>;
5058
}
5159

5260
interface RoomDocument {
@@ -86,12 +94,28 @@ export class SimpleServer {
8694

8795
start(): Promise<void> {
8896
return new Promise(resolve => {
89-
const options: { port: number; host?: string } = {
97+
const options: { port: number; host?: string; verifyClient?: any } = {
9098
port: this.config.port,
9199
};
92100
if (this.config.host) {
93101
options.host = this.config.host;
94102
}
103+
if (this.config.handshakeAuth) {
104+
options.verifyClient = (
105+
info: { origin: string; secure: boolean; req: IncomingMessage },
106+
cb: (res: boolean, code?: number, message?: string) => void
107+
) => {
108+
Promise.resolve(this.config.handshakeAuth!(info.req))
109+
.then(allowed => {
110+
if (allowed) cb(true);
111+
else cb(false, 401, "Unauthorized");
112+
})
113+
.catch(err => {
114+
console.error("Handshake auth error", err);
115+
cb(false, 500, "Internal Server Error");
116+
});
117+
};
118+
}
95119
this.wss = new WebSocketServer(options);
96120

97121
this.wss.on("connection", ws => {
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import { describe, it, expect, beforeAll, afterAll } from "vitest";
2+
import { WebSocket } from "ws";
3+
import getPort from "get-port";
4+
import { SimpleServer } from "../src/server/simple-server";
5+
6+
// Make WebSocket available globally for the client
7+
Object.defineProperty(globalThis, "WebSocket", {
8+
value: WebSocket,
9+
configurable: true,
10+
writable: true,
11+
});
12+
13+
describe("Handshake Auth", () => {
14+
let server: SimpleServer;
15+
let port: number;
16+
17+
beforeAll(async () => {
18+
port = await getPort();
19+
server = new SimpleServer({
20+
port,
21+
handshakeAuth: req => {
22+
const cookie = req.headers.cookie;
23+
return cookie === "session=valid";
24+
},
25+
});
26+
await server.start();
27+
});
28+
29+
afterAll(async () => {
30+
await server.stop();
31+
}, 10000);
32+
33+
it("should accept connection with valid cookie", async () => {
34+
const ws = new WebSocket(`ws://localhost:${port}`, {
35+
headers: {
36+
Cookie: "session=valid",
37+
},
38+
});
39+
40+
await new Promise<void>((resolve, reject) => {
41+
ws.onopen = () => resolve();
42+
ws.onerror = err => reject(err);
43+
});
44+
ws.close();
45+
});
46+
47+
it("should reject connection with invalid cookie", async () => {
48+
const ws = new WebSocket(`ws://localhost:${port}`, {
49+
headers: {
50+
Cookie: "session=invalid",
51+
},
52+
});
53+
54+
await new Promise<void>((resolve, reject) => {
55+
ws.onopen = () => reject(new Error("Should have failed"));
56+
ws.onerror = err => {
57+
resolve();
58+
};
59+
});
60+
});
61+
62+
it("should reject connection with missing cookie", async () => {
63+
const ws = new WebSocket(`ws://localhost:${port}`);
64+
65+
await new Promise<void>((resolve, reject) => {
66+
ws.onopen = () => reject(new Error("Should have failed"));
67+
ws.onerror = () => resolve();
68+
});
69+
});
70+
});

rust/Cargo.lock

Lines changed: 63 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/loro-websocket-server/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ tokio-tungstenite = "0.27"
2020
futures-util = { version = "0.3", default-features = false, features = ["sink"] }
2121
loro = "1"
2222
tracing = "0.1"
23+
cookie = "0.18.1"
2324

2425
[dev-dependencies]
2526
loro-websocket-client = { version = "0.1.0", path = "../loro-websocket-client" }

rust/loro-websocket-server/src/lib.rs

Lines changed: 74 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,40 @@ type LoadFuture<DocCtx> =
104104
type SaveFuture = Pin<Box<dyn Future<Output = Result<(), String>> + Send + 'static>>;
105105
type LoadFn<DocCtx> = Arc<dyn Fn(LoadDocArgs) -> LoadFuture<DocCtx> + Send + Sync>;
106106
type SaveFn<DocCtx> = Arc<dyn Fn(SaveDocArgs<DocCtx>) -> SaveFuture + Send + Sync>;
107+
108+
/// Arguments provided to `authenticate`.
109+
pub struct AuthArgs {
110+
pub room: String,
111+
pub crdt: CrdtType,
112+
pub auth: Vec<u8>,
113+
pub conn_id: u64,
114+
}
115+
107116
type AuthFuture =
108117
Pin<Box<dyn Future<Output = Result<Option<Permission>, String>> + Send + 'static>>;
109-
type AuthFn = Arc<dyn Fn(String, CrdtType, Vec<u8>) -> AuthFuture + Send + Sync>;
118+
type AuthFn = Arc<dyn Fn(AuthArgs) -> AuthFuture + Send + Sync>;
119+
120+
/// Arguments provided to `handshake_auth`.
121+
pub struct HandshakeAuthArgs<'a> {
122+
pub workspace: &'a str,
123+
pub token: Option<&'a str>,
124+
pub request: &'a tungstenite::handshake::server::Request,
125+
pub conn_id: u64,
126+
}
110127

111-
type HandshakeAuthFn = dyn Fn(&str, Option<&str>) -> bool + Send + Sync;
128+
type HandshakeAuthFn = dyn Fn(HandshakeAuthArgs) -> bool + Send + Sync;
129+
130+
/// Arguments provided to `on_close_connection`.
131+
pub struct CloseConnectionArgs {
132+
pub workspace: String,
133+
pub conn_id: u64,
134+
pub rooms: Vec<(CrdtType, String)>,
135+
}
136+
137+
type CloseConnectionFuture =
138+
Pin<Box<dyn Future<Output = Result<(), String>> + Send + 'static>>;
139+
type CloseConnectionFn =
140+
Arc<dyn Fn(CloseConnectionArgs) -> CloseConnectionFuture + Send + Sync>;
112141

113142
#[derive(Clone)]
114143
pub struct ServerConfig<DocCtx = ()> {
@@ -122,9 +151,14 @@ pub struct ServerConfig<DocCtx = ()> {
122151
/// Parameters:
123152
/// - `workspace_id`: extracted from request path `/{workspace}` (empty if missing)
124153
/// - `token`: `token` query parameter if present
154+
/// - `request`: the full HTTP request (headers, uri, etc)
155+
/// - `conn_id`: the connection id
125156
///
126157
/// Return true to accept, false to reject with 401.
127158
pub handshake_auth: Option<Arc<HandshakeAuthFn>>,
159+
/// Optional hook invoked after a connection fully closes.
160+
/// Receives the workspace id, connection id, and rooms the client had joined.
161+
pub on_close_connection: Option<CloseConnectionFn>,
128162
}
129163

130164
// CRDT document abstraction to reduce match-based branching
@@ -440,6 +474,7 @@ impl<DocCtx> Default for ServerConfig<DocCtx> {
440474
default_permission: Permission::Write,
441475
authenticate: None,
442476
handshake_auth: None,
477+
on_close_connection: None,
443478
}
444479
}
445480
}
@@ -884,12 +919,18 @@ async fn handle_conn<DocCtx>(
884919
where
885920
DocCtx: Clone + Send + Sync + 'static,
886921
{
922+
923+
// Generate a connection id
924+
let conn_id = NEXT_ID.fetch_add(1, Ordering::Relaxed);
925+
887926
// Capture config outside of non-async closure
888927
let handshake_auth = registry.config.handshake_auth.clone();
928+
let close_connection = registry.config.on_close_connection.clone();
889929
let workspace_holder: Arc<std::sync::Mutex<Option<String>>> =
890930
Arc::new(std::sync::Mutex::new(None));
891931
let workspace_holder_c = workspace_holder.clone();
892932

933+
893934
let ws = accept_hdr_async(
894935
stream,
895936
move |req: &tungstenite::handshake::server::Request,
@@ -925,7 +966,12 @@ where
925966
None
926967
});
927968

928-
let allowed = (check)(workspace_id, token);
969+
let allowed = (check)(HandshakeAuthArgs {
970+
workspace: workspace_id,
971+
token,
972+
request: req,
973+
conn_id,
974+
});
929975
if !allowed {
930976
warn!(workspace=%workspace_id, token=?token, "handshake auth denied");
931977
// Build a 401 Unauthorized response
@@ -971,7 +1017,6 @@ where
9711017
}
9721018
});
9731019

974-
let conn_id = NEXT_ID.fetch_add(1, Ordering::Relaxed);
9751020
let mut joined_rooms: HashSet<RoomKey> = HashSet::new();
9761021

9771022
while let Some(msg) = stream.next().await {
@@ -1001,7 +1046,14 @@ where
10011046
let mut permission = h.config.default_permission;
10021047
if let Some(auth_fn) = &h.config.authenticate {
10031048
let room_str = room.room.clone();
1004-
match (auth_fn)(room_str, room.crdt, auth.clone()).await {
1049+
match (auth_fn)(AuthArgs {
1050+
room: room_str,
1051+
crdt: room.crdt,
1052+
auth: auth.clone(),
1053+
conn_id,
1054+
})
1055+
.await
1056+
{
10051057
Ok(Some(p)) => {
10061058
permission = p;
10071059
}
@@ -1387,6 +1439,11 @@ where
13871439
}
13881440
}
13891441

1442+
let rooms_for_hook: Vec<(CrdtType, String)> = joined_rooms
1443+
.into_iter()
1444+
.map(|RoomKey { crdt, room }| (crdt, room))
1445+
.collect();
1446+
13901447
// cleanup
13911448
{
13921449
let mut h = hub.lock().await;
@@ -1395,6 +1452,18 @@ where
13951452
// drop tx to stop writer
13961453
drop(tx);
13971454
let _ = sink_task.await;
1455+
1456+
if let Some(hook) = close_connection {
1457+
let args = CloseConnectionArgs {
1458+
workspace: workspace_id.clone(),
1459+
conn_id,
1460+
rooms: rooms_for_hook,
1461+
};
1462+
if let Err(e) = (hook)(args).await {
1463+
warn!(conn_id, %e, "on_close_connection hook failed");
1464+
}
1465+
}
1466+
13981467
debug!(conn_id, "connection closed and cleaned up");
13991468
Ok(())
14001469
}

0 commit comments

Comments
 (0)