@@ -17,8 +17,42 @@ import {
17
17
getTxFromDataStore ,
18
18
parseDbTx ,
19
19
} from '../../controllers/db-controller' ;
20
- import { isProdEnv , logError , logger } from '../../../helpers' ;
20
+ import { isProdEnv , isValidPrincipal , logError , logger } from '../../../helpers' ;
21
21
import { WebSocketPrometheus } from './metrics' ;
22
+ import { isValidTxId } from '../../../api/query-helpers' ;
23
+
24
+ function getInvalidSubscriptionTopics ( subscriptions : Topic | Topic [ ] ) : undefined | string [ ] {
25
+ const isSubValid = ( sub : Topic ) : undefined | string => {
26
+ if ( sub . includes ( ':' ) ) {
27
+ const txOrAddr = sub . split ( ':' ) [ 0 ] ;
28
+ const value = sub . split ( ':' ) [ 1 ] ;
29
+ switch ( txOrAddr ) {
30
+ case 'address-transaction' :
31
+ case 'address-stx-balance' :
32
+ return isValidPrincipal ( value ) ? undefined : sub ;
33
+ case 'transaction' :
34
+ return isValidTxId ( value ) ? undefined : sub ;
35
+ default :
36
+ return sub ;
37
+ }
38
+ }
39
+ switch ( sub ) {
40
+ case 'block' :
41
+ case 'mempool' :
42
+ case 'microblock' :
43
+ return undefined ;
44
+ default :
45
+ return sub ;
46
+ }
47
+ } ;
48
+ if ( ! Array . isArray ( subscriptions ) ) {
49
+ const invalidSub = isSubValid ( subscriptions ) ;
50
+ return invalidSub ? [ invalidSub ] : undefined ;
51
+ }
52
+ const validatedSubs = subscriptions . map ( isSubValid ) ;
53
+ const invalidSubs = validatedSubs . filter ( validSub => typeof validSub === 'string' ) ;
54
+ return invalidSubs . length === 0 ? undefined : ( invalidSubs as string [ ] ) ;
55
+ }
22
56
23
57
export function createSocketIORouter ( db : DataStore , server : http . Server ) {
24
58
const io = new SocketIOServer < ClientToServerMessages , ServerToClientMessages > ( server , {
@@ -38,7 +72,6 @@ export function createSocketIORouter(db: DataStore, server: http.Server) {
38
72
}
39
73
const subscriptions = socket . handshake . query [ 'subscriptions' ] ;
40
74
if ( subscriptions ) {
41
- // TODO: check if init topics are valid, reject connection with error if not
42
75
const topics = [ ...[ subscriptions ] ] . flat ( ) . flatMap ( r => r . split ( ',' ) ) ;
43
76
for ( const topic of topics ) {
44
77
prometheus ?. subscribe ( socket , topic ) ;
@@ -51,10 +84,11 @@ export function createSocketIORouter(db: DataStore, server: http.Server) {
51
84
prometheus ?. disconnect ( socket ) ;
52
85
} ) ;
53
86
socket . on ( 'subscribe' , async ( topic , callback ) => {
54
- prometheus ?. subscribe ( socket , topic ) ;
55
- await socket . join ( topic ) ;
56
- // TODO: check if topic is valid, and return error message if not
57
- callback ?.( null ) ;
87
+ if ( ! getInvalidSubscriptionTopics ( topic ) ) {
88
+ prometheus ?. subscribe ( socket , topic ) ;
89
+ await socket . join ( topic ) ;
90
+ callback ?.( null ) ;
91
+ }
58
92
} ) ;
59
93
socket . on ( 'unsubscribe' , async ( ...topics ) => {
60
94
for ( const topic of topics ) {
@@ -64,6 +98,23 @@ export function createSocketIORouter(db: DataStore, server: http.Server) {
64
98
} ) ;
65
99
} ) ;
66
100
101
+ // Middleware checks for the invalid topic subscriptions and terminates connection if found any
102
+ io . use ( ( socket , next ) => {
103
+ const subscriptions = socket . handshake . query [ 'subscriptions' ] ;
104
+ if ( subscriptions ) {
105
+ const topics = [ ...[ subscriptions ] ] . flat ( ) . flatMap ( r => r . split ( ',' ) ) ;
106
+ const invalidSubs = getInvalidSubscriptionTopics ( topics as Topic [ ] ) ;
107
+ if ( invalidSubs ) {
108
+ const error = new Error ( `Invalid topic: ${ invalidSubs . join ( ', ' ) } ` ) ;
109
+ next ( error ) ;
110
+ } else {
111
+ next ( ) ;
112
+ }
113
+ } else {
114
+ next ( ) ;
115
+ }
116
+ } ) ;
117
+
67
118
const adapter = io . of ( '/' ) . adapter ;
68
119
69
120
adapter . on ( 'create-room' , room => {
0 commit comments