|
   getResponseFormat,
   itemsToLanguageV2Messages,
   parseArguments,
+  toolChoiceToLanguageV2Format,
   toolToLanguageV2Tool,
 } from '../src/aiSdk';
 import { protocol, withTrace, UserError } from '@openai/agents';
@@ -161,6 +162,19 @@ describe('itemsToLanguageV2Messages', () => {
     expect(() => itemsToLanguageV2Messages(stubModel({}), items)).toThrow();
   });
 
+  test('throws on computer tool calls and results', () => {
+    expect(() =>
+      itemsToLanguageV2Messages(stubModel({}), [
+        { type: 'computer_call' } as any,
+      ]),
+    ).toThrow(UserError);
+    expect(() =>
+      itemsToLanguageV2Messages(stubModel({}), [
+        { type: 'computer_call_result' } as any,
+      ]),
+    ).toThrow(UserError);
+  });
+
   test('converts user images, function results and reasoning items', () => {
     const items: protocol.ModelItem[] = [
       {
@@ -265,6 +279,61 @@ describe('itemsToLanguageV2Messages', () => {
       UserError,
     );
   });
+
+  test('supports input_file string and rejects non-string file id', () => {
+    const ok: protocol.ModelItem[] = [
+      {
+        role: 'user',
+        content: [
+          {
+            type: 'input_file',
+            file: 'file_123',
+          },
+        ],
+      } as any,
+    ];
+
+    const msgs = itemsToLanguageV2Messages(stubModel({}), ok);
+    expect(msgs).toEqual([
+      {
+        role: 'user',
+        content: [
+          {
+            type: 'file',
+            file: 'file_123',
+            mediaType: 'application/octet-stream',
+            data: 'file_123',
+            providerOptions: {},
+          },
+        ],
+        providerOptions: {},
+      },
+    ]);
+
+    const bad: protocol.ModelItem[] = [
+      {
+        role: 'user',
+        content: [
+          {
+            type: 'input_file',
+            file: { not: 'a-string' },
+          },
+        ],
+      } as any,
+    ];
+    expect(() => itemsToLanguageV2Messages(stubModel({}), bad)).toThrow(
+      /File ID is not supported/,
+    );
+  });
+
+  test('passes through unknown items via providerData', () => {
+    const custom = { role: 'system', content: 'x', providerOptions: { a: 1 } };
+    const items: protocol.ModelItem[] = [
+      { type: 'unknown', providerData: custom } as any,
+    ];
+    const msgs = itemsToLanguageV2Messages(stubModel({}), items);
+    expect(msgs).toEqual([custom]);
+  });
 });
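For orientation, the assertions in the two new tests above pin the conversion behaviour down fairly tightly. A minimal sketch of the corresponding branches, reconstructed purely from those expectations (the helper names, types, and error class are assumptions, not the package's actual source), might look like this:

```ts
// Sketch only: shapes inferred from the test expectations above.
type LanguageV2FilePart = {
  type: 'file';
  file: string;
  mediaType: string;
  data: string;
  providerOptions: Record<string, unknown>;
};

// Hypothetical helper for the `input_file` branch.
function convertInputFile(part: { type: 'input_file'; file: unknown }): LanguageV2FilePart {
  if (typeof part.file !== 'string') {
    // Non-string file references are rejected; the real code presumably throws its UserError.
    throw new Error('File ID is not supported');
  }
  return {
    type: 'file',
    file: part.file,
    // The test expects a generic default media type and the raw string as data.
    mediaType: 'application/octet-stream',
    data: part.file,
    providerOptions: {},
  };
}

// Hypothetical helper for the unknown-item branch: providerData is passed through verbatim.
function convertUnknownItem(item: { providerData?: unknown }): unknown {
  return item.providerData;
}
```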
 
 describe('toolToLanguageV2Tool', () => {
@@ -358,6 +427,77 @@ describe('AiSdkModel.getResponse', () => {
     ]);
   });
 
+  test('forwards toolChoice to AI SDK (generate)', async () => {
+    const seen: any[] = [];
+    const model = new AiSdkModel(
+      stubModel({
+        async doGenerate(options) {
+          seen.push(options.toolChoice);
+          return {
+            content: [{ type: 'text', text: 'ok' }],
+            usage: { inputTokens: 0, outputTokens: 0, totalTokens: 0 },
+            providerMetadata: {},
+            response: { id: 'id' },
+            finishReason: 'stop',
+            warnings: [],
+          } as any;
+        },
+      }),
+    );
+
+    // auto
+    await withTrace('t', () =>
+      model.getResponse({
+        input: 'hi',
+        tools: [],
+        handoffs: [],
+        modelSettings: { toolChoice: 'auto' },
+        outputType: 'text',
+        tracing: false,
+      } as any),
+    );
+    // required
+    await withTrace('t', () =>
+      model.getResponse({
+        input: 'hi',
+        tools: [],
+        handoffs: [],
+        modelSettings: { toolChoice: 'required' },
+        outputType: 'text',
+        tracing: false,
+      } as any),
+    );
+    // none
+    await withTrace('t', () =>
+      model.getResponse({
+        input: 'hi',
+        tools: [],
+        handoffs: [],
+        modelSettings: { toolChoice: 'none' },
+        outputType: 'text',
+        tracing: false,
+      } as any),
+    );
+    // specific tool
+    await withTrace('t', () =>
+      model.getResponse({
+        input: 'hi',
+        tools: [],
+        handoffs: [],
+        modelSettings: { toolChoice: 'myTool' as any },
+        outputType: 'text',
+        tracing: false,
+      } as any),
+    );
+
+    expect(seen).toEqual([
+      { type: 'auto' },
+      { type: 'required' },
+      { type: 'none' },
+      { type: 'tool', toolName: 'myTool' },
+    ]);
+  });
+
   test('aborts when signal already aborted', async () => {
     const abort = new AbortController();
     abort.abort();
@@ -736,6 +876,48 @@ describe('AiSdkModel.getStreamedResponse', () => {
 
     expect(final).toEqual({ inputTokens: 0, outputTokens: 0, totalTokens: 0 });
   });
+
+  test('prepends system instructions to prompt for doStream', async () => {
+    let received: any;
+    const model = new AiSdkModel(
+      stubModel({
+        async doStream(options) {
+          received = options.prompt;
+          return { stream: partsStream([]) } as any;
+        },
+      }),
+    );
+
+    for await (const _ of model.getStreamedResponse({
+      systemInstructions: 'inst',
+      input: 'hi',
+      tools: [],
+      handoffs: [],
+      modelSettings: {},
+      outputType: 'text',
+      tracing: false,
+    } as any)) {
+      // drain
+    }
+
+    expect(received[0]).toEqual({ role: 'system', content: 'inst' });
+  });
+});
+
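The streaming test above only checks that `systemInstructions` ends up as the first prompt message. A minimal sketch of that prepend step, consistent with the assertion (the message type and helper name are assumptions):

```ts
// Sketch only: prepend systemInstructions as a leading system message when present.
type PromptMessage = { role: string; content: unknown };

function buildPrompt(
  systemInstructions: string | undefined,
  converted: PromptMessage[],
): PromptMessage[] {
  return systemInstructions
    ? [{ role: 'system', content: systemInstructions }, ...converted]
    : converted;
}
```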
+describe('toolChoiceToLanguageV2Format', () => {
+  test('maps default choices and specific tool', () => {
+    expect(toolChoiceToLanguageV2Format(undefined)).toBeUndefined();
+    expect(toolChoiceToLanguageV2Format(null as any)).toBeUndefined();
+    expect(toolChoiceToLanguageV2Format('auto')).toEqual({ type: 'auto' });
+    expect(toolChoiceToLanguageV2Format('required')).toEqual({
+      type: 'required',
+    });
+    expect(toolChoiceToLanguageV2Format('none')).toEqual({ type: 'none' });
+    expect(toolChoiceToLanguageV2Format('runTool' as any)).toEqual({
+      type: 'tool',
+      toolName: 'runTool',
+    });
+  });
 });
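The assertions in this new describe block fully determine the mapping, so a sketch of `toolChoiceToLanguageV2Format` can be read straight off the test (the return type name and exact control flow are assumptions, not the package's actual source):

```ts
// Sketch reconstructed from the test expectations above.
type LanguageV2ToolChoice =
  | { type: 'auto' }
  | { type: 'required' }
  | { type: 'none' }
  | { type: 'tool'; toolName: string };

function toolChoiceToLanguageV2Format(
  toolChoice: string | null | undefined,
): LanguageV2ToolChoice | undefined {
  if (toolChoice == null) return undefined; // both undefined and null map to undefined
  switch (toolChoice) {
    case 'auto':
      return { type: 'auto' };
    case 'required':
      return { type: 'required' };
    case 'none':
      return { type: 'none' };
    default:
      // Any other string is treated as the name of a specific tool.
      return { type: 'tool', toolName: toolChoice };
  }
}
```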
 
 describe('AiSdkModel', () => {
|