Skip to content

Commit 184fd40

Browse files
authored
chore: validate socket.io subscription topics (#1144)
* feat: added validation check for subscriptions * fix: resolved export error * refactor: added validation in middleware * refactor: updated code according to suggestions * fix: updated code to return multiple invalid topics upon subscription * style: added space before comma-separated invalid topics
1 parent 20bf561 commit 184fd40

File tree

3 files changed

+143
-7
lines changed

3 files changed

+143
-7
lines changed

src/api/query-helpers.ts

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import { ClarityAbi } from '@stacks/transactions';
22
import { NextFunction, Request, Response } from 'express';
3-
import { has0xPrefix, hexToBuffer, isValidPrincipal, parseEventTypeStrings } from './../helpers';
3+
import { has0xPrefix, hexToBuffer, parseEventTypeStrings, isValidPrincipal } from './../helpers';
44
import { InvalidRequestError, InvalidRequestErrorType } from '../errors';
55
import { DbEventTypeId } from './../datastore/common';
66

@@ -286,3 +286,11 @@ export function parseEventTypeFilter(
286286

287287
return eventTypeFilter;
288288
}
289+
export function isValidTxId(tx_id: string) {
290+
try {
291+
validateRequestHexInput(tx_id);
292+
return true;
293+
} catch {
294+
return false;
295+
}
296+
}

src/api/routes/ws/socket-io.ts

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,42 @@ import {
1717
getTxFromDataStore,
1818
parseDbTx,
1919
} from '../../controllers/db-controller';
20-
import { isProdEnv, logError, logger } from '../../../helpers';
20+
import { isProdEnv, isValidPrincipal, logError, logger } from '../../../helpers';
2121
import { WebSocketPrometheus } from './metrics';
22+
import { isValidTxId } from '../../../api/query-helpers';
23+
24+
function getInvalidSubscriptionTopics(subscriptions: Topic | Topic[]): undefined | string[] {
25+
const isSubValid = (sub: Topic): undefined | string => {
26+
if (sub.includes(':')) {
27+
const txOrAddr = sub.split(':')[0];
28+
const value = sub.split(':')[1];
29+
switch (txOrAddr) {
30+
case 'address-transaction':
31+
case 'address-stx-balance':
32+
return isValidPrincipal(value) ? undefined : sub;
33+
case 'transaction':
34+
return isValidTxId(value) ? undefined : sub;
35+
default:
36+
return sub;
37+
}
38+
}
39+
switch (sub) {
40+
case 'block':
41+
case 'mempool':
42+
case 'microblock':
43+
return undefined;
44+
default:
45+
return sub;
46+
}
47+
};
48+
if (!Array.isArray(subscriptions)) {
49+
const invalidSub = isSubValid(subscriptions);
50+
return invalidSub ? [invalidSub] : undefined;
51+
}
52+
const validatedSubs = subscriptions.map(isSubValid);
53+
const invalidSubs = validatedSubs.filter(validSub => typeof validSub === 'string');
54+
return invalidSubs.length === 0 ? undefined : (invalidSubs as string[]);
55+
}
2256

2357
export function createSocketIORouter(db: DataStore, server: http.Server) {
2458
const io = new SocketIOServer<ClientToServerMessages, ServerToClientMessages>(server, {
@@ -38,7 +72,6 @@ export function createSocketIORouter(db: DataStore, server: http.Server) {
3872
}
3973
const subscriptions = socket.handshake.query['subscriptions'];
4074
if (subscriptions) {
41-
// TODO: check if init topics are valid, reject connection with error if not
4275
const topics = [...[subscriptions]].flat().flatMap(r => r.split(','));
4376
for (const topic of topics) {
4477
prometheus?.subscribe(socket, topic);
@@ -51,10 +84,11 @@ export function createSocketIORouter(db: DataStore, server: http.Server) {
5184
prometheus?.disconnect(socket);
5285
});
5386
socket.on('subscribe', async (topic, callback) => {
54-
prometheus?.subscribe(socket, topic);
55-
await socket.join(topic);
56-
// TODO: check if topic is valid, and return error message if not
57-
callback?.(null);
87+
if (!getInvalidSubscriptionTopics(topic)) {
88+
prometheus?.subscribe(socket, topic);
89+
await socket.join(topic);
90+
callback?.(null);
91+
}
5892
});
5993
socket.on('unsubscribe', async (...topics) => {
6094
for (const topic of topics) {
@@ -64,6 +98,23 @@ export function createSocketIORouter(db: DataStore, server: http.Server) {
6498
});
6599
});
66100

101+
// Middleware checks for the invalid topic subscriptions and terminates connection if found any
102+
io.use((socket, next) => {
103+
const subscriptions = socket.handshake.query['subscriptions'];
104+
if (subscriptions) {
105+
const topics = [...[subscriptions]].flat().flatMap(r => r.split(','));
106+
const invalidSubs = getInvalidSubscriptionTopics(topics as Topic[]);
107+
if (invalidSubs) {
108+
const error = new Error(`Invalid topic: ${invalidSubs.join(', ')}`);
109+
next(error);
110+
} else {
111+
next();
112+
}
113+
} else {
114+
next();
115+
}
116+
});
117+
67118
const adapter = io.of('/').adapter;
68119

69120
adapter.on('create-room', room => {

src/tests/socket-io-tests.ts

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,83 @@ describe('socket-io', () => {
224224
}
225225
});
226226

227+
test('socket-io > invalid topic connection', async () => {
228+
const faultyAddr = 'faulty address';
229+
const address = apiServer.address;
230+
const socket = io(`http://${address}`, {
231+
reconnection: false,
232+
query: { subscriptions: `address-stx-balance:${faultyAddr}` },
233+
});
234+
const updateWaiter: Waiter<Error> = waiter();
235+
236+
socket.on(`connect_error`, err => {
237+
updateWaiter.finish(err);
238+
});
239+
240+
const result = await updateWaiter;
241+
try {
242+
throw result;
243+
} catch (err: any) {
244+
expect(err.message).toEqual(`Invalid topic: address-stx-balance:${faultyAddr}`);
245+
} finally {
246+
socket.close();
247+
}
248+
});
249+
250+
test('socket-io > multiple invalid topic connection', async () => {
251+
const faultyAddrStx = 'address-stx-balance:faulty address';
252+
const faultyTx = 'transaction:0x1';
253+
const address = apiServer.address;
254+
const socket = io(`http://${address}`, {
255+
reconnection: false,
256+
query: { subscriptions: `${faultyAddrStx},${faultyTx}` },
257+
});
258+
const updateWaiter: Waiter<Error> = waiter();
259+
260+
socket.on(`connect_error`, err => {
261+
updateWaiter.finish(err);
262+
});
263+
264+
const result = await updateWaiter;
265+
try {
266+
throw result;
267+
} catch (err: any) {
268+
expect(err.message).toEqual(`Invalid topic: ${faultyAddrStx}, ${faultyTx}`);
269+
} finally {
270+
socket.close();
271+
}
272+
});
273+
274+
test('socket-io > valid socket subscription', async () => {
275+
const address = apiServer.address;
276+
const socket = io(`http://${address}`, {
277+
reconnection: false,
278+
query: { subscriptions: '' },
279+
});
280+
const updateWaiter: Waiter<Block> = waiter();
281+
282+
socket.on('block', block => {
283+
updateWaiter.finish(block);
284+
});
285+
286+
socket.emit('subscribe', 'block');
287+
288+
const block = new TestBlockBuilder({ block_hash: '0x1234', burn_block_hash: '0x5454' })
289+
.addTx({ tx_id: '0x4321' })
290+
.build();
291+
await db.update(block);
292+
293+
const result = await updateWaiter;
294+
try {
295+
expect(result.hash).toEqual('0x1234');
296+
expect(result.burn_block_hash).toEqual('0x5454');
297+
expect(result.txs[0]).toEqual('0x4321');
298+
} finally {
299+
socket.emit('unsubscribe', 'block');
300+
socket.close();
301+
}
302+
});
303+
227304
afterEach(async () => {
228305
await apiServer.terminate();
229306
dbClient.release();

0 commit comments

Comments
 (0)