Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/smart-mirrors-join.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@openai/agents-extensions': patch
---

fix(aisdk): make providerData less opinionated and pass to content
30 changes: 30 additions & 0 deletions docs/src/content/docs/extensions/ai-sdk.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,33 @@ of supported models that can be brought into the Agents SDK through this adapter
## Example

<Code lang="typescript" code={aiSdkSetupExample} title="AI SDK Setup" />

## Passing provider metadata

If you need to send provider-specific options with a message, set them on the
message's `providerData`. The adapter forwards the values directly to the
underlying AI SDK model as `providerMetadata`. For example, the following
`providerData` in the Agents SDK

```ts
providerData: {
  anthropic: {
    cacheControl: {
      type: 'ephemeral',
    },
  },
}
```

would become

```ts
providerMetadata: {
  anthropic: {
    cacheControl: {
      type: 'ephemeral',
    },
  },
}
```

when using the AI SDK integration.
26 changes: 26 additions & 0 deletions docs/src/content/docs/ja/extensions/ai-sdk.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,29 @@ import aiSdkSetupExample from '../../../../../../examples/docs/extensions/ai-sdk
## 例

<Code lang="typescript" code={aiSdkSetupExample} title="AI SDK Setup" />

## プロバイダーメタデータの渡し方

メッセージにプロバイダー固有のオプションを設定したい場合は、`providerData` にその値を指定します。値は `providerMetadata` として基盤の AI SDK モデルへそのまま転送されます。例えば Agents SDK で

```ts
providerData: {
  anthropic: {
    cacheControl: {
      type: 'ephemeral',
    },
  },
}
```

と指定していた場合、AI SDK 連携では次のようになります。

