Skip to content

Commit 6271520

Browse files
authored
fix: Improve WS request sanitization (#10231)
1 parent 4c3029c commit 6271520

File tree

10 files changed

+516
-74
lines changed

10 files changed

+516
-74
lines changed

packages/cubejs-api-gateway/package.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@
4949
"nexus": "^1.1.0",
5050
"node-fetch": "^2.6.1",
5151
"ramda": "^0.27.0",
52-
"uuid": "^8.3.2"
52+
"uuid": "^8.3.2",
53+
"zod": "^4.1.13"
5354
},
5455
"devDependencies": {
5556
"@cubejs-backend/linter": "1.5.12",

packages/cubejs-api-gateway/src/index.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,8 @@ export * from './sql-server';
33
export * from './interfaces';
44
export * from './cubejs-handler-error';
55
export * from './user-error';
6+
67
export { getRequestIdFromRequest } from './request-parser';
78
export { TransformDataRequest } from './types/responses';
9+
10+
export type { SubscriptionServer } from './ws';
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
export * from './local-subscription-store';
2+
export * from './message-schema';
3+
export * from './subscription-server';

packages/cubejs-api-gateway/src/ws/local-subscription-store.ts

Lines changed: 50 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,19 @@ interface LocalSubscriptionStoreOptions {
22
heartBeatInterval?: number;
33
}
44

5+
export type LocalSubscriptionStoreSubscription = {
6+
message: any,
7+
state: any,
8+
timestamp: Date,
9+
};
10+
11+
export type LocalSubscriptionStoreConnection = {
12+
subscriptions: Map<string, LocalSubscriptionStoreSubscription>,
13+
authContext?: any,
14+
};
15+
516
export class LocalSubscriptionStore {
6-
protected connections = {};
17+
protected readonly connections: Map<string, LocalSubscriptionStoreConnection> = new Map();
718

819
protected readonly hearBeatInterval: number;
920

@@ -12,60 +23,68 @@ export class LocalSubscriptionStore {
1223
}
1324

1425
public async getSubscription(connectionId: string, subscriptionId: string) {
15-
const connection = this.getConnection(connectionId);
16-
return connection.subscriptions[subscriptionId];
26+
const connection = this.getConnectionOrCreate(connectionId);
27+
return connection.subscriptions.get(subscriptionId);
1728
}
1829

1930
public async subscribe(connectionId: string, subscriptionId: string, subscription) {
20-
const connection = this.getConnection(connectionId);
21-
connection.subscriptions[subscriptionId] = {
31+
const connection = this.getConnectionOrCreate(connectionId);
32+
connection.subscriptions.set(subscriptionId, {
2233
...subscription,
2334
timestamp: new Date()
24-
};
35+
});
2536
}
2637

2738
public async unsubscribe(connectionId: string, subscriptionId: string) {
28-
const connection = this.getConnection(connectionId);
29-
delete connection.subscriptions[subscriptionId];
39+
const connection = this.getConnectionOrCreate(connectionId);
40+
connection.subscriptions.delete(subscriptionId);
3041
}
3142

32-
public async getAllSubscriptions() {
33-
return Object.keys(this.connections).map(connectionId => {
34-
Object.keys(this.connections[connectionId].subscriptions).filter(
35-
subscriptionId => new Date().getTime() -
36-
this.connections[connectionId].subscriptions[subscriptionId].timestamp.getTime() >
37-
this.hearBeatInterval * 4 * 1000
38-
).forEach(subscriptionId => { delete this.connections[connectionId].subscriptions[subscriptionId]; });
39-
40-
return Object.keys(this.connections[connectionId].subscriptions)
41-
.map(subscriptionId => ({
42-
connectionId,
43-
...this.connections[connectionId].subscriptions[subscriptionId]
44-
}));
45-
}).reduce((a, b) => a.concat(b), []);
43+
public getAllSubscriptions() {
44+
const now = Date.now();
45+
const staleThreshold = this.hearBeatInterval * 4 * 1000;
46+
const result: Array<{ connectionId: string } & LocalSubscriptionStoreSubscription> = [];
47+
48+
for (const [connectionId, connection] of this.connections) {
49+
for (const [subscriptionId, subscription] of connection.subscriptions) {
50+
if (now - subscription.timestamp.getTime() > staleThreshold) {
51+
connection.subscriptions.delete(subscriptionId);
52+
}
53+
}
54+
55+
for (const [, subscription] of connection.subscriptions) {
56+
result.push({ connectionId, ...subscription });
57+
}
58+
}
59+
60+
return result;
4661
}
4762

48-
public async cleanupSubscriptions(connectionId: string) {
49-
delete this.connections[connectionId];
63+
public async disconnect(connectionId: string) {
64+
this.connections.delete(connectionId);
5065
}
5166

5267
public async getAuthContext(connectionId: string) {
53-
return this.getConnection(connectionId).authContext;
68+
return this.getConnectionOrCreate(connectionId).authContext;
5469
}
5570

5671
public async setAuthContext(connectionId: string, authContext) {
57-
this.getConnection(connectionId).authContext = authContext;
72+
this.getConnectionOrCreate(connectionId).authContext = authContext;
5873
}
5974

60-
protected getConnection(connectionId: string) {
61-
if (!this.connections[connectionId]) {
62-
this.connections[connectionId] = { subscriptions: {} };
75+
protected getConnectionOrCreate(connectionId: string): LocalSubscriptionStoreConnection {
76+
const connect = this.connections.get(connectionId);
77+
if (connect) {
78+
return connect;
6379
}
6480

65-
return this.connections[connectionId];
81+
const connection = { subscriptions: new Map() };
82+
this.connections.set(connectionId, connection);
83+
84+
return connection;
6685
}
6786

6887
public clear() {
69-
this.connections = {};
88+
this.connections.clear();
7089
}
7190
}
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import { z } from 'zod';
2+
3+
const messageId = z.union([z.string().max(16), z.number()]);
4+
const requestId = z.string().max(64).optional();
5+
6+
export const authMessageSchema = z.object({
7+
authorization: z.string(),
8+
}).strict();
9+
10+
export const unsubscribeMessageSchema = z.object({
11+
unsubscribe: z.string().max(16),
12+
}).strict();
13+
14+
const queryParams = z.object({
15+
query: z.unknown(),
16+
queryType: z.string().optional(),
17+
}).strict();
18+
19+
const queryOnlyParams = z.object({
20+
query: z.unknown(),
21+
}).strict();
22+
23+
// Method-based messages using discriminatedUnion
24+
export const methodMessageSchema = z.discriminatedUnion('method', [
25+
z.object({
26+
method: z.literal('load'),
27+
messageId,
28+
requestId,
29+
params: queryParams,
30+
}).strict(),
31+
z.object({
32+
method: z.literal('sql'),
33+
messageId,
34+
requestId,
35+
params: queryOnlyParams,
36+
}).strict(),
37+
z.object({
38+
method: z.literal('dry-run'),
39+
messageId,
40+
requestId,
41+
params: queryOnlyParams,
42+
}).strict(),
43+
z.object({
44+
method: z.literal('meta'),
45+
messageId,
46+
requestId,
47+
params: z.object({}).strict().optional(),
48+
}).strict(),
49+
z.object({
50+
method: z.literal('subscribe'),
51+
messageId,
52+
requestId,
53+
params: queryParams,
54+
}).strict(),
55+
z.object({
56+
method: z.literal('unsubscribe'),
57+
messageId,
58+
requestId,
59+
params: z.object({}).strict().optional(),
60+
}).strict(),
61+
]);
62+
63+
// Export types
64+
export type AuthMessage = z.infer<typeof authMessageSchema>;
65+
export type UnsubscribeMessage = z.infer<typeof unsubscribeMessageSchema>;
66+
export type MethodMessage = z.infer<typeof methodMessageSchema>;
67+
export type WsMessage = AuthMessage | UnsubscribeMessage | MethodMessage;

0 commit comments

Comments
 (0)