1
+ import {
2
+ AgentConfig ,
3
+ AbstractAgent ,
4
+ EventType ,
5
+ BaseEvent ,
6
+ Message ,
7
+ AssistantMessage ,
8
+ RunAgentInput ,
9
+ MessagesSnapshotEvent ,
10
+ RunFinishedEvent ,
11
+ RunStartedEvent ,
12
+ TextMessageChunkEvent ,
13
+ ToolCallArgsEvent ,
14
+ ToolCallEndEvent ,
15
+ ToolCallStartEvent ,
16
+ ToolCall ,
17
+ ToolMessage ,
18
+ } from '@ag-ui/client' ;
19
+ import { Observable } from "rxjs" ;
20
+ import {
21
+ CoreMessage ,
22
+ LanguageModelV1 ,
23
+ processDataStream ,
24
+ streamText ,
25
+ tool as createVercelAISDKTool ,
26
+ ToolChoice ,
27
+ ToolSet
28
+ } from "ai" ;
29
+ import { randomUUID } from 'crypto' ;
30
+ import { z } from "zod" ;
31
+
32
+ type ProcessedEvent =
33
+ | MessagesSnapshotEvent
34
+ | RunFinishedEvent
35
+ | RunStartedEvent
36
+ | TextMessageChunkEvent
37
+ | ToolCallArgsEvent
38
+ | ToolCallEndEvent
39
+ | ToolCallStartEvent
40
+
41
+ interface VercelAISDKAgentConfig extends AgentConfig {
42
+ model : LanguageModelV1
43
+ maxSteps ?: number
44
+ toolChoice ?: ToolChoice < Record < string , unknown > >
45
+ }
46
+
47
+ export class VercelAISDKAgent extends AbstractAgent {
48
+ model : LanguageModelV1 ;
49
+ maxSteps : number ;
50
+ toolChoice : ToolChoice < Record < string , unknown > > ;
51
+ constructor ( { model, maxSteps, toolChoice, ...rest } : VercelAISDKAgentConfig ) {
52
+ super ( { ...rest } ) ;
53
+ this . model = model ;
54
+ this . maxSteps = maxSteps ?? 1
55
+ this . toolChoice = toolChoice ?? 'auto'
56
+ }
57
+
58
+ protected run ( input : RunAgentInput ) : Observable < BaseEvent > {
59
+ const finalMessages : Message [ ] = input . messages ;
60
+
61
+ return new Observable < ProcessedEvent > ( ( subscriber ) => {
62
+ subscriber . next ( {
63
+ type : EventType . RUN_STARTED ,
64
+ threadId : input . threadId ,
65
+ runId : input . runId ,
66
+ } as RunStartedEvent ) ;
67
+
68
+ const response = streamText ( {
69
+ model : this . model ,
70
+ messages : convertMessagesToVercelAISDKMessages ( input . messages ) ,
71
+ tools : convertToolToVerlAISDKTools ( input . tools ) ,
72
+ maxSteps : this . maxSteps ,
73
+ toolChoice : this . toolChoice ,
74
+ } ) ;
75
+
76
+ let messageId = randomUUID ( ) ;
77
+ let assistantMessage : AssistantMessage = {
78
+ id : messageId ,
79
+ role : 'assistant' ,
80
+ content : '' ,
81
+ toolCalls : [ ] ,
82
+ } ;
83
+ finalMessages . push ( assistantMessage ) ;
84
+
85
+ processDataStream ( {
86
+ stream : response . toDataStreamResponse ( ) . body ! ,
87
+ onTextPart : text => {
88
+ assistantMessage . content += text ;
89
+ const event : TextMessageChunkEvent = {
90
+ type : EventType . TEXT_MESSAGE_CHUNK ,
91
+ role : 'assistant' ,
92
+ messageId,
93
+ delta : text ,
94
+ } ;
95
+ subscriber . next ( event ) ;
96
+ } ,
97
+ onFinishMessagePart : ( ) => {
98
+ // Emit message snapshot
99
+ const event : MessagesSnapshotEvent = {
100
+ type : EventType . MESSAGES_SNAPSHOT ,
101
+ messages : finalMessages ,
102
+ } ;
103
+ subscriber . next ( event ) ;
104
+
105
+ // Emit run finished event
106
+ subscriber . next ( {
107
+ type : EventType . RUN_FINISHED ,
108
+ threadId : input . threadId ,
109
+ runId : input . runId ,
110
+ } as RunFinishedEvent ) ;
111
+
112
+ // Complete the observable
113
+ subscriber . complete ( ) ;
114
+ } ,
115
+ onToolCallPart ( streamPart ) {
116
+ let toolCall : ToolCall = {
117
+ id : streamPart . toolCallId ,
118
+ type : 'function' ,
119
+ function : {
120
+ name : streamPart . toolName ,
121
+ arguments : JSON . stringify ( streamPart . args ) ,
122
+ } ,
123
+ } ;
124
+ assistantMessage . toolCalls ! . push ( toolCall ) ;
125
+
126
+ const startEvent : ToolCallStartEvent = {
127
+ type : EventType . TOOL_CALL_START ,
128
+ parentMessageId : messageId ,
129
+ toolCallId : streamPart . toolCallId ,
130
+ toolCallName : streamPart . toolName ,
131
+ } ;
132
+ subscriber . next ( startEvent ) ;
133
+
134
+ const argsEvent : ToolCallArgsEvent = {
135
+ type : EventType . TOOL_CALL_ARGS ,
136
+ toolCallId : streamPart . toolCallId ,
137
+ delta : JSON . stringify ( streamPart . args ) ,
138
+ } ;
139
+ subscriber . next ( argsEvent ) ;
140
+
141
+ const endEvent : ToolCallEndEvent = {
142
+ type : EventType . TOOL_CALL_END ,
143
+ toolCallId : streamPart . toolCallId ,
144
+ } ;
145
+ subscriber . next ( endEvent ) ;
146
+ } ,
147
+ onToolResultPart ( streamPart ) {
148
+ const toolMessage : ToolMessage = {
149
+ role : 'tool' ,
150
+ id : randomUUID ( ) ,
151
+ toolCallId : streamPart . toolCallId ,
152
+ content : JSON . stringify ( streamPart . result ) ,
153
+ } ;
154
+ finalMessages . push ( toolMessage ) ;
155
+ } ,
156
+ onErrorPart ( streamPart ) {
157
+ subscriber . error ( streamPart )
158
+ } ,
159
+ } ) . catch ( error => {
160
+ console . error ( 'catch error' , error ) ;
161
+ // Handle error
162
+ subscriber . error ( error ) ;
163
+ } ) ;
164
+
165
+ return ( ) => { }
166
+ } ) ;
167
+ }
168
+ }
169
+
170
+ export function convertMessagesToVercelAISDKMessages ( messages : Message [ ] ) : CoreMessage [ ] {
171
+ const result : CoreMessage [ ] = [ ] ;
172
+
173
+ for ( const message of messages ) {
174
+ if ( message . role === 'assistant' ) {
175
+ const parts : any [ ] = message . content ? [ { type : 'text' , text : message . content } ] : [ ] ;
176
+ for ( const toolCall of message . toolCalls ?? [ ] ) {
177
+ parts . push ( {
178
+ type : 'tool-call' ,
179
+ toolCallId : toolCall . id ,
180
+ toolName : toolCall . function . name ,
181
+ args : JSON . parse ( toolCall . function . arguments ) ,
182
+ } ) ;
183
+ }
184
+ result . push ( {
185
+ role : 'assistant' ,
186
+ content : parts ,
187
+ } ) ;
188
+ } else if ( message . role === 'user' ) {
189
+ result . push ( {
190
+ role : 'user' ,
191
+ content : message . content || '' ,
192
+ } ) ;
193
+ } else if ( message . role === 'tool' ) {
194
+ let toolName = 'unknown' ;
195
+ for ( const msg of messages ) {
196
+ if ( msg . role === 'assistant' ) {
197
+ for ( const toolCall of msg . toolCalls ?? [ ] ) {
198
+ if ( toolCall . id === message . toolCallId ) {
199
+ toolName = toolCall . function . name ;
200
+ break ;
201
+ }
202
+ }
203
+ }
204
+ }
205
+ result . push ( {
206
+ role : 'tool' ,
207
+ content : [
208
+ {
209
+ type : 'tool-result' ,
210
+ toolCallId : message . toolCallId ,
211
+ toolName : toolName ,
212
+ result : message . content ,
213
+ } ,
214
+ ] ,
215
+ } ) ;
216
+ }
217
+ }
218
+
219
+ return result ;
220
+ }
221
+
222
+ export function convertJsonSchemaToZodSchema ( jsonSchema : any , required : boolean ) : z . ZodSchema {
223
+ if ( jsonSchema . type === "object" ) {
224
+ const spec : { [ key : string ] : z . ZodSchema } = { } ;
225
+
226
+ if ( ! jsonSchema . properties || ! Object . keys ( jsonSchema . properties ) . length ) {
227
+ return ! required ? z . object ( spec ) . optional ( ) : z . object ( spec ) ;
228
+ }
229
+
230
+ for ( const [ key , value ] of Object . entries ( jsonSchema . properties ) ) {
231
+ spec [ key ] = convertJsonSchemaToZodSchema (
232
+ value ,
233
+ jsonSchema . required ? jsonSchema . required . includes ( key ) : false ,
234
+ ) ;
235
+ }
236
+ let schema = z . object ( spec ) . describe ( jsonSchema . description ) ;
237
+ return required ? schema : schema . optional ( ) ;
238
+ } else if ( jsonSchema . type === "string" ) {
239
+ let schema = z . string ( ) . describe ( jsonSchema . description ) ;
240
+ return required ? schema : schema . optional ( ) ;
241
+ } else if ( jsonSchema . type === "number" ) {
242
+ let schema = z . number ( ) . describe ( jsonSchema . description ) ;
243
+ return required ? schema : schema . optional ( ) ;
244
+ } else if ( jsonSchema . type === "boolean" ) {
245
+ let schema = z . boolean ( ) . describe ( jsonSchema . description ) ;
246
+ return required ? schema : schema . optional ( ) ;
247
+ } else if ( jsonSchema . type === "array" ) {
248
+ let itemSchema = convertJsonSchemaToZodSchema ( jsonSchema . items , true ) ;
249
+ let schema = z . array ( itemSchema ) . describe ( jsonSchema . description ) ;
250
+ return required ? schema : schema . optional ( ) ;
251
+ }
252
+ throw new Error ( "Invalid JSON schema" ) ;
253
+ }
254
+
255
+ export function convertToolToVerlAISDKTools ( tools : RunAgentInput [ 'tools' ] ) : ToolSet {
256
+ return tools . reduce ( ( acc : ToolSet , tool : RunAgentInput [ 'tools' ] [ number ] ) => ( {
257
+ ...acc ,
258
+ [ tool . name ] : createVercelAISDKTool ( {
259
+ description : tool . description ,
260
+ parameters : convertJsonSchemaToZodSchema ( tool . parameters , true ) ,
261
+ } ) ,
262
+ } ) , { } )
263
+ }
0 commit comments