1010import tracemalloc
1111from typing import Any , Union , cast
1212from unittest import mock
13+ from typing_extensions import Literal
1314
1415import httpx
1516import pytest
@@ -764,7 +765,14 @@ def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> Non
764765 @pytest .mark .parametrize ("failures_before_success" , [0 , 2 , 4 ])
765766 @mock .patch ("openai._base_client.BaseClient._calculate_retry_timeout" , _low_retry_timeout )
766767 @pytest .mark .respx (base_url = base_url )
767- def test_retries_taken (self , client : OpenAI , failures_before_success : int , respx_mock : MockRouter ) -> None :
768+ @pytest .mark .parametrize ("failure_mode" , ["status" , "exception" ])
769+ def test_retries_taken (
770+ self ,
771+ client : OpenAI ,
772+ failures_before_success : int ,
773+ failure_mode : Literal ["status" , "exception" ],
774+ respx_mock : MockRouter ,
775+ ) -> None :
768776 client = client .with_options (max_retries = 4 )
769777
770778 nb_retries = 0
@@ -773,6 +781,8 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
773781 nonlocal nb_retries
774782 if nb_retries < failures_before_success :
775783 nb_retries += 1
784+ if failure_mode == "exception" :
785+ raise RuntimeError ("oops" )
776786 return httpx .Response (500 )
777787 return httpx .Response (200 )
778788
@@ -1623,8 +1633,13 @@ async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter)
16231633 @mock .patch ("openai._base_client.BaseClient._calculate_retry_timeout" , _low_retry_timeout )
16241634 @pytest .mark .respx (base_url = base_url )
16251635 @pytest .mark .asyncio
1636+ @pytest .mark .parametrize ("failure_mode" , ["status" , "exception" ])
16261637 async def test_retries_taken (
1627- self , async_client : AsyncOpenAI , failures_before_success : int , respx_mock : MockRouter
1638+ self ,
1639+ async_client : AsyncOpenAI ,
1640+ failures_before_success : int ,
1641+ failure_mode : Literal ["status" , "exception" ],
1642+ respx_mock : MockRouter ,
16281643 ) -> None :
16291644 client = async_client .with_options (max_retries = 4 )
16301645
@@ -1634,6 +1649,8 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
16341649 nonlocal nb_retries
16351650 if nb_retries < failures_before_success :
16361651 nb_retries += 1
1652+ if failure_mode == "exception" :
1653+ raise RuntimeError ("oops" )
16371654 return httpx .Response (500 )
16381655 return httpx .Response (200 )
16391656
0 commit comments