diff --git a/.github/workflows/qa-tests.yml b/.github/workflows/qa-tests.yml index df4c57442..e0f3a79b9 100644 --- a/.github/workflows/qa-tests.yml +++ b/.github/workflows/qa-tests.yml @@ -49,7 +49,7 @@ jobs: cp firewall-node/.github/workflows/Dockerfile.qa zen-demo-nodejs/Dockerfile - name: Run Firewall QA Tests - uses: AikidoSec/firewall-tester-action@releases/v1 + uses: AikidoSec/firewall-tester-action@v1.0.5 with: dockerfile_path: ./zen-demo-nodejs/Dockerfile app_port: 3000 diff --git a/end2end/tests/hono-xml-blocklists.test.js b/end2end/tests/hono-xml-blocklists.test.js index c9a5a481a..fe34f20e9 100644 --- a/end2end/tests/hono-xml-blocklists.test.js +++ b/end2end/tests/hono-xml-blocklists.test.js @@ -300,11 +300,8 @@ t.test("it does not block bypass IP if in blocklist", (t) => { "X-Forwarded-For": "1.3.2.2", }, }); - t.same(resp3.status, 403); - t.same( - await resp3.text(), - `Your IP address is not allowed to access this resource. (Your IP: 1.3.2.2)` - ); + t.same(resp3.status, 200); + t.match(await resp3.text(), "Admin panel"); }) .catch((error) => { t.fail(error); diff --git a/library/agent/Agent.ts b/library/agent/Agent.ts index 8d9f56766..3fb7494ce 100644 --- a/library/agent/Agent.ts +++ b/library/agent/Agent.ts @@ -33,6 +33,7 @@ import { isNewInstrumentationUnitTest } from "../helpers/isNewInstrumentationUni import { AttackWaveDetector } from "../vulnerabilities/attack-wave-detection/AttackWaveDetector"; import type { FetchListsAPI } from "./api/FetchListsAPI"; import { PendingEvents } from "./PendingEvents"; +import { domainToUnicode } from "node:url"; type WrappedPackage = { version: string | null; supported: boolean }; @@ -565,6 +566,14 @@ export class Agent { } onConnectHostname(hostname: string, port: number) { + try { + // new URL(...) always converts hostnames to punycode + // When reporting them in heartbeats, we want to send the unicode version + hostname = domainToUnicode(hostname); + } catch (e: any) { + this.logger.log(`Failed to convert hostname to unicode: ${e.message}`); + } + this.hostnames.add(hostname, port); } diff --git a/library/agent/ServiceConfig.ts b/library/agent/ServiceConfig.ts index 2d84a2e99..971dd291c 100644 --- a/library/agent/ServiceConfig.ts +++ b/library/agent/ServiceConfig.ts @@ -4,6 +4,7 @@ import { isPrivateIP } from "../vulnerabilities/ssrf/isPrivateIP"; import type { Endpoint, EndpointConfig, Domain } from "./Config"; import type { IPList, UserAgentDetails } from "./api/FetchListsAPI"; import { safeCreateRegExp } from "./safeCreateRegExp"; +import { addIPv4MappedAddresses } from "../helpers/addIPv4MappedAddresses"; export class ServiceConfig { private blockedUserIds: Map = new Map(); @@ -98,7 +99,9 @@ export class ServiceConfig { this.bypassedIPAddresses = undefined; return; } - this.bypassedIPAddresses = new IPMatcher(ipAddresses); + this.bypassedIPAddresses = new IPMatcher( + addIPv4MappedAddresses(ipAddresses) + ); } isBypassedIP(ip: string) { diff --git a/library/agent/applyHooks.test.ts b/library/agent/applyHooks.test.ts index c07c90789..2167c70b8 100644 --- a/library/agent/applyHooks.test.ts +++ b/library/agent/applyHooks.test.ts @@ -88,74 +88,79 @@ t.test( } ); -t.test("it ignores route if force protection off is on", async (t) => { - const inspectionCalls: { args: unknown[] }[] = []; +t.test( + "it still inspects outbound connections if force protection off is on", + async (t) => { + const inspectionCalls: { args: unknown[] }[] = []; - const hooks = new Hooks(); - hooks.addBuiltinModule("dns/promises").onRequire((exports, pkgInfo) => { - wrapExport(exports, "lookup", pkgInfo, { - kind: "outgoing_http_op", - inspectArgs: (args, agent) => { - inspectionCalls.push({ args }); - }, + const hooks = new Hooks(); + hooks.addBuiltinModule("dns/promises").onRequire((exports, pkgInfo) => { + wrapExport(exports, "lookup", pkgInfo, { + kind: "outgoing_http_op", + inspectArgs: (args, agent) => { + inspectionCalls.push({ args }); + }, + }); }); - }); - applyHooks(hooks, agent.isUsingNewInstrumentation()); + applyHooks(hooks, agent.isUsingNewInstrumentation()); - reportingAPI.setResult({ - success: true, - endpoints: [ - { - method: "GET", - route: "/route", - forceProtectionOff: true, - rateLimiting: { - enabled: false, - maxRequests: 0, - windowSizeInMS: 0, + reportingAPI.setResult({ + success: true, + endpoints: [ + { + method: "GET", + route: "/route", + forceProtectionOff: true, + rateLimiting: { + enabled: false, + maxRequests: 0, + windowSizeInMS: 0, + }, }, - }, - ], - heartbeatIntervalInMS: 10 * 60 * 1000, - blockedUserIds: [], - allowedIPAddresses: [], - configUpdatedAt: 0, - }); + ], + heartbeatIntervalInMS: 10 * 60 * 1000, + blockedUserIds: [], + allowedIPAddresses: [], + configUpdatedAt: 0, + }); - // Read rules from API - await agent.flushStats(1000); + // Read rules from API + await agent.flushStats(1000); - const { lookup } = require("dns/promises"); + const { lookup } = require("dns/promises"); - await lookup("www.google.com"); - t.same(inspectionCalls, [{ args: ["www.google.com"] }]); + await lookup("www.google.com"); + t.same(inspectionCalls, [{ args: ["www.google.com"] }]); - await runWithContext(context, async () => { - await lookup("www.aikido.dev"); - }); + await runWithContext(context, async () => { + await lookup("www.aikido.dev"); + }); - t.same(inspectionCalls, [ - { args: ["www.google.com"] }, - { args: ["www.aikido.dev"] }, - ]); - - await runWithContext( - { - ...context, - method: "GET", - route: "/route", - }, - async () => { - await lookup("www.times.com"); - } - ); + t.same(inspectionCalls, [ + { args: ["www.google.com"] }, + { args: ["www.aikido.dev"] }, + ]); - t.same(inspectionCalls, [ - { args: ["www.google.com"] }, - { args: ["www.aikido.dev"] }, - ]); -}); + // forceProtectionOff still allows outbound connection inspection + await runWithContext( + { + ...context, + method: "GET", + route: "/route", + }, + async () => { + await lookup("www.times.com"); + } + ); + + t.same(inspectionCalls, [ + { args: ["www.google.com"] }, + { args: ["www.aikido.dev"] }, + { args: ["www.times.com"] }, + ]); + } +); t.test("it does not report attack if IP is allowed", async (t) => { const hooks = new Hooks(); diff --git a/library/agent/hooks/wrapExport.ts b/library/agent/hooks/wrapExport.ts index b5cc1a745..af3be4a43 100644 --- a/library/agent/hooks/wrapExport.ts +++ b/library/agent/hooks/wrapExport.ts @@ -3,7 +3,7 @@ import type { Agent } from "../Agent"; import { getInstance } from "../AgentSingleton"; import { OperationKind } from "../api/Event"; import { bindContext, getContext } from "../Context"; -import type { InterceptorResult } from "./InterceptorResult"; +import { type InterceptorResult, isAttackResult } from "./InterceptorResult"; import type { PartialWrapPackageInfo } from "./WrapPackageInfo"; import { wrapDefaultOrNamed } from "./wrapDefaultOrNamed"; import { onInspectionInterceptorResult } from "./onInspectionInterceptorResult"; @@ -151,14 +151,6 @@ export function inspectArgs( methodName: string, kind: OperationKind | undefined ) { - if (context) { - const matches = agent.getConfig().getEndpoints(context); - - if (matches.find((match) => match.forceProtectionOff)) { - return; - } - } - const start = performance.now(); let result: InterceptorResult = undefined; @@ -177,6 +169,16 @@ export function inspectArgs( }); } + // When forceProtectionOff is enabled, skip attack detection + // but still allow outbound connection blocking + if (context && isAttackResult(result)) { + const matches = agent.getConfig().getEndpoints(context); + + if (matches.find((match) => match.forceProtectionOff)) { + return; + } + } + onInspectionInterceptorResult( context, agent, diff --git a/library/helpers/addIPv4MappedAddresses.test.ts b/library/helpers/addIPv4MappedAddresses.test.ts new file mode 100644 index 000000000..c2286810f --- /dev/null +++ b/library/helpers/addIPv4MappedAddresses.test.ts @@ -0,0 +1,21 @@ +import * as t from "tap"; +import { addIPv4MappedAddresses } from "./addIPv4MappedAddresses"; + +t.test("it adds IPv4-mapped IPv6 addresses", async (t) => { + t.same( + addIPv4MappedAddresses([ + "1.2.3.4", + "23.45.67.89/24", + "2606:2800:220:1:248:1893:25c8:1946", + "2001:0db9:abcd:1234::/64", + ]), + [ + "1.2.3.4", + "23.45.67.89/24", + "2606:2800:220:1:248:1893:25c8:1946", + "2001:0db9:abcd:1234::/64", + "::ffff:1.2.3.4/128", + "::ffff:23.45.67.89/120", + ] + ); +}); diff --git a/library/helpers/addIPv4MappedAddresses.ts b/library/helpers/addIPv4MappedAddresses.ts new file mode 100644 index 000000000..9e8b1c870 --- /dev/null +++ b/library/helpers/addIPv4MappedAddresses.ts @@ -0,0 +1,10 @@ +import mapIPv4ToIPv6 from "./mapIPv4ToIPv6"; + +/** + * Adds IPv4-mapped IPv6 versions for all IPv4 addresses in the array. + * e.g. ["1.2.3.4", "2001:db8::/32"] -> ["1.2.3.4", "2001:db8::/32", "::ffff:1.2.3.4/128"] + */ +export function addIPv4MappedAddresses(ips: string[]): string[] { + const ipv4Addresses = ips.filter((ip) => !ip.includes(":")); + return ips.concat(ipv4Addresses.map(mapIPv4ToIPv6)); +} diff --git a/library/middleware/shouldBlockRequest.ts b/library/middleware/shouldBlockRequest.ts index f1e356b8a..634efb3c2 100644 --- a/library/middleware/shouldBlockRequest.ts +++ b/library/middleware/shouldBlockRequest.ts @@ -33,6 +33,14 @@ export function shouldBlockRequest(): Result { updateContext(context, "executedMiddleware", true); agent.onMiddlewareExecuted(); + const isBypassedIP = + context.remoteAddress && + agent.getConfig().isBypassedIP(context.remoteAddress); + + if (isBypassedIP) { + return { block: false }; + } + if (context.user && agent.getConfig().isUserBlocked(context.user.id)) { return { block: true, type: "blocked", trigger: "user" }; } diff --git a/library/ratelimiting/shouldRateLimitRequest.test.ts b/library/ratelimiting/shouldRateLimitRequest.test.ts index 773303b5b..3afec6039 100644 --- a/library/ratelimiting/shouldRateLimitRequest.test.ts +++ b/library/ratelimiting/shouldRateLimitRequest.test.ts @@ -182,37 +182,6 @@ t.test("it rate limits localhost when not in production mode", async (t) => { }); }); -t.test("it does not rate limit when the IP is allowed", async (t) => { - const agent = await createAgent( - [ - { - method: "POST", - route: "/login", - forceProtectionOff: false, - rateLimiting: { - enabled: true, - maxRequests: 3, - windowSizeInMS: 1000, - }, - }, - ], - ["1.2.3.4"] - ); - - t.same(shouldRateLimitRequest(createContext("1.2.3.4"), agent), { - block: false, - }); - t.same(shouldRateLimitRequest(createContext("1.2.3.4"), agent), { - block: false, - }); - t.same(shouldRateLimitRequest(createContext("1.2.3.4"), agent), { - block: false, - }); - t.same(shouldRateLimitRequest(createContext("1.2.3.4"), agent), { - block: false, - }); -}); - t.test("it rate limits by user", async (t) => { const agent = await createAgent([ { @@ -437,40 +406,6 @@ t.test( } ); -t.test( - "it does not rate limit requests from allowed ip with user", - async (t) => { - const agent = await createAgent( - [ - { - method: "POST", - route: "/login", - forceProtectionOff: false, - rateLimiting: { - enabled: true, - maxRequests: 3, - windowSizeInMS: 1000, - }, - }, - ], - ["1.2.3.4"] - ); - - t.same(shouldRateLimitRequest(createContext("1.2.3.4", "123"), agent), { - block: false, - }); - t.same(shouldRateLimitRequest(createContext("1.2.3.4", "123"), agent), { - block: false, - }); - t.same(shouldRateLimitRequest(createContext("1.2.3.4", "123"), agent), { - block: false, - }); - t.same(shouldRateLimitRequest(createContext("1.2.3.4", "123"), agent), { - block: false, - }); - } -); - t.test( "it does not consume rate limit for user a second time (same request)", async (t) => { diff --git a/library/ratelimiting/shouldRateLimitRequest.ts b/library/ratelimiting/shouldRateLimitRequest.ts index b1b97b51d..d5b2ddc33 100644 --- a/library/ratelimiting/shouldRateLimitRequest.ts +++ b/library/ratelimiting/shouldRateLimitRequest.ts @@ -53,12 +53,7 @@ export function shouldRateLimitRequest( isLocalhostIP(context.remoteAddress) && isProduction; - // Allow requests from allowed IPs, e.g. never rate limit office IPs - const isBypassedIP = - context.remoteAddress && - agent.getConfig().isBypassedIP(context.remoteAddress); - - if (isFromLocalhostInProduction || isBypassedIP) { + if (isFromLocalhostInProduction) { return { block: false }; } diff --git a/library/sources/FunctionsFramework.ts b/library/sources/FunctionsFramework.ts index 5d8c02c98..711cde2cd 100644 --- a/library/sources/FunctionsFramework.ts +++ b/library/sources/FunctionsFramework.ts @@ -1,6 +1,7 @@ /* eslint-disable max-lines-per-function */ +import { Agent } from "../agent/Agent"; import { getInstance } from "../agent/AgentSingleton"; -import { getContext, runWithContext } from "../agent/Context"; +import { Context, getContext, runWithContext } from "../agent/Context"; import { Hooks } from "../agent/hooks/Hooks"; import { wrapExport } from "../agent/hooks/wrapExport"; import { PartialWrapPackageInfo } from "../agent/hooks/WrapPackageInfo"; @@ -77,24 +78,7 @@ export function createCloudFunctionWrapper(fn: HttpFunction): HttpFunction { } finally { const context = getContext(); if (agent && context) { - if ( - context.route && - context.method && - Number.isInteger(res.statusCode) - ) { - const shouldDiscover = shouldDiscoverRoute({ - statusCode: res.statusCode, - method: context.method, - route: context.route, - }); - - if (shouldDiscover) { - agent.onRouteExecute(context); - } - } - - const stats = agent.getInspectionStatistics(); - stats.onRequest(); + incrementStatsAndDiscoverAPISpec(context, agent, res.statusCode); await agent.getPendingEvents().waitUntilSent(getTimeoutInMS()); @@ -112,6 +96,34 @@ export function createCloudFunctionWrapper(fn: HttpFunction): HttpFunction { }; } +function incrementStatsAndDiscoverAPISpec( + context: Context, + agent: Agent, + statusCode: number +) { + if ( + context.remoteAddress && + agent.getConfig().isBypassedIP(context.remoteAddress) + ) { + return; + } + + if (context.route && context.method && Number.isInteger(statusCode)) { + const shouldDiscover = shouldDiscoverRoute({ + statusCode: statusCode, + method: context.method, + route: context.route, + }); + + if (shouldDiscover) { + agent.onRouteExecute(context); + } + } + + const stats = agent.getInspectionStatistics(); + stats.onRequest(); +} + export class FunctionsFramework implements Wrapper { onRequire(exports: any, pkgInfo: PartialWrapPackageInfo) { wrapExport(exports, "http", pkgInfo, { diff --git a/library/sources/Lambda.ts b/library/sources/Lambda.ts index 4db74f9e9..04c11b7fe 100644 --- a/library/sources/Lambda.ts +++ b/library/sources/Lambda.ts @@ -1,4 +1,5 @@ import type { Callback, Context, Handler } from "aws-lambda"; +import { Agent } from "../agent/Agent"; import { getInstance } from "../agent/AgentSingleton"; import { runWithContext, Context as AgentContext } from "../agent/Context"; import { envToBool } from "../helpers/envToBool"; @@ -221,25 +222,7 @@ export function createLambdaWrapper(handler: Handler): Handler { return result; } finally { if (agent) { - if ( - isGatewayEvent(event) && - isGatewayResponse(result) && - agentContext.route && - agentContext.method - ) { - const shouldDiscover = shouldDiscoverRoute({ - statusCode: result.statusCode, - method: agentContext.method, - route: agentContext.route, - }); - - if (shouldDiscover) { - agent.onRouteExecute(agentContext); - } - } - - const stats = agent.getInspectionStatistics(); - stats.onRequest(); + incrementStatsAndDiscoverAPISpec(agentContext, agent, event, result); await agent.getPendingEvents().waitUntilSent(getTimeoutInMS()); @@ -255,6 +238,40 @@ export function createLambdaWrapper(handler: Handler): Handler { }; } +function incrementStatsAndDiscoverAPISpec( + agentContext: AgentContext, + agent: Agent, + event: unknown, + result: unknown +) { + if ( + agentContext.remoteAddress && + agent.getConfig().isBypassedIP(agentContext.remoteAddress) + ) { + return; + } + + if ( + isGatewayEvent(event) && + isGatewayResponse(result) && + agentContext.route && + agentContext.method + ) { + const shouldDiscover = shouldDiscoverRoute({ + statusCode: result.statusCode, + method: agentContext.method, + route: agentContext.route, + }); + + if (shouldDiscover) { + agent.onRouteExecute(agentContext); + } + } + + const stats = agent.getInspectionStatistics(); + stats.onRequest(); +} + let loggedWarningUnsupportedTrigger = false; function logWarningUnsupportedTrigger() { diff --git a/library/sources/http-server/checkIfRequestIsBlocked.ts b/library/sources/http-server/checkIfRequestIsBlocked.ts index 6d30dd2df..02450985d 100644 --- a/library/sources/http-server/checkIfRequestIsBlocked.ts +++ b/library/sources/http-server/checkIfRequestIsBlocked.ts @@ -40,6 +40,14 @@ export function checkIfRequestIsBlocked( // Also ensures that the statistics are only counted once res[checkedBlocks] = true; + const isBypassedIP = + context.remoteAddress && + agent.getConfig().isBypassedIP(context.remoteAddress); + + if (isBypassedIP) { + return false; + } + if (!ipAllowedToAccessRoute(context, agent)) { res.statusCode = 403; res.setHeader("Content-Type", "text/plain"); @@ -54,14 +62,6 @@ export function checkIfRequestIsBlocked( return true; } - const isBypassedIP = - context.remoteAddress && - agent.getConfig().isBypassedIP(context.remoteAddress); - - if (isBypassedIP) { - return false; - } - if ( context.remoteAddress && !agent.getConfig().isAllowedIPAddress(context.remoteAddress).allowed diff --git a/library/sources/http-server/createRequestListener.ts b/library/sources/http-server/createRequestListener.ts index ceef2b31a..2998fe5d9 100644 --- a/library/sources/http-server/createRequestListener.ts +++ b/library/sources/http-server/createRequestListener.ts @@ -93,6 +93,13 @@ function onFinishRequestHandler( // Mark the request as counted req[countedRequest] = true; + if ( + context.remoteAddress && + agent.getConfig().isBypassedIP(context.remoteAddress) + ) { + return; + } + if (context.route && context.method) { const shouldDiscover = shouldDiscoverRoute({ statusCode: res.statusCode, @@ -113,11 +120,7 @@ function onFinishRequestHandler( agent.onRouteRateLimited(context.rateLimitedEndpoint); } - if ( - context.remoteAddress && - !agent.getConfig().isBypassedIP(context.remoteAddress) && - agent.getAttackWaveDetector().check(context) - ) { + if (context.remoteAddress && agent.getAttackWaveDetector().check(context)) { agent.onDetectedAttackWave({ request: context, }); diff --git a/library/sources/http-server/http2/createStreamListener.ts b/library/sources/http-server/http2/createStreamListener.ts index cb8aba993..fff86da44 100644 --- a/library/sources/http-server/http2/createStreamListener.ts +++ b/library/sources/http-server/http2/createStreamListener.ts @@ -53,7 +53,18 @@ function discoverRouteFromStream( stream: ServerHttp2Stream, agent: Agent ) { - if (context && context.route && context.method) { + if (!context) { + return; + } + + if ( + context.remoteAddress && + agent.getConfig().isBypassedIP(context.remoteAddress) + ) { + return; + } + + if (context.route && context.method) { const statusCode = parseInt(stream.sentHeaders[":status"] as string); if (!isNaN(statusCode)) { @@ -78,7 +89,6 @@ function discoverRouteFromStream( if ( context.remoteAddress && - !agent.getConfig().isBypassedIP(context.remoteAddress) && agent.getAttackWaveDetector().check(context) ) { agent.onDetectedAttackWave({ diff --git a/library/vulnerabilities/ssrf/inspectDNSLookupCalls.ts b/library/vulnerabilities/ssrf/inspectDNSLookupCalls.ts index 1642869f6..b22d34fc2 100644 --- a/library/vulnerabilities/ssrf/inspectDNSLookupCalls.ts +++ b/library/vulnerabilities/ssrf/inspectDNSLookupCalls.ts @@ -183,6 +183,17 @@ function wrapDNSLookupCallback( } } + const isBypassedIP = + context && + context.remoteAddress && + agent.getConfig().isBypassedIP(context.remoteAddress); + + if (isBypassedIP) { + // If the IP address is allowed, we don't need to block the request + // Just call the original callback to allow the DNS lookup + return callback(err, addresses, family); + } + if (!found) { if (imdsIpResult.isIMDS) { // Stored SSRF attack executed during another request (context set) @@ -211,17 +222,6 @@ function wrapDNSLookupCallback( return callback(err, addresses, family); } - const isBypassedIP = - context && - context.remoteAddress && - agent.getConfig().isBypassedIP(context.remoteAddress); - - if (isBypassedIP) { - // If the IP address is allowed, we don't need to block the request - // Just call the original callback to allow the DNS lookup - return callback(err, addresses, family); - } - // Used to get the stack trace of the calling location // We don't throw the error, we just use it to get the stack trace const stackTraceError = callingLocationStackTrace || new Error();