Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,12 @@ To enable the feature, set the `wsReconnect` option to an object with the follow
- `connectionTimeout`: The timeout for establishing the connection in ms (default: `5_000`).
- `reconnectOnClose`: Whether to reconnect on close, as long as the connection from the related client to the proxy is active (default: `false`).
- `logs`: Whether to log the reconnection process (default: `false`).
- `onReconnect`: A hook function that is called when the connection is reconnected `async onReconnect(oldSocket, newSocket)` (default: `undefined`).

## wsHooks

- `onTargetRequest`: A hook function that is called when the request is received from the client `async onTargetRequest({ data, binary })` (default: `undefined`).
- `onTargetResponse`: A hook function that is called when the response is received from the target `async onTargetResponse({ data, binary })` (default: `undefined`).
- `onReconnect`: A hook function that is called when the connection is reconnected `async onReconnect(source, target)` (default: `undefined`).

## Benchmarks

Expand Down
78 changes: 58 additions & 20 deletions index.js
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,22 @@ function isExternalUrl (url) {

function noop () { }

function proxyWebSockets (source, target) {
function proxyWebSockets (logger, source, target, hooks) {
function close (code, reason) {
closeWebSocket(source, code, reason)
closeWebSocket(target, code, reason)
}

source.on('message', (data, binary) => waitConnection(target, () => target.send(data, { binary })))
source.on('message', async (data, binary) => {
if (hooks.onTargetRequest) {
try {
await hooks.onTargetRequest({ data, binary })
} catch (err) {
logger.error({ err }, 'proxy ws error from onTargetRequest hook')
}
}
waitConnection(target, () => target.send(data, { binary }))
})
/* c8 ignore start */
source.on('ping', data => waitConnection(target, () => target.ping(data)))
source.on('pong', data => waitConnection(target, () => target.pong(data)))
Expand All @@ -100,7 +109,16 @@ function proxyWebSockets (source, target) {
/* c8 ignore stop */

// source WebSocket is already connected because it is created by ws server
target.on('message', (data, binary) => source.send(data, { binary }))
target.on('message', async (data, binary) => {
if (hooks.onTargetResponse) {
try {
await hooks.onTargetResponse({ data, binary })
} catch (err) {
logger.error({ err }, 'proxy ws error from onTargetResponse hook')
}
}
source.send(data, { binary })
})
/* c8 ignore start */
target.on('ping', data => source.ping(data))
/* c8 ignore stop */
Expand All @@ -112,37 +130,43 @@ function proxyWebSockets (source, target) {
/* c8 ignore stop */
}

async function reconnect (logger, source, wsReconnectOptions, oldTarget, targetParams) {
async function reconnect (logger, source, reconnectOptions, hooks, targetParams) {
const { url, subprotocols, optionsWs } = targetParams

let attempts = 0
let target
do {
const reconnectWait = wsReconnectOptions.reconnectInterval * (wsReconnectOptions.reconnectDecay * attempts || 1)
wsReconnectOptions.logs && logger.warn({ target: targetParams.url }, `proxy ws reconnect in ${reconnectWait} ms`)
const reconnectWait = reconnectOptions.reconnectInterval * (reconnectOptions.reconnectDecay * attempts || 1)
reconnectOptions.logs && logger.warn({ target: targetParams.url }, `proxy ws reconnect in ${reconnectWait} ms`)
await wait(reconnectWait)

try {
target = new WebSocket(url, subprotocols, optionsWs)
await waitForConnection(target, wsReconnectOptions.connectionTimeout)
await waitForConnection(target, reconnectOptions.connectionTimeout)
} catch (err) {
wsReconnectOptions.logs && logger.error({ target: targetParams.url, err, attempts }, 'proxy ws reconnect error')
reconnectOptions.logs && logger.error({ target: targetParams.url, err, attempts }, 'proxy ws reconnect error')
attempts++
target = undefined
}
} while (!target && attempts < wsReconnectOptions.maxReconnectionRetries)
} while (!target && attempts < reconnectOptions.maxReconnectionRetries)

if (!target) {
logger.error({ target: targetParams.url, attempts }, 'proxy ws failed to reconnect! No more retries')
return
}

wsReconnectOptions.logs && logger.info({ target: targetParams.url, attempts }, 'proxy ws reconnected')
await wsReconnectOptions.onReconnect(oldTarget, target)
proxyWebSocketsWithReconnection(logger, source, target, wsReconnectOptions, targetParams)
reconnectOptions.logs && logger.info({ target: targetParams.url, attempts }, 'proxy ws reconnected')
if (hooks.onReconnect) {
try {
await hooks.onReconnect(source, target)
} catch (err) {
reconnectOptions.logs && logger.error({ target: targetParams.url, err }, 'proxy ws error from onReconnect hook')
}
}
proxyWebSocketsWithReconnection(logger, source, target, reconnectOptions, hooks, targetParams)
}

function proxyWebSocketsWithReconnection (logger, source, target, options, targetParams) {
function proxyWebSocketsWithReconnection (logger, source, target, options, hooks, targetParams) {
function close (code, reason) {
target.pingTimer && clearTimeout(source.pingTimer)
target.pingTimer = undefined
Expand All @@ -155,7 +179,7 @@ function proxyWebSocketsWithReconnection (logger, source, target, options, targe
// need to specify the listeners to remove
removeSourceListeners(source)

reconnect(logger, source, options, target, targetParams)
reconnect(logger, source, options, hooks, targetParams)
return
}

Expand All @@ -174,8 +198,15 @@ function proxyWebSocketsWithReconnection (logger, source, target, options, targe
}

/* c8 ignore start */
function sourceOnMessage (data, binary) {
async function sourceOnMessage (data, binary) {
source.isAlive = true
if (hooks.onTargetRequest) {
try {
await hooks.onTargetRequest({ data, binary })
} catch (err) {
logger.error({ target: targetParams.url, err }, 'proxy ws error from onTargetRequest hook')
}
}
waitConnection(target, () => target.send(data, { binary }))
}
function sourceOnPing (data) {
Expand Down Expand Up @@ -211,8 +242,15 @@ function proxyWebSocketsWithReconnection (logger, source, target, options, targe

// source WebSocket is already connected because it is created by ws server
/* c8 ignore start */
target.on('message', (data, binary) => {
target.on('message', async (data, binary) => {
target.isAlive = true
if (hooks.onTargetResponse) {
try {
await hooks.onTargetResponse({ data, binary })
} catch (err) {
logger.error({ target: targetParams.url, err }, 'proxy ws error from onTargetResponse hook')
}
}
source.send(data, { binary })
})
target.on('ping', data => {
Expand Down Expand Up @@ -268,7 +306,7 @@ function handleUpgrade (fastify, rawRequest, socket, head) {
}

class WebSocketProxy {
constructor (fastify, { wsReconnect, wsServerOptions, wsClientOptions, upstream, wsUpstream, replyOptions: { getUpstream } = {} }) {
constructor (fastify, { wsReconnect, wsHooks, wsServerOptions, wsClientOptions, upstream, wsUpstream, replyOptions: { getUpstream } = {} }) {
this.logger = fastify.log
this.wsClientOptions = {
rewriteRequestHeaders: defaultWsHeadersRewrite,
Expand All @@ -279,7 +317,7 @@ class WebSocketProxy {
this.wsUpstream = wsUpstream ? convertUrlToWebSocket(wsUpstream) : ''
this.getUpstream = getUpstream
this.wsReconnect = wsReconnect

this.wsHooks = wsHooks
const wss = new WebSocket.Server({
noServer: true,
...wsServerOptions
Expand Down Expand Up @@ -371,9 +409,9 @@ class WebSocketProxy {

if (this.wsReconnect) {
const targetParams = { url, subprotocols, optionsWs }
proxyWebSocketsWithReconnection(this.logger, source, target, this.wsReconnect, targetParams)
proxyWebSocketsWithReconnection(this.logger, source, target, this.wsReconnect, this.wsHooks, targetParams)
} else {
proxyWebSockets(source, target)
proxyWebSockets(this.logger, source, target, this.wsHooks)
}
}
}
Expand Down
27 changes: 20 additions & 7 deletions src/options.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@ const DEFAULT_RECONNECT_DECAY = 1.5
const DEFAULT_CONNECTION_TIMEOUT = 5_000
const DEFAULT_RECONNECT_ON_CLOSE = false
const DEFAULT_LOGS = false
const DEFAULT_ON_RECONNECT = noop

function noop () {}

function validateOptions (options) {
if (!options.upstream && !options.websocket && !((options.upstream === '' || options.wsUpstream === '') && options.replyOptions && typeof options.replyOptions.getUpstream === 'function')) {
Expand Down Expand Up @@ -53,11 +50,28 @@ function validateOptions (options) {
throw new Error('wsReconnect.logs must be a boolean')
}
wsReconnect.logs = wsReconnect.logs ?? DEFAULT_LOGS
}

if (options.wsHooks) {
const wsHooks = options.wsHooks

if (wsHooks.onReconnect !== undefined && typeof wsHooks.onReconnect !== 'function') {
throw new Error('wsHooks.onReconnect must be a function')
}

if (wsHooks.onTargetRequest !== undefined && typeof wsHooks.onTargetRequest !== 'function') {
throw new Error('wsHooks.onTargetRequest must be a function')
}

if (wsReconnect.onReconnect !== undefined && typeof wsReconnect.onReconnect !== 'function') {
throw new Error('wsReconnect.onReconnect must be a function')
if (wsHooks.onTargetResponse !== undefined && typeof wsHooks.onTargetResponse !== 'function') {
throw new Error('wsHooks.onTargetResponse must be a function')
}
} else {
options.wsHooks = {
onReconnect: undefined,
onTargetRequest: undefined,
onTargetResponse: undefined
}
wsReconnect.onReconnect = wsReconnect.onReconnect ?? DEFAULT_ON_RECONNECT
}

return options
Expand All @@ -72,5 +86,4 @@ module.exports = {
DEFAULT_CONNECTION_TIMEOUT,
DEFAULT_RECONNECT_ON_CLOSE,
DEFAULT_LOGS,
DEFAULT_ON_RECONNECT
}
83 changes: 83 additions & 0 deletions test/helper/helper.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
'use strict'

const { createServer } = require('node:http')
const { promisify } = require('node:util')
const { once } = require('node:events')
const Fastify = require('fastify')
const WebSocket = require('ws')
const pinoTest = require('pino-test')
const pino = require('pino')
const proxyPlugin = require('../../')

function waitForLogMessage (loggerSpy, message, max = 100) {
return new Promise((resolve, reject) => {
let count = 0
const fn = (received) => {
if (received.msg === message) {
loggerSpy.off('data', fn)
resolve()
}
count++
if (count > max) {
loggerSpy.off('data', fn)
reject(new Error(`Max message count reached on waitForLogMessage: ${message}`))
}
}
loggerSpy.on('data', fn)
})
}

async function createTargetServer (t, wsTargetOptions, port = 0) {
const targetServer = createServer()
const targetWs = new WebSocket.Server({ server: targetServer, ...wsTargetOptions })
await promisify(targetServer.listen.bind(targetServer))({ port, host: '127.0.0.1' })

t.after(() => {
targetWs.close()
targetServer.close()
})

return { targetServer, targetWs }
}

async function createServices ({ t, wsReconnectOptions, wsTargetOptions, wsServerOptions, wsHooks, targetPort = 0 }) {
const { targetServer, targetWs } = await createTargetServer(t, wsTargetOptions, targetPort)

const loggerSpy = pinoTest.sink()
const logger = pino(loggerSpy)
const proxy = Fastify({ loggerInstance: logger })
proxy.register(proxyPlugin, {
upstream: `ws://127.0.0.1:${targetServer.address().port}`,
websocket: true,
wsReconnect: wsReconnectOptions,
wsServerOptions,
wsHooks
})

await proxy.listen({ port: 0, host: '127.0.0.1' })

const client = new WebSocket(`ws://127.0.0.1:${proxy.server.address().port}`)
await once(client, 'open')

t.after(async () => {
client.close()
await proxy.close()
})

return {
target: {
ws: targetWs,
server: targetServer
},
proxy,
client,
loggerSpy,
logger
}
}

module.exports = {
waitForLogMessage,
createTargetServer,
createServices
}
26 changes: 21 additions & 5 deletions test/options.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
const { test } = require('node:test')
const assert = require('node:assert')
const { validateOptions } = require('../src/options')
const { DEFAULT_PING_INTERVAL, DEFAULT_MAX_RECONNECTION_RETRIES, DEFAULT_RECONNECT_INTERVAL, DEFAULT_RECONNECT_DECAY, DEFAULT_CONNECTION_TIMEOUT, DEFAULT_RECONNECT_ON_CLOSE, DEFAULT_LOGS, DEFAULT_ON_RECONNECT } = require('../src/options')
const {
DEFAULT_PING_INTERVAL, DEFAULT_MAX_RECONNECTION_RETRIES, DEFAULT_RECONNECT_INTERVAL, DEFAULT_RECONNECT_DECAY, DEFAULT_CONNECTION_TIMEOUT, DEFAULT_RECONNECT_ON_CLOSE, DEFAULT_LOGS
} = require('../src/options')

test('validateOptions', (t) => {
const requiredOptions = {
Expand Down Expand Up @@ -41,8 +43,14 @@ test('validateOptions', (t) => {
assert.throws(() => validateOptions({ ...requiredOptions, wsReconnect: { logs: '1' } }), /wsReconnect.logs must be a boolean/)
assert.doesNotThrow(() => validateOptions({ ...requiredOptions, wsReconnect: { logs: true } }))

assert.throws(() => validateOptions({ ...requiredOptions, wsReconnect: { onReconnect: '1' } }), /wsReconnect.onReconnect must be a function/)
assert.doesNotThrow(() => validateOptions({ ...requiredOptions, wsReconnect: { onReconnect: () => { } } }))
assert.throws(() => validateOptions({ ...requiredOptions, wsHooks: { onReconnect: '1' } }), /wsHooks.onReconnect must be a function/)
assert.doesNotThrow(() => validateOptions({ ...requiredOptions, wsHooks: { onReconnect: () => { } } }))

assert.throws(() => validateOptions({ ...requiredOptions, wsHooks: { onTargetRequest: '1' } }), /wsHooks.onTargetRequest must be a function/)
assert.doesNotThrow(() => validateOptions({ ...requiredOptions, wsHooks: { onTargetRequest: () => { } } }))

assert.throws(() => validateOptions({ ...requiredOptions, wsHooks: { onTargetResponse: '1' } }), /wsHooks.onTargetResponse must be a function/)
assert.doesNotThrow(() => validateOptions({ ...requiredOptions, wsHooks: { onTargetResponse: () => { } } }))

// set all values
assert.doesNotThrow(() => validateOptions({
Expand All @@ -55,7 +63,11 @@ test('validateOptions', (t) => {
connectionTimeout: 1,
reconnectOnClose: true,
logs: true,
onReconnect: () => { }
},
wsHooks: {
onReconnect: () => { },
onTargetRequest: () => { },
onTargetResponse: () => { }
}
}))

Expand All @@ -70,7 +82,11 @@ test('validateOptions', (t) => {
connectionTimeout: DEFAULT_CONNECTION_TIMEOUT,
reconnectOnClose: DEFAULT_RECONNECT_ON_CLOSE,
logs: DEFAULT_LOGS,
onReconnect: DEFAULT_ON_RECONNECT
},
wsHooks: {
onReconnect: undefined,
onTargetRequest: undefined,
onTargetResponse: undefined
}
})
})
Loading
Loading