Skip to content

Commit c5a995d

Browse files
authored
fix(WebSocket): await connection event listeners (#748)
1 parent 7039bae commit c5a995d

File tree

3 files changed

+85
-2
lines changed

3 files changed

+85
-2
lines changed

src/interceptors/WebSocket/index.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import {
1717
} from './WebSocketOverride'
1818
import { bindEvent } from './utils/bindEvent'
1919
import { hasConfigurableGlobal } from '../../utils/hasConfigurableGlobal'
20+
import { emitAsync } from '../../utils/emitAsync'
2021

2122
export { type WebSocketData, WebSocketTransport } from './WebSocketTransport'
2223
export {
@@ -102,18 +103,21 @@ export class WebSocketInterceptor extends Interceptor<WebSocketEventMap> {
102103
// Emit the "connection" event to the interceptor on the next tick
103104
// so the client can modify WebSocket options, like "binaryType"
104105
// while the connection is already pending.
105-
queueMicrotask(() => {
106+
queueMicrotask(async () => {
106107
try {
107108
const server = new WebSocketServerConnection(
108109
socket,
109110
transport,
110111
createConnection
111112
)
112113

114+
const hasConnectionListeners =
115+
this.emitter.listenerCount('connection') > 0
116+
113117
// The "globalThis.WebSocket" class stands for
114118
// the client-side connection. Assume it's established
115119
// as soon as the WebSocket instance is constructed.
116-
const hasConnectionListeners = this.emitter.emit('connection', {
120+
await emitAsync(this.emitter, 'connection', {
117121
client: new WebSocketClientConnection(socket, transport),
118122
server,
119123
info: {

test/modules/WebSocket/compliance/websocket.events.test.ts

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import { DeferredPromise } from '@open-draft/deferred-promise'
88
import { WebSocketServer } from 'ws'
99
import { WebSocketInterceptor } from '../../../../src/interceptors/WebSocket'
1010
import { getWsUrl } from '../utils/getWsUrl'
11+
import { sleep } from '../../../helpers'
1112

1213
const wsServer = new WebSocketServer({
1314
host: '127.0.0.1',
@@ -21,6 +22,7 @@ beforeAll(() => {
2122
})
2223

2324
afterEach(() => {
25+
vi.restoreAllMocks()
2426
interceptor.removeAllListeners()
2527
wsServer.removeAllListeners()
2628
wsServer.clients.forEach((client) => client.close())
@@ -245,6 +247,79 @@ it('emits "error" event on passthrough client connection failure', async () => {
245247
expect(closeListener).toHaveBeenCalledOnce()
246248
})
247249

250+
it('allows erroring the connection in a synchronous listener', async () => {
251+
vi.spyOn(console, 'error').mockImplementation(() => {})
252+
253+
interceptor.once('connection', () => {
254+
throw new Error('mock error')
255+
})
256+
257+
const ws = new WebSocket('wss://localhost/non-existing-url')
258+
259+
const openListener = vi.fn()
260+
const errorListener = vi.fn()
261+
const closeListener = vi.fn()
262+
ws.onopen = openListener
263+
ws.onerror = errorListener
264+
ws.onclose = closeListener
265+
266+
await expect.poll(() => errorListener).toHaveBeenCalledTimes(1)
267+
expect(errorListener).toHaveBeenCalledWith(
268+
expect.objectContaining({
269+
type: 'error',
270+
})
271+
)
272+
273+
await expect.poll(() => ws.readyState).toBe(ws.CLOSED)
274+
expect(openListener).not.toHaveBeenCalled()
275+
expect(closeListener).toHaveBeenCalledOnce()
276+
expect(closeListener).toHaveBeenCalledWith(
277+
expect.objectContaining({
278+
type: 'close',
279+
code: 1011,
280+
reason: 'mock error',
281+
})
282+
)
283+
})
284+
285+
it('allows erroring the connection from an asynchronous listener', async ({
286+
onTestFinished,
287+
}) => {
288+
vi.spyOn(console, 'error').mockImplementation(() => {})
289+
290+
interceptor.once('connection', async () => {
291+
await sleep(200)
292+
throw new Error('mock error')
293+
})
294+
295+
const ws = new WebSocket('wss://localhost/non-existing-url')
296+
297+
const openListener = vi.fn()
298+
const errorListener = vi.fn()
299+
const closeListener = vi.fn()
300+
ws.onopen = openListener
301+
ws.onerror = errorListener
302+
ws.onclose = closeListener
303+
304+
await expect.poll(() => errorListener).toHaveBeenCalledTimes(1)
305+
expect(errorListener).toHaveBeenCalledWith(
306+
expect.objectContaining({
307+
type: 'error',
308+
})
309+
)
310+
311+
await expect.poll(() => ws.readyState).toBe(ws.CLOSED)
312+
expect(openListener).not.toHaveBeenCalled()
313+
expect(closeListener).toHaveBeenCalledOnce()
314+
expect(closeListener).toHaveBeenCalledWith(
315+
expect.objectContaining({
316+
type: 'close',
317+
code: 1011,
318+
reason: 'mock error',
319+
})
320+
)
321+
})
322+
248323
it('does not emit "error" event on mocked error code closures', async () => {
249324
interceptor.once('connection', ({ client }) => {
250325
/**

test/modules/WebSocket/utils/getWsUrl.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
import { invariant } from 'outvariant'
12
import type { WebSocketServer } from 'ws'
23

34
export function getWsUrl(ws: WebSocketServer): string {
45
const address = ws.address()
6+
7+
invariant(address != null, 'Failed to get WebSocket address: address is null')
8+
59
if (typeof address === 'string') {
610
return address
711
}

0 commit comments

Comments
 (0)