Skip to content

Commit 20a1d94

Browse files
authored
feat: add ws upgrade handler (midwayjs#4360)
* feat: add websocket upgrade handler * fix: lint
1 parent e0f9abb commit 20a1d94

File tree

9 files changed

+404
-2
lines changed

9 files changed

+404
-2
lines changed

packages/ws/src/framework.ts

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import {
2727
IMidwayWSConfigurationOptions,
2828
IMidwayWSContext,
2929
NextFunction,
30+
UpgradeAuthHandler,
3031
} from './interface';
3132
import * as WebSocket from 'ws';
3233

@@ -39,6 +40,7 @@ export class MidwayWSFramework extends BaseFramework<
3940
server: http.Server;
4041
protected heartBeatInterval: NodeJS.Timeout;
4142
protected connectionMiddlewareManager = this.createMiddlewareManager();
43+
protected upgradeAuthHandler: UpgradeAuthHandler | null = null;
4244

4345
configure(): IMidwayWSConfigurationOptions {
4446
return this.configService.getConfiguration('webSocket');
@@ -61,6 +63,9 @@ export class MidwayWSFramework extends BaseFramework<
6163
> => {
6264
return this.getConnectionMiddleware();
6365
},
66+
onWebSocketUpgrade: (handler: UpgradeAuthHandler) => {
67+
return this.onWebSocketUpgrade(handler);
68+
},
6469
});
6570
}
6671
public app: IMidwayWSApplication;
@@ -82,7 +87,35 @@ export class MidwayWSFramework extends BaseFramework<
8287
server = this.configurationOptions.server ?? http.createServer();
8388
}
8489

