Skip to content

Commit c902c1e

Browse files
[BUGFIX] validate client credentials to make error messages clearer (#5384)
# Description <!-- Please include a summary of the changes and the related issue. Please also include relevant motivation and context. List any dependencies that are required for this change. --> This PR improves error messaging when users pass incorrect credentials to the `Argilla` client. It does these things: - adds a custom credentials exception - logs the user details when client is successfully inited - raise the custom exception when above fails Closes #<issue_number> **Type of change** <!-- Please delete options that are not relevant. Remember to title the PR according to the type of change --> - Bug fix (non-breaking change which fixes an issue) **How Has This Been Tested** <!-- Please add some reference about how your feature has been tested. --> No new tests. **Checklist** <!-- Please go over the list and make sure you've taken everything into account --> - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 1d75656 commit c902c1e

File tree

6 files changed

+141
-7
lines changed

6 files changed

+141
-7
lines changed

argilla-server/pdm.lock

Lines changed: 86 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

argilla-server/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ dependencies = [
5656
# for CLI
5757
"typer >= 0.6.0, < 0.10.0", # spaCy only supports typer<0.10.0
5858
"packaging>=23.2",
59+
"psycopg2-binary>=2.9.9",
5960
]
6061

6162
[project.optional-dependencies]

argilla/src/argilla/_api/_client.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from typing import Optional
1818

1919
import httpx
20+
from argilla._exceptions._api import UnauthorizedError
21+
from argilla._exceptions._client import ArgillaCredentialsError
2022

2123
from argilla._api import HTTPClientConfig, create_http_client
2224
from argilla._api._datasets import DatasetsAPI
@@ -127,6 +129,11 @@ def __init__(
127129

128130
self.api = ArgillaAPI(self.http_client)
129131

132+
try:
133+
self._validate_connection()
134+
except UnauthorizedError as e:
135+
raise ArgillaCredentialsError() from e
136+
130137
##############################
131138
# Utility methods
132139
##############################
@@ -135,3 +142,8 @@ def log(self, message: str, level: int = logging.INFO) -> None:
135142
class_name = self.__class__.__name__
136143
message = f"{class_name}: {message}"
137144
logging.log(level=level, msg=message)
145+
146+
def _validate_connection(self) -> None:
147+
user = self.api.users.get_me()
148+
message = f"Logged in as {user.username} with the role {user.role}"
149+
self.log(message=message, level=logging.INFO)

argilla/src/argilla/_exceptions/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
from argilla._exceptions._api import * # noqa: F403
16+
from argilla._exceptions._client import * # noqa: F403
1617
from argilla._exceptions._metadata import * # noqa: F403
1718
from argilla._exceptions._serialization import * # noqa: F403
1819
from argilla._exceptions._settings import * # noqa: F403
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright 2024-present, Argilla, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from argilla._exceptions._base import ArgillaError
16+
17+
18+
class ArgillaCredentialsError(ArgillaError):
19+
def __init__(self, message: str = "Credentials (api_key and/or api_url) are invalid") -> None:
20+
super().__init__(message)

argilla/tests/unit/conftest.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,24 @@
1919
# argilla.DEFAULT_HTTP_CLIENT = mock_client
2020

2121
# return mock_client
22+
23+
import pytest
24+
from unittest.mock import patch
25+
from httpx import Timeout
26+
from argilla import Argilla
27+
28+
29+
@pytest.fixture(autouse=True)
30+
def mock_validate_connection():
31+
with patch("argilla._api._client.APIClient._validate_connection") as mocked_validator:
32+
yield mocked_validator
33+
34+
35+
# Example usage in a test module
36+
def test_create_default_client(mock_validate_connection):
37+
http_client = Argilla().http_client
38+
39+
assert http_client is not None
40+
assert http_client.base_url == "http://localhost:6900"
41+
assert http_client.timeout == Timeout(60)
42+
assert http_client.headers["X-Argilla-Api-Key"] == "argilla.apikey"

0 commit comments

Comments
 (0)