Skip to content

Commit 4a2c81f

Browse files
authored
Merge pull request #629 from AikidoSec/ai-calls
Add support for OpenAI SDK and AWS Bedrock Client
2 parents ee4617b + 9ece919 commit 4a2c81f

File tree

21 files changed

+4014
-4
lines changed

21 files changed

+4014
-4
lines changed

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,14 @@ See list above for supported database drivers.
9898

9999
* [`@koa/router`](https://www.npmjs.com/package/@koa/router) 13.x, 12.x, 11.x and 10.x
100100

101+
### AI SDKs
102+
103+
Zen instruments the following AI SDKs to track which models are used and how many tokens are consumed, allowing you to monitor your AI usage and costs:
104+
105+
* [`openai`](https://www.npmjs.com/package/openai) 4.x
106+
* [`@aws-sdk/client-bedrock-runtime`](https://www.npmjs.com/package/@aws-sdk/client-bedrock-runtime) 3.x
107+
108+
_Note: Prompt injection attacks are currently not covered by Zen._
101109

102110
## Installation
103111

library/agent/AIStatistics.test.ts

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
import * as t from "tap";
2+
import { AIStatistics } from "./AIStatistics";
3+
4+
t.test("it initializes with empty state", async () => {
5+
const stats = new AIStatistics();
6+
7+
t.same(stats.getStats(), []);
8+
t.equal(stats.isEmpty(), true);
9+
});
10+
11+
t.test("it tracks basic AI calls", async () => {
12+
const stats = new AIStatistics();
13+
14+
stats.onAICall({
15+
provider: "openai",
16+
model: "gpt-4",
17+
inputTokens: 100,
18+
outputTokens: 50,
19+
});
20+
21+
const result = stats.getStats();
22+
t.equal(result.length, 1);
23+
t.same(result[0], {
24+
provider: "openai",
25+
model: "gpt-4",
26+
calls: 1,
27+
tokens: {
28+
input: 100,
29+
output: 50,
30+
total: 150,
31+
},
32+
});
33+
34+
t.equal(stats.isEmpty(), false);
35+
});
36+
37+
t.test("it tracks multiple calls to the same provider/model", async () => {
38+
const stats = new AIStatistics();
39+
40+
stats.onAICall({
41+
provider: "openai",
42+
model: "gpt-4",
43+
inputTokens: 100,
44+
outputTokens: 50,
45+
});
46+
47+
stats.onAICall({
48+
provider: "openai",
49+
model: "gpt-4",
50+
inputTokens: 200,
51+
outputTokens: 75,
52+
});
53+
54+
const result = stats.getStats();
55+
t.same(result.length, 1);
56+
t.same(result[0], {
57+
provider: "openai",
58+
model: "gpt-4",
59+
calls: 2,
60+
tokens: {
61+
input: 300,
62+
output: 125,
63+
total: 425,
64+
},
65+
});
66+
});
67+
68+
t.test(
69+
"it tracks different provider/model combinations separately",
70+
async () => {
71+
const stats = new AIStatistics();
72+
73+
stats.onAICall({
74+
provider: "openai",
75+
model: "gpt-4",
76+
inputTokens: 100,
77+
outputTokens: 50,
78+
});
79+
80+
stats.onAICall({
81+
provider: "openai",
82+
model: "gpt-3.5-turbo",
83+
inputTokens: 80,
84+
outputTokens: 40,
85+
});
86+
87+
stats.onAICall({
88+
provider: "anthropic",
89+
model: "claude-3",
90+
inputTokens: 120,
91+
outputTokens: 60,
92+
});
93+
94+
const result = stats.getStats();
95+
t.equal(result.length, 3);
96+
97+
// Sort by provider:model for consistent testing
98+
result.sort((a, b) =>
99+
`${a.provider}:${a.model}`.localeCompare(`${b.provider}:${b.model}`)
100+
);
101+
102+
t.same(result[0], {
103+
provider: "anthropic",
104+
model: "claude-3",
105+
calls: 1,
106+
tokens: {
107+
input: 120,
108+
output: 60,
109+
total: 180,
110+
},
111+
});
112+
113+
t.same(result[1], {
114+
provider: "openai",
115+
model: "gpt-3.5-turbo",
116+
calls: 1,
117+
tokens: {
118+
input: 80,
119+
output: 40,
120+
total: 120,
121+
},
122+
});
123+
124+
t.same(result[2], {
125+
provider: "openai",
126+
model: "gpt-4",
127+
calls: 1,
128+
tokens: {
129+
input: 100,
130+
output: 50,
131+
total: 150,
132+
},
133+
});
134+
}
135+
);
136+
137+
t.test("it resets all statistics", async () => {
138+
const stats = new AIStatistics();
139+
140+
stats.onAICall({
141+
provider: "openai",
142+
model: "gpt-4",
143+
inputTokens: 100,
144+
outputTokens: 50,
145+
});
146+
147+
stats.onAICall({
148+
provider: "anthropic",
149+
model: "claude-3",
150+
inputTokens: 120,
151+
outputTokens: 60,
152+
});
153+
154+
t.equal(stats.isEmpty(), false);
155+
t.equal(stats.getStats().length, 2);
156+
157+
stats.reset();
158+
159+
t.equal(stats.isEmpty(), true);
160+
t.same(stats.getStats(), []);
161+
});
162+
163+
t.test("it handles zero token inputs", async () => {
164+
const stats = new AIStatistics();
165+
166+
stats.onAICall({
167+
provider: "openai",
168+
model: "gpt-4",
169+
inputTokens: 0,
170+
outputTokens: 0,
171+
});
172+
173+
const result = stats.getStats();
174+
t.equal(result.length, 1);
175+
t.same(result[0].tokens, {
176+
input: 0,
177+
output: 0,
178+
total: 0,
179+
});
180+
});
181+
182+
t.test("called with empty provider", async () => {
183+
const stats = new AIStatistics();
184+
185+
stats.onAICall({
186+
provider: "",
187+
model: "gpt-4",
188+
inputTokens: 100,
189+
outputTokens: 50,
190+
});
191+
192+
t.same(true, stats.isEmpty());
193+
});
194+
195+
t.test("called with empty model", async () => {
196+
const stats = new AIStatistics();
197+
198+
stats.onAICall({
199+
provider: "openai",
200+
model: "",
201+
inputTokens: 100,
202+
outputTokens: 50,
203+
});
204+
205+
t.same(true, stats.isEmpty());
206+
});

library/agent/AIStatistics.ts

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
type AIProviderStats = {
2+
provider: string;
3+
model: string;
4+
calls: number;
5+
tokens: {
6+
input: number;
7+
output: number;
8+
total: number;
9+
};
10+
};
11+
12+
export class AIStatistics {
13+
private calls: Map<string, AIProviderStats> = new Map();
14+
15+
private getProviderKey(provider: string, model: string): string {
16+
return `${provider}:${model}`;
17+
}
18+
19+
private getRouteKey(path: string, method: string): string {
20+
return `${method}:${path}`;
21+
}
22+
23+
private ensureProviderStats(
24+
provider: string,
25+
model: string
26+
): AIProviderStats {
27+
const key = this.getProviderKey(provider, model);
28+
29+
if (!this.calls.has(key)) {
30+
this.calls.set(key, {
31+
provider,
32+
model,
33+
calls: 0,
34+
tokens: {
35+
input: 0,
36+
output: 0,
37+
total: 0,
38+
},
39+
});
40+
}
41+
42+
return this.calls.get(key)!;
43+
}
44+
45+
onAICall({
46+
provider,
47+
model,
48+
inputTokens,
49+
outputTokens,
50+
}: {
51+
provider: string;
52+
model: string;
53+
inputTokens: number;
54+
outputTokens: number;
55+
}) {
56+
if (!provider || !model) {
57+
return;
58+
}
59+
60+
const providerStats = this.ensureProviderStats(provider, model);
61+
providerStats.calls += 1;
62+
providerStats.tokens.input += inputTokens;
63+
providerStats.tokens.output += outputTokens;
64+
providerStats.tokens.total += inputTokens + outputTokens;
65+
}
66+
67+
getStats() {
68+
return Array.from(this.calls.values()).map((stats) => {
69+
return {
70+
provider: stats.provider,
71+
model: stats.model,
72+
calls: stats.calls,
73+
tokens: {
74+
input: stats.tokens.input,
75+
output: stats.tokens.output,
76+
total: stats.tokens.total,
77+
},
78+
};
79+
});
80+
}
81+
82+
reset() {
83+
this.calls.clear();
84+
}
85+
86+
isEmpty(): boolean {
87+
return this.calls.size === 0;
88+
}
89+
}

library/agent/Agent.ts

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import { Wrapper } from "./Wrapper";
2626
import { isAikidoCI } from "../helpers/isAikidoCI";
2727
import { AttackLogger } from "./AttackLogger";
2828
import { Packages } from "./Packages";
29+
import { AIStatistics } from "./AIStatistics";
2930

3031
type WrappedPackage = { version: string | null; supported: boolean };
3132

@@ -58,6 +59,7 @@ export class Agent {
5859
maxPerfSamplesInMemory: 5000,
5960
maxCompressedStatsInMemory: 20, // per operation
6061
});
62+
private aiStatistics = new AIStatistics();
6163
private middlewareInstalled = false;
6264
private attackLogger = new AttackLogger(1000);
6365

@@ -85,6 +87,10 @@ export class Agent {
8587
return this.statistics;
8688
}
8789

90+
getAIStatistics() {
91+
return this.aiStatistics;
92+
}
93+
8894
unableToPreventPrototypePollution(
8995
incompatiblePackages: Record<string, string>
9096
) {
@@ -295,12 +301,14 @@ export class Agent {
295301
if (this.token) {
296302
this.logger.log("Heartbeat...");
297303
const stats = this.statistics.getStats();
304+
const aiStats = this.aiStatistics.getStats();
298305
const routes = this.routes.asArray();
299306
const outgoingDomains = this.hostnames.asArray();
300307
const users = this.users.asArray();
301308
const packages = this.packages.asArray();
302309
const endedAt = Date.now();
303310
this.statistics.reset();
311+
this.aiStatistics.reset();
304312
this.routes.clear();
305313
this.hostnames.clear();
306314
this.users.clear();
@@ -320,6 +328,7 @@ export class Agent {
320328
ipAddresses: stats.ipAddresses,
321329
sqlTokenizationFailures: stats.sqlTokenizationFailures,
322330
},
331+
ai: aiStats,
323332
packages,
324333
hostnames: outgoingDomains,
325334
routes: routes,
@@ -358,8 +367,10 @@ export class Agent {
358367
const now = performance.now();
359368
const diff = now - this.lastHeartbeat;
360369
const shouldSendHeartbeat = diff > this.sendHeartbeatEveryMS;
370+
const hasStats =
371+
!this.statistics.isEmpty() || !this.aiStatistics.isEmpty();
361372
const canSendInitialStats =
362-
!this.serviceConfig.hasReceivedAnyStats() && !this.statistics.isEmpty();
373+
!this.serviceConfig.hasReceivedAnyStats() && hasStats;
363374
const shouldReportInitialStats =
364375
!this.reportedInitialStats && canSendInitialStats;
365376

0 commit comments

Comments
 (0)