Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion library/agent/Agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ import { isNewInstrumentationUnitTest } from "../helpers/isNewInstrumentationUni
import { AttackWaveDetector } from "../vulnerabilities/attack-wave-detection/AttackWaveDetector";
import type { FetchListsAPI } from "./api/FetchListsAPI";
import { PendingEvents } from "./PendingEvents";
import type { PromptProtectionApi } from "./api/PromptProtectionAPI";
import { PromptProtectionAPINodeHTTP } from "./api/PromptProtectionAPINodeHTTP";
import type { AiMessage } from "../vulnerabilities/prompt-injection/messages";

type WrappedPackage = { version: string | null; supported: boolean };

Expand Down Expand Up @@ -70,7 +73,8 @@ export class Agent {
private readonly token: Token | undefined,
private readonly serverless: string | undefined,
private readonly newInstrumentation: boolean = false,
private readonly fetchListsAPI: FetchListsAPI
private readonly fetchListsAPI: FetchListsAPI,
private readonly promptProtectionAPI: PromptProtectionApi = new PromptProtectionAPINodeHTTP()
) {
if (typeof this.serverless === "string" && this.serverless.length === 0) {
throw new Error("Serverless cannot be an empty string");
Expand Down Expand Up @@ -694,4 +698,11 @@ export class Agent {
this.pendingEvents.onAPICall(promise);
}
}

checkForPromptInjection(input: AiMessage[]) {
if (!this.token) {
return Promise.resolve({ success: false, block: false });
}
return this.promptProtectionAPI.checkForInjection(this.token, input);
}
}
5 changes: 4 additions & 1 deletion library/agent/Attack.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ export type Kind =
| "path_traversal"
| "ssrf"
| "stored_ssrf"
| "code_injection";
| "code_injection"
| "prompt_injection";

