@@ -13,6 +13,8 @@ import { attemptValidation } from '../utils/validation'
1313import { createLogger } from '../factories/logger-factory'
1414import { Event } from '../@types/event'
1515import { Factory } from '../@types/base'
16+ import { IRateLimiter } from '../@types/utils'
17+ import { ISettings } from '../@types/settings'
1618import { isEventMatchingFilter } from '../utils/event'
1719import { messageSchema } from '../schemas/message-schema'
1820
@@ -21,7 +23,7 @@ const debugHeartbeat = debug.extend('heartbeat')
2123
2224export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter {
2325 public clientId : string
24- // private clientAddress: string
26+ private clientAddress : string
2527 private alive : boolean
2628 private subscriptions : Map < SubscriptionId , SubscriptionFilter [ ] >
2729
@@ -30,13 +32,17 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
3032 private readonly request : IncomingHttpMessage ,
3133 private readonly webSocketServer : IWebSocketServerAdapter ,
3234 private readonly createMessageHandler : Factory < IMessageHandler , [ IncomingMessage , IWebSocketAdapter ] > ,
35+ private readonly slidingWindowRateLimiter : Factory < IRateLimiter > ,
36+ private readonly settingsFactory : Factory < ISettings > ,
3337 ) {
3438 super ( )
3539 this . alive = true
3640 this . subscriptions = new Map ( )
3741
3842 this . clientId = Buffer . from ( this . request . headers [ 'sec-websocket-key' ] , 'base64' ) . toString ( 'hex' )
39- // this.clientAddress = this.request.headers['x-forwarded-for'] as string
43+ this . clientAddress = ( this . request . headers [ 'x-forwarded-for' ] ?? this . request . socket . remoteAddress ) as string
44+
45+ debug ( 'client %s from address %s' , this . clientId , this . clientAddress )
4046
4147 this . client
4248 . on ( 'message' , this . onClientMessage . bind ( this ) )
@@ -120,10 +126,15 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
120126 private async onClientMessage ( raw : Buffer ) {
121127 let abort : ( ) => void
122128 try {
129+ if ( await this . isRateLimited ( this . clientAddress ) ) {
130+ this . sendMessage ( createNoticeMessage ( 'rate limited' ) )
131+ return
132+ }
133+
123134 const message = attemptValidation ( messageSchema ) ( JSON . parse ( raw . toString ( 'utf8' ) ) )
124135
125136 const messageHandler = this . createMessageHandler ( [ message , this ] ) as IMessageHandler & IAbortable
126- if ( typeof messageHandler . abort === 'function' ) {
137+ if ( typeof messageHandler ? .abort === 'function' ) {
127138 abort = messageHandler . abort . bind ( messageHandler )
128139 this . client . prependOnceListener ( 'close' , abort )
129140 }
@@ -145,6 +156,36 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
145156 }
146157 }
147158
159+ private async isRateLimited ( client : string ) : Promise < boolean > {
160+ const {
161+ rateLimits,
162+ ipWhitelist = [ ] ,
163+ } = this . settingsFactory ( ) . limits ?. message ?? { }
164+
165+ if ( ipWhitelist . includes ( client ) ) {
166+ debug ( 'rate limit check %s: skipped' , client )
167+ return false
168+ }
169+
170+ const rateLimiter = this . slidingWindowRateLimiter ( )
171+
172+ const hit = ( period : number , rate : number ) =>
173+ rateLimiter . hit (
174+ `${ client } :message:${ period } ` ,
175+ 1 ,
176+ { period : period , rate : rate } ,
177+ )
178+
179+ const hits = await Promise . all (
180+ rateLimits
181+ . map ( ( { period, rate } ) => hit ( period , rate ) )
182+ )
183+
184+ debug ( 'rate limit check %s: %o = %o' , client , rateLimits . map ( ( { period } ) => period ) , hits )
185+
186+ return hits . some ( ( thresholdCrossed ) => thresholdCrossed )
187+ }
188+
148189 private onClientPong ( ) {
149190 debugHeartbeat ( 'client %s pong' , this . clientId )
150191 this . alive = true
0 commit comments