Skip to content

Commit d048ca6

Browse files
committed
Use DNS to make passthrough relative-localhost check more accurate
Previously we just checked for known localhost URLs. This doesn't help for requests to other hostnames that will resolve to localhost though. We now handle that too, by resolving hostnames to IPs and then checking those instead, when necessary. This PR also extracts all that logic, and DNS configuration generally, out from req/WS handlers into shared methods.
1 parent 79674d3 commit d048ca6

File tree

5 files changed

+89
-75
lines changed

5 files changed

+89
-75
lines changed

src/rules/passthrough-handling.ts

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,20 @@ import * as _ from 'lodash';
22
import * as tls from 'tls';
33
import url = require('url');
44
import { oneLine } from 'common-tags';
5+
import CacheableLookup from 'cacheable-lookup';
56

67
import { CompletedBody, Headers } from '../types';
78
import { byteLength } from '../util/util';
89
import { asBuffer } from '../util/buffer-utils';
10+
import { isLocalhostAddress } from '../util/socket-util';
11+
import { CachedDns, dnsLookup, DnsLookupFunction } from '../util/dns';
912
import { isMockttpBody, encodeBodyBuffer } from '../util/request-utils';
1013
import { areFFDHECurvesSupported } from '../util/openssl-compat';
1114

1215
import {
1316
CallbackRequestResult,
14-
CallbackResponseMessageResult
17+
CallbackResponseMessageResult,
18+
PassThroughLookupOptions
1519
} from './requests/request-handler-definitions';
1620

1721
// TLS settings for proxied connections, intended to avoid TLS fingerprint blocking
@@ -270,3 +274,55 @@ export function shouldUseStrictHttps(
270274
}
271275
return !skipHttpsErrors;
272276
}
277+
278+
export const getDnsLookupFunction = _.memoize((lookupOptions: PassThroughLookupOptions | undefined) => {
279+
if (!lookupOptions) {
280+
// By default, use 10s caching of hostnames, just to reduce the delay from
281+
// endlessly 10ms query delay for 'localhost' with every request.
282+
return new CachedDns(10000).lookup;
283+
} else {
284+
// Or if options are provided, use those to configure advanced DNS cases:
285+
const cacheableLookup = new CacheableLookup({
286+
maxTtl: lookupOptions.maxTtl,
287+
errorTtl: lookupOptions.errorTtl,
288+
// As little caching of "use the fallback server" as possible:
289+
fallbackDuration: 0
290+
});
291+
292+
if (lookupOptions.servers) {
293+
cacheableLookup.servers = lookupOptions.servers;
294+
}
295+
296+
return cacheableLookup.lookup;
297+
}
298+
});
299+
300+
export async function getClientRelativeHostname(
301+
hostname: string | null,
302+
remoteIp: string | undefined,
303+
lookupFn: DnsLookupFunction
304+
) {
305+
if (!hostname || !remoteIp || isLocalhostAddress(remoteIp)) return hostname;
306+
307+
// Otherwise, we have a request from a different machine (or Docker container/VM/etc) and we need
308+
// to make sure that 'localhost' means _that_ machine, not ourselves.
309+
310+
// This check must be run before req modifications. If a modification changes the address to localhost,
311+
// then presumably it really does mean *this* localhost.
312+
313+
if (
314+
// If the hostname is a known localhost address, we're done:
315+
isLocalhostAddress(hostname) ||
316+
// Otherwise, we look up the IP, so we can accurately check for localhost-bound requests. This adds a little
317+
// delay, but since it's cached we save the equivalent delay in request lookup later, so it should be
318+
// effectively free. We ignore errors to delegate unresolvable etc to request processing later.
319+
isLocalhostAddress(await dnsLookup(lookupFn, hostname).catch(() => null))
320+
) {
321+
return remoteIp;
322+
323+
// Note that we just redirect - we don't update the host header. From the POV of the target, it's still
324+
// 'localhost' traffic that should appear identical to normal.
325+
} else {
326+
return hostname;
327+
}
328+
}

src/rules/requests/request-handlers.ts

Lines changed: 9 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import http = require('http');
77
import https = require('https');
88
import * as fs from 'fs/promises';
99
import * as h2Client from 'http2-wrapper';
10-
import CacheableLookup from 'cacheable-lookup';
1110
import { decode as decodeBase64 } from 'base64-arraybuffer';
1211
import { Transform } from 'stream';
1312
import { stripIndent, oneLine } from 'common-tags';
@@ -60,7 +59,6 @@ import {
6059
withDeserializedCallbackBuffers,
6160
WithSerializedCallbackBuffers
6261
} from '../../serialization/body-serialization';
63-
import { CachedDns, DnsLookupFunction } from '../../util/dns';
6462
import { ErrorLike, isErrorLike } from '../../util/error';
6563

