@@ -18,13 +18,14 @@ import {
1818} from "@ag-ui/client" ;
1919import { Observable } from "rxjs" ;
2020import {
21- CoreMessage ,
22- LanguageModelV1 ,
23- processDataStream ,
21+ ModelMessage ,
22+ LanguageModel ,
2423 streamText ,
2524 tool as createVercelAISDKTool ,
25+ Tool ,
2626 ToolChoice ,
2727 ToolSet ,
28+ stepCountIs ,
2829} from "ai" ;
2930import { randomUUID } from "@ag-ui/client" ;
3031import { z } from "zod" ;
@@ -39,13 +40,13 @@ type ProcessedEvent =
3940 | ToolCallStartEvent ;
4041
4142interface VercelAISDKAgentConfig extends AgentConfig {
42- model : LanguageModelV1 ;
43+ model : LanguageModel ;
4344 maxSteps ?: number ;
4445 toolChoice ?: ToolChoice < Record < string , unknown > > ;
4546}
4647
4748export class VercelAISDKAgent extends AbstractAgent {
48- model : LanguageModelV1 ;
49+ model : LanguageModel ;
4950 maxSteps : number ;
5051 toolChoice : ToolChoice < Record < string , unknown > > ;
5152 constructor ( { model, maxSteps, toolChoice, ...rest } : VercelAISDKAgentConfig ) {
@@ -65,12 +66,15 @@ export class VercelAISDKAgent extends AbstractAgent {
6566 runId : input . runId ,
6667 } as RunStartedEvent ) ;
6768
69+ const toolSet = convertToolToVercelAISDKTools ( input . tools ) ;
70+ const stopCondition = this . maxSteps > 0 ? stepCountIs ( this . maxSteps ) : undefined ;
71+
6872 const response = streamText ( {
6973 model : this . model ,
70- messages : convertMessagesToVercelAISDKMessages ( input . messages ) ,
71- tools : convertToolToVerlAISDKTools ( input . tools ) ,
72- maxSteps : this . maxSteps ,
74+ messages : convertMessagesToModelMessages ( input . messages ) ,
7375 toolChoice : this . toolChoice ,
76+ ...( Object . keys ( toolSet ) . length > 0 ? { tools : toolSet } : { } ) ,
77+ ...( stopCondition ? { stopWhen : stopCondition } : { } ) ,
7478 } ) ;
7579
7680 let messageId = randomUUID ( ) ;
@@ -82,93 +86,130 @@ export class VercelAISDKAgent extends AbstractAgent {
8286 } ;
8387 finalMessages . push ( assistantMessage ) ;
8488
85- processDataStream ( {
86- stream : response . toDataStreamResponse ( ) . body ! ,
87- onTextPart : ( text ) => {
88- assistantMessage . content += text ;
89- const event : TextMessageChunkEvent = {
90- type : EventType . TEXT_MESSAGE_CHUNK ,
91- role : "assistant" ,
92- messageId,
93- delta : text ,
94- } ;
95- subscriber . next ( event ) ;
96- } ,
97- onFinishMessagePart : ( ) => {
98- // Emit message snapshot
99- const event : MessagesSnapshotEvent = {
100- type : EventType . MESSAGES_SNAPSHOT ,
101- messages : finalMessages ,
102- } ;
103- subscriber . next ( event ) ;
89+ let hasCompleted = false ;
90+ const seenToolCallIds = new Set < string > ( ) ;
91+
92+ const finalizeRun = ( ) => {
93+ if ( hasCompleted ) {
94+ return ;
95+ }
96+ hasCompleted = true ;
97+
98+ const snapshotEvent : MessagesSnapshotEvent = {
99+ type : EventType . MESSAGES_SNAPSHOT ,
100+ messages : finalMessages ,
101+ } ;
102+ subscriber . next ( snapshotEvent ) ;
104103
105- // Emit run finished event
106- subscriber . next ( {
107- type : EventType . RUN_FINISHED ,
108- threadId : input . threadId ,
109- runId : input . runId ,
110- } as RunFinishedEvent ) ;
104+ subscriber . next ( {
105+ type : EventType . RUN_FINISHED ,
106+ threadId : input . threadId ,
107+ runId : input . runId ,
108+ } as RunFinishedEvent ) ;
111109
112- // Complete the observable
113- subscriber . complete ( ) ;
114- } ,
115- onToolCallPart ( streamPart ) {
116- let toolCall : ToolCall = {
117- id : streamPart . toolCallId ,
118- type : "function" ,
119- function : {
120- name : streamPart . toolName ,
121- arguments : JSON . stringify ( streamPart . args ) ,
122- } ,
123- } ;
124- assistantMessage . toolCalls ! . push ( toolCall ) ;
110+ subscriber . complete ( ) ;
111+ } ;
125112
126- const startEvent : ToolCallStartEvent = {
127- type : EventType . TOOL_CALL_START ,
128- parentMessageId : messageId ,
129- toolCallId : streamPart . toolCallId ,
130- toolCallName : streamPart . toolName ,
131- } ;
132- subscriber . next ( startEvent ) ;
113+ const processStream = async ( ) => {
114+ try {
115+ for await ( const part of response . fullStream ) {
116+ switch ( part . type ) {
117+ case "text-delta" : {
118+ if ( ! part . text ) {
119+ break ;
120+ }
121+ assistantMessage . content += part . text ;
122+ const event : TextMessageChunkEvent = {
123+ type : EventType . TEXT_MESSAGE_CHUNK ,
124+ role : "assistant" ,
125+ messageId,
126+ delta : part . text ,
127+ } ;
128+ subscriber . next ( event ) ;
129+ break ;
130+ }
131+ case "tool-call" : {
132+ if ( seenToolCallIds . has ( part . toolCallId ) ) {
133+ break ;
134+ }
135+ seenToolCallIds . add ( part . toolCallId ) ;
136+ const argumentsJson = safeStringify ( part . input ) ;
137+ let toolCall : ToolCall = {
138+ id : part . toolCallId ,
139+ type : "function" ,
140+ function : {
141+ name : part . toolName ,
142+ arguments : argumentsJson ,
143+ } ,
144+ } ;
145+ assistantMessage . toolCalls ! . push ( toolCall ) ;
133146
134- const argsEvent : ToolCallArgsEvent = {
135- type : EventType . TOOL_CALL_ARGS ,
136- toolCallId : streamPart . toolCallId ,
137- delta : JSON . stringify ( streamPart . args ) ,
138- } ;
139- subscriber . next ( argsEvent ) ;
147+ const startEvent : ToolCallStartEvent = {
148+ type : EventType . TOOL_CALL_START ,
149+ parentMessageId : messageId ,
150+ toolCallId : part . toolCallId ,
151+ toolCallName : part . toolName ,
152+ } ;
153+ subscriber . next ( startEvent ) ;
140154
141- const endEvent : ToolCallEndEvent = {
142- type : EventType . TOOL_CALL_END ,
143- toolCallId : streamPart . toolCallId ,
144- } ;
145- subscriber . next ( endEvent ) ;
146- } ,
147- onToolResultPart ( streamPart ) {
148- const toolMessage : ToolMessage = {
149- role : "tool" ,
150- id : randomUUID ( ) ,
151- toolCallId : streamPart . toolCallId ,
152- content : JSON . stringify ( streamPart . result ) ,
153- } ;
154- finalMessages . push ( toolMessage ) ;
155- } ,
156- onErrorPart ( streamPart ) {
157- subscriber . error ( streamPart ) ;
158- } ,
159- } ) . catch ( ( error ) => {
160- console . error ( "catch error" , error ) ;
161- // Handle error
162- subscriber . error ( error ) ;
163- } ) ;
155+ const argsEvent : ToolCallArgsEvent = {
156+ type : EventType . TOOL_CALL_ARGS ,
157+ toolCallId : part . toolCallId ,
158+ delta : argumentsJson ,
159+ } ;
160+ subscriber . next ( argsEvent ) ;
161+
162+ const endEvent : ToolCallEndEvent = {
163+ type : EventType . TOOL_CALL_END ,
164+ toolCallId : part . toolCallId ,
165+ } ;
166+ subscriber . next ( endEvent ) ;
167+ break ;
168+ }
169+ case "tool-result" : {
170+ if ( part . preliminary ) {
171+ break ;
172+ }
173+ const toolMessage : ToolMessage = {
174+ role : "tool" ,
175+ id : randomUUID ( ) ,
176+ toolCallId : part . toolCallId ,
177+ content : safeStringify ( part . output ) ,
178+ } ;
179+ finalMessages . push ( toolMessage ) ;
180+ break ;
181+ }
182+ case "tool-error" : {
183+ subscriber . error ( part . error ?? new Error ( `Tool ${ part . toolName } failed` ) ) ;
184+ return ;
185+ }
186+ case "error" : {
187+ subscriber . error ( part . error ?? new Error ( "Stream error" ) ) ;
188+ return ;
189+ }
190+ case "finish" : {
191+ finalizeRun ( ) ;
192+ return ;
193+ }
194+ default :
195+ break ;
196+ }
197+ }
198+ finalizeRun ( ) ;
199+ } catch ( error ) {
200+ subscriber . error ( error ) ;
201+ }
202+ } ;
203+
204+ processStream ( ) ;
164205
165206 return ( ) => { } ;
166207 } ) ;
167208 }
168209}
169210
170- export function convertMessagesToVercelAISDKMessages ( messages : Message [ ] ) : CoreMessage [ ] {
171- const result : CoreMessage [ ] = [ ] ;
211+ export function convertMessagesToModelMessages ( messages : Message [ ] ) : ModelMessage [ ] {
212+ const result : ModelMessage [ ] = [ ] ;
172213
173214 for ( const message of messages ) {
174215 if ( message . role === "assistant" ) {
@@ -178,7 +219,7 @@ export function convertMessagesToVercelAISDKMessages(messages: Message[]): CoreM
178219 type : "tool-call" ,
179220 toolCallId : toolCall . id ,
180221 toolName : toolCall . function . name ,
181- args : JSON . parse ( toolCall . function . arguments ) ,
222+ input : JSON . parse ( toolCall . function . arguments ) ,
182223 } ) ;
183224 }
184225 result . push ( {
@@ -209,7 +250,7 @@ export function convertMessagesToVercelAISDKMessages(messages: Message[]): CoreM
209250 type : "tool-result" ,
210251 toolCallId : message . toolCallId ,
211252 toolName : toolName ,
212- result : message . content ,
253+ output : parseToolMessageContent ( message . content ) ,
213254 } ,
214255 ] ,
215256 } ) ;
@@ -219,9 +260,9 @@ export function convertMessagesToVercelAISDKMessages(messages: Message[]): CoreM
219260 return result ;
220261}
221262
222- export function convertJsonSchemaToZodSchema ( jsonSchema : any , required : boolean ) : z . ZodSchema {
263+ export function convertJsonSchemaToZodSchema ( jsonSchema : any , required : boolean ) : z . ZodTypeAny {
223264 if ( jsonSchema . type === "object" ) {
224- const spec : { [ key : string ] : z . ZodSchema } = { } ;
265+ const spec : Record < string , z . ZodTypeAny > = { } ;
225266
226267 if ( ! jsonSchema . properties || ! Object . keys ( jsonSchema . properties ) . length ) {
227268 return ! required ? z . object ( spec ) . optional ( ) : z . object ( spec ) ;
@@ -252,15 +293,42 @@ export function convertJsonSchemaToZodSchema(jsonSchema: any, required: boolean)
252293 throw new Error ( "Invalid JSON schema" ) ;
253294}
254295
255- export function convertToolToVerlAISDKTools ( tools : RunAgentInput [ "tools" ] ) : ToolSet {
256- return tools . reduce (
257- ( acc : ToolSet , tool : RunAgentInput [ "tools" ] [ number ] ) => ( {
258- ...acc ,
259- [ tool . name ] : createVercelAISDKTool ( {
260- description : tool . description ,
261- parameters : convertJsonSchemaToZodSchema ( tool . parameters , true ) ,
262- } ) ,
263- } ) ,
264- { } ,
265- ) ;
296+ export function convertToolToVercelAISDKTools ( tools : RunAgentInput [ "tools" ] ) : ToolSet {
297+ const toolSet : Record < string , unknown > = { } ;
298+
299+ for ( const tool of tools ) {
300+ const inputSchema = convertJsonSchemaToZodSchema ( tool . parameters , true ) as z . ZodTypeAny ;
301+ const toolDefinition = {
302+ description : tool . description ,
303+ inputSchema,
304+ outputSchema : z . any ( ) ,
305+ } as unknown ;
306+ toolSet [ tool . name ] = createVercelAISDKTool ( toolDefinition as any ) ;
307+ }
308+
309+ return toolSet as ToolSet ;
310+ }
311+
312+ function safeStringify ( value : unknown ) : string {
313+ if ( typeof value === "string" ) {
314+ return value ;
315+ }
316+ try {
317+ return JSON . stringify ( value ?? { } ) ;
318+ } catch {
319+ return JSON . stringify ( { value : String ( value ) } ) ;
320+ }
321+ }
322+
323+ function parseToolMessageContent ( content : string ) {
324+ if ( ! content ) {
325+ return { type : "text" as const , value : "" } ;
326+ }
327+
328+ try {
329+ const parsed = JSON . parse ( content ) ;
330+ return { type : "json" as const , value : parsed } ;
331+ } catch {
332+ return { type : "text" as const , value : content } ;
333+ }
266334}
0 commit comments