Skip to content
4 changes: 3 additions & 1 deletion library/agent/Context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ import type { Endpoint } from "./Config";
export type User = { id: string; name?: string };

export type Context = {
url: string | undefined;
url: string | undefined; // Full URL including protocol and host, if available
urlPath?: string | undefined; // The path part of the URL (e.g. /api/user)
method: string | undefined;
query: ParsedQs;
headers: Record<string, string | string[] | undefined>;
Expand Down Expand Up @@ -75,6 +76,7 @@ export function runWithContext<T>(context: Context, fn: () => T) {
// In this way we don't lose the `attackDetected` flag
if (current) {
current.url = context.url;
current.urlPath = context.urlPath;
current.method = context.method;
current.query = context.query;
current.headers = context.headers;
Expand Down
2 changes: 1 addition & 1 deletion library/agent/Source.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ export const SOURCES = [
"xml",
"subdomains",
"markUnsafe",
"url",
"urlPath",
] as const;

export type Source = (typeof SOURCES)[number];
49 changes: 49 additions & 0 deletions library/helpers/getRawRequestPath.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import * as t from "tap";
import { getRawRequestPath } from "./getRawRequestPath";

t.test("it returns the raw URL path", async (t) => {
t.equal(getRawRequestPath(""), "/");
t.equal(getRawRequestPath("/"), "/");
t.equal(getRawRequestPath("/?test=abc"), "/");
t.equal(getRawRequestPath("#"), "/");
t.equal(getRawRequestPath("https://example.com"), "/");

t.equal(
getRawRequestPath("https://example.com/path/to/resource"),
"/path/to/resource"
);
t.equal(
getRawRequestPath("http://example.com/path/to/resource/"),
"/path/to/resource/"
);
t.equal(
getRawRequestPath("https://example.com/path/to/resource/123"),
"/path/to/resource/123"
);
t.equal(
getRawRequestPath("https://example.com/path/to/resource/123/456"),
"/path/to/resource/123/456"
);
t.equal(
getRawRequestPath("https://example.com/path/to/resource/123/456/789"),
"/path/to/resource/123/456/789"
);
t.equal(
getRawRequestPath(
"https://example.com/path/to/resource/123/456/789?query=string"
),
"/path/to/resource/123/456/789"
);
t.equal(
getRawRequestPath(
"https://example.com/path/to/resource/123/456/789#fragment"
),
"/path/to/resource/123/456/789"
);
t.equal(
getRawRequestPath(
"https://example.com/path/to/resource/123/456/789?query=string#fragment"
),
"/path/to/resource/123/456/789"
);
});
22 changes: 22 additions & 0 deletions library/helpers/getRawRequestPath.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
export function getRawRequestPath(url: string): string {
let partialUrl = url;

// Remove protocol (http://, https://, etc.)
const pathStart = partialUrl.indexOf("://");
if (pathStart !== -1) partialUrl = partialUrl.slice(pathStart + 3);

// Remove hostname and port
const slashIndex = partialUrl.indexOf("/");
if (slashIndex === -1) return "/"; // only hostname given
partialUrl = partialUrl.slice(slashIndex);

// Remove query and fragment
const queryIndex = partialUrl.indexOf("?");
const hashIndex = partialUrl.indexOf("#");

let endIndex = partialUrl.length;
if (queryIndex !== -1) endIndex = Math.min(endIndex, queryIndex);
if (hashIndex !== -1) endIndex = Math.min(endIndex, hashIndex);

return partialUrl.slice(0, endIndex) || "/";
}
200 changes: 200 additions & 0 deletions library/helpers/getRequestUrl.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
import * as t from "tap";
import { getRequestUrl } from "./getRequestUrl";

import { get as httpGet, createServer, IncomingMessage } from "node:http";

t.beforeEach(() => {
delete process.env.AIKIDO_TRUST_PROXY;
});

async function getRealRequest(): Promise<IncomingMessage> {
return new Promise((resolve) => {
const server = createServer((req, res) => {
res.statusCode = 200;
res.end();
server.close(); // stop server once we have a request
resolve(req);
});
server.listen(0, () => {
const { port } = server.address() as any;
// Send a real request to trigger IncomingMessage creation
httpGet({ port, path: "/", headers: {} }, (res) => {
while (res.read()) {
// consume body to prevent test from not exiting
}

t.same(res.statusCode, 200);
}).end();
});
});
}

let baseMockRequest: IncomingMessage | null = null;

async function createMockRequest(
overrides: Partial<IncomingMessage> = {}
): Promise<IncomingMessage> {
if (!baseMockRequest) {
baseMockRequest = await getRealRequest();
}
return Object.assign(
Object.create(Object.getPrototypeOf(baseMockRequest)),
baseMockRequest,
overrides
);
}

t.test("already absolute URL", async (t) => {
t.equal(
getRequestUrl(await createMockRequest({ url: "http://example.com/path" })),
"http://example.com/path"
);

t.equal(
getRequestUrl(
await createMockRequest({ url: "https://example.com/path?test=123" })
),
"https://example.com/path?test=123"
);
});

t.test("no url set", async (t) => {
const mockReq = await createMockRequest({ url: undefined });
t.equal(getRequestUrl(mockReq), `http://${mockReq.headers.host}`);
});

t.test("no url and no host set", async (t) => {
t.equal(
getRequestUrl(
await createMockRequest({
url: undefined,
headers: {},
})
),
""
);
});

t.test("no host header", async (t) => {
t.equal(
getRequestUrl(
await createMockRequest({
url: "/some/path?query=1",
headers: {},
})
),
"/some/path?query=1"
);
});

t.test("relative URL with host header", async (t) => {
t.equal(
getRequestUrl(
await createMockRequest({
url: "/some/path?query=1",
headers: { host: "example.com" },
})
),
"http://example.com/some/path?query=1"
);
});

t.test("With X-Forwarded-Host header and trust proxy disabled", async (t) => {
process.env.AIKIDO_TRUST_PROXY = "0";

t.equal(
getRequestUrl(
await createMockRequest({
url: "/forwarded/path",
headers: {
host: "original.com",
"x-forwarded-host": "forwarded.com",
},
})
),
"http://original.com/forwarded/path"
);
});

t.test("With X-Forwarded-Host header and trust proxy enabled", async (t) => {
t.equal(
getRequestUrl(
await createMockRequest({
url: "/forwarded/path",
headers: {
host: "original.com",
"x-forwarded-host": "forwarded.com",
},
})
),
"http://forwarded.com/forwarded/path"
);
});

t.test("With X-Forwarded-Proto header and trust proxy enabled", async (t) => {
t.equal(
getRequestUrl(
await createMockRequest({
url: "/secure/path",
headers: {
host: "example.com",
"x-forwarded-proto": "https",
},
})
),
"https://example.com/secure/path"
);
});

t.test("With X-Forwarded-Proto header set to invalid value", async (t) => {
t.equal(
getRequestUrl(
await createMockRequest({
url: "/weird/path",
headers: {
host: "example.com",
"x-forwarded-proto": "abc",
},
})
),
"http://example.com/weird/path"
);
});

t.test("With X-Forwarded-Proto header and trust proxy disabled", async (t) => {
process.env.AIKIDO_TRUST_PROXY = "0";

t.equal(
getRequestUrl(
await createMockRequest({
url: "/notrust/path",
headers: {
host: "example.com",
"x-forwarded-proto": "https",
},
})
),
"http://example.com/notrust/path"
);
});

t.test("url does not start with slash and is not absolute", async (t) => {
t.equal(
getRequestUrl(
await createMockRequest({
url: "noslash/path",
headers: { host: "example.com" },
})
),
"http://example.com/noslash/path"
);
});

t.test("url does not start with http/https but is absolute", async (t) => {
t.match(
getRequestUrl(
await createMockRequest({ url: "ftp://example.com/resource" })
),
/http:\/\/localhost:\d+\/ftp:\/\/example.com\/resource/
);
});
81 changes: 81 additions & 0 deletions library/helpers/getRequestUrl.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import type { IncomingMessage } from "http";
import { Http2ServerRequest } from "http2";
import type { TLSSocket } from "tls";
import { trustProxy } from "./trustProxy";

/**
* Get the full request URL including protocol and host.
* Falls back to relative URL if host is not available.
* Also respects forwarded headers if proxies are trusted.
*/
export function getRequestUrl(
req: IncomingMessage | Http2ServerRequest
): string {
const reqUrl = req.url || "";

// Already absolute URL
if (
reqUrl[0] !== "/" && // performance improvement
(reqUrl.startsWith("http://") || reqUrl.startsWith("https://"))
) {
return reqUrl;
}

// Relative URL
const host = getHost(req);
if (!host) {
// Fallback to relative URL if host is not available
return reqUrl;
}

// Determine protocol, fallback to http if not detectable
const protocol = getProtocol(req);

if (reqUrl.length && !reqUrl.startsWith("/")) {
// Ensure there's a slash between host and path
return `${protocol}://${host}/${reqUrl}`;
}

return `${protocol}://${host}${reqUrl}`;
}

function getHost(
req: IncomingMessage | Http2ServerRequest
): string | undefined {
const forwardedHost = req.headers?.["x-forwarded-host"];

if (typeof forwardedHost === "string" && trustProxy()) {
return forwardedHost;
}

const host =
req instanceof Http2ServerRequest ? req.authority : req.headers?.host;
if (typeof host === "string") {
return host;
}

return undefined;
}

function getProtocol(
req: IncomingMessage | Http2ServerRequest
): "http" | "https" {
const forwarded =
req.headers?.["x-forwarded-proto"] || req.headers?.["x-forwarded-protocol"];
if (typeof forwarded === "string" && trustProxy()) {
const normalized = forwarded.toLowerCase();
if (normalized === "https" || normalized === "http") {
return normalized;
}
}

if (req instanceof Http2ServerRequest && req.scheme === "https") {
return "https";
}

if (req.socket && (req.socket as TLSSocket).encrypted) {
return "https";
}

return "http";
}
Loading