Skip to content

Commit f5ff647

Browse files
gabrielmbmbburtenshawfrascuchon
authored
[FEATURE] Add retries to the internal httpx.Client used by the SDK (#5386)
# Description This PR adds a new argument `retries` that can be used to specify the number of times that an HTTP request performed by the internal `httpx.Client` used by the SDK should be retried before raising an exception. This is useful as sometimes while using the `dataset.records.log` you can receive a `ConnectionError` or 5xx from the server and if you retry a few seconds later everything is fine. **Type of change** - Improvement **How Has This Been Tested** <!-- Please add some reference about how your feature has been tested. --> **Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --------- Co-authored-by: Ben Burtenshaw <[email protected]> Co-authored-by: burtenshaw <[email protected]> Co-authored-by: Paco Aranda <[email protected]>
1 parent a9dd0fb commit f5ff647

File tree

4 files changed

+44
-17
lines changed

4 files changed

+44
-17
lines changed

argilla/src/argilla/_api/_client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,9 @@ class APIClient:
105105
def __init__(
106106
self,
107107
api_url: Optional[str] = DEFAULT_HTTP_CONFIG.api_url,
108-
api_key: str = DEFAULT_HTTP_CONFIG.api_key,
108+
api_key: Optional[str] = DEFAULT_HTTP_CONFIG.api_key,
109109
timeout: int = DEFAULT_HTTP_CONFIG.timeout,
110+
retries: int = DEFAULT_HTTP_CONFIG.retries,
110111
**http_client_args,
111112
):
112113
if not api_url:
@@ -120,6 +121,7 @@ def __init__(
120121

121122
http_client_args = http_client_args or {}
122123
http_client_args["timeout"] = timeout
124+
http_client_args["retries"] = retries
123125

124126
self.http_client = create_http_client(
125127
api_url=self.api_url, # type: ignore

argilla/src/argilla/_api/_http/_client.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,8 @@ class HTTPClientConfig:
2323

2424
api_url: str
2525
api_key: str
26-
timeout: int = None
27-
28-
def __post_init__(self):
29-
self.api_url = self.api_url
30-
self.api_key = self.api_key
31-
self.timeout = self.timeout or 60
26+
timeout: int = 60
27+
retries: int = 5
3228

3329

3430
def create_http_client(api_url: str, api_key: str, **client_args) -> httpx.Client:
@@ -37,5 +33,11 @@ def create_http_client(api_url: str, api_key: str, **client_args) -> httpx.Clien
3733

3834
headers = client_args.pop("headers", {})
3935
headers["X-Argilla-Api-Key"] = api_key
40-
41-
return httpx.Client(base_url=api_url, headers=headers, **client_args)
36+
retries = client_args.pop("retries", 0)
37+
38+
return httpx.Client(
39+
base_url=api_url,
40+
headers=headers,
41+
transport=httpx.HTTPTransport(retries=retries),
42+
**client_args,
43+
)

argilla/src/argilla/client.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,8 @@ class Argilla(_api.APIClient):
4141
datasets: A collection of datasets.
4242
users: A collection of users.
4343
me: The current user.
44-
4544
"""
4645

47-
workspaces: "Workspaces"
48-
datasets: "Datasets"
49-
users: "Users"
50-
me: "User"
51-
5246
# Default instance of Argilla
5347
_default_client: Optional["Argilla"] = None
5448

@@ -57,9 +51,24 @@ def __init__(
5751
api_url: Optional[str] = DEFAULT_HTTP_CONFIG.api_url,
5852
api_key: Optional[str] = DEFAULT_HTTP_CONFIG.api_key,
5953
timeout: int = DEFAULT_HTTP_CONFIG.timeout,
54+
retries: int = DEFAULT_HTTP_CONFIG.retries,
6055
**http_client_args,
6156
) -> None:
62-
super().__init__(api_url=api_url, api_key=api_key, timeout=timeout, **http_client_args)
57+
"""Inits the `Argilla` client.
58+
59+
Args:
60+
api_url: the URL of the Argilla API. If not provided, then the value will try
61+
to be set from `ARGILLA_API_URL` environment variable. Defaults to
62+
`"http://localhost:6900"`.
63+
api_key: the key to be used to authenticate in the Argilla API. If not provided,
64+
then the value will try to be set from `ARGILLA_API_KEY` environment variable.
65+
Defaults to `None`.
66+
timeout: the maximum time in seconds to wait for a request to the Argilla API
67+
to be completed before raising an exception. Defaults to `60`.
68+
retries: the number of times to retry the HTTP connection to the Argilla API
69+
before raising an exception. Defaults to `5`.
70+
"""
71+
super().__init__(api_url=api_url, api_key=api_key, timeout=timeout, retries=retries, **http_client_args)
6372

6473
self._set_default(self)
6574

argilla/tests/unit/api/http/test_http_client.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from httpx import Timeout
15+
from unittest.mock import MagicMock, patch
1616

17+
import pytest
1718
from argilla import Argilla
19+
from httpx import Timeout
1820

1921

2022
class TestHTTPClient:
@@ -62,3 +64,15 @@ def test_create_client_with_extra_cookies(self):
6264
assert http_client.base_url == "http://localhost:6900"
6365
assert http_client.headers["X-Argilla-Api-Key"] == "argilla.apikey"
6466
assert http_client.cookies["session"] == "session_id"
67+
68+
@pytest.mark.parametrize("retries", [0, 1, 5, 10])
69+
def test_create_client_with_various_retries(self, retries):
70+
with patch("argilla._api._client.create_http_client") as mock_create_http_client:
71+
mock_http_client = MagicMock()
72+
mock_create_http_client.return_value = mock_http_client
73+
74+
Argilla(api_url="http://test.com", api_key="test_key", retries=retries)
75+
76+
mock_create_http_client.assert_called_once_with(
77+
api_url="http://test.com", api_key="test_key", timeout=60, retries=retries
78+
)

0 commit comments

Comments
 (0)