Skip to content

Commit 8b3e4c1

Browse files
committed
Replace __ socket fields with proper symbols & internal type extensions
This is generally cleaner, and avoids any risk of conflicts or weirdness elsewhere.
1 parent 062f534 commit 8b3e4c1

File tree

6 files changed

+181
-133
lines changed

6 files changed

+181
-133
lines changed

custom-typings/node-type-extensions.d.ts

Lines changed: 7 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -3,113 +3,45 @@
33

44
declare module "net" {
55
import * as net from 'net';
6-
import * as streams from 'stream';
6+
import * as stream from 'stream';
77

88
interface Socket {
9-
// Is this socket trying to send encrypted data upstream? For direct connections
10-
// this always matches socket.encrypted. For CONNECT-proxied connections (where
11-
// the initial connection could be HTTPS and the upstream connection HTTP, or
12-
// vice versa) all on one socket, this is the value for the final hop.
13-
__lastHopEncrypted?: boolean;
14-
15-
// For CONNECT-based socket tunnels, this is the address that was listed in the
16-
// last layer of the tunnelling so far.
17-
__lastHopConnectAddress?: string;
18-
19-
// Extra metadata attached to a TLS socket, taken from the client hello and
20-
// preceeding tunneling steps.
21-
__tlsMetadata?: {}; // Can't ref Mockttp real type here
22-
239
// Normally only defined on TLSSocket, but useful to explicitly include here
2410
// Undefined on plain HTTP, 'true' on TLSSocket.
2511
encrypted?: boolean;
2612

27-
// If there's a client error being sent, we track the corresponding packet
28-
// data on the socket, so that when it fires repeatedly we can combine them
29-
// into a single response & error event.
30-
clientErrorInProgress?: { rawPacket?: Buffer; }
31-
32-
// Our recordings of various timestamps, used for monitoring &
33-
// performance analysis later on
34-
__timingInfo?: {
35-
initialSocket: number; // Initial raw socket time, since unix epoch
36-
37-
// High-precision timestamps:
38-
initialSocketTimestamp: number;
39-
tunnelSetupTimestamp?: number; // Latest CONNECT completion, if any
40-
tlsConnectedTimestamp?: number; // Latest TLS handshake completion, if any
41-
};
42-
4313
// Internal reference to the parent socket, available on TLS sockets
4414
_parent?: Socket;
4515

4616
// Internal reference to the underlying stream, available on _stream_wrap
47-
stream?: streams.Duplex & Partial<net.Socket>;
17+
stream?: stream.Duplex & Partial<net.Socket>;
4818
}
4919
}
5020

5121
declare module "tls" {
52-
import SocketWrapper = require('_stream_wrap');
22+
import * as stream from 'stream';
23+
import * as net from 'net';
5324

5425
interface TLSSocket {
5526
// This is a real field that actually exists - unclear why it's not
5627
// in the type definitions.
5728
servername?: string;
5829

59-
// We cache the initially set remote address & port on sockets, because it's cleared
60-
// before the TLS error callback is called, exactly when we want to read it.
61-
initialRemoteAddress?: string;
62-
initialRemotePort?: number;
63-
64-
// Marker used to detect whether client errors should be reported as TLS issues
65-
// (RST during handshake) or as subsequent client issues (RST during request)
66-
tlsSetupCompleted?: true;
67-
6830
_handle?: { // Internal, used for monkeypatching & error tracking
6931
oncertcb?: (info: any) => any;
70-
_parentWrap?: SocketWrapper;
32+
_parentWrap?: { // SocketWrapper
33+
stream?: stream.Duplex & Partial<net.Socket>
34+
};
7135
}
7236
}
7337
}
7438

75-
// Undocumented module that allows us to turn a stream into a usable net.Socket.
76-
// Deprecated in Node 12+, but I'm hopeful that that will be cancelled...
77-
// Necessary for our HTTP2 re-CONNECT handling, so for now I'm using it regardless.
78-
declare module "_stream_wrap" {
79-
import * as net from 'net';
80-
import * as streams from 'stream';
81-
82-
class SocketWrapper extends net.Socket {
83-
constructor(stream: streams.Duplex);
84-
stream?: streams.Duplex & Partial<net.Socket>;
85-
}
86-
87-
export = SocketWrapper;
88-
}
89-
9039
declare module "http" {
9140
// Two missing methods from the official types:
9241
export function validateHeaderName(name: string): void;
9342
export function validateHeaderValue(name: string, value: unknown): void;
9443
}
9544

