Skip to content

feat: add handleUpgradeRequest route option #342

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,47 @@ fastify.register(require('@fastify/websocket'), {
})
```

### Custom upgrade request handling

By default, `@fastify/websocket` handles upgrading incoming connections to the websocket protocol before handing off the handler you have defined.
If you wish to handle upgrade events yourself you can pass your own `handleUpgradeRequest` function:

```js
const fastify = require('fastify')()

fastify.register(require('@fastify/websocket'))

fastify.register(async function () {
fastify.route({
method: 'GET',
url: '/hello',
handleUpgradeRequest: (request, source, head) => {
// handle the FastifyRequest which has triggered an upgrade event
// throwing an error will abort the upgrade
// return a Promise for a Websocket to proceed
if (request.params.allow === "false") {
const error = new Error("Upgrade not allow")
error.statusCode = 403
throw error
} else {
return new Promise((resolve) => {
fastify.websocketServer.handleUpgrade(request.raw, socket, head, (ws) => {
resolve(ws)
})
})
}
}
wsHandler: (socket, req) => {
socket.send('hello client')

socket.once('message', chunk => {
socket.close()
})
}
})
})
```

### Creating a stream from the WebSocket

```js
Expand Down
25 changes: 22 additions & 3 deletions index.js
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ function fastifyWebsocket (fastify, opts, next) {
}
websocketListenServer.on('upgrade', onUpgrade)

