Skip to content

Commit c9c4438

Browse files
committed
Fire passthrough-websocket-connect rule event on passthrough WebSockets
1 parent bb61c2c commit c9c4438

File tree

4 files changed

+120
-16
lines changed

4 files changed

+120
-16
lines changed

src/rules/websockets/websocket-handlers.ts

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
import * as _ from 'lodash';
22
import net = require('net');
33
import * as url from 'url';
4-
import * as tls from 'tls';
54
import * as http from 'http';
6-
import * as fs from 'fs/promises';
75
import * as WebSocket from 'ws';
86

97
import {
@@ -12,10 +10,11 @@ import {
1210
deserializeProxyConfig
1311
} from "../../serialization/serialization";
1412

15-
import { OngoingRequest, RawHeaders } from "../../types";
13+
import { Headers, OngoingRequest, RawHeaders } from "../../types";
1614

1715
import {
1816
CloseConnectionHandler,
17+
RequestHandlerOptions,
1918
ResetConnectionHandler,
2019
TimeoutHandler
2120
} from '../requests/request-handlers';
@@ -60,7 +59,9 @@ export interface WebSocketHandler extends WebSocketHandlerDefinition {
6059
// The raw socket on which we'll be communicating
6160
socket: net.Socket,
6261
// Initial data received
63-
head: Buffer
62+
head: Buffer,
63+
// Other general handler options
64+
options: RequestHandlerOptions
6465
): Promise<void>;
6566
}
6667

@@ -219,7 +220,7 @@ export class PassThroughWebSocketHandler extends PassThroughWebSocketHandlerDefi
219220
return this._trustedCACertificates;
220221
}
221222

222-
async handle(req: OngoingRequest, socket: net.Socket, head: Buffer) {
223+
async handle(req: OngoingRequest, socket: net.Socket, head: Buffer, options: RequestHandlerOptions) {
223224
this.initializeWsServer();
224225

225226
let { protocol, hostname, port, path } = url.parse(req.url!);
@@ -266,7 +267,7 @@ export class PassThroughWebSocketHandler extends PassThroughWebSocketHandlerDefi
266267
hostHeader[1] = updateHostHeader;
267268
} // Otherwise: falsey means don't touch it.
268269

269-
await this.connectUpstream(wsUrl, reqMessage, rawHeaders, socket, head);
270+
await this.connectUpstream(wsUrl, reqMessage, rawHeaders, socket, head, options);
270271
} else if (!hostname) { // No hostname in URL means transparent proxy, so use Host header
271272
const hostHeader = req.headers[hostHeaderName];
272273
[ hostname, port ] = hostHeader!.split(':');
@@ -280,14 +281,14 @@ export class PassThroughWebSocketHandler extends PassThroughWebSocketHandlerDefi
280281
}
281282

282283
const wsUrl = `${protocol}://${hostname}${port ? ':' + port : ''}${path}`;
283-
await this.connectUpstream(wsUrl, reqMessage, rawHeaders, socket, head);
284+
await this.connectUpstream(wsUrl, reqMessage, rawHeaders, socket, head, options);
284285
} else {
285286
// Connect directly according to the specified URL
286287
const wsUrl = `${
287288
protocol!.replace('http', 'ws')
288289
}//${hostname}${port ? ':' + port : ''}${path}`;
289290

290-
await this.connectUpstream(wsUrl, reqMessage, rawHeaders, socket, head);
291+
await this.connectUpstream(wsUrl, reqMessage, rawHeaders, socket, head, options);
291292
}
292293
}
293294

@@ -296,7 +297,8 @@ export class PassThroughWebSocketHandler extends PassThroughWebSocketHandlerDefi
296297
req: http.IncomingMessage,
297298
rawHeaders: RawHeaders,
298299
incomingSocket: net.Socket,
299-
head: Buffer
300+
head: Buffer,
301+
options: RequestHandlerOptions
300302
) {
301303
const parsedUrl = url.parse(wsUrl);
302304

@@ -370,6 +372,19 @@ export class PassThroughWebSocketHandler extends PassThroughWebSocketHandlerDefi
370372
...caConfig
371373
} as WebSocket.ClientOptions & { lookup: any, maxPayload: number });
372374

