Skip to content

Commit 931055f

Browse files
committed
Correctly mirror upstream websocket subprotocols back to clients
Previously the first subprotocol was always used automatically, regardless of the upstream protocol. That was almost always what you wanted (most cases just use one subprotocol anyway) but could easily be wrong in some advanced cases. We now specifically match what the upstream server sends, when forwarding. When not forwarding, we stick to the same behaviour, replying with the first suggested protocol, if one is provided.
1 parent 1049ba8 commit 931055f

File tree

2 files changed

+74
-13
lines changed

2 files changed

+74
-13
lines changed

src/rules/websockets/websocket-handlers.ts

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ export interface WebSocketHandler extends WebSocketHandlerDefinition {
6565
): Promise<void>;
6666
}
6767

68+
interface InterceptedWebSocketRequest extends http.IncomingMessage {
69+
upstreamWebSocketProtocol?: string | false;
70+
}
71+
6872
interface InterceptedWebSocket extends WebSocket {
6973
upstreamWebSocket: WebSocket;
7074
}
@@ -188,7 +192,17 @@ export class PassThroughWebSocketHandler extends PassThroughWebSocketHandlerDefi
188192
private initializeWsServer() {
189193
if (this.wsServer) return;
190194

191-
this.wsServer = new WebSocket.Server({ noServer: true });
195+
this.wsServer = new WebSocket.Server({
196+
noServer: true,
197+
// Mirror subprotocols back to the client:
198+
handleProtocols(protocols, request: InterceptedWebSocketRequest) {
199+
return request.upstreamWebSocketProtocol
200+
// If there's no upstream socket, default to mirroring the first protocol. This matches
201+
// WS's default behaviour - we could be stricter, but it'd be a breaking change.
202+
?? protocols.values().next().value
203+
?? false; // If there were no protocols specific and this is called for some reason
204+
},
205+
});
192206
this.wsServer.on('connection', (ws: InterceptedWebSocket) => {
193207
pipeWebSocket(ws, ws.upstreamWebSocket);
194208
pipeWebSocket(ws.upstreamWebSocket, ws);
@@ -355,6 +369,9 @@ export class PassThroughWebSocketHandler extends PassThroughWebSocketHandlerDefi
355369
} as WebSocket.ClientOptions & { lookup: any, maxPayload: number });
356370

357371
upstreamWebSocket.once('open', () => {
372+
// Used in the subprotocol selection handler during the upgrade:
373+
(req as InterceptedWebSocketRequest).upstreamWebSocketProtocol = upstreamWebSocket.protocol || false;
374+
358375
this.wsServer!.handleUpgrade(req, incomingSocket, head, (ws) => {
359376
(<InterceptedWebSocket> ws).upstreamWebSocket = upstreamWebSocket;
360377
incomingSocket.emit('ws-upgrade', ws);

test/integration/websockets.spec.ts

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,19 @@ nodeOnly(() => {
7575

7676
// Real server that echoes every message
7777
wsPort = await portfinder.getPortPromise();
78-
wsServer = new WebSocket.Server({ port: wsPort });
78+
wsServer = new WebSocket.Server({
79+
port: wsPort,
80+
handleProtocols: (protocols, request) => {
81+
const protocolIndex = request.headers['echo-ws-protocol-index'];
82+
if (protocolIndex !== undefined) {
83+
return [...protocols.values()][
84+
parseInt(protocolIndex as string)
85+
];
86+
} else {
87+
return false;
88+
}
89+
}
90+
});
7991

8092
wsServer.on('connection', (ws, request) => {
8193
if (request.headers['echo-headers']) {
@@ -146,17 +158,13 @@ nodeOnly(() => {
146158
it("forwards the incoming requests's headers", async () => {
147159
mockServer.forAnyWebSocket().thenPassThrough();
148160

149-
const ws = new WebSocket(
150-
`ws://localhost:${wsPort}`,
151-
['subprotocol-a', 'subprotocol-b'],
152-
{
153-
agent: new HttpProxyAgent(`http://localhost:${mockServer.port}`),
154-
headers: {
155-
'echo-headers': 'true',
156-
'Funky-HEADER-casing': 'Header-Value'
157-
}
161+
const ws = new WebSocket(`ws://localhost:${wsPort}`, {
162+
agent: new HttpProxyAgent(`http://localhost:${mockServer.port}`),
163+
headers: {
164+
'echo-headers': 'true',
165+
'Funky-HEADER-casing': 'Header-Value'
158166
}
159-
);
167+
});
160168

161169
const response = await new Promise<Buffer>((resolve, reject) => {
162170
ws.on('message', resolve);
@@ -176,7 +184,43 @@ nodeOnly(() => {
176184
[ 'Sec-WebSocket-Version', '13' ],
177185
[ 'Connection', 'Upgrade' ],
178186
[ 'Upgrade', 'websocket' ],
179-
[ 'Sec-WebSocket-Extensions', 'permessage-deflate; client_max_window_bits' ],
187+
[ 'Sec-WebSocket-Extensions', 'permessage-deflate; client_max_window_bits' ]
188+
]);
189+
});
190+
191+
192+
it("forwards the incoming requests' & resulting response's subprotocols", async () => {
193+
mockServer.forAnyWebSocket().thenPassThrough();
194+
195+
const ws = new WebSocket(
196+
`ws://localhost:${wsPort}`,
197+
['subprotocol-a', 'subprotocol-b'], // Request two sub protocols
198+
{
199+
agent: new HttpProxyAgent(`http://localhost:${mockServer.port}`),
200+
headers: {
201+
'echo-headers': 'true',
202+
'echo-ws-protocol-index': 1 // Server should select index 1 (2nd)
203+
}
204+
}
205+
);
206+
207+
const response = await new Promise<Buffer>((resolve, reject) => {
208+
ws.on('message', resolve);
209+
ws.on('error', reject);
210+
});
211+
212+
// The server's selected subprotocol should be mirrored back to the client:
213+
expect(ws.protocol).to.equal('subprotocol-b');
214+
215+
ws.close(1000);
216+
217+
const protocolHeaders = JSON.parse(response.toString()).filter(([key]: [key: string]) =>
218+
// The key is random, so we don't check it here.
219+
key == 'Sec-WebSocket-Protocol'
220+
);
221+
222+
// Server should have seen both requested protocols:
223+
expect(protocolHeaders).to.deep.equal([
180224
[ 'Sec-WebSocket-Protocol', 'subprotocol-a,subprotocol-b' ]
181225
]);
182226
});

0 commit comments

Comments
 (0)