@@ -21,6 +21,8 @@ import { messageSchema } from '../schemas/message-schema'
2121const debug = createLogger ( 'web-socket-adapter' )
2222const debugHeartbeat = debug . extend ( 'heartbeat' )
2323
24+ const abortableMessageHandlers : WeakMap < WebSocket , IAbortable [ ] > = new WeakMap ( )
25+
2426export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter {
2527 public clientId : string
2628 private clientAddress : string
@@ -33,23 +35,26 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
3335 private readonly webSocketServer : IWebSocketServerAdapter ,
3436 private readonly createMessageHandler : Factory < IMessageHandler , [ IncomingMessage , IWebSocketAdapter ] > ,
3537 private readonly slidingWindowRateLimiter : Factory < IRateLimiter > ,
36- private readonly settingsFactory : Factory < ISettings > ,
38+ private readonly settings : Factory < ISettings > ,
3739 ) {
3840 super ( )
3941 this . alive = true
4042 this . subscriptions = new Map ( )
4143
4244 this . clientId = Buffer . from ( this . request . headers [ 'sec-websocket-key' ] , 'base64' ) . toString ( 'hex' )
43- this . clientAddress = ( this . request . headers [ 'x-forwarded-for' ] ?? this . request . socket . remoteAddress ) as string
44-
45- debug ( 'client %s from address %s' , this . clientId , this . clientAddress )
45+ const remoteIpHeader = this . settings ( ) . network ?. remote_ip_header ?? 'x-forwarded-for'
46+ this . clientAddress = ( this . request . headers [ remoteIpHeader ] ?? this . request . socket . remoteAddress ) as string
4647
4748 this . client
4849 . on ( 'message' , this . onClientMessage . bind ( this ) )
4950 . on ( 'close' , this . onClientClose . bind ( this ) )
5051 . on ( 'pong' , this . onClientPong . bind ( this ) )
5152 . on ( 'error' , ( error ) => {
52- debug ( 'error' , error )
53+ if ( error . name === 'RangeError' && error . message === 'Max payload size exceeded' ) {
54+ debug ( 'client %s from %s sent payload too large' , this . clientId , this . clientAddress )
55+ } else {
56+ debug ( 'error' , error )
57+ }
5358 } )
5459
5560 this
@@ -60,7 +65,7 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
6065 . on ( WebSocketAdapterEvent . Broadcast , this . onBroadcast . bind ( this ) )
6166 . on ( WebSocketAdapterEvent . Message , this . sendMessage . bind ( this ) )
6267
63- debug ( 'client %s connected' , this . clientId )
68+ debug ( 'client %s connected from %s ' , this . clientId , this . clientAddress )
6469 }
6570
6671 public getClientId ( ) : string {
@@ -78,10 +83,8 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
7883 }
7984
8085 public onBroadcast ( event : Event ) : void {
81- debug ( 'client %s broadcast event: %o' , this . clientId , event )
8286 this . webSocketServer . emit ( WebSocketServerAdapterEvent . Broadcast , event )
8387 if ( cluster . isWorker ) {
84- debug ( 'client %s broadcast event to primary: %o' , this . clientId , event )
8588 process . send ( {
8689 eventName : WebSocketServerAdapterEvent . Broadcast ,
8790 event,
@@ -100,7 +103,6 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
100103 }
101104
102105 private sendMessage ( message : OutgoingMessage ) : void {
103- debug ( 'sending message to client %s: %o' , this . clientId , message )
104106 this . client . send ( JSON . stringify ( message ) )
105107 }
106108
@@ -127,7 +129,8 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
127129 }
128130
129131 private async onClientMessage ( raw : Buffer ) {
130- let abort : ( ) => void
132+ let abortable = false
133+ let messageHandler : IMessageHandler & IAbortable
131134 try {
132135 if ( await this . isRateLimited ( this . clientAddress ) ) {
133136 this . sendMessage ( createNoticeMessage ( 'rate limited' ) )
@@ -136,10 +139,13 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
136139
137140 const message = attemptValidation ( messageSchema ) ( JSON . parse ( raw . toString ( 'utf8' ) ) )
138141
139- const messageHandler = this . createMessageHandler ( [ message , this ] ) as IMessageHandler & IAbortable
140- if ( typeof messageHandler ?. abort === 'function' ) {
141- abort = messageHandler . abort . bind ( messageHandler )
142- this . client . prependOnceListener ( 'close' , abort )
142+ messageHandler = this . createMessageHandler ( [ message , this ] ) as IMessageHandler & IAbortable
143+ abortable = typeof messageHandler ?. abort === 'function'
144+
145+ if ( abortable ) {
146+ const handlers = abortableMessageHandlers . get ( this . client ) ?? [ ]
147+ handlers . push ( messageHandler )
148+ abortableMessageHandlers . set ( this . client , handlers )
143149 }
144150
145151 await messageHandler ?. handleMessage ( message )
@@ -150,11 +156,15 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
150156 debug ( 'invalid message: %o' , ( error as any ) . annotate ( ) )
151157 this . sendMessage ( createNoticeMessage ( `Invalid message: ${ error . message } ` ) )
152158 } else {
153- debug ( 'unable to handle message: %o ' , error )
159+ console . error ( 'unable to handle message' , error )
154160 }
155161 } finally {
156- if ( abort ) {
157- this . client . removeListener ( 'close' , abort )
162+ if ( abortable ) {
163+ const handlers = abortableMessageHandlers . get ( this . client )
164+ const index = handlers . indexOf ( messageHandler )
165+ if ( index >= 0 ) {
166+ handlers . splice ( index , 1 )
167+ }
158168 }
159169 }
160170 }
@@ -163,10 +173,9 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
163173 const {
164174 rateLimits,
165175 ipWhitelist = [ ] ,
166- } = this . settingsFactory ( ) . limits ?. message ?? { }
176+ } = this . settings ( ) . limits ?. message ?? { }
167177
168178 if ( ipWhitelist . includes ( client ) ) {
169- debug ( 'rate limit check %s: skipped' , client )
170179 return false
171180 }
172181
@@ -195,8 +204,15 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
195204 }
196205
197206 private onClientClose ( ) {
198- debug ( 'client %s closing' , this . clientId )
199207 this . alive = false
208+ this . subscriptions . clear ( )
209+
210+ const handlers = abortableMessageHandlers . get ( this . client )
211+ if ( Array . isArray ( handlers ) && handlers . length ) {
212+ for ( const handler of handlers ) {
213+ handler . abort ( )
214+ }
215+ }
200216
201217 this . removeAllListeners ( )
202218 this . client . removeAllListeners ( )
0 commit comments