6664
import { assertParamDereferenced, RuleParameters } from '../rule-parameters';
@@ -74,7 +72,9 @@ import {
7472
OVERRIDABLE_REQUEST_PSEUDOHEADERS,
7573
buildOverriddenBody,
7674
UPSTREAM_TLS_OPTIONS,
77-
shouldUseStrictHttps
75+
shouldUseStrictHttps,
76+
getClientRelativeHostname,
77+
getDnsLookupFunction
7878
} from '../passthrough-handling';
7979

8080
import {
@@ -380,33 +380,6 @@ export class PassThroughHandler extends PassThroughHandlerDefinition {
380380
return this._trustedCACertificates;
381381
}
382382

383-
private _cacheableLookupInstance: CacheableLookup | CachedDns | undefined;
384-
private lookup(): DnsLookupFunction {
385-
if (!this.lookupOptions) {
386-
if (!this._cacheableLookupInstance) {
387-
// By default, use 10s caching of hostnames, just to reduce the delay from
388-
// endlessly 10ms query delay for 'localhost' with every request.
389-
this._cacheableLookupInstance = new CachedDns(10000);
390-
}
391-
return this._cacheableLookupInstance.lookup;
392-
} else {
393-
if (!this._cacheableLookupInstance) {
394-
this._cacheableLookupInstance = new CacheableLookup({
395-
maxTtl: this.lookupOptions.maxTtl,
396-
errorTtl: this.lookupOptions.errorTtl,
397-
// As little caching of "use the fallback server" as possible:
398-
fallbackDuration: 0
399-
});
400-
401-
if (this.lookupOptions.servers) {
402-
this._cacheableLookupInstance.servers = this.lookupOptions.servers;
403-
}
404-
}
405-
406-
return this._cacheableLookupInstance.lookup;
407-
}
408-
}
409-
410383
async handle(clientReq: OngoingRequest, clientRes: OngoingResponse) {
411384
// Don't let Node add any default standard headers - we want full control
412385
dropDefaultHeaders(clientRes);
@@ -434,14 +407,11 @@ export class PassThroughHandler extends PassThroughHandlerDefinition {
434407

435408
const isH2Downstream = isHttp2(clientReq);
436409

437-
if (isLocalhostAddress(hostname) && clientReq.remoteIpAddress && !isLocalhostAddress(clientReq.remoteIpAddress)) {
438-
// If we're proxying localhost traffic from another remote machine, then we should really be proxying
439-
// back to that machine, not back to ourselves! Best example is docker containers: if we capture & inspect
440-
// their localhost traffic, it should still be sent back into that docker container.
441-
hostname = clientReq.remoteIpAddress;
442-
443-
// We don't update the host header - from the POV of the target, it's still localhost traffic.
444-
}
410+
hostname = await getClientRelativeHostname(
411+
hostname,
412+
clientReq.remoteIpAddress,
413+
getDnsLookupFunction(this.lookupOptions)
414+
);
445415

446416
if (this.forwarding) {
447417
const { targetHost, updateHostHeader } = this.forwarding;
@@ -747,7 +717,7 @@ export class PassThroughHandler extends PassThroughHandlerDefinition {
747717
headers: shouldTryH2Upstream
748718
? rawHeadersToObjectPreservingCase(rawHeaders)
749719
: flattenPairedRawHeaders(rawHeaders) as any,
750-
lookup: this.lookup() as typeof dns.lookup,
720+
lookup: getDnsLookupFunction(this.lookupOptions) as typeof dns.lookup,
751721
// ^ Cast required to handle __promisify__ type hack in the official Node types
752722
agent,
753723
// TLS options:

src/rules/websockets/websocket-handlers.ts

Lines changed: 8 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import * as tls from 'tls';
55
import * as http from 'http';
66
import * as fs from 'fs/promises';
77
import * as WebSocket from 'ws';
8-
import CacheableLookup from 'cacheable-lookup';
98

109
import {
1110
ClientServerChannel,
@@ -29,14 +28,15 @@ import {
2928
rawHeadersToObjectPreservingCase
3029
} from '../../util/header-utils';
3130
import { streamToBuffer } from '../../util/buffer-utils';
32-
import { isLocalhostAddress } from '../../util/socket-util';
3331
import { MaybePromise } from '../../util/type-utils';
3432

3533
import { getAgent } from '../http-agents';
3634
import { ProxySettingSource } from '../proxy-config';
3735
import { assertParamDereferenced, RuleParameters } from '../rule-parameters';
3836
import {
3937
UPSTREAM_TLS_OPTIONS,
38+
getClientRelativeHostname,
39+
getDnsLookupFunction,
4040
shouldUseStrictHttps
4141
} from '../passthrough-handling';
4242

@@ -212,26 +212,6 @@ export class PassThroughWebSocketHandler extends PassThroughWebSocketHandlerDefi
212212
return this._trustedCACertificates;
213213
}
214214

215-
private _cacheableLookupInstance: CacheableLookup | undefined;
216-
private lookup() {
217-
if (!this.lookupOptions) return undefined;
218-
219-
if (!this._cacheableLookupInstance) {
220-
this._cacheableLookupInstance = new CacheableLookup({
221-
maxTtl: this.lookupOptions.maxTtl,
222-
errorTtl: this.lookupOptions.errorTtl,
223-
// As little caching of "use the fallback server" as possible:
224-
fallbackDuration: 0
225-
});
226-
227-
if (this.lookupOptions.servers) {
228-
this._cacheableLookupInstance.servers = this.lookupOptions.servers;
229-
}
230-
}
231-
232-
return this._cacheableLookupInstance.lookup;
233-
}
234-
235215
async handle(req: OngoingRequest, socket: net.Socket, head: Buffer) {
236216
this.initializeWsServer();
237217

@@ -242,14 +222,11 @@ export class PassThroughWebSocketHandler extends PassThroughWebSocketHandlerDefi
242222
const isH2Downstream = isHttp2(req);
243223
const hostHeaderName = isH2Downstream ? ':authority' : 'host';
244224

245-
if (isLocalhostAddress(hostname) && req.remoteIpAddress && !isLocalhostAddress(req.remoteIpAddress)) {
246-
// If we're proxying localhost traffic from another remote machine, then we should really be proxying
247-
// back to that machine, not back to ourselves! Best example is docker containers: if we capture & inspect
248-
// their localhost traffic, it should still be sent back into that docker container.
249-
hostname = req.remoteIpAddress;
250-
251-
// We don't update the host header - from the POV of the target, it's still localhost traffic.
252-
}
225+
hostname = await getClientRelativeHostname(
226+
hostname,
227+
req.remoteIpAddress,
228+
getDnsLookupFunction(this.lookupOptions)
229+
);
253230

254231
if (this.forwarding) {
255232
const { targetHost, updateHostHeader } = this.forwarding;
@@ -348,7 +325,7 @@ export class PassThroughWebSocketHandler extends PassThroughWebSocketHandlerDefi
348325
const upstreamWebSocket = new WebSocket(wsUrl, {
349326
maxPayload: 0,
350327
agent,
351-
lookup: this.lookup(),
328+
lookup: getDnsLookupFunction(this.lookupOptions),
352329
headers: _.omitBy(headers, (_v, headerName) =>
353330
headerName.toLowerCase().startsWith('sec-websocket') ||
354331
headerName.toLowerCase() === 'connection' ||

src/util/dns.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,13 @@ export class CachedDns {
4040
}
4141
}
4242

43+
}
44+
45+
export function dnsLookup(lookupFn: typeof dns.lookup | DnsLookupFunction, hostname: string) {
46+
return new Promise<string>((resolve, reject) => {
47+
(lookupFn as typeof dns.lookup)(hostname!, (err, address) => {
48+
if (err) reject(err);
49+
else resolve(address);
50+
});
51+
})
4352
}

src/util/socket-util.ts

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,12 @@ const normalizeIp = (ip: string | null | undefined) =>
4545
: ip;
4646

4747
export const isLocalhostAddress = (host: string | null | undefined) =>
48-
host === 'localhost' || // Most common
49-
host?.endsWith('.localhost') ||
50-
host === '::1' || // IPv6
51-
normalizeIp(host)?.match(/^127\.\d{1,3}\.\d{1,3}\.\d{1,3}$/); // 127.0.0.0/8 range
48+
!!host && ( // Null/undef are something else weird, but not localhost
49+
host === 'localhost' || // Most common
50+
host.endsWith('.localhost') ||
51+
host === '::1' || // IPv6
52+
normalizeIp(host)!.match(/^127\.\d{1,3}\.\d{1,3}\.\d{1,3}$/) // 127.0.0.0/8 range
53+
);
5254

5355

5456
// Check whether an incoming socket is the other end of one of our outgoing sockets:

0 commit comments

Comments
 (0)