Skip to content

Commit bf08e83

Browse files
authored
[BUGFIX] argilla: prevent errors when vector, suggestion, or responses are None (#5491)
# Description <!-- Please include a summary of the changes and the related issue. Please also include relevant motivation and context. List any dependencies that are required for this change. --> The API can return records with `None` values for vectors, suggestions, or responses, making the SDK raise an error. This PR prevents those errors. **Type of change** <!-- Please delete options that are not relevant. Remember to title the PR according to the type of change --> - Bug fix (non-breaking change which fixes an issue) **How Has This Been Tested** <!-- Please add some reference about how your feature has been tested. --> **Checklist** <!-- Please go over the list and make sure you've taken everything into account --> - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/)
1 parent bafd92f commit bf08e83

File tree

3 files changed

+58
-2
lines changed

3 files changed

+58
-2
lines changed

argilla/src/argilla/_models/_record/_record.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,29 @@ def validate_external_id(cls, external_id: Any) -> Union[str, int, uuid.UUID]:
7777
if external_id is None:
7878
external_id = uuid.uuid4()
7979
return external_id
80+
81+
@field_validator("vectors", mode="before")
82+
@classmethod
83+
def empty_vectors_if_none(cls, vectors: Optional[List[VectorModel]]) -> Optional[List[VectorModel]]:
84+
"""Ensure vectors is None if not provided."""
85+
if vectors is None:
86+
return []
87+
return vectors
88+
89+
@field_validator("responses", mode="before")
90+
@classmethod
91+
def empty_responses_if_none(cls, responses: Optional[List[UserResponseModel]]) -> Optional[List[UserResponseModel]]:
92+
"""Ensure responses is None if not provided."""
93+
if responses is None:
94+
return []
95+
return responses
96+
97+
@field_validator("suggestions", mode="before")
98+
@classmethod
99+
def empty_suggestions_if_none(
100+
cls, suggestions: Optional[Union[Tuple[SuggestionModel], List[SuggestionModel]]]
101+
) -> Optional[Union[Tuple[SuggestionModel], List[SuggestionModel]]]:
102+
"""Ensure suggestions is None if not provided."""
103+
if suggestions is None:
104+
return []
105+
return suggestions

argilla/src/argilla/records/_resource.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,9 @@ def __iter__(self):
408408
def __getitem__(self, question_name: str):
409409
return self._suggestion_by_question_name[question_name]
410410

411+
def __len__(self):
412+
return len(self._suggestion_by_question_name)
413+
411414
def __repr__(self) -> str:
412415
return self.to_dict().__repr__()
413416

argilla/tests/unit/test_resources/test_records.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,21 @@
1515
import uuid
1616

1717
import pytest
18-
from argilla import Record, Response, Suggestion
18+
19+
from argilla import Record, Response, Suggestion, Dataset, Settings, TextQuestion, TextField
1920
from argilla._exceptions import ArgillaError
20-
from argilla._models import MetadataModel
21+
from argilla._models import MetadataModel, RecordModel
22+
23+
24+
@pytest.fixture()
25+
def dataset():
26+
return Dataset(
27+
name="test_dataset",
28+
settings=Settings(
29+
fields=[TextField(name="name", required=True), TextField(name="age", required=True)],
30+
questions=[TextQuestion(name="question", required=True)],
31+
),
32+
)
2133

2234

2335
class TestRecords:
@@ -88,3 +100,18 @@ def test_add_record_response_for_the_same_question_and_user_id(self):
88100

89101
with pytest.raises(ArgillaError):
90102
record.responses.add(response)
103+
104+
def test_record_from_model_with_none_vectors(self, dataset: Dataset):
105+
record = Record.from_model(RecordModel(fields={"name": "John"}, vectors=None), dataset=dataset)
106+
107+
assert len(record.vectors) == 0
108+
109+
def test_record_from_model_with_none_suggestions(self, dataset: Dataset):
110+
record = Record.from_model(RecordModel(fields={"name": "John"}, suggestions=None), dataset=dataset)
111+
112+
assert len(record.suggestions) == 0
113+
114+
def test_record_from_model_with_none_responses(self, dataset: Dataset):
115+
record = Record.from_model(RecordModel(fields={"name": "John"}, responses=None), dataset=dataset)
116+
117+
assert len(record.responses) == 0

0 commit comments

Comments
 (0)