Skip to content

Commit a78a7ee

Browse files
committed
feat: Update to AI v5 in vercel-ai-sdk integration
1 parent ed6f275 commit a78a7ee

File tree

1 file changed

+167
-99
lines changed
  • integrations/vercel-ai-sdk/typescript/src

1 file changed

+167
-99
lines changed

integrations/vercel-ai-sdk/typescript/src/index.ts

Lines changed: 167 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,14 @@ import {
1818
} from "@ag-ui/client";
1919
import { Observable } from "rxjs";
2020
import {
21-
CoreMessage,
22-
LanguageModelV1,
23-
processDataStream,
21+
ModelMessage,
22+
LanguageModel,
2423
streamText,
2524
tool as createVercelAISDKTool,
25+
Tool,
2626
ToolChoice,
2727
ToolSet,
28+
stepCountIs,
2829
} from "ai";
2930
import { randomUUID } from "@ag-ui/client";
3031
import { z } from "zod";
@@ -39,13 +40,13 @@ type ProcessedEvent =
3940
| ToolCallStartEvent;
4041

4142
interface VercelAISDKAgentConfig extends AgentConfig {
42-
model: LanguageModelV1;
43+
model: LanguageModel;
4344
maxSteps?: number;
4445
toolChoice?: ToolChoice<Record<string, unknown>>;
4546
}
4647

