diff --git a/library/agent/hooks/wrapExport.test.ts b/library/agent/hooks/wrapExport.test.ts index ee60c7021..d73a6b10d 100644 --- a/library/agent/hooks/wrapExport.test.ts +++ b/library/agent/hooks/wrapExport.test.ts @@ -172,6 +172,30 @@ t.test("With callback", async (t) => { { name: "test", type: "external" }, { kind: "outgoing_http_op", + bindContext: false, + inspectArgs: (args) => { + t.same(args, ["input", () => {}]); + }, + } + ); + + toWrap.test("input", () => {}); +}); + +t.test("With callback with bindContext", async (t) => { + const toWrap = { + test(input: string, callback: (input: string) => void) { + callback(input); + }, + }; + + wrapExport( + toWrap, + "test", + { name: "test", type: "external" }, + { + kind: "outgoing_http_op", + bindContext: true, inspectArgs: (args) => { t.same(args, ["input", bindContext(() => {})]); }, diff --git a/library/agent/hooks/wrapExport.ts b/library/agent/hooks/wrapExport.ts index 62dea4369..30e59dd90 100644 --- a/library/agent/hooks/wrapExport.ts +++ b/library/agent/hooks/wrapExport.ts @@ -31,6 +31,12 @@ export type InterceptorObject = { // This will be used to collect stats // For sources, this will often be undefined kind: OperationKind | undefined; + // Whether to bind the async resource execution context to callback functions passed as arguments + // Only applies to inspectArgs right now + // If the called function uses code where the context would be lost, like calling the callback in a setTimeout + // or an event listener, this should be true + // In other cases this can be false to avoid unnecessary overhead + bindContext?: boolean; }; /** @@ -60,10 +66,12 @@ export function wrapExport( // Run inspectArgs interceptor if provided if (typeof interceptors.inspectArgs === "function") { - // Bind context to functions in arguments - for (let i = 0; i < args.length; i++) { - if (typeof args[i] === "function") { - args[i] = bindContext(args[i]); + if (interceptors.bindContext) { + // Bind context to functions in arguments + for (let i = 0; i < args.length; i++) { + if (typeof args[i] === "function") { + args[i] = bindContext(args[i]); + } } } diff --git a/library/sinks/ChildProcess.test.ts b/library/sinks/ChildProcess.test.ts index c35245d5c..9e7349fc9 100644 --- a/library/sinks/ChildProcess.test.ts +++ b/library/sinks/ChildProcess.test.ts @@ -1,5 +1,5 @@ import * as t from "tap"; -import { Context, runWithContext } from "../agent/Context"; +import { Context, getContext, runWithContext } from "../agent/Context"; import { ChildProcess } from "./ChildProcess"; import { execFile, execFileSync } from "child_process"; import { createTestAgent } from "../helpers/createTestAgent"; @@ -219,4 +219,13 @@ t.test("it works", async (t) => { ); } ); + + await new Promise((resolve) => { + runWithContext(unsafeContext, () => { + exec("ls", () => { + t.same(getContext(), unsafeContext); + resolve(); + }).unref(); + }); + }); }); diff --git a/library/sinks/FileSystem.test.ts b/library/sinks/FileSystem.test.ts index 85394993d..776a1acf5 100644 --- a/library/sinks/FileSystem.test.ts +++ b/library/sinks/FileSystem.test.ts @@ -1,5 +1,5 @@ import * as t from "tap"; -import { Context, runWithContext } from "../agent/Context"; +import { Context, getContext, runWithContext } from "../agent/Context"; import { FileSystem } from "./FileSystem"; import { createTestAgent } from "../helpers/createTestAgent"; @@ -52,11 +52,12 @@ t.test("it works", async (t) => { const { writeFile, writeFileSync, + readFile, rename, realpath, promises: fsDotPromise, realpathSync, - } = require("fs"); + } = require("fs") as typeof import("fs"); const { writeFile: writeFilePromise } = require("fs/promises") as typeof import("fs/promises"); @@ -64,7 +65,9 @@ t.test("it works", async (t) => { t.ok(typeof realpathSync.native === "function"); const runCommandsWithInvalidArgs = () => { + // @ts-expect-error Invalid args test throws(() => writeFile(), /Received undefined/); + // @ts-expect-error Invalid args test throws(() => writeFileSync(), /Received undefined/); }; @@ -308,4 +311,13 @@ t.test("it works", async (t) => { rename(new URL("file:///../../test.txt"), "../test2.txt", () => {}); } ); + + await new Promise((resolve) => { + runWithContext(unsafeContext, () => { + readFile("./test.txt", "utf-8", (err, data) => { + t.same(getContext(), unsafeContext); + resolve(); + }); + }); + }); }); diff --git a/library/sinks/HTTPRequest.test.ts b/library/sinks/HTTPRequest.test.ts index e31d6b223..bb0eab21a 100644 --- a/library/sinks/HTTPRequest.test.ts +++ b/library/sinks/HTTPRequest.test.ts @@ -2,7 +2,7 @@ import * as dns from "dns"; import * as t from "tap"; import { Token } from "../agent/api/Token"; -import { Context, runWithContext } from "../agent/Context"; +import { Context, getContext, runWithContext } from "../agent/Context"; import { wrap } from "../helpers/wrap"; import { HTTPRequest } from "./HTTPRequest"; import { createTestAgent } from "../helpers/createTestAgent"; @@ -358,6 +358,15 @@ t.test("it works", (t) => { } ); + runWithContext(createContext(), () => { + const req = https.get("https://app.aikido.dev", (res) => { + t.same(getContext(), createContext()); + res.on("data", () => {}); + res.on("end", () => {}); + }); + req.end(); + }); + setTimeout(() => { t.end(); }, 3000); diff --git a/library/sinks/MariaDB.test.ts b/library/sinks/MariaDB.test.ts index f151e304e..998a69e32 100644 --- a/library/sinks/MariaDB.test.ts +++ b/library/sinks/MariaDB.test.ts @@ -1,5 +1,5 @@ import * as t from "tap"; -import { runWithContext, type Context } from "../agent/Context"; +import { getContext, runWithContext, type Context } from "../agent/Context"; import { createTestAgent } from "../helpers/createTestAgent"; import { MariaDB } from "./MariaDB"; @@ -267,9 +267,15 @@ t.test("it detects SQL injections using callbacks", (t) => { } } - connection.end(); - pool.end(); - t.end(); + runWithContext(dangerousContext, () => { + connection.query("SELECT 1;", () => { + t.same(getContext(), dangerousContext); + + connection.end(); + pool.end(); + t.end(); + }); + }); } ); } catch (error: any) { diff --git a/library/sinks/MariaDB.ts b/library/sinks/MariaDB.ts index df6bc11be..019ec0036 100644 --- a/library/sinks/MariaDB.ts +++ b/library/sinks/MariaDB.ts @@ -55,6 +55,7 @@ export class MariaDB implements Wrapper { for (const fn of functions) { wrapExport(exports.prototype, fn, pkgInfo, { kind: "sql_op", + bindContext: true, inspectArgs: (args) => this.inspectQuery(args, fn), }); } @@ -66,6 +67,7 @@ export class MariaDB implements Wrapper { for (const fn of functions) { wrapExport(exports.prototype, fn, pkgInfo, { kind: "sql_op", + bindContext: true, inspectArgs: (args) => this.inspectQuery(args, fn), }); } diff --git a/library/sinks/MySQL.ts b/library/sinks/MySQL.ts index e0b9a761a..28292f8f9 100644 --- a/library/sinks/MySQL.ts +++ b/library/sinks/MySQL.ts @@ -55,6 +55,7 @@ export class MySQL implements Wrapper { .onFileRequire("lib/Connection.js", (exports, pkgInfo) => { wrapExport(exports.prototype, "query", pkgInfo, { kind: "sql_op", + bindContext: true, inspectArgs: (args) => this.inspectQuery(args), }); }); diff --git a/library/sinks/MySQL2.tests.ts b/library/sinks/MySQL2.tests.ts index 52bb6f521..739855704 100644 --- a/library/sinks/MySQL2.tests.ts +++ b/library/sinks/MySQL2.tests.ts @@ -1,5 +1,5 @@ import * as t from "tap"; -import { runWithContext, type Context } from "../agent/Context"; +import { getContext, runWithContext, type Context } from "../agent/Context"; import { MySQL2 } from "./MySQL2"; import { startTestAgent } from "../helpers/startTestAgent"; @@ -159,6 +159,15 @@ export function createMySQL2Tests(versionPkgName: string) { runWithContext(safeContext, () => { connection2!.query("-- This is a comment"); }); + + await runWithContext(dangerousContext, () => { + return new Promise((resolve) => { + connection2!.query("SELECT petname FROM cats;", () => { + t.same(getContext(), dangerousContext); + resolve(); + }); + }); + }); } catch (error: any) { t.fail(error); } finally { diff --git a/library/sinks/MySQL2.ts b/library/sinks/MySQL2.ts index fb8e5d5a9..e0dd942d3 100644 --- a/library/sinks/MySQL2.ts +++ b/library/sinks/MySQL2.ts @@ -86,6 +86,7 @@ export class MySQL2 implements Wrapper { // Wrap connection.query wrapExport(connectionPrototype, "query", pkgInfo, { kind: "sql_op", + bindContext: true, inspectArgs: (args) => this.inspectQuery("mysql2.query", args), }); } @@ -94,6 +95,7 @@ export class MySQL2 implements Wrapper { // Wrap connection.execute wrapExport(connectionPrototype, "execute", pkgInfo, { kind: "sql_op", + bindContext: true, inspectArgs: (args) => this.inspectQuery("mysql2.execute", args), }); } diff --git a/library/sinks/Postgres.ts b/library/sinks/Postgres.ts index c749ded7d..2d91c324d 100644 --- a/library/sinks/Postgres.ts +++ b/library/sinks/Postgres.ts @@ -55,6 +55,7 @@ export class Postgres implements Wrapper { .onRequire((exports, pkgInfo) => { wrapExport(exports.Client.prototype, "query", pkgInfo, { kind: "sql_op", + bindContext: true, inspectArgs: (args) => this.inspectQuery(args), }); }); diff --git a/library/sinks/SQLite3.test.ts b/library/sinks/SQLite3.test.ts index 4a5b4aaaa..24f9c7e74 100644 --- a/library/sinks/SQLite3.test.ts +++ b/library/sinks/SQLite3.test.ts @@ -1,5 +1,5 @@ import * as t from "tap"; -import { runWithContext, type Context } from "../agent/Context"; +import { getContext, runWithContext, type Context } from "../agent/Context"; import { SQLite3 } from "./SQLite3"; import { promisify } from "util"; import { createTestAgent } from "../helpers/createTestAgent"; @@ -124,6 +124,25 @@ t.test("it detects SQL injections", async () => { 'SQLITE_ERROR: unrecognized token: "\' SELECT * FROM test"' ); } + + await new Promise((resolve) => { + runWithContext(dangerousContext, () => { + db.get("SELECT petname FROM cats;", () => { + t.match(getContext(), dangerousContext); + + try { + db.get("-- should be blocked", () => {}); + } catch (error: any) { + t.match( + error.message, + /Zen has blocked an SQL injection: sqlite3\.get\(\.\.\.\) originating from body\.myTitle/ + ); + } + + resolve(); + }); + }); + }); } catch (error: any) { t.fail(error); } finally { diff --git a/library/sinks/SQLite3.ts b/library/sinks/SQLite3.ts index 860869ada..5bdbbd3a5 100644 --- a/library/sinks/SQLite3.ts +++ b/library/sinks/SQLite3.ts @@ -81,6 +81,7 @@ export class SQLite3 implements Wrapper { for (const func of sqlFunctions) { wrapExport(db, func, pkgInfo, { kind: "sql_op", + bindContext: true, inspectArgs: (args) => { return this.inspectQuery(`sqlite3.${func}`, args); }, diff --git a/library/sinks/Shelljs.test.ts b/library/sinks/Shelljs.test.ts index bc068e57d..14d487c9d 100644 --- a/library/sinks/Shelljs.test.ts +++ b/library/sinks/Shelljs.test.ts @@ -1,5 +1,5 @@ import * as t from "tap"; -import { runWithContext, type Context } from "../agent/Context"; +import { getContext, runWithContext, type Context } from "../agent/Context"; import { Shelljs } from "./Shelljs"; import { ChildProcess } from "./ChildProcess"; import { FileSystem } from "./FileSystem"; @@ -201,3 +201,16 @@ t.test("invalid arguments are passed to shelljs", async () => { t.same(result.code, 1); }); }); + +t.test("context is available in callbacks", async (t) => { + const shell = require("shelljs"); + + await new Promise((resolve) => { + runWithContext(safeContext, () => { + shell.exec("ls", { silent: true }, () => { + t.match(getContext(), safeContext); + resolve(); + }); + }); + }); +});