Skip to content

Commit 07a7877

Browse files
committed
typing
1 parent dce5fc2 commit 07a7877

File tree

1 file changed: +57 additions, −40 deletions

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 57 additions & 40 deletions
(columns: original file line number · diff line number · diff line change)
@@ -7,7 +7,7 @@
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field, replace
 from datetime import datetime
-from typing import Any, Literal, cast, overload
+from typing import TYPE_CHECKING, Any, Literal, cast, overload
 
 from pydantic import ValidationError
 from pydantic_core import to_json
@@ -53,7 +53,7 @@
 from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent
 
 try:
-    from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream, NotGiven
+    from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream
     from openai.types import AllModels, chat, responses
     from openai.types.chat import (
         ChatCompletionChunk,
@@ -88,6 +88,23 @@
         'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
     ) from _import_error
 
+if TYPE_CHECKING:
+    from openai import Omit, omit
+
+    OMIT = omit
+else:
+    # Backward compatibility with openai<2
+    try:
+        from openai import Omit, omit
+
+        OMIT = omit
+    except ImportError:
+        from openai import NOT_GIVEN, NotGiven
+
+        OMIT = NOT_GIVEN
+        Omit = NotGiven
+
+
 __all__ = (
     'OpenAIModel',
     'OpenAIChatModel',
@@ -499,28 +516,28 @@ async def _completions_create(
         return await self.client.chat.completions.create(
             model=self._model_name,
             messages=openai_messages,
-            parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
-            tools=tools or NOT_GIVEN,
-            tool_choice=tool_choice or NOT_GIVEN,
+            parallel_tool_calls=model_settings.get('parallel_tool_calls', OMIT),
+            tools=tools or OMIT,
+            tool_choice=tool_choice or OMIT,
             stream=stream,
-            stream_options={'include_usage': True} if stream else NOT_GIVEN,
-            stop=model_settings.get('stop_sequences', NOT_GIVEN),
-            max_completion_tokens=model_settings.get('max_tokens', NOT_GIVEN),
+            stream_options={'include_usage': True} if stream else OMIT,
+            stop=model_settings.get('stop_sequences', OMIT),
+            max_completion_tokens=model_settings.get('max_tokens', OMIT),
             timeout=model_settings.get('timeout', NOT_GIVEN),
-            response_format=response_format or NOT_GIVEN,
-            seed=model_settings.get('seed', NOT_GIVEN),
-            reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN),
-            user=model_settings.get('openai_user', NOT_GIVEN),
-            web_search_options=web_search_options or NOT_GIVEN,
-            service_tier=model_settings.get('openai_service_tier', NOT_GIVEN),
-            prediction=model_settings.get('openai_prediction', NOT_GIVEN),
-            temperature=model_settings.get('temperature', NOT_GIVEN),
-            top_p=model_settings.get('top_p', NOT_GIVEN),
-            presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
-            frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
-            logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
-            logprobs=model_settings.get('openai_logprobs', NOT_GIVEN),
-            top_logprobs=model_settings.get('openai_top_logprobs', NOT_GIVEN),
+            response_format=response_format or OMIT,
+            seed=model_settings.get('seed', OMIT),
+            reasoning_effort=model_settings.get('openai_reasoning_effort', OMIT),
+            user=model_settings.get('openai_user', OMIT),
+            web_search_options=web_search_options or OMIT,
+            service_tier=model_settings.get('openai_service_tier', OMIT),
+            prediction=model_settings.get('openai_prediction', OMIT),
+            temperature=model_settings.get('temperature', OMIT),
+            top_p=model_settings.get('top_p', OMIT),
+            presence_penalty=model_settings.get('presence_penalty', OMIT),
+            frequency_penalty=model_settings.get('frequency_penalty', OMIT),
+            logit_bias=model_settings.get('logit_bias', OMIT),
+            logprobs=model_settings.get('openai_logprobs', OMIT),
+            top_logprobs=model_settings.get('openai_top_logprobs', OMIT),
             extra_headers=extra_headers,
             extra_body=model_settings.get('extra_body'),
         )
@@ -1184,7 +1201,7 @@ async def _responses_create(
             # Apparently they're only checking input messages for "JSON", not instructions.
             assert isinstance(instructions, str)
             openai_messages.insert(0, responses.EasyInputMessageParam(role='system', content=instructions))
-            instructions = NOT_GIVEN
+            instructions = OMIT
 
         if verbosity := model_settings.get('openai_text_verbosity'):
             text = text or {}
@@ -1200,7 +1217,7 @@ async def _responses_create(
         if model_settings.get('openai_include_code_execution_outputs'):
             include.append('code_interpreter_call.outputs')
         if model_settings.get('openai_include_web_search_sources'):
-            include.append('web_search_call.action.sources')  # pyright: ignore[reportArgumentType]
+            include.append('web_search_call.action.sources')
 
         try:
             extra_headers = model_settings.get('extra_headers', {})
@@ -1209,21 +1226,21 @@ async def _responses_create(
                 input=openai_messages,
                 model=self._model_name,
                 instructions=instructions,
-                parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
-                tools=tools or NOT_GIVEN,
-                tool_choice=tool_choice or NOT_GIVEN,
-                max_output_tokens=model_settings.get('max_tokens', NOT_GIVEN),
+                parallel_tool_calls=model_settings.get('parallel_tool_calls', OMIT),
+                tools=tools or OMIT,
+                tool_choice=tool_choice or OMIT,
+                max_output_tokens=model_settings.get('max_tokens', OMIT),
                 stream=stream,
-                temperature=model_settings.get('temperature', NOT_GIVEN),
-                top_p=model_settings.get('top_p', NOT_GIVEN),
-                truncation=model_settings.get('openai_truncation', NOT_GIVEN),
+                temperature=model_settings.get('temperature', OMIT),
+                top_p=model_settings.get('top_p', OMIT),
+                truncation=model_settings.get('openai_truncation', OMIT),
                 timeout=model_settings.get('timeout', NOT_GIVEN),
-                service_tier=model_settings.get('openai_service_tier', NOT_GIVEN),
-                previous_response_id=previous_response_id or NOT_GIVEN,
+                service_tier=model_settings.get('openai_service_tier', OMIT),
+                previous_response_id=previous_response_id or OMIT,
                 reasoning=reasoning,
-                user=model_settings.get('openai_user', NOT_GIVEN),
-                text=text or NOT_GIVEN,
-                include=include or NOT_GIVEN,
+                user=model_settings.get('openai_user', OMIT),
+                text=text or OMIT,
+                include=include or OMIT,
                 extra_headers=extra_headers,
                 extra_body=model_settings.get('extra_body'),
             )
@@ -1232,7 +1249,7 @@ async def _responses_create(
             raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
         raise  # pragma: lax no cover
 
-    def _get_reasoning(self, model_settings: OpenAIResponsesModelSettings) -> Reasoning | NotGiven:
+    def _get_reasoning(self, model_settings: OpenAIResponsesModelSettings) -> Reasoning | Omit:
         reasoning_effort = model_settings.get('openai_reasoning_effort', None)
         reasoning_summary = model_settings.get('openai_reasoning_summary', None)
         reasoning_generate_summary = model_settings.get('openai_reasoning_generate_summary', None)
@@ -1248,7 +1265,7 @@ def _get_reasoning(self, model_settings: OpenAIResponsesModelSettings) -> Reason
             reasoning_summary = reasoning_generate_summary
 
         if reasoning_effort is None and reasoning_summary is None:
-            return NOT_GIVEN
+            return OMIT
         return Reasoning(effort=reasoning_effort, summary=reasoning_summary)
 
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[responses.FunctionToolParam]:
@@ -1358,7 +1375,7 @@ async def _map_messages(  # noqa: C901
         messages: list[ModelMessage],
         model_settings: OpenAIResponsesModelSettings,
         model_request_parameters: ModelRequestParameters,
-    ) -> tuple[str | NotGiven, list[responses.ResponseInputItemParam]]:
+    ) -> tuple[str | Omit, list[responses.ResponseInputItemParam]]:
         """Just maps a `pydantic_ai.Message` to a `openai.types.responses.ResponseInputParam`."""
         profile = OpenAIModelProfile.from_profile(self.profile)
         send_item_ids = model_settings.get(
@@ -1582,7 +1599,7 @@ async def _map_messages(  # noqa: C901
                     assert_never(item)
             else:
                 assert_never(message)
-        instructions = self._get_instructions(messages, model_request_parameters) or NOT_GIVEN
+        instructions = self._get_instructions(messages, model_request_parameters) or OMIT
         return instructions, openai_messages
 
     def _map_json_schema(self, o: OutputObjectDefinition) -> responses.ResponseFormatTextJSONSchemaConfigParam:

Comments (0)