4748
export class VercelAISDKAgent extends AbstractAgent {
48-
model: LanguageModelV1;
49+
model: LanguageModel;
4950
maxSteps: number;
5051
toolChoice: ToolChoice<Record<string, unknown>>;
5152
constructor({ model, maxSteps, toolChoice, ...rest }: VercelAISDKAgentConfig) {
@@ -65,12 +66,15 @@ export class VercelAISDKAgent extends AbstractAgent {
6566
runId: input.runId,
6667
} as RunStartedEvent);
6768

69+
const toolSet = convertToolToVercelAISDKTools(input.tools);
70+
const stopCondition = this.maxSteps > 0 ? stepCountIs(this.maxSteps) : undefined;
71+
6872
const response = streamText({
6973
model: this.model,
70-
messages: convertMessagesToVercelAISDKMessages(input.messages),
71-
tools: convertToolToVerlAISDKTools(input.tools),
72-
maxSteps: this.maxSteps,
74+
messages: convertMessagesToModelMessages(input.messages),
7375
toolChoice: this.toolChoice,
76+
...(Object.keys(toolSet).length > 0 ? { tools: toolSet } : {}),
77+
...(stopCondition ? { stopWhen: stopCondition } : {}),
7478
});
7579

7680
let messageId = randomUUID();
@@ -82,93 +86,130 @@ export class VercelAISDKAgent extends AbstractAgent {
8286
};
8387
finalMessages.push(assistantMessage);
8488

85-
processDataStream({
86-
stream: response.toDataStreamResponse().body!,
87-
onTextPart: (text) => {
88-
assistantMessage.content += text;
89-
const event: TextMessageChunkEvent = {
90-
type: EventType.TEXT_MESSAGE_CHUNK,
91-
role: "assistant",
92-
messageId,
93-
delta: text,
94-
};
95-
subscriber.next(event);
96-
},
97-
onFinishMessagePart: () => {
98-
// Emit message snapshot
99-
const event: MessagesSnapshotEvent = {
100-
type: EventType.MESSAGES_SNAPSHOT,
101-
messages: finalMessages,
102-
};
103-
subscriber.next(event);
89+
let hasCompleted = false;
90+
const seenToolCallIds = new Set<string>();
91+
92+
const finalizeRun = () => {
93+
if (hasCompleted) {
94+
return;
95+
}
96+
hasCompleted = true;
97+
98+
const snapshotEvent: MessagesSnapshotEvent = {
99+
type: EventType.MESSAGES_SNAPSHOT,
100+
messages: finalMessages,
101+
};
102+
subscriber.next(snapshotEvent);
104103

105-
// Emit run finished event
106-
subscriber.next({
107-
type: EventType.RUN_FINISHED,
108-
threadId: input.threadId,
109-
runId: input.runId,
110-
} as RunFinishedEvent);
104+
subscriber.next({
105+
type: EventType.RUN_FINISHED,
106+
threadId: input.threadId,
107+
runId: input.runId,
108+
} as RunFinishedEvent);
111109

112-
// Complete the observable
113-
subscriber.complete();
114-
},
115-
onToolCallPart(streamPart) {
116-
let toolCall: ToolCall = {
117-
id: streamPart.toolCallId,
118-
type: "function",
119-
function: {
120-
name: streamPart.toolName,
121-
arguments: JSON.stringify(streamPart.args),
122-
},
123-
};
124-
assistantMessage.toolCalls!.push(toolCall);
110+
subscriber.complete();
111+
};
125112

126-
const startEvent: ToolCallStartEvent = {
127-
type: EventType.TOOL_CALL_START,
128-
parentMessageId: messageId,
129-
toolCallId: streamPart.toolCallId,
130-
toolCallName: streamPart.toolName,
131-
};
132-
subscriber.next(startEvent);
113+
const processStream = async () => {
114+
try {
115+
for await (const part of response.fullStream) {
116+
switch (part.type) {
117+
case "text-delta": {
118+
if (!part.text) {
119+
break;
120+
}
121+
assistantMessage.content += part.text;
122+
const event: TextMessageChunkEvent = {
123+
type: EventType.TEXT_MESSAGE_CHUNK,
124+
role: "assistant",
125+
messageId,
126+
delta: part.text,
127+
};
128+
subscriber.next(event);
129+
break;
130+
}
131+
case "tool-call": {
132+
if (seenToolCallIds.has(part.toolCallId)) {
133+
break;
134+
}
135+
seenToolCallIds.add(part.toolCallId);
136+
const argumentsJson = safeStringify(part.input);
137+
let toolCall: ToolCall = {
138+
id: part.toolCallId,
139+
type: "function",
140+
function: {
141+
name: part.toolName,
142+
arguments: argumentsJson,
143+
},
144+
};
145+
assistantMessage.toolCalls!.push(toolCall);
133146

134-
const argsEvent: ToolCallArgsEvent = {
135-
type: EventType.TOOL_CALL_ARGS,
136-
toolCallId: streamPart.toolCallId,
137-
delta: JSON.stringify(streamPart.args),
138-
};
139-
subscriber.next(argsEvent);
147+
const startEvent: ToolCallStartEvent = {
148+
type: EventType.TOOL_CALL_START,
149+
parentMessageId: messageId,
150+
toolCallId: part.toolCallId,
151+
toolCallName: part.toolName,
152+
};
153+
subscriber.next(startEvent);
140154

141-
const endEvent: ToolCallEndEvent = {
142-
type: EventType.TOOL_CALL_END,
143-
toolCallId: streamPart.toolCallId,
144-
};
145-
subscriber.next(endEvent);
146-
},
147-
onToolResultPart(streamPart) {
148-
const toolMessage: ToolMessage = {
149-
role: "tool",
150-
id: randomUUID(),
151-
toolCallId: streamPart.toolCallId,
152-
content: JSON.stringify(streamPart.result),
153-
};
154-
finalMessages.push(toolMessage);
155-
},
156-
onErrorPart(streamPart) {
157-
subscriber.error(streamPart);
158-
},
159-
}).catch((error) => {
160-
console.error("catch error", error);
161-
// Handle error
162-
subscriber.error(error);
163-
});
155+
const argsEvent: ToolCallArgsEvent = {
156+
type: EventType.TOOL_CALL_ARGS,
157+
toolCallId: part.toolCallId,
158+
delta: argumentsJson,
159+
};
160+
subscriber.next(argsEvent);
161+
162+
const endEvent: ToolCallEndEvent = {
163+
type: EventType.TOOL_CALL_END,
164+
toolCallId: part.toolCallId,
165+
};
166+
subscriber.next(endEvent);
167+
break;
168+
}
169+
case "tool-result": {
170+
if (part.preliminary) {
171+
break;
172+
}
173+
const toolMessage: ToolMessage = {
174+
role: "tool",
175+
id: randomUUID(),
176+
toolCallId: part.toolCallId,
177+
content: safeStringify(part.output),
178+
};
179+
finalMessages.push(toolMessage);
180+
break;
181+
}
182+
case "tool-error": {
183+
subscriber.error(part.error ?? new Error(`Tool ${part.toolName} failed`));
184+
return;
185+
}
186+
case "error": {
187+
subscriber.error(part.error ?? new Error("Stream error"));
188+
return;
189+
}
190+
case "finish": {
191+
finalizeRun();
192+
return;
193+
}
194+
default:
195+
break;
196+
}
197+
}
198+
finalizeRun();
199+
} catch (error) {
200+
subscriber.error(error);
201+
}
202+
};
203+
204+
processStream();
164205

165206
return () => {};
166207
});
167208
}
168209
}
169210

