Skip to content

Commit 7bd8ce3

Browse files
authored
[BUGFIX] argilla: raise error adding record responses when a response with same question_name and user_id is found (#5209)
# Description <!-- Please include a summary of the changes and the related issue. Please also include relevant motivation and context. List any dependencies that are required for this change. --> This PR prevents adding multiple question responses per user which result in a server error. **Type of change** <!-- Please delete options that are not relevant. Remember to title the PR according to the type of change --> - Bug fix (non-breaking change which fixes an issue) **How Has This Been Tested** <!-- Please add some reference about how your feature has been tested. --> **Checklist** <!-- Please go over the list and make sure you've taken everything into account --> - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/)
1 parent 5cdd7f5 commit 7bd8ce3

File tree

8 files changed

+48
-17
lines changed

8 files changed

+48
-17
lines changed

argilla/src/argilla/_exceptions/_api.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,22 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from typing import Optional
1415

1516
from httpx import HTTPStatusError
1617

17-
from argilla._exceptions._base import ArgillaErrorBase
18+
from argilla._exceptions._base import ArgillaError
1819

1920

20-
class ArgillaAPIError(ArgillaErrorBase):
21-
pass
21+
class ArgillaAPIError(ArgillaError):
22+
def __init__(self, message: Optional[str] = None, status_code: int = 500):
23+
"""Base class for all Argilla API exceptions
24+
Args:
25+
message (str): The message to display when the exception is raised
26+
status_code (int): The status code of the response that caused the exception
27+
"""
28+
super().__init__(message)
29+
self.status_code = status_code
2230

2331

2432
class BadRequestError(ArgillaAPIError):

argilla/src/argilla/_exceptions/_base.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,18 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from typing import Optional
1415

1516

16-
class ArgillaErrorBase(Exception):
17+
class ArgillaError(Exception):
1718
message_stub = "Argilla SDK error"
18-
message: str = message_stub
1919

20-
def __init__(self, message: str = message, status_code: int = 500):
20+
def __init__(self, message: Optional[str] = None):
2121
"""Base class for all Argilla exceptions
2222
Args:
2323
message (str): The message to display when the exception is raised
24-
status_code (int): The status code of the response that caused the exception
2524
"""
26-
super().__init__(message)
27-
self.status_code = status_code
25+
super().__init__(message or self.message_stub)
2826

2927
def __str__(self):
3028
return f"{self.message_stub}: {self.__class__.__name__}: {super().__str__()}"

argilla/src/argilla/_exceptions/_metadata.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from argilla._exceptions._base import ArgillaErrorBase
15+
from argilla._exceptions._base import ArgillaError
1616

1717

18-
class MetadataError(ArgillaErrorBase):
18+
class MetadataError(ArgillaError):
1919
message: str = "Error defining dataset metadata settings"

argilla/src/argilla/_exceptions/_records.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from argilla._exceptions._base import ArgillaErrorBase
15+
from argilla._exceptions._base import ArgillaError
1616

1717

18-
class RecordsIngestionError(ArgillaErrorBase):
18+
class RecordsIngestionError(ArgillaError):
1919
pass

argilla/src/argilla/_exceptions/_serialization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from argilla._exceptions._base import ArgillaErrorBase
15+
from argilla._exceptions._base import ArgillaError
1616

1717

18-
class ArgillaSerializeError(ArgillaErrorBase):
18+
class ArgillaSerializeError(ArgillaError):
1919
pass

argilla/src/argilla/_exceptions/_settings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from argilla._exceptions._base import ArgillaErrorBase
15+
from argilla._exceptions._base import ArgillaError
1616

1717

18-
class SettingsError(ArgillaErrorBase):
18+
class SettingsError(ArgillaError):
1919
pass

argilla/src/argilla/records/_resource.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, Iterable
1717
from uuid import UUID, uuid4
1818

19+
from argilla._exceptions import ArgillaError
1920
from argilla._models import (
2021
MetadataModel,
2122
RecordModel,
@@ -324,6 +325,9 @@ def __iter__(self):
324325
def __getitem__(self, name: str):
325326
return self.__responses_by_question_name[name]
326327

328+
def __len__(self):
329+
return len(self.__responses)
330+
327331
def __repr__(self) -> str:
328332
return {k: [{"value": v["value"]} for v in values] for k, values in self.to_dict().items()}.__repr__()
329333

@@ -354,10 +358,21 @@ def add(self, response: Response) -> None:
354358
Args:
355359
response: The response to add.
356360
"""
361+
self._check_response_already_exists(response)
362+
357363
response.record = self.record
358364
self.__responses.append(response)
359365
self.__responses_by_question_name[response.question_name].append(response)
360366

367+
def _check_response_already_exists(self, response: Response) -> None:
368+
"""Checks if a response for the same question name and user id already exists"""
369+
for response in self.__responses_by_question_name[response.question_name]:
370+
if response.user_id == response.user_id:
371+
raise ArgillaError(
372+
f"Response for question with name {response.question_name!r} and user id {response.user_id!r} "
373+
f"already found. The responses for the same question name do not support more than one user"
374+
)
375+
361376

362377
class RecordSuggestions(Iterable[Suggestion]):
363378
"""This is a container class for the suggestions of a Record.

argilla/tests/unit/test_resources/test_records.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414

1515
import uuid
1616

17+
import pytest
18+
1719
from argilla import Record, Suggestion, Response
20+
from argilla._exceptions import ArgillaError
1821
from argilla._models import MetadataModel
1922

2023

@@ -62,3 +65,10 @@ def test_update_record_vectors(self):
6265

6366
record.vectors["new-vector"] = [1.0, 2.0, 3.0]
6467
assert record.vectors == {"vector": [1.0, 2.0, 3.0], "new-vector": [1.0, 2.0, 3.0]}
68+
69+
def test_add_record_response_for_the_same_question_and_user_id(self):
70+
response = Response(question_name="question", value="value", user_id=uuid.uuid4())
71+
record = Record(fields={"name": "John"}, responses=[response])
72+
73+
with pytest.raises(ArgillaError):
74+
record.responses.add(response)

0 commit comments

Comments
 (0)