Commit d05f012

fallback for openai api connection error
1 parent 359c6d2 commit d05f012
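
What this enables, sketched below: once APIConnectionError is mapped to ModelHTTPError with status_code=0, FallbackModel's default fallback condition (ModelHTTPError) also covers unreachable endpoints, so an agent can fall back from a local OpenAI-compatible server to a hosted model. The snippet is a hedged illustration rather than code from this commit; the base_url, model names, and api_key value are placeholders.

# Illustrative setup only: the localhost URL, model names, and api_key are
# placeholders, not part of this commit.
from pydantic_ai import Agent
from pydantic_ai.models.fallback import FallbackModel
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.providers.openai import OpenAIProvider

# Primary: a local OpenAI-compatible server that may refuse connections.
local_model = OpenAIChatModel(
    'llama3.2',
    provider=OpenAIProvider(base_url='http://localhost:11434/v1', api_key='unused'),
)
# Fallback: a hosted model used when the primary fails.
hosted_model = OpenAIChatModel('gpt-4o')

# With this commit, a connection failure on the primary surfaces as
# ModelHTTPError (status_code=0), which FallbackModel's default condition
# catches, so the run proceeds with the fallback model.
agent = Agent(FallbackModel(local_model, hosted_model))
result = agent.run_sync('hello')
print(result.output)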

3 files changed: +34 -2 lines changed

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 5 additions & 1 deletion

@@ -54,7 +54,7 @@
 from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent

 try:
-    from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream
+    from openai import NOT_GIVEN, APIConnectionError, APIStatusError, AsyncOpenAI, AsyncStream
     from openai.types import AllModels, chat, responses
     from openai.types.chat import (
         ChatCompletionChunk,

@@ -546,6 +546,8 @@ async def _completions_create(
             if (status_code := e.status_code) >= 400:
                 raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
             raise  # pragma: lax no cover
+        except APIConnectionError as e:
+            raise ModelHTTPError(status_code=0, model_name=self.model_name, body=str(e)) from e

     def _process_response(self, response: chat.ChatCompletion | str) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""

@@ -1252,6 +1254,8 @@ async def _responses_create(
             if (status_code := e.status_code) >= 400:
                 raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
             raise  # pragma: lax no cover
+        except APIConnectionError as e:
+            raise ModelHTTPError(status_code=0, model_name=self.model_name, body=str(e)) from e

     def _get_reasoning(self, model_settings: OpenAIResponsesModelSettings) -> Reasoning | Omit:
         reasoning_effort = model_settings.get('openai_reasoning_effort', None)
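
A hedged usage note on the status_code=0 sentinel used above: it marks a connection-level failure (no HTTP response was received at all), as opposed to a real 4xx/5xx status, so callers that do not use FallbackModel can still tell the two apart. The handling below is illustrative only; the model string and prompt are placeholders.

# Illustrative handling of the new error mapping (not part of this commit).
from pydantic_ai import Agent
from pydantic_ai.exceptions import ModelHTTPError

agent = Agent('openai:gpt-4o')
try:
    result = agent.run_sync('hello')
except ModelHTTPError as e:
    if e.status_code == 0:
        # Connection-level failure: the endpoint was never reached.
        print(f'could not reach the model endpoint: {e.body}')
    else:
        # Genuine HTTP error returned by the server.
        print(f'server returned HTTP {e.status_code}')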

tests/models/test_fallback.py

Lines changed: 12 additions & 0 deletions

@@ -570,6 +570,18 @@ async def test_fallback_condition_tuple() -> None:
     assert response.output == 'success'


+async def test_fallback_connection_error() -> None:
+    def connection_error_response(_model_messages: list[ModelMessage], _agent_info: AgentInfo) -> ModelResponse:
+        raise ModelHTTPError(status_code=0, model_name='test-connection-model', body='Connection timed out')
+
+    connection_error_model = FunctionModel(connection_error_response)
+    fallback_model = FallbackModel(connection_error_model, success_model)
+    agent = Agent(model=fallback_model)
+
+    response = await agent.run('hello')
+    assert response.output == 'success'
+
+
 async def test_fallback_model_settings_merge():
     """Test that FallbackModel properly merges model settings from wrapped model and runtime settings."""

tests/models/test_openai.py

Lines changed: 17 additions & 1 deletion

@@ -56,7 +56,7 @@
 )

 with try_import() as imports_successful:
-    from openai import APIStatusError, AsyncOpenAI
+    from openai import APIConnectionError, APIStatusError, AsyncOpenAI
     from openai.types import chat
     from openai.types.chat.chat_completion import ChoiceLogprobs
     from openai.types.chat.chat_completion_chunk import (

@@ -1146,6 +1146,22 @@ def test_model_status_error(allow_model_requests: None) -> None:
     assert str(exc_info.value) == snapshot("status_code: 500, model_name: gpt-4o, body: {'error': 'test error'}")


+def test_model_connection_error(allow_model_requests: None) -> None:
+    mock_client = MockOpenAI.create_mock(
+        APIConnectionError(
+            message='Connection to http://localhost:11434/v1 timed out',
+            request=httpx.Request('POST', 'http://localhost:11434/v1'),
+        )
+    )
+    m = OpenAIChatModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client))
+    agent = Agent(m)
+    with pytest.raises(ModelHTTPError) as exc_info:
+        agent.run_sync('hello')
+    assert exc_info.value.status_code == 0
+    assert exc_info.value.model_name == 'gpt-4o'
+    assert 'Connection to http://localhost:11434/v1 timed out' in str(exc_info.value.body)
+
+
 @pytest.mark.parametrize('model_name', ['o3-mini', 'gpt-4o-mini', 'gpt-4.5-preview'])
 async def test_max_completion_tokens(allow_model_requests: None, model_name: str, openai_api_key: str):
     m = OpenAIChatModel(model_name, provider=OpenAIProvider(api_key=openai_api_key))
