Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/spotty-kids-follow.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@openai/agents-extensions': patch
---

fix: #261 add support for other models that could produce JSON output in text content
221 changes: 200 additions & 21 deletions packages/agents-extensions/src/aiSdk.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,54 @@ import {
} from '@openai/agents';
import { isZodObject } from '@openai/agents/utils';

function isJsonString(candidate: string): boolean {
try {
JSON.parse(candidate);
return true;
} catch {
return false;
}
}

function applyStructuredOutputInstruction(
prompt: LanguageModelV2Prompt,
outputType: SerializedOutputType,
): LanguageModelV2Prompt {
if (outputType === 'text') {
return prompt;
}

const schemaText = JSON.stringify(outputType.schema, null, 2);
const instruction =
'You must respond with a JSON object that matches this schema:\n' +
`${schemaText}\n` +
'Return only valid JSON without code fences or extra commentary.';

if (prompt.length > 0) {
const [firstMessage, ...rest] = prompt;
if (
firstMessage.role === 'system' &&
typeof firstMessage.content === 'string'
) {
return [
{
...firstMessage,
content: `${firstMessage.content}\n\n${instruction}`,
},
...rest,
];
}
}

return [
{
role: 'system',
content: instruction,
},
...prompt,
];
}

/**
* @internal
* Converts a list of model items to a list of language model V2 messages.
Expand Down Expand Up @@ -430,6 +478,8 @@ export class AiSdkModel implements Model {
];
}

input = applyStructuredOutputInstruction(input, request.outputType);

const tools = request.tools.map((tool) =>
toolToLanguageV2Tool(this.#model, tool),
);
Expand All @@ -448,6 +498,12 @@ export class AiSdkModel implements Model {

const responseFormat: LanguageModelV2CallOptions['responseFormat'] =
getResponseFormat(request.outputType);
const expectsStructuredOutput =
responseFormat && 'type' in responseFormat
? responseFormat.type === 'json'
: false;
const allowStructuredFromToolCalls =
expectsStructuredOutput && tools.length === 0;

const aiSdkRequest: LanguageModelV2CallOptions = {
tools,
Expand Down Expand Up @@ -480,16 +536,38 @@ export class AiSdkModel implements Model {
const toolCalls = resultContent.filter(
(c: any) => c && c.type === 'tool-call',
);
const hasToolCalls = toolCalls.length > 0;
let structuredJsonFromToolCalls: string | undefined;
const forwardedToolCalls: Array<{
toolCall: (typeof toolCalls)[number];
serializedInput: string;
}> = [];
for (const toolCall of toolCalls) {
const serializedInput =
typeof (toolCall as any).input === 'string'
? ((toolCall as any).input as string)
: JSON.stringify((toolCall as any).input ?? {});
let shouldForward = true;
if (
allowStructuredFromToolCalls &&
serializedInput &&
isJsonString(serializedInput)
) {
structuredJsonFromToolCalls = serializedInput;
shouldForward = false;
}

if (shouldForward) {
forwardedToolCalls.push({ toolCall, serializedInput });
}
}

const hasToolCalls = forwardedToolCalls.length > 0;
for (const { toolCall, serializedInput } of forwardedToolCalls) {
output.push({
type: 'function_call',
callId: toolCall.toolCallId,
name: toolCall.toolName,
arguments:
typeof toolCall.input === 'string'
? toolCall.input
: JSON.stringify(toolCall.input ?? {}),
arguments: serializedInput,
status: 'completed',
providerData: hasToolCalls ? result.providerMetadata : undefined,
});
Expand All @@ -499,19 +577,58 @@ export class AiSdkModel implements Model {
// Putting a text message here will let the agent loop to complete,
// so adding this item only when the tool calls are empty.
// Note that the same support is not available for streaming mode.
if (!hasToolCalls) {
const textItem = resultContent.find(
(c: any) => c && c.type === 'text' && typeof c.text === 'string',
);
if (textItem) {
output.push({
type: 'message',
content: [{ type: 'output_text', text: textItem.text }],
role: 'assistant',
status: 'completed',
providerData: (result as any).providerMetadata,
});
}
const textItem = resultContent.find(
(c: any) => c && c.type === 'text' && typeof c.text === 'string',
);
const structuredJsonFromText = expectsStructuredOutput
? textItem?.text && isJsonString(textItem.text)
? textItem.text
: undefined
: undefined;

let messageAdded = false;
if (structuredJsonFromText) {
output.push({
type: 'message',
content: [
{ type: 'output_text', text: structuredJsonFromText } as const,
],
role: 'assistant',
status: 'completed',
providerData: (result as any).providerMetadata,
});
messageAdded = true;
}

if (
!messageAdded &&
!hasToolCalls &&
textItem &&
!expectsStructuredOutput
) {
output.push({
type: 'message',
content: [{ type: 'output_text', text: textItem.text }],
role: 'assistant',
status: 'completed',
providerData: (result as any).providerMetadata,
});
messageAdded = true;
}

if (!messageAdded && !hasToolCalls && structuredJsonFromToolCalls) {
output.push({
type: 'message',
content: [
{
type: 'output_text',
text: structuredJsonFromToolCalls,
} as const,
],
role: 'assistant',
status: 'completed',
providerData: (result as any).providerMetadata,
});
}

if (span && request.tracing === true) {
Expand Down Expand Up @@ -624,6 +741,8 @@ export class AiSdkModel implements Model {
];
}

input = applyStructuredOutputInstruction(input, request.outputType);

const tools = request.tools.map((tool) =>
toolToLanguageV2Tool(this.#model, tool),
);
Expand All @@ -638,6 +757,12 @@ export class AiSdkModel implements Model {

const responseFormat: LanguageModelV2CallOptions['responseFormat'] =
getResponseFormat(request.outputType);
const expectsStructuredOutput =
responseFormat && 'type' in responseFormat
? responseFormat.type === 'json'
: false;
const allowStructuredFromToolCalls =
expectsStructuredOutput && tools.length === 0;

const aiSdkRequest: LanguageModelV2CallOptions = {
tools,
Expand Down Expand Up @@ -669,6 +794,7 @@ export class AiSdkModel implements Model {
let usageCompletionTokens = 0;
const functionCalls: Record<string, protocol.FunctionCallItem> = {};
let textOutput: protocol.OutputText | undefined;
let structuredJsonFromToolCalls: string | undefined;

for await (const part of stream) {
if (!started) {
Expand All @@ -689,12 +815,26 @@ export class AiSdkModel implements Model {
}
case 'tool-call': {
const toolCallId = (part as any).toolCallId;
if (toolCallId) {
let shouldForward = true;
const serializedInput =
typeof (part as any).input === 'string'
? ((part as any).input as string)
: JSON.stringify((part as any).input ?? {});
if (
allowStructuredFromToolCalls &&
serializedInput &&
isJsonString(serializedInput)
) {
structuredJsonFromToolCalls = serializedInput;
shouldForward = false;
}

if (shouldForward && toolCallId) {
functionCalls[toolCallId] = {
type: 'function_call',
callId: toolCallId,
name: (part as any).toolName,
arguments: (part as any).input ?? '',
arguments: serializedInput,
status: 'completed',
};
}
Expand Down Expand Up @@ -727,10 +867,49 @@ export class AiSdkModel implements Model {

const outputs: protocol.OutputModelItem[] = [];
if (textOutput) {
if (expectsStructuredOutput) {
if (textOutput.text && isJsonString(textOutput.text)) {
outputs.push({
type: 'message',
role: 'assistant',
content: [textOutput],
status: 'completed',
});
} else if (structuredJsonFromToolCalls) {
outputs.push({
type: 'message',
role: 'assistant',
content: [
{
type: 'output_text',
text: structuredJsonFromToolCalls,
} as const,
],
status: 'completed',
});
} else {
outputs.push({
type: 'message',
role: 'assistant',
content: [textOutput],
status: 'completed',
});
}
} else {
outputs.push({
type: 'message',
role: 'assistant',
content: [textOutput],
status: 'completed',
});
}
} else if (structuredJsonFromToolCalls) {
outputs.push({
type: 'message',
role: 'assistant',
content: [textOutput],
content: [
{ type: 'output_text', text: structuredJsonFromToolCalls } as const,
],
status: 'completed',
});
}
Expand Down
Loading