Skip to content

Commit 47b717c

Browse files
committed
Update readme and sources, fix tests
1 parent bc5c163 commit 47b717c

File tree

5 files changed

+87
-70
lines changed

5 files changed

+87
-70
lines changed

.github/workflows/unit-test.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ jobs:
8080
- run: npm run install-lib-only
8181
- run: npm run build
8282
- run: npm run test:ci
83+
env:
84+
GOOGLE_GENERATIVE_AI_API_KEY: ${{ secrets.GOOGLE_GENERATIVE_AI_API_KEY }}
8385
- name: "Upload coverage"
8486
uses: codecov/codecov-action@0565863a31f2c772f9f0395002a31e3f06189574 # v5
8587
with:

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ See list above for supported database drivers.
9898

9999
*[`@koa/router`](https://www.npmjs.com/package/@koa/router) 13.x, 12.x, 11.x and 10.x
100100

101+
### AI SDKs
102+
*[`ai`](https://www.npmjs.com/package/ai) 4.x
101103

102104
## Installation
103105

library/agent/Source.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ export const SOURCES = [
99
"subdomains",
1010
"markUnsafe",
1111
"url",
12+
"aiToolParams",
1213
] as const;
1314

1415
export type Source = (typeof SOURCES)[number];

library/agent/protect.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ import { Fastify } from "../sources/Fastify";
4848
import { Koa } from "../sources/Koa";
4949
import { ClickHouse } from "../sinks/ClickHouse";
5050
import { Prisma } from "../sinks/Prisma";
51+
import { AiSDK } from "../sources/AiSDK";
5152

5253
function getLogger(): Logger {
5354
if (isDebugging()) {
@@ -141,6 +142,7 @@ export function getWrappers() {
141142
new ClickHouse(),
142143
new Prisma(),
143144
// new Function(), Disabled because functionName.constructor === Function is false after patching global
145+
new AiSDK(),
144146
];
145147
}
146148

library/sources/AiSDK.test.ts

Lines changed: 80 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -2,88 +2,98 @@ import * as t from "tap";
22
import { startTestAgent } from "../helpers/startTestAgent";
33
import { AiSDK } from "./AiSDK";
44
import { getContext, runWithContext, type Context } from "../agent/Context";
5+
import { getMajorNodeVersion } from "../helpers/getNodeVersion";
56

6-
t.test("It works with agentic", async (t) => {
7-
startTestAgent({
8-
wrappers: [new AiSDK()],
9-
rewrite: {},
10-
});
11-
12-
const getTestContext = (message: string): Context => {
13-
return {
14-
remoteAddress: "::1",
15-
method: "POST",
16-
url: "http://localhost:4000",
17-
query: {
18-
message,
19-
},
20-
body: undefined,
21-
headers: {},
22-
cookies: {},
23-
routeParams: {},
24-
source: "express",
25-
route: "/posts/:id",
26-
};
27-
};
28-
29-
const { google } =
30-
require("@ai-sdk/google") as typeof import("@ai-sdk/google");
31-
const { generateText, tool } = require("ai") as typeof import("ai");
32-
const { z } = require("zod") as typeof import("zod");
33-
34-
const callWithPrompt = async (prompt: string) => {
35-
return await generateText({
36-
model: google("models/gemini-2.0-flash"),
37-
tools: {
38-
weather: tool({
39-
description: "Get the weather in a location",
40-
parameters: z.object({
41-
location: z
42-
.string()
43-
.describe("The location to get the weather for"),
44-
}),
45-
execute: async ({ location }) => {
46-
const temperature = location === "Norway" ? 5 : 24;
47-
return {
48-
temperature,
49-
context: getContext(),
50-
};
51-
},
52-
}),
53-
},
54-
prompt: prompt,
7+
t.test(
8+
"It works with AI tool execution",
9+
{
10+
skip:
11+
!process.env.GOOGLE_GENERATIVE_AI_API_KEY || getMajorNodeVersion() < 20
12+
? "Google API key not set or Node version < 20"
13+
: undefined,
14+
},
15+
async (t) => {
16+
startTestAgent({
17+
wrappers: [new AiSDK()],
18+
rewrite: {},
5519
});
56-
};
57-
58-
await runWithContext(
59-
getTestContext("What is the weather in San Francisco?"),
60-
async () => {
61-
const result = await callWithPrompt(
62-
"What is the weather in San Francisco?"
63-
);
6420

65-
t.same(result.toolResults.length, 1);
66-
t.same(result.toolResults[0].toolName, "weather");
67-
t.same(result.toolResults[0].result.temperature, 24);
68-
t.match(result.toolResults[0].result.context, {
21+
const getTestContext = (message: string): Context => {
22+
return {
6923
remoteAddress: "::1",
7024
method: "POST",
7125
url: "http://localhost:4000",
7226
query: {
73-
message: "What is the weather in San Francisco?",
27+
message,
7428
},
7529
body: undefined,
7630
headers: {},
7731
cookies: {},
7832
routeParams: {},
7933
source: "express",
8034
route: "/posts/:id",
81-
aiToolParams: [
82-
{
83-
location: "San Francisco",
84-
},
85-
],
35+
};
36+
};
37+
38+
const { google } =
39+
require("@ai-sdk/google") as typeof import("@ai-sdk/google");
40+
const { generateText, tool } = require("ai") as typeof import("ai");
41+
const { z } = require("zod") as typeof import("zod");
42+
43+
const callWithPrompt = async (prompt: string) => {
44+
return await generateText({
45+
model: google("models/gemini-2.0-flash-lite"),
46+
tools: {
47+
weather: tool({
48+
description: "Get the weather in a location",
49+
parameters: z.object({
50+
location: z
51+
.string()
52+
.describe("The location to get the weather for"),
53+
}),
54+
execute: async ({ location }) => {
55+
const temperature = location === "Norway" ? 5 : 24;
56+
return {
57+
temperature,
58+
context: getContext(),
59+
};
60+
},
61+
}),
62+
},
63+
prompt: prompt,
8664
});
87-
}
88-
);
89-
});
65+
};
66+
67+
await runWithContext(
68+
getTestContext("What is the weather in San Francisco?"),
69+
async () => {
70+
const result = await callWithPrompt(
71+
"What is the weather in San Francisco?"
72+
);
73+
74+
t.same(result.toolResults.length, 1);
75+
t.same(result.toolResults[0].toolName, "weather");
76+
t.same(result.toolResults[0].result.temperature, 24);
77+
t.match(result.toolResults[0].result.context, {
78+
remoteAddress: "::1",
79+
method: "POST",
80+
url: "http://localhost:4000",
81+
query: {
82+
message: "What is the weather in San Francisco?",
83+
},
84+
body: undefined,
85+
headers: {},
86+
cookies: {},
87+
routeParams: {},
88+
source: "express",
89+
route: "/posts/:id",
90+
aiToolParams: [
91+
{
92+
location: "San Francisco",
93+
},
94+
],
95+
});
96+
}
97+
);
98+
}
99+
);

0 commit comments

Comments
 (0)