diff --git a/README.md b/README.md index 5502cfe..933c8f0 100644 --- a/README.md +++ b/README.md @@ -263,6 +263,47 @@ fastify.register(require('@fastify/websocket'), { }) ``` +### Custom upgrade request handling + +By default, `@fastify/websocket` handles upgrading incoming connections to the websocket protocol before handing off the handler you have defined. +If you wish to handle upgrade events yourself you can pass your own `handleUpgradeRequest` function: + +```js +const fastify = require('fastify')() + +fastify.register(require('@fastify/websocket')) + +fastify.register(async function () { + fastify.route({ + method: 'GET', + url: '/hello', + handleUpgradeRequest: (request, source, head) => { + // handle the FastifyRequest which has triggered an upgrade event + // throwing an error will abort the upgrade + // return a Promise for a Websocket to proceed + if (request.params.allow === "false") { + const error = new Error("Upgrade not allow") + error.statusCode = 403 + throw error + } else { + return new Promise((resolve) => { + fastify.websocketServer.handleUpgrade(request.raw, socket, head, (ws) => { + resolve(ws) + }) + }) + } + } + wsHandler: (socket, req) => { + socket.send('hello client') + + socket.once('message', chunk => { + socket.close() + }) + } + }) +}) +``` + ### Creating a stream from the WebSocket ```js diff --git a/index.js b/index.js index c76f1db..b6781b8 100644 --- a/index.js +++ b/index.js @@ -123,7 +123,8 @@ function fastifyWebsocket (fastify, opts, next) { } websocketListenServer.on('upgrade', onUpgrade) - const handleUpgrade = (rawRequest, callback) => { + const defaultHandleUpgrade = (request, _reply, callback) => { + const rawRequest = request.raw wss.handleUpgrade(rawRequest, rawRequest[kWs], rawRequest[kWsHead], (socket) => { wss.emit('connection', socket, rawRequest) @@ -155,6 +156,22 @@ function fastifyWebsocket (fastify, opts, next) { let isWebsocketRoute = false let wsHandler = routeOptions.wsHandler let handler = routeOptions.handler + const handleUpgrade = routeOptions.handleUpgradeRequest + ? (request, reply, callback) => { + const rawRequest = request.raw + routeOptions.handleUpgradeRequest(request, rawRequest[kWs], rawRequest[kWsHead]) + .then(socket => { + callback(socket) + }) + .catch(error => { + const ended = reply.raw.writableEnded || reply.raw.socket.writableEnded + if (!ended) { + reply.raw.statusCode = error.statusCode || 500 + reply.raw.end(error.message) + } + }) + } + : defaultHandleUpgrade if (routeOptions.websocket || routeOptions.wsHandler) { if (routeOptions.method === 'HEAD') { @@ -188,7 +205,7 @@ function fastifyWebsocket (fastify, opts, next) { // within the route handler, we check if there has been a connection upgrade by looking at request.raw[kWs]. we need to dispatch the normal HTTP handler if not, and hijack to dispatch the websocket handler if so if (request.raw[kWs]) { reply.hijack() - handleUpgrade(request.raw, socket => { + const onUpgrade = (socket) => { let result try { if (isWebsocketRoute) { @@ -203,7 +220,9 @@ function fastifyWebsocket (fastify, opts, next) { if (result && typeof result.catch === 'function') { result.catch(err => errorHandler.call(this, err, socket, request, reply)) } - }) + } + + handleUpgrade(request, reply, onUpgrade) } else { return handler.call(this, request, reply) } diff --git a/test/base.test.js b/test/base.test.js index 5fbb287..c22de63 100644 --- a/test/base.test.js +++ b/test/base.test.js @@ -661,3 +661,160 @@ test('clashing upgrade handler', async (t) => { const ws = new WebSocket('ws://localhost:' + fastify.server.address().port) await once(ws, 'error') }) + +test('Should handleUpgradeRequest successfully', async (t) => { + t.plan(4) + + const fastify = Fastify() + t.after(() => fastify.close()) + + await fastify.register(fastifyWebsocket) + + let customUpgradeCalled = false + + fastify.get('/', { + websocket: true, + handleUpgradeRequest: async (request, socket, head) => { + customUpgradeCalled = true + t.assert.equal(typeof socket, 'object', 'socket parameter is provided') + t.assert.equal(Buffer.isBuffer(head), true, 'head parameter is a buffer') + + return new Promise((resolve) => { + fastify.websocketServer.handleUpgrade(request.raw, socket, head, (ws) => { + resolve(ws) + }) + }) + } + }, (socket) => { + socket.on('message', (data) => { + socket.send(`echo: ${data}`) + }) + t.after(() => socket.terminate()) + }) + + await fastify.listen({ port: 0 }) + + const ws = new WebSocket('ws://localhost:' + fastify.server.address().port) + t.after(() => ws.close()) + + await once(ws, 'open') + ws.send('hello') + + const [message] = await once(ws, 'message') + t.assert.equal(message.toString(), 'echo: hello') + + t.assert.ok(customUpgradeCalled, 'handleUpgradeRequest was called') +}) + +test.only('Should handle errors thrown in handleUpgradeRequest', async (t) => { + t.plan(1) + + const fastify = Fastify() + t.after(() => fastify.close()) + + await fastify.register(fastifyWebsocket) + + fastify.get('/', { + websocket: true, + handleUpgradeRequest: async () => { + throw new Error('Custom upgrade error') + } + }, () => { + t.fail('websocket handler should not be called when upgrade fails') + }) + + await fastify.listen({ port: 0 }) + + const ws = new WebSocket('ws://localhost:' + fastify.server.address().port) + + let wsErrorResolved + const wsErrorPromise = new Promise((resolve) => { + wsErrorResolved = resolve + }) + + ws.on('error', (error) => { + wsErrorResolved(error) + }) + + const wsError = await wsErrorPromise + + t.assert.equal(wsError.message, 'Unexpected server response: 500') +}) + +test('Should allow for handleUpgradeRequest to send a response to the client before throwing an error', async (t) => { + t.plan(1) + + const fastify = Fastify() + t.after(() => fastify.close()) + + await fastify.register(fastifyWebsocket) + + fastify.get('/', { + websocket: true, + handleUpgradeRequest: async () => { + const error = new Error('Forbidden') + error.statusCode = 403 + throw error + } + }, () => { + t.fail('websocket handler should not be called when upgrade fails') + }) + + await fastify.listen({ port: 0 }) + + const ws = new WebSocket('ws://localhost:' + fastify.server.address().port) + + let wsErrorResolved + const wsErrorPromise = new Promise((resolve) => { + wsErrorResolved = resolve + }) + + ws.on('error', (error) => { + wsErrorResolved(error) + }) + + const wsError = await wsErrorPromise + + t.assert.equal(wsError.message, 'Unexpected server response: 403') +}) + +test('Should not send a response if handleUpgradeRequest has already ended the underlying socket and thrown an error', async (t) => { + t.plan(1) + + const fastify = Fastify() + t.after(() => fastify.close()) + + await fastify.register(fastifyWebsocket) + + fastify.get('/', { + websocket: true, + handleUpgradeRequest: async (request, socket, head) => { + socket.write('HTTP/1.1 400 Bad Request\r\n') + socket.write('Connection: closed\r\n') + socket.write('\r\n') + socket.end() + socket.destroy() + + throw new Error('thrown after response has ended') + } + }, () => { + t.fail('websocket handler should not be called when upgrade fails') + }) + + await fastify.listen({ port: 0 }) + + const ws = new WebSocket('ws://localhost:' + fastify.server.address().port) + + let wsErrorResolved + const wsErrorPromise = new Promise((resolve) => { + wsErrorResolved = resolve + }) + + ws.on('error', (error) => { + wsErrorResolved(error) + }) + + const wsError = await wsErrorPromise + + t.assert.equal(wsError.message, 'Unexpected server response: 400') +}) diff --git a/types/index.d.ts b/types/index.d.ts index 4ff7cab..ee83b45 100644 --- a/types/index.d.ts +++ b/types/index.d.ts @@ -5,6 +5,7 @@ import { preCloseAsyncHookHandler, preCloseHookHandler } from 'fastify/types/hoo import { FastifyReply } from 'fastify/types/reply' import { RouteGenericInterface } from 'fastify/types/route' import { IncomingMessage, Server, ServerResponse } from 'node:http' +import { Duplex } from 'node:stream' import * as WebSocket from 'ws' interface WebsocketRouteOptions< @@ -17,6 +18,7 @@ interface WebsocketRouteOptions< Logger extends FastifyBaseLogger = FastifyBaseLogger > { wsHandler?: fastifyWebsocket.WebsocketHandler; + handleUpgradeRequest?: (request: FastifyRequest, rawSocket: Duplex, socketHead: Buffer) => Promise; } declare module 'fastify' { diff --git a/types/index.test-d.ts b/types/index.test-d.ts index eeebc55..46b43ab 100644 --- a/types/index.test-d.ts +++ b/types/index.test-d.ts @@ -4,9 +4,10 @@ import fastify, { FastifyBaseLogger, FastifyInstance, FastifyReply, FastifyReque import { RouteGenericInterface } from 'fastify/types/route' import type { IncomingMessage } from 'node:http' import { expectType } from 'tsd' -import { Server } from 'ws' +import { Server, WebSocket as BaseWebSocket } from 'ws' // eslint-disable-next-line import-x/no-named-default -- Test default export import fastifyWebsocket, { default as defaultFastifyWebsocket, fastifyWebsocket as namedFastifyWebsocket, WebSocket, WebsocketHandler } from '..' +import { Duplex } from 'node:stream' const app: FastifyInstance = fastify() app.register(fastifyWebsocket) @@ -82,6 +83,27 @@ const augmentedRouteOptions: RouteOptions = { } app.route(augmentedRouteOptions) +const handleUpgradeRequestOptions: RouteOptions = { + method: 'GET', + url: '/route-with-handle-upgrade-request', + handler: (request, reply) => { + expectType(request) + expectType(reply) + }, + handleUpgradeRequest: (request, socket, head) => { + expectType(request) + expectType(socket) + expectType(head) + + return Promise.resolve(new BaseWebSocket('ws://localhost:8080')) + }, + wsHandler: (socket, request) => { + expectType(socket) + expectType>(request) + }, +} +app.route(handleUpgradeRequestOptions) + app.get<{ Params: { foo: string }, Body: { bar: string }, Querystring: { search: string }, Headers: { auth: string } }>('/shorthand-explicit-types', { websocket: true }, async (socket, request) => {