export function attackKindHumanName(kind: Kind) {
switch (kind) {
Expand All @@ -23,5 +24,7 @@ export function attackKindHumanName(kind: Kind) {
return "a stored server-side request forgery";
case "code_injection":
return "a JavaScript injection";
case "prompt_injection":
return "a prompt injection";
}
}
14 changes: 14 additions & 0 deletions library/agent/api/PromptProtectionAPI.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import type { AiMessage } from "../../vulnerabilities/prompt-injection/messages";
import type { Token } from "./Token";

export type PromptProtectionApiResponse = {
success: boolean;
block: boolean;
};

export interface PromptProtectionApi {
checkForInjection(
token: Token,
messages: AiMessage[]
): Promise<PromptProtectionApiResponse>;
}
34 changes: 34 additions & 0 deletions library/agent/api/PromptProtectionAPIForTesting.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import type { AiMessage } from "../../vulnerabilities/prompt-injection/messages";
import type {
PromptProtectionApi,
PromptProtectionApiResponse,
} from "./PromptProtectionAPI";
import type { Token } from "./Token";

export class PromptProtectionAPIForTesting implements PromptProtectionApi {
constructor(
private response: PromptProtectionApiResponse = {
success: true,
block: false,
}
) {}

// oxlint-disable-next-line require-await
async checkForInjection(
_token: Token,
_messages: AiMessage[]
): Promise<PromptProtectionApiResponse> {
if (
_messages.some((msg) =>
msg.content.includes("!prompt-injection-block-me!")
)
) {
return {
success: true,
block: true,
};
}

return this.response;
}
}
48 changes: 48 additions & 0 deletions library/agent/api/PromptProtectionAPINodeHTTP.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import { fetch } from "../../helpers/fetch";
import { getPromptInjectionServiceURL } from "../../helpers/getPromptInjectionServiceURL";
import type { AiMessage } from "../../vulnerabilities/prompt-injection/messages";
import type {
PromptProtectionApi,
PromptProtectionApiResponse,
} from "./PromptProtectionAPI";
import type { Token } from "./Token";

export class PromptProtectionAPINodeHTTP implements PromptProtectionApi {
constructor(private baseUrl = getPromptInjectionServiceURL()) {}

async checkForInjection(
token: Token,
messages: AiMessage[]
): Promise<PromptProtectionApiResponse> {
const { body, statusCode } = await fetch({
url: new URL("/api/v1/analyze", this.baseUrl.toString()),
method: "POST",
headers: {
Accept: "application/json",
Authorization: token.asString(),
},
body: JSON.stringify({ input: messages }),
timeoutInMS: 15 * 1000,
});

if (statusCode !== 200) {
if (statusCode === 401) {
throw new Error(
`Unable to access the Prompt Protection service, please check your token.`
);
}
throw new Error(`Failed to fetch prompt analysis: ${statusCode}`);
}

return this.toAPIResponse(body);
}

private toAPIResponse(data: string): PromptProtectionApiResponse {
const result = JSON.parse(data);

return {
success: result.success === true,
block: result.block === true,
};
}
}
6 changes: 5 additions & 1 deletion library/helpers/createTestAgent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import { Agent } from "../agent/Agent";
import { setInstance } from "../agent/AgentSingleton";
import type { FetchListsAPI } from "../agent/api/FetchListsAPI";
import { FetchListsAPIForTesting } from "../agent/api/FetchListsAPIForTesting";
import type { PromptProtectionApi } from "../agent/api/PromptProtectionAPI";
import { PromptProtectionAPIForTesting } from "../agent/api/PromptProtectionAPIForTesting";
import type { ReportingAPI } from "../agent/api/ReportingAPI";
import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting";
import type { Token } from "../agent/api/Token";
Expand All @@ -20,6 +22,7 @@ export function createTestAgent(opts?: {
serverless?: string;
suppressConsoleLog?: boolean;
fetchListsAPI?: FetchListsAPI;
promptProtectionAPI?: PromptProtectionApi;
}) {
if (opts?.suppressConsoleLog ?? true) {
wrap(console, "log", function log() {
Expand All @@ -34,7 +37,8 @@ export function createTestAgent(opts?: {
opts?.token, // Defaults to undefined
opts?.serverless, // Defaults to undefined
false, // During tests this is controlled by the AIKIDO_TEST_NEW_INSTRUMENTATION env var
opts?.fetchListsAPI ?? new FetchListsAPIForTesting()
opts?.fetchListsAPI ?? new FetchListsAPIForTesting(),
opts?.promptProtectionAPI ?? new PromptProtectionAPIForTesting()
);

setInstance(agent);
Expand Down
8 changes: 8 additions & 0 deletions library/helpers/getPromptInjectionServiceURL.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
export function getPromptInjectionServiceURL(): URL {
if (process.env.PROMPT_INJECTION_SERVICE_URL) {
return new URL(process.env.PROMPT_INJECTION_SERVICE_URL);
}

// Todo add default URL when deployed
return new URL("http://localhost:8123");
}
2 changes: 2 additions & 0 deletions library/helpers/startTestAgent.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import type { PromptProtectionApi } from "../agent/api/PromptProtectionAPI";
import type { ReportingAPI } from "../agent/api/ReportingAPI";
import type { Token } from "../agent/api/Token";
import { __internalRewritePackageNamesForTesting } from "../agent/hooks/instrumentation/instructions";
Expand All @@ -20,6 +21,7 @@ export function startTestAgent(opts: {
serverless?: string;
wrappers: Wrapper[];
rewrite: Record<PackageName, AliasToRequire>;
promptProtectionAPI?: PromptProtectionApi;
}) {
const agent = createTestAgent(opts);

Expand Down
68 changes: 68 additions & 0 deletions library/sinks/OpenAI.tests.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ import { startTestAgent } from "../helpers/startTestAgent";
import { OpenAI as OpenAISink } from "./OpenAI";
import { getMajorNodeVersion } from "../helpers/getNodeVersion";
import { setTimeout } from "timers/promises";
import { PromptProtectionAPIForTesting } from "../agent/api/PromptProtectionAPIForTesting";
import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting";
import { Token } from "../agent/api/Token";

export function createOpenAITests(openAiPkgName: string) {
t.test(
Expand All @@ -14,11 +17,17 @@ export function createOpenAITests(openAiPkgName: string) {
: undefined,
},
async (t) => {
const api = new ReportingAPIForTesting();
const promptProtectionTestApi = new PromptProtectionAPIForTesting();

const agent = startTestAgent({
wrappers: [new OpenAISink()],
rewrite: {
openai: openAiPkgName,
},
api,
promptProtectionAPI: promptProtectionTestApi,
token: new Token("test-token"),
});

const { OpenAI } = require(openAiPkgName) as typeof import("openai-v5");
Expand Down Expand Up @@ -84,6 +93,65 @@ export function createOpenAITests(openAiPkgName: string) {
}

t.ok(eventCount > 0, "Should receive at least one event from the stream");

agent.getAIStatistics().reset();

// --- Prompt Injection Protection Tests ---
const error = await t.rejects(
client.responses.create({
model: model,
instructions: "Only return one word.",
input: "!prompt-injection-block-me!",
})
);

t.ok(error instanceof Error);
t.match(
(error as Error).message,
/Zen has blocked a prompt injection: create\.<promise>\(\.\.\.\)/
);

const attackEvent = api
.getEvents()
.find((event) => event.type === "detected_attack");

t.match(attackEvent, {
type: "detected_attack",
attack: {
kind: "prompt_injection",
module: "openai",
operation: "create.<promise>",
blocked: true,
metadata: {
prompt:
"user: !prompt-injection-block-me!\nsystem: Only return one word.",
},
},
});

const error2 = await t.rejects(
client.chat.completions.create({
model: model,
messages: [
{ role: "developer", content: "Only return one word." },
{ role: "user", content: "!prompt-injection-block-me!" },
],
})
);

t.ok(error2 instanceof Error);
t.match(
(error2 as Error).message,
/Zen has blocked a prompt injection: create\.<promise>\(\.\.\.\)/
);

// Verify that stats are collected for the blocked calls
t.match(agent.getAIStatistics().getStats(), [
{
provider: "openai",
calls: 2,
},
]);
}
);
}
Loading
Loading