96-
declare module "http2" {
97-
import * as net from 'net';
98-
99-
class Http2Session {
100-
// session.socket is cleared before error handling kicks in. That's annoying,
101-
// so we manually preserve the socket elsewhere to work around it.
102-
initialSocket?: net.Socket;
103-
}
104-
105-
class ServerHttp2Stream {
106-
// Treated the same as net.Socket, when we unwrap them in our combo server:
107-
__lastHopEncrypted?: net.Socket['__lastHopEncrypted'];
108-
__lastHopConnectAddress?: net.Socket['__lastHopConnectAddress'];
109-
__timingInfo?: net.Socket['__timingInfo'];
110-
}
111-
}
112-
11345
declare class AggregateError extends Error {
11446
errors: Error[]
11547
}

src/rules/websockets/websocket-handlers.ts

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ import {
5151
WebSocketHandlerDefinition,
5252
WsHandlerDefinitionLookup,
5353
} from './websocket-handler-definitions';
54-
import { resetOrDestroy } from '../../util/socket-util';
54+
import { LastHopEncrypted, resetOrDestroy } from '../../util/socket-util';
5555

5656
export interface WebSocketHandler extends WebSocketHandlerDefinition {
5757
handle(
@@ -292,10 +292,10 @@ export class PassThroughWebSocketHandler extends PassThroughWebSocketHandlerDefi
292292
const hostHeader = req.headers[hostHeaderName];
293293
[ hostname, port ] = hostHeader!.split(':');
294294

295-
// __lastHopEncrypted is set in http-combo-server, for requests that have explicitly
296-
// CONNECTed upstream (which may then up/downgrade from the current encryption).
297-
if (socket.__lastHopEncrypted !== undefined) {
298-
protocol = socket.__lastHopEncrypted ? 'wss' : 'ws';
295+
// LastHopEncrypted is set in http-combo-server, for requests that use TLS in the
296+
// inner-most tunnel (or direct connection) to us.
297+
if (socket[LastHopEncrypted] !== undefined) {
298+
protocol = socket[LastHopEncrypted] ? 'wss' : 'ws';
299299
} else {
300300
protocol = reqMessage.connection.encrypted ? 'wss' : 'ws';
301301
}

src/server/http-combo-server.ts

Lines changed: 39 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,15 @@ import { shouldPassThrough } from '../util/server-utils';
2424
import {
2525
getParentSocket,
2626
buildSocketTimingInfo,
27-
buildSocketEventData
27+
buildSocketEventData,
28+
SocketIsh,
29+
InitialRemoteAddress,
30+
InitialRemotePort,
31+
SocketTimingInfo,
32+
LastTunnelAddress,
33+
LastHopEncrypted,
34+
TlsMetadata,
35+
TlsSetupCompleted
2836
} from '../util/socket-util';
2937
import { MockttpHttpsOptions } from '../mockttp';
3038
import { buildSocksServer, SocksTcpAddress } from './socks-server';
@@ -43,10 +51,10 @@ const originalSocketInit = (<any>tls.TLSSocket.prototype)._init;
4351
const loadSNI = _handle.oncertcb;
4452
_handle.oncertcb = function (info: any) {
4553
tlsSocket.servername = info.servername;
46-
tlsSocket.initialRemoteAddress = tlsSocket.remoteAddress || // Normal case
54+
tlsSocket[InitialRemoteAddress] = tlsSocket.remoteAddress || // Normal case
4755
tlsSocket._parent?.remoteAddress || // For early failing sockets
4856
tlsSocket._handle?._parentWrap?.stream?.remoteAddress; // For HTTP/2 CONNECT
49-
tlsSocket.initialRemotePort = tlsSocket.remotePort ||
57+
tlsSocket[InitialRemotePort] = tlsSocket.remotePort ||
5058
tlsSocket._parent?.remotePort ||
5159
tlsSocket._handle?._parentWrap?.stream?.remotePort;
5260

@@ -76,7 +84,7 @@ function ifTlsDropped(socket: tls.TLSSocket, errorCallback: () => void) {
7684
// Even if these are shut later on, that doesn't mean they're are rejected connections.
7785
// To differentiate the two cases, we consider connections OK after waiting 10x longer
7886
// than the initial TLS handshake for an unhappy disconnection.
79-
const timing = socket.__timingInfo;
87+
const timing = socket[SocketTimingInfo];
8088
const tlsSetupDuration = timing
8189
? timing.tlsConnectedTimestamp! - (timing.tunnelSetupTimestamp! || timing.initialSocketTimestamp)
8290
: 0;
@@ -89,11 +97,11 @@ function ifTlsDropped(socket: tls.TLSSocket, errorCallback: () => void) {
8997
.then(() => {
9098
// Mark the socket as having completed TLS setup - this ensures that future
9199
// errors fire as client errors, not TLS setup errors.
92-
socket.tlsSetupCompleted = true;
100+
socket[TlsSetupCompleted] = true;
93101
})
94102
.catch(() => {
95103
// If TLS setup was confirmed in any way, we know we don't have a TLS error.
96-
if (socket.tlsSetupCompleted) return;
104+
if (socket[TlsSetupCompleted]) return;
97105

98106
// To get here, the socket must have connected & done the TLS handshake, but then
99107
// closed/ended without ever sending any data. We can fairly confidently assume
@@ -227,8 +235,8 @@ export async function createComboServer(
227235

228236
if (options.debug) console.log(`Proxying SOCKS TCP connection to ${addressString}`);
229237

230-
socket.__timingInfo!.tunnelSetupTimestamp = now();
231-
socket.__lastHopConnectAddress = addressString;
238+
socket[SocketTimingInfo]!.tunnelSetupTimestamp = now();
239+
socket[LastTunnelAddress] = addressString;
232240

233241
// Put the socket back into the server, so we can handle the data within:
234242
server.emit('connection', socket);
@@ -245,11 +253,11 @@ export async function createComboServer(
245253
(server as any)._httpServer.requireHostHeader = false;
246254

247255
server.on('connection', (socket: net.Socket | http2.ServerHttp2Stream) => {
248-
socket.__timingInfo = socket.__timingInfo || buildSocketTimingInfo();
256+
socket[SocketTimingInfo] ||= buildSocketTimingInfo();
249257

250258
// All sockets are initially marked as using unencrypted upstream connections.
251259
// If TLS is used, this is upgraded to 'true' by secureConnection below.
252-
socket.__lastHopEncrypted = false;
260+
socket[LastHopEncrypted] = false;
253261

254262
// For actual sockets, set NODELAY to avoid any buffering whilst streaming. This is
255263
// off by default in Node HTTP, but likely to be enabled soon & is default in curl.
@@ -265,14 +273,14 @@ export async function createComboServer(
265273
copyTimingDetails(parentSocket, socket);
266274
// With TLS metadata, we only propagate directly from parent sockets, not through
267275
// CONNECT etc - we only want it if the final hop is TLS, previous values don't matter.
268-
socket.__tlsMetadata ??= parentSocket.__tlsMetadata;
269-
} else if (!socket.__timingInfo) {
270-
socket.__timingInfo = buildSocketTimingInfo();
276+
socket[TlsMetadata] ??= parentSocket[TlsMetadata];
277+
} else if (!socket[SocketTimingInfo]) {
278+
socket[SocketTimingInfo] = buildSocketTimingInfo();
271279
}
272280

273-
socket.__timingInfo!.tlsConnectedTimestamp = now();
281+
socket[SocketTimingInfo]!.tlsConnectedTimestamp = now();
274282

275-
socket.__lastHopEncrypted = true;
283+
socket[LastHopEncrypted] = true;
276284
ifTlsDropped(socket, () => {
277285
tlsClientErrorListener(socket, buildTlsError(socket, 'closed'));
278286
});
@@ -282,7 +290,7 @@ export async function createComboServer(
282290
// happens immediately after the connection preface, as long as the connection is OK.
283291
server!.on('session', (session) => {
284292
session.once('remoteSettings', () => {
285-
session.socket.tlsSetupCompleted = true;
293+
(session.socket as tls.TLSSocket)[TlsSetupCompleted] = true;
286294
});
287295
});
288296

@@ -321,8 +329,8 @@ export async function createComboServer(
321329
if (options.debug) console.log(`Proxying HTTP/1 CONNECT to ${connectUrl}`);
322330

323331
socket.write('HTTP/' + req.httpVersion + ' 200 OK\r\n\r\n', 'utf-8', () => {
324-
socket.__timingInfo!.tunnelSetupTimestamp = now();
325-
socket.__lastHopConnectAddress = connectUrl;
332+
socket[SocketTimingInfo]!.tunnelSetupTimestamp = now();
333+
socket[LastTunnelAddress] = connectUrl;
326334
server.emit('connection', socket);
327335
});
328336
}
@@ -343,7 +351,7 @@ export async function createComboServer(
343351
res.writeHead(200, {});
344352
copyAddressDetails(res.socket, res.stream);
345353
copyTimingDetails(res.socket, res.stream);
346-
res.stream.__lastHopConnectAddress = connectUrl;
354+
res.stream[LastTunnelAddress] = connectUrl;
347355

348356
// When layering HTTP/2 on JS streams, we have to make sure the JS stream won't autoclose
349357
// when the other side does, because the upper HTTP/2 layers want to handle shutdown, so
@@ -359,15 +367,13 @@ export async function createComboServer(
359367
return makeDestroyable(server);
360368
}
361369

362-
type SocketIsh<MinProps extends keyof net.Socket> =
363-
streams.Duplex & Partial<Pick<net.Socket, MinProps>>;
364370

365371
const SOCKET_ADDRESS_METADATA_FIELDS = [
366372
'localAddress',
367373
'localPort',
368374
'remoteAddress',
369375
'remotePort',
370-
'__lastHopConnectAddress'
376+
LastTunnelAddress
371377
] as const;
372378

373379
// Update the target socket(-ish) with the address details from the source socket,
@@ -388,13 +394,13 @@ function copyAddressDetails(
388394
});
389395
}
390396

391-
function copyTimingDetails<T extends SocketIsh<'__timingInfo'>>(
392-
source: SocketIsh<'__timingInfo'>,
397+
function copyTimingDetails<T extends SocketIsh<typeof SocketTimingInfo>>(
398+
source: SocketIsh<typeof SocketTimingInfo>,
393399
target: T
394-
): asserts target is T & { __timingInfo: Required<net.Socket>['__timingInfo'] } {
395-
if (!target.__timingInfo) {
400+
): asserts target is T & { [SocketTimingInfo]: Required<net.Socket>[typeof SocketTimingInfo] } {
401+
if (!target[SocketTimingInfo]) {
396402
// Clone timing info, don't copy it - child sockets get their own independent timing stats
397-
target.__timingInfo = Object.assign({}, source.__timingInfo);
403+
target[SocketTimingInfo] = Object.assign({}, source[SocketTimingInfo]);
398404
}
399405
}
400406

@@ -427,17 +433,17 @@ function analyzeAndMaybePassThroughTls(
427433
// CONNECT or SOCKS) is even better. Note that this may be a hostname or IPv4/6 address:
428434
let connectHostname: string | undefined;
429435
let connectPort: string | undefined;
430-
if (socket.__lastHopConnectAddress) {
431-
const lastColonIndex = socket.__lastHopConnectAddress.lastIndexOf(':');
436+
if (socket[LastTunnelAddress]) {
437+
const lastColonIndex = socket[LastTunnelAddress].lastIndexOf(':');
432438
if (lastColonIndex !== -1) {
433-
connectHostname = socket.__lastHopConnectAddress.slice(0, lastColonIndex);
434-
connectPort = socket.__lastHopConnectAddress.slice(lastColonIndex + 1);
439+
connectHostname = socket[LastTunnelAddress].slice(0, lastColonIndex);
440+
connectPort = socket[LastTunnelAddress].slice(lastColonIndex + 1);
435441
} else {
436-
connectHostname = socket.__lastHopConnectAddress;
442+
connectHostname = socket[LastTunnelAddress];
437443
}
438444
}
439445

440-
socket.__tlsMetadata = {
446+
socket[TlsMetadata] = {
441447
sniHostname,
442448
connectHostname,
443449
connectPort,

0 commit comments

Comments
 (0)