Skip to content

Commit 7469d97

Browse files
committed
fix: abort message handler using listeners
1 parent 4c038a6 commit 7469d97

File tree

3 files changed

+39
-23
lines changed

3 files changed

+39
-23
lines changed

src/adapters/web-socket-adapter.ts

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ import { messageSchema } from '../schemas/message-schema'
2121
const debug = createLogger('web-socket-adapter')
2222
const debugHeartbeat = debug.extend('heartbeat')
2323

24+
const abortableMessageHandlers: WeakMap<WebSocket, IAbortable[]> = new WeakMap()
25+
2426
export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter {
2527
public clientId: string
2628
private clientAddress: string
@@ -33,23 +35,26 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
3335
private readonly webSocketServer: IWebSocketServerAdapter,
3436
private readonly createMessageHandler: Factory<IMessageHandler, [IncomingMessage, IWebSocketAdapter]>,
3537
private readonly slidingWindowRateLimiter: Factory<IRateLimiter>,
36-
private readonly settingsFactory: Factory<ISettings>,
38+
private readonly settings: Factory<ISettings>,
3739
) {
3840
super()
3941
this.alive = true
4042
this.subscriptions = new Map()
4143

4244
this.clientId = Buffer.from(this.request.headers['sec-websocket-key'], 'base64').toString('hex')
43-
this.clientAddress = (this.request.headers['x-forwarded-for'] ?? this.request.socket.remoteAddress) as string
44-
45-
debug('client %s from address %s', this.clientId, this.clientAddress)
45+
const remoteIpHeader = this.settings().network?.remote_ip_header ?? 'x-forwarded-for'
46+
this.clientAddress = (this.request.headers[remoteIpHeader] ?? this.request.socket.remoteAddress) as string
4647

4748
this.client
4849
.on('message', this.onClientMessage.bind(this))
4950
.on('close', this.onClientClose.bind(this))
5051
.on('pong', this.onClientPong.bind(this))
5152
.on('error', (error) => {
52-
debug('error', error)
53+
if (error.name === 'RangeError' && error.message === 'Max payload size exceeded') {
54+
debug('client %s from %s sent payload too large', this.clientId, this.clientAddress)
55+
} else {
56+
debug('error', error)
57+
}
5358
})
5459

5560
this
@@ -60,7 +65,7 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
6065
.on(WebSocketAdapterEvent.Broadcast, this.onBroadcast.bind(this))
6166
.on(WebSocketAdapterEvent.Message, this.sendMessage.bind(this))
6267

63-
debug('client %s connected', this.clientId)
68+
debug('client %s connected from %s', this.clientId, this.clientAddress)
6469
}
6570

6671
public getClientId(): string {
@@ -78,10 +83,8 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
7883
}
7984

8085
public onBroadcast(event: Event): void {
81-
debug('client %s broadcast event: %o', this.clientId, event)
8286
this.webSocketServer.emit(WebSocketServerAdapterEvent.Broadcast, event)
8387
if (cluster.isWorker) {
84-
debug('client %s broadcast event to primary: %o', this.clientId, event)
8588
process.send({
8689
eventName: WebSocketServerAdapterEvent.Broadcast,
8790
event,
@@ -100,7 +103,6 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
100103
}
101104

102105
private sendMessage(message: OutgoingMessage): void {
103-
debug('sending message to client %s: %o', this.clientId, message)
104106
this.client.send(JSON.stringify(message))
105107
}
106108

@@ -127,7 +129,8 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
127129
}
128130

129131
private async onClientMessage(raw: Buffer) {
130-
let abort: () => void
132+
let abortable = false
133+
let messageHandler: IMessageHandler & IAbortable
131134
try {
132135
if (await this.isRateLimited(this.clientAddress)) {
133136
this.sendMessage(createNoticeMessage('rate limited'))
@@ -136,10 +139,13 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
136139

137140
const message = attemptValidation(messageSchema)(JSON.parse(raw.toString('utf8')))
138141

139-
const messageHandler = this.createMessageHandler([message, this]) as IMessageHandler & IAbortable
140-
if (typeof messageHandler?.abort === 'function') {
141-
abort = messageHandler.abort.bind(messageHandler)
142-
this.client.prependOnceListener('close', abort)
142+
messageHandler = this.createMessageHandler([message, this]) as IMessageHandler & IAbortable
143+
abortable = typeof messageHandler?.abort === 'function'
144+
145+
if (abortable) {
146+
const handlers = abortableMessageHandlers.get(this.client) ?? []
147+
handlers.push(messageHandler)
148+
abortableMessageHandlers.set(this.client, handlers)
143149
}
144150

145151
await messageHandler?.handleMessage(message)
@@ -150,11 +156,15 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
150156
debug('invalid message: %o', (error as any).annotate())
151157
this.sendMessage(createNoticeMessage(`Invalid message: ${error.message}`))
152158
} else {
153-
debug('unable to handle message: %o', error)
159+
console.error('unable to handle message', error)
154160
}
155161
} finally {
156-
if (abort) {
157-
this.client.removeListener('close', abort)
162+
if (abortable) {
163+
const handlers = abortableMessageHandlers.get(this.client)
164+
const index = handlers.indexOf(messageHandler)
165+
if (index >= 0) {
166+
handlers.splice(index, 1)
167+
}
158168
}
159169
}
160170
}
@@ -163,10 +173,9 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
163173
const {
164174
rateLimits,
165175
ipWhitelist = [],
166-
} = this.settingsFactory().limits?.message ?? {}
176+
} = this.settings().limits?.message ?? {}
167177

168178
if (ipWhitelist.includes(client)) {
169-
debug('rate limit check %s: skipped', client)
170179
return false
171180
}
172181

@@ -195,8 +204,15 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
195204
}
196205

197206
private onClientClose() {
198-
debug('client %s closing', this.clientId)
199207
this.alive = false
208+
this.subscriptions.clear()
209+
210+
const handlers = abortableMessageHandlers.get(this.client)
211+
if (Array.isArray(handlers) && handlers.length) {
212+
for (const handler of handlers) {
213+
handler.abort()
214+
}
215+
}
200216

201217
this.removeAllListeners()
202218
this.client.removeAllListeners()

src/factories/worker-factory.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ export const workerFactory = (): AppWorker => {
1717
const server = http.createServer()
1818
const webSocketServer = new WebSocketServer({
1919
server,
20-
maxPayload: 131072, // 128 kB
20+
maxPayload: createSettings().network?.max_payload_size ?? 131072, // 128 kB
2121
})
2222
const adapter = new WebSocketServerAdapter(
2323
server,

src/handlers/subscribe-message-handler.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ export class SubscribeMessageHandler implements IMessageHandler, IAbortable {
7070
)
7171
} catch (error) {
7272
if (error instanceof Error && error.name === 'AbortError') {
73-
debug('aborted: %o', error)
74-
findEvents.end()
73+
debug('subscription aborted: %o', error)
74+
findEvents.destroy()
7575
} else {
7676
debug('error streaming events: %o', error)
7777
}

0 commit comments

Comments
 (0)