375+
if (options.emitEventCallback) {
376+
const upstreamReq = (upstreamWebSocket as any as { _req: http.ClientRequest })._req;
377+
options.emitEventCallback('passthrough-websocket-connect', {
378+
method: upstreamReq.method,
379+
protocol: upstreamReq.protocol.replace(/:$/, ''),
380+
hostname: upstreamReq.host,
381+
port: effectivePort.toString(),
382+
path: upstreamReq.path,
383+
rawHeaders: objectHeadersToRaw(upstreamReq.getHeaders() as Headers),
384+
subprotocols: filteredSubprotocols
385+
});
386+
}
387+
373388
upstreamWebSocket.once('open', () => {
374389
// Used in the subprotocol selection handler during the upgrade:
375390
(req as InterceptedWebSocketRequest).upstreamWebSocketProtocol = upstreamWebSocket.protocol || false;

src/rules/websockets/websocket-rule.ts

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,15 @@ export interface WebSocketRule extends Explainable {
2626

2727
// We don't extend the main interfaces for these, because MockRules are not Serializable
2828
matches(request: OngoingRequest): MaybePromise<boolean>;
29-
handle(request: OngoingRequest, response: net.Socket, head: Buffer, record: boolean): Promise<void>;
29+
handle(
30+
request: OngoingRequest,
31+
response: net.Socket,
32+
head: Buffer,
33+
options: {
34+
record: boolean,
35+
emitEventCallback?: (type: string, event: unknown) => void
36+
}
37+
): Promise<void>;
3038
isComplete(): boolean | null;
3139
}
3240

@@ -71,14 +79,22 @@ export class WebSocketRule implements WebSocketRule {
7179
return matchers.matchesAll(request, this.matchers);
7280
}
7381

74-
handle(req: OngoingRequest, res: net.Socket, head: Buffer, record: boolean): Promise<void> {
82+
handle(
83+
req: OngoingRequest,
84+
res: net.Socket,
85+
head: Buffer,
86+
options: {
87+
record: boolean,
88+
emitEventCallback?: (type: string, event: unknown) => void
89+
}
90+
): Promise<void> {
7591
let handlerPromise = (async () => { // Catch (a)sync errors
76-
return this.handler.handle(req as OngoingRequest & http.IncomingMessage, res, head);
92+
return this.handler.handle(req as OngoingRequest & http.IncomingMessage, res, head, options);
7793
})();
7894

7995
// Requests are added to rule.requests as soon as they start being handled,
8096
// as promises, which resolve only when the response & request body is complete.
81-
if (record) {
97+
if (options.record) {
8298
this.requests.push(
8399
Promise.race([
84100
// When the handler resolves, the request is completed:

src/server/mockttp-server.ts

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -756,15 +756,24 @@ export class MockttpServer extends AbstractMockttp implements Mockttp {
756756
let nextRule = await nextRulePromise;
757757
if (nextRule) {
758758
if (this.debug) console.log(`Websocket matched rule: ${nextRule.explain()}`);
759-
await nextRule.handle(request, socket, head, this.recordTraffic);
759+
await nextRule.handle(request, socket, head, {
760+
record: this.recordTraffic,
761+
emitEventCallback: (this.eventEmitter.listenerCount('rule-event') !== 0)
762+
? (type, event) => this.announceRuleEventAsync(request.id, nextRule!.id, type, event)
763+
: undefined
764+
});
760765
} else {
761766
// Unmatched requests get passed through untouched automatically. This exists for
762767
// historical/backward-compat reasons, to match the initial WS implementation, and
763768
// will probably be removed to match handleRequest in future.
764769
await this.defaultWsHandler.handle(
765770
request as OngoingRequest & http.IncomingMessage,
766771
socket,
767-
head
772+
head,
773+
{ emitEventCallback: (this.eventEmitter.listenerCount('rule-event') !== 0)
774+
? (type, event) => this.announceRuleEventAsync(request.id, nextRule!.id, type, event)
775+
: undefined
776+
}
768777
);
769778
}
770779
} catch (e) {

test/integration/subscriptions/rule-events.spec.ts

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
import * as _ from 'lodash';
2+
import * as WebSocket from 'isomorphic-ws';
23

34
import {
45
getLocal,
6+
RawHeaders,
57
RuleEvent
68
} from "../../..";
79
import {
810
delay,
911
expect,
10-
fetch
12+
fetch,
13+
isNode
1114
} from "../../test-utils";
1215

1316
describe("Rule event susbcriptions", () => {
@@ -168,4 +171,65 @@ describe("Rule event susbcriptions", () => {
168171
expect(responseBodyEvent.rawBody.toString('utf8')).to.equal('Original response body');
169172
});
170173

174+
it("should fire for proxied websockets", async () => {
175+
await remoteServer.forAnyWebSocket().thenPassivelyListen();
176+
const forwardingRule = await server.forAnyWebSocket().thenForwardTo(remoteServer.url);
177+
178+
const ruleEvents: RuleEvent<any>[] = [];
179+
await server.on('rule-event', (e) => ruleEvents.push(e));
180+
181+
const ws = new WebSocket(`ws://localhost:${server.port}`);
182+
const downstreamWsKey = isNode
183+
? (ws as any)._req.getHeaders()['sec-websocket-key']
184+
: undefined;
185+
186+
await new Promise<void>((resolve, reject) => {
187+
ws.addEventListener('open', () => {
188+
resolve();
189+
ws.close();
190+
});
191+
ws.addEventListener('error', reject);
192+
});
193+
194+
await delay(100);
195+
196+
expect(ruleEvents.length).to.equal(1);
197+
198+
const requestId = (await forwardingRule.getSeenRequests())[0].id;
199+
ruleEvents.forEach((event) => {
200+
expect(event.ruleId).to.equal(forwardingRule.id);
201+
expect(event.requestId).to.equal(requestId);
202+
});
203+
204+
expect(ruleEvents.map(e => e.eventType)).to.deep.equal([
205+
'passthrough-websocket-connect'
206+
]);
207+
208+
const connectEvent = ruleEvents[0].eventData;
209+
expect(_.omit(connectEvent, 'rawHeaders')).to.deep.equal({
210+
method: 'GET',
211+
protocol: 'http',
212+
hostname: 'localhost',
213+
// This reports the *modified* port, not the original:
214+
port: remoteServer.port.toString(),
215+
path: '/',
216+
subprotocols: []
217+
});
218+
219+
// This reports the *modified* header, not the original:
220+
expect(connectEvent.rawHeaders).to.deep.include(['host', `localhost:${remoteServer.port}`]);
221+
expect(connectEvent.rawHeaders).to.deep.include(['sec-websocket-version', '13']);
222+
expect(connectEvent.rawHeaders).to.deep.include(['sec-websocket-extensions', 'permessage-deflate; client_max_window_bits']);
223+
expect(connectEvent.rawHeaders).to.deep.include(['connection', 'Upgrade']);
224+
expect(connectEvent.rawHeaders).to.deep.include(['upgrade', 'websocket']);
225+
226+
// Make sure we want to see the upstream WS key, not the downstream one
227+
const upstreamWsKey = (connectEvent.rawHeaders as RawHeaders)
228+
.find(([key]) => key.toLowerCase() === 'sec-websocket-key')!;
229+
expect(upstreamWsKey[1]).to.not.equal(downstreamWsKey);
230+
});
231+
232+
// For now, we only support transformation of websocket URLs in forwarding, and nothing
233+
// else, so initial conn params are the only passthrough data that's useful to expose.
234+
171235
});

0 commit comments

Comments
 (0)