Skip to content

Commit 1be9f10

Browse files
committed
Support UploadedFile for OpenAI models
1 parent 6abc260 commit 1be9f10

File tree

1 file changed

+28
-13
lines changed

1 file changed

+28
-13
lines changed

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from datetime import datetime
99
from typing import Any, Literal, Union, cast, overload
1010

11+
from httpx import URL
1112
from openai.types import FileObject
1213
from pydantic import ValidationError
1314
from typing_extensions import assert_never, deprecated
@@ -625,7 +626,7 @@ async def _map_user_message(self, message: ModelRequest) -> AsyncIterable[chat.C
625626
else:
626627
yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
627628
elif isinstance(part, UserPromptPart):
628-
yield await self._map_user_prompt(part)
629+
yield await self._map_user_prompt(part, self._provider)
629630
elif isinstance(part, ToolReturnPart):
630631
yield chat.ChatCompletionToolMessageParam(
631632
role='tool',
@@ -647,7 +648,7 @@ async def _map_user_message(self, message: ModelRequest) -> AsyncIterable[chat.C
647648
assert_never(part)
648649

649650
@staticmethod
650-
async def _map_user_prompt(part: UserPromptPart) -> chat.ChatCompletionUserMessageParam:
651+
async def _map_user_prompt(part: UserPromptPart, provider: Provider[Any]) -> chat.ChatCompletionUserMessageParam:
651652
content: str | list[ChatCompletionContentPartParam]
652653
if isinstance(part.content, str):
653654
content = part.content
@@ -700,15 +701,8 @@ async def _map_user_prompt(part: UserPromptPart) -> chat.ChatCompletionUserMessa
700701
elif isinstance(item, VideoUrl): # pragma: no cover
701702
raise NotImplementedError('VideoUrl is not supported for OpenAI')
702703
elif isinstance(item, UploadedFile):
703-
if not isinstance(item.file, FileObject):
704-
raise UserError('UploadedFile.file_object must be an OpenAI FileObject')
705-
file = File(
706-
file=FileFile(
707-
file_id=item.file.id,
708-
),
709-
type='file',
710-
)
711-
content.append(file)
704+
file = _map_uploaded_file(item, provider)
705+
content.append(File(file=FileFile(file_id=file.id, filename=file.filename), type='file'))
712706
else:
713707
assert_never(item)
714708
return chat.ChatCompletionUserMessageParam(role='user', content=content)
@@ -996,7 +990,7 @@ async def _map_messages(
996990
if isinstance(part, SystemPromptPart):
997991
openai_messages.append(responses.EasyInputMessageParam(role='system', content=part.content))
998992
elif isinstance(part, UserPromptPart):
999-
openai_messages.append(await self._map_user_prompt(part))
993+
openai_messages.append(await self._map_user_prompt(part, self._provider))
1000994
elif isinstance(part, ToolReturnPart):
1001995
openai_messages.append(
1002996
FunctionCallOutput(
@@ -1078,7 +1072,7 @@ def _map_json_schema(self, o: OutputObjectDefinition) -> responses.ResponseForma
10781072
return response_format_param
10791073

10801074
@staticmethod
1081-
async def _map_user_prompt(part: UserPromptPart) -> responses.EasyInputMessageParam:
1075+
async def _map_user_prompt(part: UserPromptPart, provider: Provider[Any]) -> responses.EasyInputMessageParam:
10821076
content: str | list[responses.ResponseInputContentParam]
10831077
if isinstance(part.content, str):
10841078
content = part.content
@@ -1136,6 +1130,11 @@ async def _map_user_prompt(part: UserPromptPart) -> responses.EasyInputMessagePa
11361130
)
11371131
elif isinstance(item, VideoUrl): # pragma: no cover
11381132
raise NotImplementedError('VideoUrl is not supported for OpenAI.')
1133+
elif isinstance(item, UploadedFile):
1134+
file = _map_uploaded_file(item, provider)
1135+
content.append(
1136+
responses.ResponseInputFileParam(file_id=file.id, filename=file.filename, type='input_file')
1137+
)
11391138
else:
11401139
assert_never(item)
11411140
return responses.EasyInputMessageParam(role='user', content=content)
@@ -1370,3 +1369,19 @@ def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.R
13701369
u.input_audio_tokens = response_usage.prompt_tokens_details.audio_tokens or 0
13711370
u.cache_read_tokens = response_usage.prompt_tokens_details.cached_tokens or 0
13721371
return u
1372+
1373+
1374+
def _map_openai_uploaded_file(item: UploadedFile) -> FileObject:
1375+
if not isinstance(item.file, FileObject):
1376+
raise UserError('UploadedFile.file must be an openai.types.FileObject')
1377+
return item.file
1378+
1379+
1380+
def _map_uploaded_file(uploaded_file: UploadedFile, provider: Provider[Any]) -> FileObject:
1381+
"""Map an UploadedFile to a File object."""
1382+
url = URL(provider.base_url)
1383+
1384+
if url.host == 'api.openai.com':
1385+
return _map_openai_uploaded_file(uploaded_file)
1386+
else:
1387+
raise UserError(f'UploadedFile is not supported for `{provider.name}` with base_url {provider.base_url}.')

0 commit comments

Comments
 (0)