const handleUpgrade = (rawRequest, callback) => {
const defaultHandleUpgrade = (request, _reply, callback) => {
const rawRequest = request.raw
wss.handleUpgrade(rawRequest, rawRequest[kWs], rawRequest[kWsHead], (socket) => {
wss.emit('connection', socket, rawRequest)

Expand Down Expand Up @@ -155,6 +156,22 @@ function fastifyWebsocket (fastify, opts, next) {
let isWebsocketRoute = false
let wsHandler = routeOptions.wsHandler
let handler = routeOptions.handler
const handleUpgrade = routeOptions.handleUpgradeRequest
? (request, reply, callback) => {
const rawRequest = request.raw
routeOptions.handleUpgradeRequest(request, rawRequest[kWs], rawRequest[kWsHead])
.then(socket => {
callback(socket)
})
.catch(error => {
const ended = reply.raw.writableEnded || reply.raw.socket.writableEnded
if (!ended) {
reply.raw.statusCode = error.statusCode || 500
reply.raw.end(error.message)
}
})
}
: defaultHandleUpgrade

if (routeOptions.websocket || routeOptions.wsHandler) {
if (routeOptions.method === 'HEAD') {
Expand Down Expand Up @@ -188,7 +205,7 @@ function fastifyWebsocket (fastify, opts, next) {
// within the route handler, we check if there has been a connection upgrade by looking at request.raw[kWs]. we need to dispatch the normal HTTP handler if not, and hijack to dispatch the websocket handler if so
if (request.raw[kWs]) {
reply.hijack()
handleUpgrade(request.raw, socket => {
const onUpgrade = (socket) => {
let result
try {
if (isWebsocketRoute) {
Expand All @@ -203,7 +220,9 @@ function fastifyWebsocket (fastify, opts, next) {
if (result && typeof result.catch === 'function') {
result.catch(err => errorHandler.call(this, err, socket, request, reply))
}
})
}

handleUpgrade(request, reply, onUpgrade)
} else {
return handler.call(this, request, reply)
}
Expand Down
157 changes: 157 additions & 0 deletions test/base.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -661,3 +661,160 @@ test('clashing upgrade handler', async (t) => {
const ws = new WebSocket('ws://localhost:' + fastify.server.address().port)
await once(ws, 'error')
})

test('Should handleUpgradeRequest successfully', async (t) => {
t.plan(4)

const fastify = Fastify()
t.after(() => fastify.close())

await fastify.register(fastifyWebsocket)

let customUpgradeCalled = false

fastify.get('/', {
websocket: true,
handleUpgradeRequest: async (request, socket, head) => {
customUpgradeCalled = true
t.assert.equal(typeof socket, 'object', 'socket parameter is provided')
t.assert.equal(Buffer.isBuffer(head), true, 'head parameter is a buffer')

return new Promise((resolve) => {
fastify.websocketServer.handleUpgrade(request.raw, socket, head, (ws) => {
resolve(ws)
})
})
}
}, (socket) => {
socket.on('message', (data) => {
socket.send(`echo: ${data}`)
})
t.after(() => socket.terminate())
})

await fastify.listen({ port: 0 })

const ws = new WebSocket('ws://localhost:' + fastify.server.address().port)
t.after(() => ws.close())

await once(ws, 'open')
ws.send('hello')

const [message] = await once(ws, 'message')
t.assert.equal(message.toString(), 'echo: hello')

t.assert.ok(customUpgradeCalled, 'handleUpgradeRequest was called')
})

test.only('Should handle errors thrown in handleUpgradeRequest', async (t) => {
t.plan(1)

const fastify = Fastify()
t.after(() => fastify.close())

await fastify.register(fastifyWebsocket)

fastify.get('/', {
websocket: true,
handleUpgradeRequest: async () => {
throw new Error('Custom upgrade error')
}
}, () => {
t.fail('websocket handler should not be called when upgrade fails')
})

await fastify.listen({ port: 0 })

const ws = new WebSocket('ws://localhost:' + fastify.server.address().port)

let wsErrorResolved
const wsErrorPromise = new Promise((resolve) => {
wsErrorResolved = resolve
})

ws.on('error', (error) => {
wsErrorResolved(error)
})

const wsError = await wsErrorPromise

t.assert.equal(wsError.message, 'Unexpected server response: 500')
})

test('Should allow for handleUpgradeRequest to send a response to the client before throwing an error', async (t) => {
t.plan(1)

const fastify = Fastify()
t.after(() => fastify.close())

await fastify.register(fastifyWebsocket)

fastify.get('/', {
websocket: true,
handleUpgradeRequest: async () => {
const error = new Error('Forbidden')
error.statusCode = 403
throw error
}
}, () => {
t.fail('websocket handler should not be called when upgrade fails')
})

await fastify.listen({ port: 0 })

const ws = new WebSocket('ws://localhost:' + fastify.server.address().port)

let wsErrorResolved
const wsErrorPromise = new Promise((resolve) => {
wsErrorResolved = resolve
})

ws.on('error', (error) => {
wsErrorResolved(error)
})

const wsError = await wsErrorPromise

t.assert.equal(wsError.message, 'Unexpected server response: 403')
})

test('Should not send a response if handleUpgradeRequest has already ended the underlying socket and thrown an error', async (t) => {
t.plan(1)

const fastify = Fastify()
t.after(() => fastify.close())

await fastify.register(fastifyWebsocket)

fastify.get('/', {
websocket: true,
handleUpgradeRequest: async (request, socket, head) => {
socket.write('HTTP/1.1 400 Bad Request\r\n')
socket.write('Connection: closed\r\n')
socket.write('\r\n')
socket.end()
socket.destroy()

throw new Error('thrown after response has ended')
}
}, () => {
t.fail('websocket handler should not be called when upgrade fails')
})

await fastify.listen({ port: 0 })

const ws = new WebSocket('ws://localhost:' + fastify.server.address().port)

let wsErrorResolved
const wsErrorPromise = new Promise((resolve) => {
wsErrorResolved = resolve
})

ws.on('error', (error) => {
wsErrorResolved(error)
})

const wsError = await wsErrorPromise

t.assert.equal(wsError.message, 'Unexpected server response: 400')
})
2 changes: 2 additions & 0 deletions types/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { preCloseAsyncHookHandler, preCloseHookHandler } from 'fastify/types/hoo
import { FastifyReply } from 'fastify/types/reply'
import { RouteGenericInterface } from 'fastify/types/route'
import { IncomingMessage, Server, ServerResponse } from 'node:http'
import { Duplex } from 'node:stream'
import * as WebSocket from 'ws'

interface WebsocketRouteOptions<
Expand All @@ -17,6 +18,7 @@ interface WebsocketRouteOptions<
Logger extends FastifyBaseLogger = FastifyBaseLogger
> {
wsHandler?: fastifyWebsocket.WebsocketHandler<RawServer, RawRequest, RequestGeneric, ContextConfig, SchemaCompiler, TypeProvider, Logger>;
handleUpgradeRequest?: (request: FastifyRequest<RequestGeneric, RawServer, RawRequest, SchemaCompiler, TypeProvider, ContextConfig, Logger>, rawSocket: Duplex, socketHead: Buffer) => Promise<WebSocket.WebSocket>;
}

declare module 'fastify' {
Expand Down
24 changes: 23 additions & 1 deletion types/index.test-d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ import fastify, { FastifyBaseLogger, FastifyInstance, FastifyReply, FastifyReque
import { RouteGenericInterface } from 'fastify/types/route'
import type { IncomingMessage } from 'node:http'
import { expectType } from 'tsd'
import { Server } from 'ws'
import { Server, WebSocket as BaseWebSocket } from 'ws'
// eslint-disable-next-line import-x/no-named-default -- Test default export
import fastifyWebsocket, { default as defaultFastifyWebsocket, fastifyWebsocket as namedFastifyWebsocket, WebSocket, WebsocketHandler } from '..'
import { Duplex } from 'node:stream'

const app: FastifyInstance = fastify()
app.register(fastifyWebsocket)
Expand Down Expand Up @@ -82,6 +83,27 @@ const augmentedRouteOptions: RouteOptions = {
}
app.route(augmentedRouteOptions)

const handleUpgradeRequestOptions: RouteOptions = {
method: 'GET',
url: '/route-with-handle-upgrade-request',
handler: (request, reply) => {
expectType<FastifyRequest>(request)
expectType<FastifyReply>(reply)
},
handleUpgradeRequest: (request, socket, head) => {
expectType<FastifyRequest>(request)
expectType<Duplex>(socket)
expectType<Buffer>(head)

return Promise.resolve(new BaseWebSocket('ws://localhost:8080'))
},
wsHandler: (socket, request) => {
expectType<WebSocket>(socket)
expectType<FastifyRequest<RouteGenericInterface>>(request)
},
}
app.route(handleUpgradeRequestOptions)

app.get<{ Params: { foo: string }, Body: { bar: string }, Querystring: { search: string }, Headers: { auth: string } }>('/shorthand-explicit-types', {
websocket: true
}, async (socket, request) => {
Expand Down