Skip to content

Commit 28aa90a

Browse files
committed
Support UploadedFile for OpenAI models
1 parent 9e427d7 commit 28aa90a

File tree

1 file changed

+29
-13
lines changed

1 file changed

+29
-13
lines changed

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from datetime import datetime
99
from typing import Any, Literal, Union, cast, overload
1010

11+
from httpx import URL
1112
from openai.types import FileObject
1213
from pydantic import ValidationError
1314
from typing_extensions import assert_never, deprecated
@@ -625,7 +626,7 @@ async def _map_user_message(self, message: ModelRequest) -> AsyncIterable[chat.C
625626
else:
626627
yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
627628
elif isinstance(part, UserPromptPart):
628-
yield await self._map_user_prompt(part)
629+
yield await self._map_user_prompt(part, self._provider)
629630
elif isinstance(part, ToolReturnPart):
630631
yield chat.ChatCompletionToolMessageParam(
631632
role='tool',
@@ -647,7 +648,7 @@ async def _map_user_message(self, message: ModelRequest) -> AsyncIterable[chat.C
647648
assert_never(part)
648649

649650
@staticmethod
650-
async def _map_user_prompt(part: UserPromptPart) -> chat.ChatCompletionUserMessageParam:
651+
async def _map_user_prompt(part: UserPromptPart, provider: Provider[Any]) -> chat.ChatCompletionUserMessageParam:
651652
content: str | list[ChatCompletionContentPartParam]
652653
if isinstance(part.content, str):
653654
content = part.content
@@ -700,15 +701,7 @@ async def _map_user_prompt(part: UserPromptPart) -> chat.ChatCompletionUserMessa
700701
elif isinstance(item, VideoUrl): # pragma: no cover
701702
raise NotImplementedError('VideoUrl is not supported for OpenAI')
702703
elif isinstance(item, UploadedFile):
703-
if not isinstance(item.file, FileObject):
704-
raise UserError('UploadedFile.file_object must be an OpenAI FileObject')
705-
file = File(
706-
file=FileFile(
707-
file_id=item.file.id,
708-
),
709-
type='file',
710-
)
711-
content.append(file)
704+
content.append(_map_uploaded_file(item, provider))
712705
else:
713706
assert_never(item)
714707
return chat.ChatCompletionUserMessageParam(role='user', content=content)
@@ -996,7 +989,7 @@ async def _map_messages(
996989
if isinstance(part, SystemPromptPart):
997990
openai_messages.append(responses.EasyInputMessageParam(role='system', content=part.content))
998991
elif isinstance(part, UserPromptPart):
999-
openai_messages.append(await self._map_user_prompt(part))
992+
openai_messages.append(await self._map_user_prompt(part, self._provider))
1000993
elif isinstance(part, ToolReturnPart):
1001994
openai_messages.append(
1002995
FunctionCallOutput(
@@ -1078,7 +1071,7 @@ def _map_json_schema(self, o: OutputObjectDefinition) -> responses.ResponseForma
10781071
return response_format_param
10791072

10801073
@staticmethod
1081-
async def _map_user_prompt(part: UserPromptPart) -> responses.EasyInputMessageParam:
1074+
async def _map_user_prompt(part: UserPromptPart, provider: Provider[Any]) -> responses.EasyInputMessageParam:
10821075
content: str | list[responses.ResponseInputContentParam]
10831076
if isinstance(part.content, str):
10841077
content = part.content
@@ -1136,6 +1129,8 @@ async def _map_user_prompt(part: UserPromptPart) -> responses.EasyInputMessagePa
11361129
)
11371130
elif isinstance(item, VideoUrl): # pragma: no cover
11381131
raise NotImplementedError('VideoUrl is not supported for OpenAI.')
1132+
elif isinstance(item, UploadedFile):
1133+
content.append(_map_uploaded_file(item, provider))
11391134
else:
11401135
assert_never(item)
11411136
return responses.EasyInputMessageParam(role='user', content=content)
@@ -1370,3 +1365,24 @@ def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.R
13701365
u.input_audio_tokens = response_usage.prompt_tokens_details.audio_tokens or 0
13711366
u.cache_read_tokens = response_usage.prompt_tokens_details.cached_tokens or 0
13721367
return u
1368+
1369+
1370+
def _map_openai_uploaded_file(item: UploadedFile):
1371+
if not isinstance(item.file, FileObject):
1372+
raise UserError('UploadedFile.file must be an openai.types.FileObject')
1373+
return File(
1374+
file=FileFile(
1375+
file_id=item.file.id,
1376+
),
1377+
type='file',
1378+
)
1379+
1380+
1381+
def _map_uploaded_file(uploaded_file: UploadedFile, provider: Provider[Any]):
1382+
"""Map an UploadedFile to a File object."""
1383+
url = URL(provider.base_url)
1384+
1385+
if url.host == 'api.openai.com':
1386+
return _map_openai_uploaded_file(uploaded_file)
1387+
else:
1388+
raise UserError(f'UploadedFile is not supported for `{provider.name}` with base_url {provider.base_url}.')

0 commit comments

Comments
 (0)