Skip to content

Commit ffcd204

Browse files
authored
fix: #239 enable to pass toolChoice through ai-sdk (#467)
1 parent 01fad84 commit ffcd204

File tree

3 files changed

+210
-0
lines changed

3 files changed

+210
-0
lines changed

.changeset/curly-pumas-visit.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"@openai/agents-extensions": patch
3+
---
4+
5+
fix: #239 enable to pass toolChoice through ai-sdk

packages/agents-extensions/src/aiSdk.ts

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import type {
77
LanguageModelV2Prompt,
88
LanguageModelV2ProviderDefinedTool,
99
LanguageModelV2ToolCallPart,
10+
LanguageModelV2ToolChoice,
1011
LanguageModelV2ToolResultPart,
1112
} from '@ai-sdk/provider';
1213
import {
@@ -25,6 +26,7 @@ import {
2526
UserError,
2627
withGenerationSpan,
2728
getLogger,
29+
ModelSettingsToolChoice,
2830
} from '@openai/agents';
2931
import { isZodObject } from '@openai/agents/utils';
3032

@@ -449,6 +451,9 @@ export class AiSdkModel implements Model {
449451

450452
const aiSdkRequest: LanguageModelV2CallOptions = {
451453
tools,
454+
toolChoice: toolChoiceToLanguageV2Format(
455+
request.modelSettings.toolChoice,
456+
),
452457
prompt: input,
453458
temperature: request.modelSettings.temperature,
454459
topP: request.modelSettings.topP,
@@ -829,3 +834,21 @@ export function parseArguments(args: string | undefined | null): any {
829834
return {};
830835
}
831836
}
837+
838+
export function toolChoiceToLanguageV2Format(
839+
toolChoice: ModelSettingsToolChoice | undefined,
840+
): LanguageModelV2ToolChoice | undefined {
841+
if (!toolChoice) {
842+
return undefined;
843+
}
844+
switch (toolChoice) {
845+
case 'auto':
846+
return { type: 'auto' };
847+
case 'required':
848+
return { type: 'required' };
849+
case 'none':
850+
return { type: 'none' };
851+
default:
852+
return { type: 'tool', toolName: toolChoice };
853+
}
854+
}

packages/agents-extensions/test/aiSdk.test.ts

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import {
44
getResponseFormat,
55
itemsToLanguageV2Messages,
66
parseArguments,
7+
toolChoiceToLanguageV2Format,
78
toolToLanguageV2Tool,
89
} from '../src/aiSdk';
910
import { protocol, withTrace, UserError } from '@openai/agents';
@@ -161,6 +162,19 @@ describe('itemsToLanguageV2Messages', () => {
161162
expect(() => itemsToLanguageV2Messages(stubModel({}), items)).toThrow();
162163
});
163164

165+
test('throws on computer tool calls and results', () => {
166+
expect(() =>
167+
itemsToLanguageV2Messages(stubModel({}), [
168+
{ type: 'computer_call' } as any,
169+
]),
170+
).toThrow(UserError);
171+
expect(() =>
172+
itemsToLanguageV2Messages(stubModel({}), [
173+
{ type: 'computer_call_result' } as any,
174+
]),
175+
).toThrow(UserError);
176+
});
177+
164178
test('converts user images, function results and reasoning items', () => {
165179
const items: protocol.ModelItem[] = [
166180
{
@@ -265,6 +279,61 @@ describe('itemsToLanguageV2Messages', () => {
265279
UserError,
266280
);
267281
});
282+
283+
test('supports input_file string and rejects non-string file id', () => {
284+
const ok: protocol.ModelItem[] = [
285+
{
286+
role: 'user',
287+
content: [
288+
{
289+
type: 'input_file',
290+
file: 'file_123',
291+
},
292+
],
293+
} as any,
294+
];
295+
296+
const msgs = itemsToLanguageV2Messages(stubModel({}), ok);
297+
expect(msgs).toEqual([
298+
{
299+
role: 'user',
300+
content: [
301+
{
302+
type: 'file',
303+
file: 'file_123',
304+
mediaType: 'application/octet-stream',
305+
data: 'file_123',
306+
providerOptions: {},
307+
},
308+
],
309+
providerOptions: {},
310+
},
311+
]);
312+
313+
const bad: protocol.ModelItem[] = [
314+
{
315+
role: 'user',
316+
content: [
317+
{
318+
type: 'input_file',
319+
file: { not: 'a-string' },
320+
},
321+
],
322+
} as any,
323+
];
324+
expect(() => itemsToLanguageV2Messages(stubModel({}), bad)).toThrow(
325+
/File ID is not supported/,
326+
);
327+
});
328+
329+
test('passes through unknown items via providerData', () => {
330+
const custom = { role: 'system', content: 'x', providerOptions: { a: 1 } };
331+
const items: protocol.ModelItem[] = [
332+
{ type: 'unknown', providerData: custom } as any,
333+
];
334+
const msgs = itemsToLanguageV2Messages(stubModel({}), items);
335+
expect(msgs).toEqual([custom]);
336+
});
268337
});
269338

270339
describe('toolToLanguageV2Tool', () => {
@@ -358,6 +427,77 @@ describe('AiSdkModel.getResponse', () => {
358427
]);
359428
});
360429

430+
test('forwards toolChoice to AI SDK (generate)', async () => {
431+
const seen: any[] = [];
432+
const model = new AiSdkModel(
433+
stubModel({
434+
async doGenerate(options) {
435+
seen.push(options.toolChoice);
436+
return {
437+
content: [{ type: 'text', text: 'ok' }],
438+
usage: { inputTokens: 0, outputTokens: 0, totalTokens: 0 },
439+
providerMetadata: {},
440+
response: { id: 'id' },
441+
finishReason: 'stop',
442+
warnings: [],
443+
} as any;
444+
},
445+
}),
446+
);
447+
448+
// auto
449+
await withTrace('t', () =>
450+
model.getResponse({
451+
input: 'hi',
452+
tools: [],
453+
handoffs: [],
454+
modelSettings: { toolChoice: 'auto' },
455+
outputType: 'text',
456+
tracing: false,
457+
} as any),
458+
);
459+
// required
460+
await withTrace('t', () =>
461+
model.getResponse({
462+
input: 'hi',
463+
tools: [],
464+
handoffs: [],
465+
modelSettings: { toolChoice: 'required' },
466+
outputType: 'text',
467+
tracing: false,
468+
} as any),
469+
);
470+
// none
471+
await withTrace('t', () =>
472+
model.getResponse({
473+
input: 'hi',
474+
tools: [],
475+
handoffs: [],
476+
modelSettings: { toolChoice: 'none' },
477+
outputType: 'text',
478+
tracing: false,
479+
} as any),
480+
);
481+
// specific tool
482+
await withTrace('t', () =>
483+
model.getResponse({
484+
input: 'hi',
485+
tools: [],
486+
handoffs: [],
487+
modelSettings: { toolChoice: 'myTool' as any },
488+
outputType: 'text',
489+
tracing: false,
490+
} as any),
491+
);
492+
493+
expect(seen).toEqual([
494+
{ type: 'auto' },
495+
{ type: 'required' },
496+
{ type: 'none' },
497+
{ type: 'tool', toolName: 'myTool' },
498+
]);
499+
});
500+
361501
test('aborts when signal already aborted', async () => {
362502
const abort = new AbortController();
363503
abort.abort();
@@ -736,6 +876,48 @@ describe('AiSdkModel.getStreamedResponse', () => {
736876

737877
expect(final).toEqual({ inputTokens: 0, outputTokens: 0, totalTokens: 0 });
738878
});
879+
880+
test('prepends system instructions to prompt for doStream', async () => {
881+
let received: any;
882+
const model = new AiSdkModel(
883+
stubModel({
884+
async doStream(options) {
885+
received = options.prompt;
886+
return { stream: partsStream([]) } as any;
887+
},
888+
}),
889+
);
890+
891+
for await (const _ of model.getStreamedResponse({
892+
systemInstructions: 'inst',
893+
input: 'hi',
894+
tools: [],
895+
handoffs: [],
896+
modelSettings: {},
897+
outputType: 'text',
898+
tracing: false,
899+
} as any)) {
900+
// drain
901+
}
902+
903+
expect(received[0]).toEqual({ role: 'system', content: 'inst' });
904+
});
905+
});
906+
907+
describe('toolChoiceToLanguageV2Format', () => {
908+
test('maps default choices and specific tool', () => {
909+
expect(toolChoiceToLanguageV2Format(undefined)).toBeUndefined();
910+
expect(toolChoiceToLanguageV2Format(null as any)).toBeUndefined();
911+
expect(toolChoiceToLanguageV2Format('auto')).toEqual({ type: 'auto' });
912+
expect(toolChoiceToLanguageV2Format('required')).toEqual({
913+
type: 'required',
914+
});
915+
expect(toolChoiceToLanguageV2Format('none')).toEqual({ type: 'none' });
916+
expect(toolChoiceToLanguageV2Format('runTool' as any)).toEqual({
917+
type: 'tool',
918+
toolName: 'runTool',
919+
});
920+
});
739921
});
740922

741923
describe('AiSdkModel', () => {

0 commit comments

Comments
 (0)