```ts
providerMetadata: {
  anthropic: {
    cacheControl: {
      type: 'ephemeral',
    },
  },
}
```
115 changes: 74 additions & 41 deletions packages/agents-extensions/src/aiSdk.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ export function itemsToLanguageV1Messages(
role: 'system',
content: content,
providerMetadata: {
[model.provider]: {
...(providerData ?? {}),
},
...(providerData ?? {}),
},
});
continue;
Expand All @@ -66,12 +64,25 @@ export function itemsToLanguageV1Messages(
typeof content === 'string'
? [{ type: 'text', text: content }]
: content.map((c) => {
const { providerData: contentProviderData } = c;
if (c.type === 'input_text') {
return { type: 'text', text: c.text };
return {
type: 'text',
text: c.text,
providerMetadata: {
...(contentProviderData ?? {}),
},
};
}
if (c.type === 'input_image') {
const url = new URL(c.image);
return { type: 'image', image: url };
return {
type: 'image',
image: url,
providerMetadata: {
...(contentProviderData ?? {}),
},
};
}
if (c.type === 'input_file') {
if (typeof c.file !== 'string') {
Expand All @@ -82,14 +93,15 @@ export function itemsToLanguageV1Messages(
file: c.file,
mimeType: 'application/octet-stream',
data: c.file,
providerMetadata: {
...(contentProviderData ?? {}),
},
};
}
throw new UserError(`Unknown content type: ${c.type}`);
}),
providerMetadata: {
[model.provider]: {
...(providerData ?? {}),
},
...(providerData ?? {}),
},
});
continue;
Expand All @@ -106,19 +118,30 @@ export function itemsToLanguageV1Messages(
content: content
.filter((c) => c.type === 'input_text' || c.type === 'output_text')
.map((c) => {
const { providerData: contentProviderData } = c;
if (c.type === 'output_text') {
return { type: 'text', text: c.text };
return {
type: 'text',
text: c.text,
providerMetadata: {
...(contentProviderData ?? {}),
},
};
}
if (c.type === 'input_text') {
return { type: 'text', text: c.text };
return {
type: 'text',
text: c.text,
providerMetadata: {
...(contentProviderData ?? {}),
},
};
}
const exhaustiveCheck = c satisfies never;
throw new UserError(`Unknown content type: ${exhaustiveCheck}`);
}),
providerMetadata: {
[model.provider]: {
...(providerData ?? {}),
},
...(providerData ?? {}),
},
});
continue;
Expand All @@ -132,9 +155,7 @@ export function itemsToLanguageV1Messages(
role: 'assistant',
content: [],
providerMetadata: {
[model.provider]: {
...(item.providerData ?? {}),
},
...(item.providerData ?? {}),
},
};
}
Expand All @@ -148,6 +169,9 @@ export function itemsToLanguageV1Messages(
toolCallId: item.callId,
toolName: item.name,
args: JSON.parse(item.arguments),
providerMetadata: {
...(item.providerData ?? {}),
},
};
currentAssistantMessage.content.push(content);
}
Expand All @@ -162,14 +186,15 @@ export function itemsToLanguageV1Messages(
toolCallId: item.callId,
toolName: item.name,
result: item.output,
providerMetadata: {
...(item.providerData ?? {}),
},
};
messages.push({
role: 'tool',
content: [toolResult],
providerMetadata: {
[model.provider]: {
...(item.providerData ?? {}),
},
...(item.providerData ?? {}),
},
});
continue;
Expand All @@ -194,11 +219,15 @@ export function itemsToLanguageV1Messages(
) {
messages.push({
role: 'assistant',
content: [{ type: 'reasoning', text: item.content[0].text }],
providerMetadata: {
[model.provider]: {
...(item.providerData ?? {}),
content: [
{
type: 'reasoning',
text: item.content[0].text,
providerMetadata: { ...(item.providerData ?? {}) },
},
],
providerMetadata: {
...(item.providerData ?? {}),
},
});
continue;
Expand Down Expand Up @@ -344,11 +373,6 @@ export class AiSdkModel implements Model {
}

async getResponse(request: ModelRequest) {
if (this.#logger.dontLogModelData) {
this.#logger.debug('Request received');
} else {
this.#logger.debug('Request:', request);
}
return withGenerationSpan(async (span) => {
try {
span.spanData.model = this.#model.provider + ':' + this.#model.modelId;
Expand Down Expand Up @@ -396,7 +420,7 @@ export class AiSdkModel implements Model {
const responseFormat: LanguageModelV1CallOptions['responseFormat'] =
getResponseFormat(request.outputType);

const result = await this.#model.doGenerate({
const aiSdkRequest: LanguageModelV1CallOptions = {
inputFormat: 'messages',
mode: {
type: 'regular',
Expand All @@ -412,7 +436,15 @@ export class AiSdkModel implements Model {
abortSignal: request.signal,

...(request.modelSettings.providerData ?? {}),
});
};

if (this.#logger.dontLogModelData) {
this.#logger.debug('Request sent');
} else {
this.#logger.debug('Request:', aiSdkRequest);
}

const result = await this.#model.doGenerate(aiSdkRequest);

const output: ModelResponse['output'] = [];

Expand All @@ -423,9 +455,7 @@ export class AiSdkModel implements Model {
name: toolCall.toolName,
arguments: toolCall.args,
status: 'completed',
providerData: !result.text
? result.providerMetadata?.[this.#model.provider]
: undefined,
providerData: !result.text ? result.providerMetadata : undefined,
});
});

Expand All @@ -439,7 +469,7 @@ export class AiSdkModel implements Model {
content: [{ type: 'output_text', text: result.text }],
role: 'assistant',
status: 'completed',
providerData: result.providerMetadata?.[this.#model.provider],
providerData: result.providerMetadata,
});
}

Expand Down Expand Up @@ -509,11 +539,6 @@ export class AiSdkModel implements Model {
async *getStreamedResponse(
request: ModelRequest,
): AsyncIterable<ResponseStreamEvent> {
if (this.#logger.dontLogModelData) {
this.#logger.debug('Request received (streamed)');
} else {
this.#logger.debug('Request (streamed):', request);
}
const span = request.tracing ? createGenerationSpan() : undefined;
try {
if (span) {
Expand Down Expand Up @@ -564,7 +589,7 @@ export class AiSdkModel implements Model {
const responseFormat: LanguageModelV1CallOptions['responseFormat'] =
getResponseFormat(request.outputType);

const { stream } = await this.#model.doStream({
const aiSdkRequest: LanguageModelV1CallOptions = {
inputFormat: 'messages',
mode: {
type: 'regular',
Expand All @@ -579,7 +604,15 @@ export class AiSdkModel implements Model {
responseFormat,
abortSignal: request.signal,
...(request.modelSettings.providerData ?? {}),
});
};

if (this.#logger.dontLogModelData) {
this.#logger.debug('Request received (streamed)');
} else {
this.#logger.debug('Request (streamed):', aiSdkRequest);
}

const { stream } = await this.#model.doStream(aiSdkRequest);

let started = false;
let responseId: string | undefined;
Expand Down
Loading