Skip to content

Commit de00032

Browse files
authored
Fix inability to call response.raise_for_status in AsyncTenacityTransport (#2668)
1 parent e603f68 commit de00032

File tree

2 files changed

+107
-3
lines changed

2 files changed

+107
-3
lines changed

pydantic_ai_slim/pydantic_ai/retries.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
from datetime import datetime, timezone
3030
from email.utils import parsedate_to_datetime
31-
from typing import Callable, cast
31+
from typing import Any, Callable, cast
3232

3333
from httpx import HTTPStatusError
3434
from tenacity import RetryCallState, wait_exponential
@@ -78,7 +78,7 @@ def __init__(
7878
self,
7979
controller: Retrying,
8080
wrapped: BaseTransport | None = None,
81-
validate_response: Callable[[Response], None] | None = None,
81+
validate_response: Callable[[Response], Any] | None = None,
8282
):
8383
self.controller = controller
8484
self.wrapped = wrapped or HTTPTransport()
@@ -100,6 +100,10 @@ def handle_request(self, request: Request) -> Response:
100100
for attempt in self.controller:
101101
with attempt:
102102
response = self.wrapped.handle_request(request)
103+
104+
# this is normally set by httpx _after_ calling this function, but we want the request in the validator:
105+
response.request = request
106+
103107
if self.validate_response:
104108
self.validate_response(response)
105109
return response
@@ -149,7 +153,7 @@ def __init__(
149153
self,
150154
controller: AsyncRetrying,
151155
wrapped: AsyncBaseTransport | None = None,
152-
validate_response: Callable[[Response], None] | None = None,
156+
validate_response: Callable[[Response], Any] | None = None,
153157
):
154158
self.controller = controller
155159
self.wrapped = wrapped or AsyncHTTPTransport()
@@ -171,6 +175,10 @@ async def handle_async_request(self, request: Request) -> Response:
171175
async for attempt in self.controller:
172176
with attempt:
173177
response = await self.wrapped.handle_async_request(request)
178+
179+
# this is normally set by httpx _after_ calling this function, but we want the request in the validator:
180+
response.request = request
181+
174182
if self.validate_response:
175183
self.validate_response(response)
176184
return response

tests/test_tenacity.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,54 @@ def validate_response(response: httpx.Response):
134134
assert result is mock_response_success
135135
assert mock_transport.handle_request.call_count == 2
136136

137+
def test_raise_for_status_in_validate_response(self):
138+
"""Test that response.raise_for_status() works in validate_response callback."""
139+
mock_transport = Mock(spec=httpx.BaseTransport)
140+
mock_response_fail = Mock(spec=httpx.Response)
141+
mock_response_fail.status_code = 429
142+
mock_response_fail.is_success = False
143+
mock_response_fail.is_error = True
144+
mock_response_fail.request = None # Initially None, will be set by transport
145+
146+
# Mock raise_for_status to check if request is set
147+
def mock_raise_for_status():
148+
if mock_response_fail.request is None:
149+
raise RuntimeError( # pragma: no cover
150+
'Cannot call `raise_for_status` as the request instance has not been set on this response.'
151+
)
152+
raise httpx.HTTPStatusError(
153+
'Too Many Requests', request=mock_response_fail.request, response=mock_response_fail
154+
)
155+
156+
mock_response_fail.raise_for_status = mock_raise_for_status
157+
158+
mock_response_success = Mock(spec=httpx.Response)
159+
mock_response_success.status_code = 200
160+
mock_response_success.is_success = True
161+
mock_response_success.is_error = False
162+
mock_response_success.raise_for_status = Mock() # Should not raise
163+
164+
mock_transport.handle_request.side_effect = [mock_response_fail, mock_response_success]
165+
166+
controller = Retrying(
167+
retry=retry_if_exception_type(httpx.HTTPStatusError),
168+
stop=stop_after_attempt(3),
169+
wait=wait_fixed(0.001),
170+
reraise=True,
171+
)
172+
transport = TenacityTransport(
173+
controller, mock_transport, validate_response=lambda response: response.raise_for_status()
174+
)
175+
176+
request = httpx.Request('GET', 'https://example.com')
177+
result = transport.handle_request(request)
178+
179+
assert result is mock_response_success
180+
assert mock_transport.handle_request.call_count == 2
181+
# Verify that the request was set on the failed response before raise_for_status was called
182+
assert mock_response_fail.request is request
183+
mock_response_success.raise_for_status.assert_called_once()
184+
137185

138186
class TestAsyncTenacityTransport:
139187
"""Tests for the asynchronous AsyncTenacityTransport."""
@@ -243,6 +291,54 @@ def validate_response(response: httpx.Response):
243291
assert result is mock_response_success
244292
assert mock_transport.handle_async_request.call_count == 2
245293

294+
async def test_raise_for_status_in_validate_response(self):
295+
"""Test that response.raise_for_status() works in validate_response callback."""
296+
mock_transport = AsyncMock(spec=httpx.AsyncBaseTransport)
297+
mock_response_fail = Mock(spec=httpx.Response)
298+
mock_response_fail.status_code = 429
299+
mock_response_fail.is_success = False
300+
mock_response_fail.is_error = True
301+
mock_response_fail.request = None # Initially None, will be set by transport
302+
303+
# Mock raise_for_status to check if request is set
304+
def mock_raise_for_status():
305+
if mock_response_fail.request is None:
306+
raise RuntimeError( # pragma: no cover
307+
'Cannot call `raise_for_status` as the request instance has not been set on this response.'
308+
)
309+
raise httpx.HTTPStatusError(
310+
'Too Many Requests', request=mock_response_fail.request, response=mock_response_fail
311+
)
312+
313+
mock_response_fail.raise_for_status = mock_raise_for_status
314+
315+
mock_response_success = Mock(spec=httpx.Response)
316+
mock_response_success.status_code = 200
317+
mock_response_success.is_success = True
318+
mock_response_success.is_error = False
319+
mock_response_success.raise_for_status = Mock() # Should not raise
320+
321+
mock_transport.handle_async_request.side_effect = [mock_response_fail, mock_response_success]
322+
323+
controller = AsyncRetrying(
324+
retry=retry_if_exception_type(httpx.HTTPStatusError),
325+
stop=stop_after_attempt(3),
326+
wait=wait_fixed(0.001),
327+
reraise=True,
328+
)
329+
transport = AsyncTenacityTransport(
330+
controller, mock_transport, validate_response=lambda response: response.raise_for_status()
331+
)
332+
333+
request = httpx.Request('GET', 'https://example.com')
334+
result = await transport.handle_async_request(request)
335+
336+
assert result is mock_response_success
337+
assert mock_transport.handle_async_request.call_count == 2
338+
# Verify that the request was set on the failed response before raise_for_status was called
339+
assert mock_response_fail.request is request
340+
mock_response_success.raise_for_status.assert_called_once()
341+
246342

247343
class TestWaitRetryAfter:
248344
"""Tests for the wait_retry_after wait strategy."""

0 commit comments

Comments
 (0)