85-
server.on('upgrade', (request, socket: any, head: Buffer) => {
90+
server.on('upgrade', async (request, socket: any, head: Buffer) => {
91+
// check if the upgrade auth handler is set
92+
if (this.upgradeAuthHandler) {
93+
try {
94+
const authResult = await this.upgradeAuthHandler(
95+
request,
96+
socket,
97+
head
98+
);
99+
if (!authResult) {
100+
this.logger.warn(
101+
'[midway:ws] WebSocket upgrade authentication failed'
102+
);
103+
socket.destroy();
104+
return;
105+
}
106+
this.logger.debug(
107+
'[midway:ws] WebSocket upgrade authentication passed'
108+
);
109+
} catch (error) {
110+
this.logger.error(
111+
'[midway:ws] WebSocket upgrade authentication error:',
112+
error
113+
);
114+
socket.destroy();
115+
return;
116+
}
117+
}
118+
86119
this.app.handleUpgrade(request, socket, head, ws => {
87120
this.app.emit('connection', ws, request);
88121
});
@@ -120,6 +153,23 @@ export class MidwayWSFramework extends BaseFramework<
120153
return MidwayFrameworkType.WS;
121154
}
122155

156+
/**
157+
* 设置升级前鉴权处理函数
158+
* @param handler 鉴权处理函数,传入 null 可以禁用鉴权
159+
*/
160+
public onWebSocketUpgrade(handler: UpgradeAuthHandler | null): void {
161+
this.upgradeAuthHandler = handler;
162+
if (handler) {
163+
this.logger.info(
164+
'[midway:ws] WebSocket upgrade authentication handler set'
165+
);
166+
} else {
167+
this.logger.info(
168+
'[midway:ws] WebSocket upgrade authentication handler removed'
169+
);
170+
}
171+
}
172+
123173
private async loadMidwayController() {
124174
// create room
125175
const controllerModules = listModule(WS_CONTROLLER_KEY);

packages/ws/src/interface.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ export type IMidwayWSApplication = IMidwayApplication<IMidwayWSContext, {
1414
middleware: CommonMiddlewareUnion<Context, NextFunction, undefined>
1515
) => void;
1616
getConnectionMiddleware: ContextMiddlewareManager<Context, NextFunction, undefined>;
17+
onWebSocketUpgrade: (handler: UpgradeAuthHandler | null) => void;
1718
}> & WebSocket.Server;
1819

1920
export type IMidwayWSConfigurationOptions = {
@@ -38,3 +39,12 @@ export type IMidwayWSContext = IMidwayContext<WebSocket & {
3839
export type Application = IMidwayWSApplication;
3940
export type NextFunction = BaseNextFunction;
4041
export interface Context extends IMidwayWSContext {}
42+
43+
/**
44+
* WebSocket 升级前鉴权处理函数类型
45+
*/
46+
export type UpgradeAuthHandler = (
47+
request: IncomingMessage,
48+
socket: any,
49+
head: Buffer
50+
) => Promise<boolean>;
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
"name": "base-app-upgrade-auth"
3+
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import { Configuration, App } from '@midwayjs/core';
2+
import { ILifeCycle } from '@midwayjs/core';
3+
import { Application } from '../../../../src';
4+
import { UpgradeAuthHandler } from '../../../../src';
5+
import * as http from 'http';
6+
7+
@Configuration({
8+
importConfigs: [
9+
{
10+
default: {
11+
webSocket: {
12+
port: 3000,
13+
}
14+
}
15+
}
16+
]
17+
})
18+
export class AutoConfiguration implements ILifeCycle {
19+
20+
@App()
21+
app: Application;
22+
23+
async onReady() {
24+
// 设置升级鉴权处理函数
25+
this.app.onWebSocketUpgrade(this.authHandler);
26+
}
27+
28+
private authHandler: UpgradeAuthHandler = async (
29+
request: http.IncomingMessage,
30+
socket: any,
31+
head: Buffer
32+
): Promise<boolean> => {
33+
try {
34+
// 从 URL 参数获取 token
35+
const url = new URL(request.url || '', `http://${request.headers.host}`);
36+
const token = url.searchParams.get('token');
37+
38+
// 简单的 token 验证
39+
if (token === 'valid-token') {
40+
console.log('[Test Auth] Valid token, connection allowed');
41+
return true;
42+
}
43+
44+
console.log('[Test Auth] Invalid or missing token, connection denied');
45+
return false;
46+
} catch (error) {
47+
console.error('[Test Auth] Authentication error:', error);
48+
return false;
49+
}
50+
};
51+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import {
2+
OnWSConnection,
3+
OnWSMessage,
4+
Provide,
5+
WSController,
6+
} from '@midwayjs/core';
7+
8+
@Provide()
9+
@WSController()
10+
export class HelloSocketController {
11+
@OnWSConnection()
12+
async onConnectionMethod() {
13+
console.log('on connection');
14+
}
15+
16+
@OnWSMessage('message')
17+
async onMessage(data: any) {
18+
// 处理 Buffer 数据
19+
let messageData = data;
20+
if (Buffer.isBuffer(data)) {
21+
messageData = data.toString();
22+
}
23+
24+
return { echo: messageData, timestamp: Date.now() };
25+
}
26+
}

packages/ws/test/index.test.ts

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { closeApp, createServer } from './utils';
1+
import { closeApp, createServer, testConnectionRejected } from './utils';
22
import { sleep } from '@midwayjs/core';
33
import { once } from 'events';
44
import { createWebSocketClient } from '@midwayjs/mock';
@@ -109,4 +109,31 @@ describe('/test/index.test.ts', () => {
109109

110110
await closeApp(app);
111111
});
112+
113+
it('should test onWebSocketUpgrade authentication', async () => {
114+
const app = await createServer('base-app-upgrade-auth');
115+
116+
// 测试1: 没有 token 的连接应该被拒绝
117+
const rejected1 = await testConnectionRejected('ws://localhost:3000');
118+
expect(rejected1).toBe(true);
119+
120+
// 测试2: 无效 token 的连接应该被拒绝
121+
const rejected2 = await testConnectionRejected('ws://localhost:3000?token=invalid-token');
122+
expect(rejected2).toBe(true);
123+
124+
// 测试3: 有效 token 的连接应该成功
125+
const client3 = await createWebSocketClient(`ws://localhost:3000?token=valid-token`);
126+
127+
// 发送消息测试连接是否正常工作
128+
client3.send('test-message');
129+
const gotEvent = once(client3, 'message');
130+
const [data] = await gotEvent;
131+
const response = JSON.parse(data);
132+
133+
expect(response.echo).toEqual('test-message');
134+
expect(response.timestamp).toBeDefined();
135+
136+
await client3.close();
137+
await closeApp(app);
138+
});
112139
});

packages/ws/test/utils.ts

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,50 @@ export async function createServer(name: string, options: IMidwayWSConfiguration
1515
export async function closeApp(app) {
1616
return close(app);
1717
}
18+
19+
/**
20+
* 测试 WebSocket 连接是否被拒绝
21+
* @param url WebSocket 连接 URL
22+
* @param timeout 超时时间(毫秒)
23+
* @returns Promise<boolean> true 表示连接被拒绝,false 表示连接成功
24+
*/
25+
export async function testConnectionRejected(url: string, timeout: number = 2000): Promise<boolean> {
26+
return new Promise((resolve) => {
27+
const WebSocket = require('ws');
28+
const client = new WebSocket(url);
29+
30+
let resolved = false;
31+
const timer = setTimeout(() => {
32+
if (!resolved) {
33+
resolved = true;
34+
client.terminate();
35+
resolve(true); // 超时认为连接被拒绝
36+
}
37+
}, timeout);
38+
39+
client.on('open', () => {
40+
if (!resolved) {
41+
resolved = true;
42+
clearTimeout(timer);
43+
client.close();
44+
resolve(false); // 连接成功
45+
}
46+
});
47+
48+
client.on('error', (error) => {
49+
if (!resolved) {
50+
resolved = true;
51+
clearTimeout(timer);
52+
resolve(true); // 连接被拒绝
53+
}
54+
});
55+
56+
client.on('close', () => {
57+
if (!resolved) {
58+
resolved = true;
59+
clearTimeout(timer);
60+
resolve(true); // 连接被拒绝
61+
}
62+
});
63+
});
64+
}

site/docs/extensions/ws.md

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,101 @@ const client = new WebSocket('wss://websocket-echo.com/');
276276
client.on('ping', heartbeat);
277277
```
278278

279+
## 鉴权
280+
281+
在 WebSocket 连接建立之前,您可能需要对客户端进行身份验证。从 v3.20.9 开始 Midway 提供了 `onWebSocketUpgrade` 方法来在 WebSocket 握手前进行鉴权。
282+
283+
### 设置鉴权处理器
284+
285+
您可以在应用启动时设置鉴权处理器:
286+
287+
```typescript
288+
import { Configuration, Inject } from '@midwayjs/core';
289+
import { MidwayWSFramework } from '@midwayjs/ws';
290+
291+
@Configuration()
292+
export class WSConfiguration {
293+
@Inject()
294+
wsFramework: MidwayWSFramework;
295+
296+
async onReady() {
297+
// 设置升级前鉴权处理器
298+
this.wsFramework.onWebSocketUpgrade(async (request, socket, head) => {
299+
// 从 URL 参数中获取 token
300+
const url = new URL(request.url, `http://${request.headers.host}`);
301+
const token = url.searchParams.get('token');
302+
303+
// 验证 token
304+
if (token === 'valid-token') {
305+
return true; // 允许连接
306+
}
307+
308+
return false; // 拒绝连接
309+
});
310+
}
311+
}
312+
```
313+
314+
### 鉴权处理器参数
315+
316+
鉴权处理器接收三个参数:
317+
318+
- `request`: HTTP 请求对象 (`http.IncomingMessage`)
319+
- `socket`: 原始 socket 对象
320+
- `head`: WebSocket 握手的头部数据 (`Buffer`)
321+
322+
处理器需要返回一个 `Promise<boolean>`
323+
- `true`: 允许 WebSocket 连接
324+
- `false`: 拒绝 WebSocket 连接
325+
326+
### 获取鉴权信息
327+
328+
您可以从多个来源获取鉴权信息:
329+
330+
**URL 参数**
331+
332+
```typescript
333+
this.wsFramework.onWebSocketUpgrade(async (request, socket, head) => {
334+
const url = new URL(request.url, `http://${request.headers.host}`);
335+
const token = url.searchParams.get('token');
336+
const userId = url.searchParams.get('userId');
337+
338+
// 验证逻辑
339+
return await this.validateToken(token, userId);
340+
});
341+
```
342+
343+
**请求头**
344+
345+
```typescript
346+
this.wsFramework.onWebSocketUpgrade(async (request, socket, head) => {
347+
const authorization = request.headers.authorization;
348+
349+
if (!authorization) {
350+
return false;
351+
}
352+
353+
const token = authorization.replace('Bearer ', '');
354+
return await this.validateToken(token);
355+
});
356+
```
357+
358+
**Cookie**
359+
360+
```typescript
361+
this.wsFramework.onWebSocketUpgrade(async (request, socket, head) => {
362+
const cookie = request.headers.cookie;
363+
364+
if (!cookie) {
365+
return false;
366+
}
367+
368+
// 解析 cookie 获取 session 信息
369+
const sessionId = this.parseCookie(cookie).sessionId;
370+
return await this.validateSession(sessionId);
371+
});
372+
```
373+
279374

280375

281376
## 本地测试

0 commit comments

Comments
 (0)