170-
export function convertMessagesToVercelAISDKMessages(messages: Message[]): CoreMessage[] {
171-
const result: CoreMessage[] = [];
211+
export function convertMessagesToModelMessages(messages: Message[]): ModelMessage[] {
212+
const result: ModelMessage[] = [];
172213

173214
for (const message of messages) {
174215
if (message.role === "assistant") {
@@ -178,7 +219,7 @@ export function convertMessagesToVercelAISDKMessages(messages: Message[]): CoreM
178219
type: "tool-call",
179220
toolCallId: toolCall.id,
180221
toolName: toolCall.function.name,
181-
args: JSON.parse(toolCall.function.arguments),
222+
input: JSON.parse(toolCall.function.arguments),
182223
});
183224
}
184225
result.push({
@@ -209,7 +250,7 @@ export function convertMessagesToVercelAISDKMessages(messages: Message[]): CoreM
209250
type: "tool-result",
210251
toolCallId: message.toolCallId,
211252
toolName: toolName,
212-
result: message.content,
253+
output: parseToolMessageContent(message.content),
213254
},
214255
],
215256
});
@@ -219,9 +260,9 @@ export function convertMessagesToVercelAISDKMessages(messages: Message[]): CoreM
219260
return result;
220261
}
221262

222-
export function convertJsonSchemaToZodSchema(jsonSchema: any, required: boolean): z.ZodSchema {
263+
export function convertJsonSchemaToZodSchema(jsonSchema: any, required: boolean): z.ZodTypeAny {
223264
if (jsonSchema.type === "object") {
224-
const spec: { [key: string]: z.ZodSchema } = {};
265+
const spec: Record<string, z.ZodTypeAny> = {};
225266

226267
if (!jsonSchema.properties || !Object.keys(jsonSchema.properties).length) {
227268
return !required ? z.object(spec).optional() : z.object(spec);
@@ -252,15 +293,42 @@ export function convertJsonSchemaToZodSchema(jsonSchema: any, required: boolean)
252293
throw new Error("Invalid JSON schema");
253294
}
254295

255-
export function convertToolToVerlAISDKTools(tools: RunAgentInput["tools"]): ToolSet {
256-
return tools.reduce(
257-
(acc: ToolSet, tool: RunAgentInput["tools"][number]) => ({
258-
...acc,
259-
[tool.name]: createVercelAISDKTool({
260-
description: tool.description,
261-
parameters: convertJsonSchemaToZodSchema(tool.parameters, true),
262-
}),
263-
}),
264-
{},
265-
);
296+
export function convertToolToVercelAISDKTools(tools: RunAgentInput["tools"]): ToolSet {
297+
const toolSet: Record<string, unknown> = {};
298+
299+
for (const tool of tools) {
300+
const inputSchema = convertJsonSchemaToZodSchema(tool.parameters, true) as z.ZodTypeAny;
301+
const toolDefinition = {
302+
description: tool.description,
303+
inputSchema,
304+
outputSchema: z.any(),
305+
} as unknown;
306+
toolSet[tool.name] = createVercelAISDKTool(toolDefinition as any);
307+
}
308+
309+
return toolSet as ToolSet;
310+
}
311+
312+
function safeStringify(value: unknown): string {
313+
if (typeof value === "string") {
314+
return value;
315+
}
316+
try {
317+
return JSON.stringify(value ?? {});
318+
} catch {
319+
return JSON.stringify({ value: String(value) });
320+
}
321+
}
322+
323+
function parseToolMessageContent(content: string) {
324+
if (!content) {
325+
return { type: "text" as const, value: "" };
326+
}
327+
328+
try {
329+
const parsed = JSON.parse(content);
330+
return { type: "json" as const, value: parsed };
331+
} catch {
332+
return { type: "text" as const, value: content };
333+
}
266334
}

0 commit comments

Comments
 (0)