diff --git a/docs/outbound-requests.md b/docs/outbound-requests.md new file mode 100644 index 000000000..c97d4dcfd --- /dev/null +++ b/docs/outbound-requests.md @@ -0,0 +1,39 @@ +# Monitoring Outbound Requests + +To monitor outbound HTTP/HTTPS requests made by your application, you can use the `addHook` function with the `beforeOutboundRequest` hook. This is useful when you want to track external API calls, log outbound traffic, or analyze which domains your application connects to. + +## Basic Usage + +```js +const { addHook } = require("@aikidosec/firewall"); + +addHook("beforeOutboundRequest", ({ url, port, method }) => { + // url is a URL object: https://nodejs.org/api/url.html#class-url + console.log(`${new Date().toISOString()} - ${method} ${url.href}`); +}); +``` + +## Removing Hooks + +You can remove a previously registered hook using the `removeHook` function: + +```js +const { addHook, removeHook } = require("@aikidosec/firewall"); + +function myHook({ url, port, method }) { + console.log(`${method} ${url.href}`); +} + +addHook("beforeOutboundRequest", myHook); + +// Later, when you want to remove it: +removeHook("beforeOutboundRequest", myHook); +``` + +## Important Notes + +- You can register multiple hooks by calling `addHook` multiple times. +- The same hook function can only be registered once (duplicates are automatically prevented). +- Hooks are triggered for all HTTP/HTTPS requests made through the Node.js built-in modules (`http`, `https`), the built-in `fetch` function, `undici`, and any library built on top of them. +- Hooks are called when the connection is initiated, before Zen decides whether to block the request, so a hook can fire for a request that is subsequently blocked. +- Errors thrown in hooks (both sync and async) are silently caught and never logged, so a failing hook cannot break your application. Async hooks are not awaited. 
diff --git a/library/agent/Agent.ts b/library/agent/Agent.ts index a68cb24e1..098692eb3 100644 --- a/library/agent/Agent.ts +++ b/library/agent/Agent.ts @@ -27,6 +27,7 @@ import { wrapInstalledPackages } from "./wrapInstalledPackages"; import { Wrapper } from "./Wrapper"; import { isAikidoCI } from "../helpers/isAikidoCI"; import { AttackLogger } from "./AttackLogger"; +import { executeHooks } from "./hooks"; import { Packages } from "./Packages"; import { AIStatistics } from "./AIStatistics"; import { isNewInstrumentationUnitTest } from "../helpers/isNewInstrumentationUnitTest"; @@ -579,6 +580,14 @@ export class Agent { this.hostnames.add(hostname, port); } + onConnectHTTP(url: URL, port: number, method: string) { + executeHooks("beforeOutboundRequest", { + url, + port, + method: method.toUpperCase(), + }); + } + onRouteExecute(context: Context) { this.routes.addRoute(context); } diff --git a/library/agent/hooks.test.ts b/library/agent/hooks.test.ts new file mode 100644 index 000000000..5288674e7 --- /dev/null +++ b/library/agent/hooks.test.ts @@ -0,0 +1,150 @@ +import * as t from "tap"; +import { + addHook, + removeHook, + executeHooks, + OutboundRequestInfo, +} from "./hooks"; + +t.test("it works", async (t) => { + let hookOneCalls = 0; + let hookTwoCalls = 0; + + const testRequest: OutboundRequestInfo = { + url: new URL("https://example.com"), + port: 443, + method: "GET", + }; + + function hook1(request: OutboundRequestInfo) { + t.equal(request.url.href, "https://example.com/"); + t.equal(request.port, 443); + t.equal(request.method, "GET"); + hookOneCalls++; + } + + function hook2(request: OutboundRequestInfo) { + t.equal(request.url.href, "https://example.com/"); + t.equal(request.port, 443); + t.equal(request.method, "GET"); + hookTwoCalls++; + } + + function hook3() { + throw new Error("hook3 should not be called"); + } + + t.same(hookOneCalls, 0, "hookOneCalls starts at 0"); + t.same(hookTwoCalls, 0, "hookTwoCalls starts at 0"); + + 
executeHooks("beforeOutboundRequest", testRequest); + + t.same(hookOneCalls, 0, "hookOneCalls still at 0"); + t.same(hookTwoCalls, 0, "hookTwoCalls still at 0"); + + addHook("beforeOutboundRequest", hook1); + // @ts-expect-error some other hook is not defined in the types + addHook("someOtherHook", hook3); + executeHooks("beforeOutboundRequest", testRequest); + + t.equal(hookOneCalls, 1, "hook1 called once"); + t.equal(hookTwoCalls, 0, "hook2 not called"); + + addHook("beforeOutboundRequest", hook2); + executeHooks("beforeOutboundRequest", testRequest); + + t.equal(hookOneCalls, 2, "hook1 called twice"); + t.equal(hookTwoCalls, 1, "hook2 called once"); + + removeHook("beforeOutboundRequest", hook1); + executeHooks("beforeOutboundRequest", testRequest); + + t.equal(hookOneCalls, 2, "hook1 still called twice"); + t.equal(hookTwoCalls, 2, "hook2 called twice"); + + removeHook("beforeOutboundRequest", hook2); + executeHooks("beforeOutboundRequest", testRequest); + + t.equal(hookOneCalls, 2, "hook1 still called twice"); + t.equal(hookTwoCalls, 2, "hook2 still called twice"); +}); + +t.test("it handles errors gracefully", async (t) => { + let successCalls = 0; + + function throwingHook() { + throw new Error("This should be caught"); + } + + function successHook() { + successCalls++; + } + + const testRequest: OutboundRequestInfo = { + url: new URL("https://example.com"), + port: 443, + method: "POST", + }; + + addHook("beforeOutboundRequest", throwingHook); + addHook("beforeOutboundRequest", successHook); + + // Should not throw even though one hook throws + executeHooks("beforeOutboundRequest", testRequest); + + t.equal( + successCalls, + 1, + "success hook still called despite error in other hook" + ); + + removeHook("beforeOutboundRequest", throwingHook); + removeHook("beforeOutboundRequest", successHook); +}); + +t.test("it handles async hooks with rejected promises", async (t) => { + let asyncCalls = 0; + + async function asyncHook() { + asyncCalls++; + throw new 
Error("Async error"); + } + + const testRequest: OutboundRequestInfo = { + url: new URL("https://example.com"), + port: 443, + method: "DELETE", + }; + + addHook("beforeOutboundRequest", asyncHook); + + // Should not throw even though async hook rejects + executeHooks("beforeOutboundRequest", testRequest); + + t.equal(asyncCalls, 1, "async hook was called"); + + removeHook("beforeOutboundRequest", asyncHook); +}); + +t.test("it prevents duplicate hooks using Set", async (t) => { + let hookCalls = 0; + + function hook() { + hookCalls++; + } + + const testRequest: OutboundRequestInfo = { + url: new URL("https://example.com"), + port: 443, + method: "GET", + }; + + addHook("beforeOutboundRequest", hook); + addHook("beforeOutboundRequest", hook); // Try to add the same hook again + + executeHooks("beforeOutboundRequest", testRequest); + + t.equal(hookCalls, 1, "hook only called once despite being added twice"); + + removeHook("beforeOutboundRequest", hook); +}); diff --git a/library/agent/hooks.ts b/library/agent/hooks.ts new file mode 100644 index 000000000..a43ba04d8 --- /dev/null +++ b/library/agent/hooks.ts @@ -0,0 +1,60 @@ +export type OutboundRequestInfo = { + url: URL; + port: number; + method: string; +}; + +type HookName = "beforeOutboundRequest"; + +// Map hook names to argument types +interface HookTypes { + beforeOutboundRequest: { + args: [data: OutboundRequestInfo]; + }; +} + +const hooks = new Map< + HookName, + Set<(...args: HookTypes[HookName]["args"]) => void | Promise> +>(); + +export function addHook( + name: N, + fn: (...args: HookTypes[N]["args"]) => void | Promise +) { + if (!hooks.has(name)) { + hooks.set(name, new Set([fn])); + } else { + hooks.get(name)!.add(fn); + } +} + +export function removeHook( + name: N, + fn: (...args: HookTypes[N]["args"]) => void | Promise +) { + hooks.get(name)?.delete(fn); +} + +export function executeHooks( + name: N, + ...args: [...HookTypes[N]["args"]] +): void { + const hookSet = hooks.get(name); + + for (const 
fn of hookSet ?? []) { + try { + const result = ( + fn as (...args: HookTypes[N]["args"]) => void | Promise + )(...args); + // If it returns a promise, catch any errors but don't wait + if (result instanceof Promise) { + result.catch(() => { + // Silently ignore errors from user hooks + }); + } + } catch { + // Silently ignore errors from user hooks + } + } +} diff --git a/library/index.ts b/library/index.ts index ab2b6925c..794ba8c29 100644 --- a/library/index.ts +++ b/library/index.ts @@ -15,6 +15,7 @@ import { isESM } from "./helpers/isESM"; import { checkIndexImportGuard } from "./helpers/indexImportGuard"; import { setRateLimitGroup } from "./ratelimiting/group"; import { isLibBundled } from "./helpers/isLibBundled"; +import { addHook, removeHook } from "./agent/hooks"; // Prevent logging twice / trying to start agent twice if (!isNewHookSystemUsed()) { @@ -51,6 +52,8 @@ export { addKoaMiddleware, addRestifyMiddleware, setRateLimitGroup, + addHook, + removeHook, }; // Required for ESM / TypeScript default export support @@ -67,4 +70,6 @@ export default { addKoaMiddleware, addRestifyMiddleware, setRateLimitGroup, + addHook, + removeHook, }; diff --git a/library/sinks/Fetch.test.ts b/library/sinks/Fetch.test.ts index 41963a423..9d662310a 100644 --- a/library/sinks/Fetch.test.ts +++ b/library/sinks/Fetch.test.ts @@ -3,6 +3,7 @@ import * as t from "tap"; import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { Token } from "../agent/api/Token"; import { Context, runWithContext } from "../agent/Context"; +import { addHook, removeHook } from "../agent/hooks"; import { wrap } from "../helpers/wrap"; import { Fetch } from "./Fetch"; import * as dns from "dns"; @@ -92,12 +93,31 @@ t.test( t.same(agent.getHostnames().asArray(), []); + const hookArgs: unknown[] = []; + const hook = (args: unknown) => { + hookArgs.push(args); + }; + addHook("beforeOutboundRequest", hook); await fetch("http://app.aikido.dev"); - + await fetch(new 
Request("https://app.aikido.dev", { method: "POST" })); t.same(agent.getHostnames().asArray(), [ { hostname: "app.aikido.dev", port: 80, hits: 1 }, + { hostname: "app.aikido.dev", port: 443, hits: 1 }, ]); agent.getHostnames().clear(); + t.same(hookArgs, [ + { + url: new URL("http://app.aikido.dev"), + method: "GET", + port: 80, + }, + { + url: new URL("https://app.aikido.dev/"), + method: "POST", + port: 443, + }, + ]); + removeHook("beforeOutboundRequest", hook); await fetch(new URL("https://app.aikido.dev")); diff --git a/library/sinks/Fetch.ts b/library/sinks/Fetch.ts index 5b69c89e9..39eae204c 100644 --- a/library/sinks/Fetch.ts +++ b/library/sinks/Fetch.ts @@ -15,19 +15,21 @@ export class Fetch implements Wrapper { private inspectHostname( agent: Agent, - hostname: string, - port: number | undefined + url: URL, + port: number | undefined, + method: string ): InterceptorResult { // Let the agent know that we are connecting to this hostname // This is to build a list of all hostnames that the application is connecting to if (typeof port === "number" && port > 0) { - agent.onConnectHostname(hostname, port); + agent.onConnectHostname(url.hostname, port); + agent.onConnectHTTP(url, port, method); } - if (agent.getConfig().shouldBlockOutgoingRequest(hostname)) { + if (agent.getConfig().shouldBlockOutgoingRequest(url.hostname)) { return { operation: "fetch", - hostname: hostname, + hostname: url.hostname, }; } @@ -38,7 +40,7 @@ export class Fetch implements Wrapper { } return checkContextForSSRF({ - hostname: hostname, + hostname: url.hostname, operation: "fetch", context: context, port: port, @@ -47,14 +49,26 @@ export class Fetch implements Wrapper { inspectFetch(args: unknown[], agent: Agent): InterceptorResult { if (args.length > 0) { + // Extract method from options or Request object + let method = "GET"; + if (args[0] instanceof Request) { + method = args[0].method.toUpperCase(); + } else if (args.length > 1 && args[1] && typeof args[1] === "object") { + const 
options = args[1] as { method?: string }; + if (options.method) { + method = options.method.toUpperCase(); + } + } + // URL string if (typeof args[0] === "string" && args[0].length > 0) { const url = tryParseURL(args[0]); if (url) { const attack = this.inspectHostname( agent, - url.hostname, - getPortFromURL(url) + url, + getPortFromURL(url), + method ); if (attack) { return attack; @@ -71,8 +85,9 @@ export class Fetch implements Wrapper { if (url) { const attack = this.inspectHostname( agent, - url.hostname, - getPortFromURL(url) + url, + getPortFromURL(url), + method ); if (attack) { return attack; @@ -84,8 +99,9 @@ export class Fetch implements Wrapper { if (args[0] instanceof URL && args[0].hostname.length > 0) { const attack = this.inspectHostname( agent, - args[0].hostname, - getPortFromURL(args[0]) + args[0], + getPortFromURL(args[0]), + method ); if (attack) { return attack; @@ -98,8 +114,9 @@ export class Fetch implements Wrapper { if (url) { const attack = this.inspectHostname( agent, - url.hostname, - getPortFromURL(url) + url, + getPortFromURL(url), + method ); if (attack) { return attack; diff --git a/library/sinks/HTTPRequest.test.ts b/library/sinks/HTTPRequest.test.ts index 549f36115..49ffd175c 100644 --- a/library/sinks/HTTPRequest.test.ts +++ b/library/sinks/HTTPRequest.test.ts @@ -3,6 +3,7 @@ import * as dns from "dns"; import * as t from "tap"; import { Token } from "../agent/api/Token"; import { Context, runWithContext } from "../agent/Context"; +import { addHook, removeHook } from "../agent/hooks"; import { wrap } from "../helpers/wrap"; import { HTTPRequest } from "./HTTPRequest"; import { createTestAgent } from "../helpers/createTestAgent"; @@ -74,15 +75,27 @@ const oldUrl = require("url"); t.test("it works", (t) => { t.same(agent.getHostnames().asArray(), []); + const hookArgs: unknown[] = []; + const hook = (args: unknown) => { + hookArgs.push(args); + }; + addHook("beforeOutboundRequest", hook); runWithContext(createContext(), () => { 
const aikido = http.request("http://aikido.dev"); aikido.end(); }); - t.same(agent.getHostnames().asArray(), [ { hostname: "aikido.dev", port: 80, hits: 1 }, ]); agent.getHostnames().clear(); + t.same(hookArgs, [ + { + url: new URL("http://aikido.dev/"), + port: 80, + method: "GET", + }, + ]); + removeHook("beforeOutboundRequest", hook); runWithContext(createContext(), () => { const aikido = https.request("https://aikido.dev"); diff --git a/library/sinks/HTTPRequest.ts b/library/sinks/HTTPRequest.ts index e39b43b72..67dbb6821 100644 --- a/library/sinks/HTTPRequest.ts +++ b/library/sinks/HTTPRequest.ts @@ -19,12 +19,14 @@ export class HTTPRequest implements Wrapper { agent: Agent, url: URL, port: number | undefined, - module: "http" | "https" + module: "http" | "https", + method: string ): InterceptorResult { // Let the agent know that we are connecting to this hostname // This is to build a list of all hostnames that the application is connecting to if (typeof port === "number" && port > 0) { agent.onConnectHostname(url.hostname, port); + agent.onConnectHTTP(url, port, method); } if (agent.getConfig().shouldBlockOutgoingRequest(url.hostname)) { @@ -82,11 +84,21 @@ export class HTTPRequest implements Wrapper { } if (url.hostname.length > 0) { + // Extract method from options object + let method = "GET"; + const optionObj = args.find((arg): arg is RequestOptions => + isOptionsObject(arg) + ); + if (optionObj && optionObj.method) { + method = optionObj.method.toUpperCase(); + } + const attack = this.inspectHostname( agent, url, getPortFromURL(url), - module + module, + method ); if (attack) { return attack; diff --git a/library/sinks/Undici.tests.ts b/library/sinks/Undici.tests.ts index 188d2631e..25dcd3650 100644 --- a/library/sinks/Undici.tests.ts +++ b/library/sinks/Undici.tests.ts @@ -2,6 +2,7 @@ import * as t from "tap"; import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { Token } from "../agent/api/Token"; import { Context, 
runWithContext } from "../agent/Context"; +import { addHook, removeHook } from "../agent/hooks"; import { LoggerForTesting } from "../agent/logger/LoggerForTesting"; import { startTestAgent } from "../helpers/startTestAgent"; import { getMajorNodeVersion } from "../helpers/getNodeVersion"; @@ -71,6 +72,39 @@ export async function createUndiciTests(undiciPkgName: string, port: number) { undiciPkgName ) as typeof import("undici-v6"); + const hookArgs: unknown[] = []; + const beforeOutbound = (args: unknown) => { + hookArgs.push(args); + }; + addHook("beforeOutboundRequest", beforeOutbound); + await request({ + protocol: "https:", + hostname: "ssrf-redirects.testssandbox.com", + pathname: "/my-path", + search: "?a=b", + }); + await request(`http://localhost:${port}/api/internal`, { + method: "POST", + }); + t.same(agent.getHostnames().asArray(), [ + { hostname: "ssrf-redirects.testssandbox.com", port: 443, hits: 1 }, + { hostname: "localhost", port: port, hits: 1 }, + ]); + agent.getHostnames().clear(); + t.same(hookArgs, [ + { + url: new URL("https://ssrf-redirects.testssandbox.com/my-path?a=b"), + method: "GET", + port: 443, + }, + { + url: new URL(`http://localhost:${port}/api/internal`), + method: "POST", + port: port, + }, + ]); + removeHook("beforeOutboundRequest", beforeOutbound); + await request("https://ssrf-redirects.testssandbox.com"); t.same(agent.getHostnames().asArray(), [ { diff --git a/library/sinks/Undici.ts b/library/sinks/Undici.ts index 8892297dc..4ac3815f3 100644 --- a/library/sinks/Undici.ts +++ b/library/sinks/Undici.ts @@ -6,12 +6,14 @@ import { Hooks } from "../agent/hooks/Hooks"; import { InterceptorResult } from "../agent/hooks/InterceptorResult"; import { Wrapper } from "../agent/Wrapper"; import { getSemverNodeVersion } from "../helpers/getNodeVersion"; +import { getPortFromURL } from "../helpers/getPortFromURL"; +import { isPlainObject } from "../helpers/isPlainObject"; import { isVersionGreaterOrEqual } from 
"../helpers/isVersionGreaterOrEqual"; import { checkContextForSSRF } from "../vulnerabilities/ssrf/checkContextForSSRF"; import { inspectDNSLookupCalls } from "../vulnerabilities/ssrf/inspectDNSLookupCalls"; +import { buildURLFromArgs } from "./undici/buildURLFromObject"; import { wrapDispatch } from "./undici/wrapDispatch"; import { wrapExport } from "../agent/hooks/wrapExport"; -import { getHostnameAndPortFromArgs } from "./undici/getHostnameAndPortFromArgs"; import type { PartialWrapPackageInfo } from "../agent/hooks/WrapPackageInfo"; const methods = [ @@ -26,20 +28,22 @@ const methods = [ export class Undici implements Wrapper { private inspectHostname( agent: Agent, - hostname: string, - port: number | undefined, - method: string + url: URL, + method: string, + httpMethod: string ): InterceptorResult { // Let the agent know that we are connecting to this hostname // This is to build a list of all hostnames that the application is connecting to + const port = getPortFromURL(url); if (typeof port === "number" && port > 0) { - agent.onConnectHostname(hostname, port); + agent.onConnectHostname(url.hostname, port); + agent.onConnectHTTP(url, port, httpMethod); } - if (agent.getConfig().shouldBlockOutgoingRequest(hostname)) { + if (agent.getConfig().shouldBlockOutgoingRequest(url.hostname)) { return { operation: `undici.${method}`, - hostname: hostname, + hostname: url.hostname, }; } @@ -50,7 +54,7 @@ export class Undici implements Wrapper { } return checkContextForSSRF({ - hostname: hostname, + hostname: url.hostname, operation: `undici.${method}`, context, port, @@ -62,14 +66,18 @@ export class Undici implements Wrapper { agent: Agent, method: string ): InterceptorResult { - const hostnameAndPort = getHostnameAndPortFromArgs(args); - if (hostnameAndPort) { - const attack = this.inspectHostname( - agent, - hostnameAndPort.hostname, - hostnameAndPort.port, - method - ); + let httpMethod = "GET"; + if ( + args.length > 1 && + isPlainObject(args[1]) && + typeof 
args[1].method === "string" + ) { + httpMethod = args[1].method; + } + + const url = buildURLFromArgs(args); + if (url) { + const attack = this.inspectHostname(agent, url, method, httpMethod); if (attack) { return attack; } diff --git a/library/sinks/http-request/isOptionsObject.ts b/library/sinks/http-request/isOptionsObject.ts index facfca768..206a7d2e4 100644 --- a/library/sinks/http-request/isOptionsObject.ts +++ b/library/sinks/http-request/isOptionsObject.ts @@ -1,6 +1,6 @@ /** * Check if the argument is treated as an options object by Node.js. - * For checking if the argument can be used as options for a outgoing HTTP request. + * For checking if the argument can be used as options for an outgoing HTTP request. */ export function isOptionsObject(arg: any): arg is { [key: string]: unknown } { return ( diff --git a/library/sinks/undici/buildURLFromObject.test.ts b/library/sinks/undici/buildURLFromObject.test.ts new file mode 100644 index 000000000..6582f35ee --- /dev/null +++ b/library/sinks/undici/buildURLFromObject.test.ts @@ -0,0 +1,148 @@ +import * as t from "tap"; +import { buildURLFromArgs } from "./buildURLFromObject"; +import { parse as parseUrl } from "url"; + +t.test("empty", async (t) => { + const url = buildURLFromArgs([]); + t.same(url, undefined); +}); + +t.test("it returns an URL instance", async (t) => { + const url = buildURLFromArgs(["http://localhost:4000"]); + t.ok(url instanceof URL); +}); + +t.test("it returns the full url", async () => { + t.same( + buildURLFromArgs([ + { origin: "http://localhost:4000", pathname: "/api", search: "?page=1" }, + ])?.toString(), + "http://localhost:4000/api?page=1" + ); + t.same( + buildURLFromArgs([ + { origin: "http://localhost:4000", path: "/api?page=1" }, + ])?.toString(), + "http://localhost:4000/api?page=1" + ); +}); + +t.test("origin ends with slash", async (t) => { + t.same( + buildURLFromArgs([ + { origin: "http://localhost:4000/", pathname: "/api", search: "?page=1" }, + ])?.toString(), + 
"http://localhost:4000/api?page=1" + ); + t.same( + buildURLFromArgs([ + { origin: "http://localhost:4000/", path: "/api?page=1" }, + ])?.toString(), + "http://localhost:4000/api?page=1" + ); +}); + +t.test("it works with url string", async (t) => { + t.same( + buildURLFromArgs(["http://localhost:4000"])?.toString(), + "http://localhost:4000/" + ); + t.same( + buildURLFromArgs(["http://localhost?test=1"])?.toString(), + "http://localhost/?test=1" + ); + t.same( + buildURLFromArgs(["https://localhost"])?.toString(), + "https://localhost/" + ); +}); + +t.test("it works with url object", async (t) => { + t.same( + buildURLFromArgs([new URL("http://localhost:4000")])?.toString(), + "http://localhost:4000/" + ); + t.same( + buildURLFromArgs([new URL("http://localhost?test=1")])?.toString(), + "http://localhost/?test=1" + ); + t.same( + buildURLFromArgs([new URL("https://localhost")])?.toString(), + "https://localhost/" + ); +}); + +t.test("it works with an array of strings", async (t) => { + t.same( + buildURLFromArgs([["http://localhost:4000"]])?.toString(), + "http://localhost:4000/" + ); + t.same( + buildURLFromArgs([["http://localhost?test=1"]])?.toString(), + "http://localhost/?test=1" + ); + t.same( + buildURLFromArgs([["https://localhost"]])?.toString(), + "https://localhost/" + ); +}); + +t.test("it works with an legacy url object", async (t) => { + t.same( + buildURLFromArgs([parseUrl("http://localhost:4000")])?.toString(), + "http://localhost:4000/" + ); + t.same( + buildURLFromArgs([parseUrl("http://localhost?test=1")])?.toString(), + "http://localhost/?test=1" + ); + t.same( + buildURLFromArgs([parseUrl("https://localhost")])?.toString(), + "https://localhost/" + ); +}); + +t.test("it works with an options object containing origin", async (t) => { + t.same( + buildURLFromArgs([{ origin: "http://localhost:4000" }])?.toString(), + "http://localhost:4000/" + ); + t.same( + buildURLFromArgs([ + { origin: "http://localhost", search: "?test=1" }, + ])?.toString(), 
+ "http://localhost/?test=1" + ); + t.same( + buildURLFromArgs([{ origin: "https://localhost" }])?.toString(), + "https://localhost/" + ); +}); + +t.test( + "it works with an options object containing protocol, hostname and port", + async (t) => { + t.same( + buildURLFromArgs([ + { protocol: "http:", hostname: "localhost", port: 4000 }, + ])?.toString(), + "http://localhost:4000/" + ); + t.same( + buildURLFromArgs([ + { protocol: "https:", hostname: "localhost" }, + ])?.toString(), + "https://localhost/" + ); + } +); + +t.test("invalid origin url", async (t) => { + t.same(buildURLFromArgs([{ origin: "invalid url" }]), undefined); + t.same(buildURLFromArgs([{ origin: "" }]), undefined); +}); + +t.test("without hostname", async (t) => { + t.same(buildURLFromArgs([{}]), undefined); + t.same(buildURLFromArgs([{ protocol: "https:", port: 4000 }]), undefined); +}); diff --git a/library/sinks/undici/buildURLFromObject.ts b/library/sinks/undici/buildURLFromObject.ts new file mode 100644 index 000000000..0d81cb432 --- /dev/null +++ b/library/sinks/undici/buildURLFromObject.ts @@ -0,0 +1,53 @@ +// oxlint-disable typescript-eslint(no-base-to-string) +// oxlint-disable typescript-eslint(restrict-template-expressions) +import { tryParseURL } from "../../helpers/tryParseURL"; +import { isOptionsObject } from "../http-request/isOptionsObject"; + +export function buildURLFromArgs(args: unknown[]) { + if (args.length === 0) { + return undefined; + } + + if (typeof args[0] === "string") { + return tryParseURL(args[0]); + } + + // undici also exports `fetch` (like the global fetch) + // Fetch accepts any object with a stringifier. User input may be an array if the user provides an array + // query parameter (e.g., ?example[0]=https://example.com/) in frameworks like Express. Since an Array has + // a default stringifier, this is exploitable in a default setup. + // The following condition ensures that we see the same value as what's passed down to the sink. 
+ if (Array.isArray(args[0])) { + return tryParseURL(args[0].toString()); + } + + if (args[0] instanceof URL) { + return args[0]; + } + + if (isOptionsObject(args[0])) { + return buildURLFromObject(args[0] as Record); + } + + return undefined; +} + +// Logic copied from parseURL in https://github.com/nodejs/undici/blob/main/lib/core/util.js +// Note: { hostname: string, port: number } is not accepted by Undici +function buildURLFromObject(url: Record) { + const port = url.port ? url.port : url.protocol === "https:" ? 443 : 80; + + let origin = url.origin + ? url.origin + : `${url.protocol || ""}//${url.hostname || ""}:${port}`; + if (typeof origin === "string" && origin[origin.length - 1] === "/") { + origin = origin.slice(0, origin.length - 1); + } + + let path = url.path ? url.path : `${url.pathname || ""}${url.search || ""}`; + if (typeof path === "string" && path[0] !== "/") { + path = `/${path}`; + } + + return tryParseURL(`${origin}${path}`); +} diff --git a/library/sinks/undici/getHostnameAndPortFromArgs.test.ts b/library/sinks/undici/getHostnameAndPortFromArgs.test.ts index b5b1bfcde..135e27d78 100644 --- a/library/sinks/undici/getHostnameAndPortFromArgs.test.ts +++ b/library/sinks/undici/getHostnameAndPortFromArgs.test.ts @@ -1,7 +1,22 @@ +// There used to be a function named `getHostnameAndPortFromArgs`, we kept the tests for it! 
import * as t from "tap"; -import { getHostnameAndPortFromArgs as get } from "./getHostnameAndPortFromArgs"; +import { getPortFromURL } from "../../helpers/getPortFromURL"; +import { buildURLFromArgs } from "./buildURLFromObject"; import { parse as parseUrl } from "url"; +function get(args: unknown[]) { + const url = buildURLFromArgs(args); + + if (url) { + return { + hostname: url.hostname, + port: getPortFromURL(url), + }; + } + + return undefined; +} + t.test("it works with url string", async (t) => { t.same(get(["http://localhost:4000"]), { hostname: "localhost", @@ -84,10 +99,6 @@ t.test( hostname: "localhost", port: 4000, }); - t.same(get([{ hostname: "localhost", port: 4000 }]), { - hostname: "localhost", - port: 4000, - }); t.same(get([{ protocol: "https:", hostname: "localhost" }]), { hostname: "localhost", port: 443, diff --git a/library/sinks/undici/getHostnameAndPortFromArgs.ts b/library/sinks/undici/getHostnameAndPortFromArgs.ts deleted file mode 100644 index e27c848c9..000000000 --- a/library/sinks/undici/getHostnameAndPortFromArgs.ts +++ /dev/null @@ -1,95 +0,0 @@ -import { getPortFromURL } from "../../helpers/getPortFromURL"; -import { tryParseURL } from "../../helpers/tryParseURL"; -import { isOptionsObject } from "../http-request/isOptionsObject"; - -type HostnameAndPort = { - hostname: string; - port: number | undefined; -}; - -/** - * Extract hostname and port from the arguments of a undici request. - * Used for SSRF detection. - */ -export function getHostnameAndPortFromArgs( - args: unknown[] -): HostnameAndPort | undefined { - let url: URL | undefined; - if (args.length > 0) { - // URL provided as a string - if (typeof args[0] === "string" && args[0].length > 0) { - url = tryParseURL(args[0]); - } - // Fetch accepts any object with a stringifier. User input may be an array if the user provides an array - // query parameter (e.g., ?example[0]=https://example.com/) in frameworks like Express. 
Since an Array has - // a default stringifier, this is exploitable in a default setup. - // The following condition ensures that we see the same value as what's passed down to the sink. - if (Array.isArray(args[0])) { - url = tryParseURL(args[0].toString()); - } - - // URL provided as a URL object - if (args[0] instanceof URL) { - url = args[0]; - } - - // If url is not undefined, extract the hostname and port - if (url && url.hostname.length > 0) { - return { - hostname: url.hostname, - port: getPortFromURL(url), - }; - } - - // Check if it can be a request options object - if (isOptionsObject(args[0])) { - return parseOptionsObject(args[0]); - } - } - - return undefined; -} - -/** - * Parse a undici request options object to extract hostname and port. - */ -function parseOptionsObject(obj: any): HostnameAndPort | undefined { - // Origin is preferred over hostname - // See https://github.com/nodejs/undici/blob/c926a43ac5952b8b5a6c7d15529b56599bc1b762/lib/core/util.js#L177 - // oxlint-disable-next-line eqeqeq - if (obj.origin != null && typeof obj.origin === "string") { - const url = tryParseURL(obj.origin); - if (url) { - return { - hostname: url.hostname, - port: getPortFromURL(url), - }; - } - - // Undici should throw an error if the origin is not a valid URL - return undefined; - } - - let port = 80; - if (typeof obj.protocol === "string") { - port = obj.protocol === "https:" ? 443 : 80; - } - if (typeof obj.port === "number") { - port = obj.port; - } else if ( - typeof obj.port === "string" && - Number.isInteger(parseInt(obj.port, 10)) - ) { - port = parseInt(obj.port, 10); - } - - // hostname is required by undici and host is not supported - if (typeof obj.hostname !== "string" || obj.hostname.length === 0) { - return undefined; - } - - return { - hostname: obj.hostname, - port, - }; -}