Skip to content

Commit c68d0ba

Browse files
committed
Add support for streams
1 parent be6a396 commit c68d0ba

File tree

2 files changed

+126
-3
lines changed

2 files changed

+126
-3
lines changed

library/sinks/AiSDK.test.ts

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import { runWithContext, type Context } from "../agent/Context";
55
import { getMajorNodeVersion } from "../helpers/getNodeVersion";
66
import { getInstance } from "../agent/AgentSingleton";
77
import { z } from "zod";
8+
import { setTimeout } from "timers/promises";
89

910
t.test(
1011
"It works",
@@ -37,7 +38,7 @@ t.test(
3738

3839
const { google } =
3940
require("@ai-sdk/google") as typeof import("@ai-sdk/google");
40-
const { generateText, generateObject } =
41+
const { generateText, generateObject, streamText, streamObject } =
4142
require("ai") as typeof import("ai");
4243

4344
await runWithContext(getTestContext(), async () => {
@@ -73,6 +74,8 @@ t.test(
7374
"Total tokens should match input + output"
7475
);
7576

77+
await setTimeout(400);
78+
7679
const resultObj = await generateObject({
7780
model: google("models/gemini-2.0-flash-lite"),
7881
prompt: "Return numbers one to five",
@@ -91,6 +94,74 @@ t.test(
9194
},
9295
},
9396
]);
97+
98+
await setTimeout(400);
99+
100+
const stream = streamText({
101+
model: google("models/gemini-2.0-flash"),
102+
prompt: "What is Zen by Aikido Security? Return one sentence.",
103+
});
104+
105+
let streamedText = "";
106+
for await (const chunk of stream.textStream) {
107+
streamedText += chunk;
108+
}
109+
110+
t.ok(streamedText.length > 0, "Streamed text should not be empty");
111+
112+
t.match(agent.getAIStatistics().getStats(), [
113+
{
114+
provider: "gemini",
115+
model: "gemini-2.0-flash-lite",
116+
calls: 2,
117+
tokens: {
118+
input: 23,
119+
},
120+
},
121+
{
122+
provider: "gemini",
123+
model: "gemini-2.0-flash",
124+
calls: 1,
125+
tokens: {
126+
input: 12,
127+
},
128+
},
129+
]);
130+
131+
await setTimeout(400);
132+
133+
const objectStream = streamObject({
134+
model: google("models/gemini-2.0-flash"),
135+
prompt: "Return numbers one to five",
136+
output: "array",
137+
schema: z.array(z.number()),
138+
});
139+
140+
const streamedObject = [];
141+
for await (const chunk of objectStream.elementStream) {
142+
streamedObject.push(chunk);
143+
}
144+
145+
t.same(streamedObject, [[1], [2], [3], [4], [5]]);
146+
147+
t.match(agent.getAIStatistics().getStats(), [
148+
{
149+
provider: "gemini",
150+
model: "gemini-2.0-flash-lite",
151+
calls: 2,
152+
tokens: {
153+
input: 23,
154+
},
155+
},
156+
{
157+
provider: "gemini",
158+
model: "gemini-2.0-flash",
159+
calls: 2,
160+
tokens: {
161+
input: 23,
162+
},
163+
},
164+
]);
94165
});
95166
}
96167
);

library/sinks/AiSDK.ts

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@ export class AiSDK implements Wrapper {
111111
return modelName;
112112
}
113113

114-
wrap(hooks: Hooks) {
115-
const interceptors: InterceptorObject = {
114+
private getInterceptors(): InterceptorObject {
115+
return {
116116
kind: "ai_op",
117117
modifyReturnValue: (args, returnValue, agent) => {
118118
if (returnValue instanceof Promise) {
@@ -128,14 +128,54 @@ export class AiSDK implements Wrapper {
128128
return returnValue;
129129
},
130130
};
131+
}
132+
133+
private getStreamInterceptors(): InterceptorObject {
134+
return {
135+
kind: "ai_op",
136+
modifyReturnValue: (args, returnValue, agent) => {
137+
if (
138+
!returnValue ||
139+
typeof returnValue !== "object" ||
140+
!("response" in returnValue) ||
141+
!(returnValue.response instanceof Promise) ||
142+
!("usage" in returnValue) ||
143+
!(returnValue.usage instanceof Promise)
144+
) {
145+
return returnValue;
146+
}
147+
148+
Promise.all([returnValue.response, returnValue.usage]).then(
149+
([response, usage]) => {
150+
try {
151+
this.inspectAiCall(agent, args, {
152+
response,
153+
usage,
154+
});
155+
} catch {
156+
// If we don't catch these errors, it will result in an unhandled promise rejection!
157+
}
158+
}
159+
);
160+
161+
return returnValue;
162+
},
163+
};
164+
}
131165

166+
wrap(hooks: Hooks) {
132167
hooks
133168
.addPackage("ai")
134169
.withVersion("^4.0.0")
135170
.onRequire((exports, pkgInfo) => {
136171
// Can't wrap it directly because it's a readonly proxy
137172
const generateTextFunc = exports.generateText;
138173
const generateObjectFunc = exports.generateObject;
174+
const streamTextFunc = exports.streamText;
175+
const streamObjectFunc = exports.streamObject;
176+
177+
const interceptors = this.getInterceptors();
178+
const streamInterceptors = this.getStreamInterceptors();
139179

140180
return {
141181
...exports,
@@ -151,6 +191,18 @@ export class AiSDK implements Wrapper {
151191
pkgInfo,
152192
interceptors
153193
),
194+
streamText: wrapExport(
195+
streamTextFunc,
196+
undefined,
197+
pkgInfo,
198+
streamInterceptors
199+
),
200+
streamObject: wrapExport(
201+
streamObjectFunc,
202+
undefined,
203+
pkgInfo,
204+
streamInterceptors
205+
),
154206
};
155207
});
156208
}

0 commit comments

Comments
 (0)