Skip to content

Commit 01f8e2e

Browse files
authored
5272 bug pythondeployment copying records with suggestions produces an unprocessableentityerror (#5282)
# Description <!-- Please include a summary of the changes and the related issue. Please also include relevant motivation and context. List any dependencies that are required for this change. --> This PR fixes errors when creating records with suggestions from other datasets. Also, remove the internal `question_id` and `id` from the suggestion __init__ method. Closes #5272 **Type of change** <!-- Please delete options that are not relevant. Remember to title the PR according to the type of change --> - Bug fix (non-breaking change which fixes an issue) **How Has This Been Tested** <!-- Please add some reference about how your feature has been tested. --> **Checklist** <!-- Please go over the list and make sure you've taken everything into account --> - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/)
1 parent 0e88ede commit 01f8e2e

File tree

3 files changed

+45
-28
lines changed

3 files changed

+45
-28
lines changed

argilla/src/argilla/records/_mapping/_mapper.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,11 +242,10 @@ def _map_suggestions(self, data: Dict[str, Any], mapping) -> List[Suggestion]:
242242
parameters = {param.parameter_type: data.get(param.source) for param in route.parameters}
243243
if parameters.get(ParameterType.VALUE) is None:
244244
continue
245-
schema_item = self._dataset.schema.get(name)
245+
question = self._dataset.questions[name]
246246
suggestion = Suggestion(
247247
**parameters,
248-
question_name=route.name,
249-
question_id=schema_item.id,
248+
question_name=question.name,
250249
)
251250
suggestions.append(suggestion)
252251

argilla/src/argilla/suggestions.py

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from typing import Any, Optional, Literal, Union, List, TYPE_CHECKING, Dict
15-
from uuid import UUID
1615

1716
from argilla._models import SuggestionModel
1817
from argilla._resource import Resource
@@ -29,13 +28,11 @@ class Suggestion(Resource):
2928
Suggestions are rendered in the user interfaces as 'hints' or 'suggestions' for the user to review and accept or reject.
3029
3130
Attributes:
32-
value (str): The value of the suggestion.add()
3331
question_name (str): The name of the question that the suggestion is for.
34-
type (str): The type of suggestion, either 'model' or 'human'.
32+
value (str): The value of the suggestion
3533
score (float): The score of the suggestion. For example, the probability of the model prediction.
3634
agent (str): The agent that created the suggestion. For example, the model name.
37-
question_id (UUID): The ID of the question that the suggestion is for.
38-
35+
type (str): The type of suggestion, either 'model' or 'human'.
3936
"""
4037

4138
_model: SuggestionModel
@@ -47,8 +44,6 @@ def __init__(
4744
score: Union[float, List[float], None] = None,
4845
agent: Optional[str] = None,
4946
type: Optional[Literal["model", "human"]] = None,
50-
id: Optional[UUID] = None,
51-
question_id: Optional[UUID] = None,
5247
_record: Optional["Record"] = None,
5348
) -> None:
5449
super().__init__()
@@ -60,9 +55,7 @@ def __init__(
6055

6156
self.record = _record
6257
self._model = SuggestionModel(
63-
id=id,
6458
question_name=question_name,
65-
question_id=question_id,
6659
value=value,
6760
type=type,
6861
score=score,
@@ -87,15 +80,6 @@ def question_name(self) -> Optional[str]:
8780
def question_name(self, value: str) -> None:
8881
self._model.question_name = value
8982

90-
@property
91-
def question_id(self) -> Optional[UUID]:
92-
"""The ID of the question that the suggestion is for."""
93-
return self._model.question_id
94-
95-
@question_id.setter
96-
def question_id(self, value: UUID) -> None:
97-
self._model.question_id = value
98-
9983
@property
10084
def type(self) -> Optional[Literal["model", "human"]]:
10185
"""The type of suggestion, either 'model' or 'human'."""
@@ -125,7 +109,10 @@ def from_model(cls, model: SuggestionModel, dataset: "Dataset") -> "Suggestion":
125109
model.question_name = question.name
126110
model.value = cls.__from_model_value(model.value, question)
127111

128-
return cls(**model.model_dump())
112+
instance = cls(question.name, model.value)
113+
instance._model = model
114+
115+
return instance
129116

130117
def api_model(self) -> SuggestionModel:
131118
if self.record is None or self.record.dataset is None:
@@ -134,8 +121,8 @@ def api_model(self) -> SuggestionModel:
134121
question = self.record.dataset.settings.questions[self.question_name]
135122
return SuggestionModel(
136123
value=self.__to_model_value(self.value, question),
137-
question_name=self.question_name,
138-
question_id=self.question_id or question.id,
124+
question_name=question.name,
125+
question_id=question.id,
139126
type=self._model.type,
140127
score=self._model.score,
141128
agent=self._model.agent,

argilla/tests/integration/test_create_datasets.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,17 @@
1414

1515
import pytest
1616

17-
from argilla import Argilla, Dataset, Settings, TextField, RatingQuestion, LabelQuestion, Workspace
17+
from argilla import (
18+
Argilla,
19+
Dataset,
20+
Settings,
21+
TextField,
22+
RatingQuestion,
23+
LabelQuestion,
24+
Workspace,
25+
VectorField,
26+
TermsMetadataProperty,
27+
)
1828
from argilla.settings._task_distribution import TaskDistribution
1929

2030

@@ -127,20 +137,41 @@ def test_create_a_dataset_copy(self, client: Argilla, dataset_name: str):
127137
settings=Settings(
128138
fields=[TextField(name="text")],
129139
questions=[RatingQuestion(name="question", values=[1, 2, 3, 4, 5])],
140+
vectors=[VectorField(name="vector", dimensions=2)],
141+
metadata=[TermsMetadataProperty(name="terms")],
130142
),
131143
).create()
132144

133-
dataset.records.log([{"text": "This is a text"}])
145+
dataset.records.log(
146+
[
147+
{
148+
"text": "This is a text",
149+
"terms": ["a", "b"],
150+
"vector": [1, 2],
151+
"question": 3,
152+
}
153+
]
154+
)
134155

135156
new_dataset = Dataset(
136157
name=f"{dataset_name}_copy",
137158
settings=dataset.settings,
138159
).create()
139160

140-
records = list(dataset.records)
161+
records = list(dataset.records(with_vectors=True))
141162
new_dataset.records.log(records)
142163

143-
assert len(list(dataset.records)) == len(list(new_dataset.records))
164+
expected_records = list(dataset.records(with_vectors=True))
165+
records = list(new_dataset.records(with_vectors=True))
166+
assert len(expected_records) == len(records)
167+
assert len(records) == 1
168+
169+
record, expected_record = records[0], expected_records[0]
170+
171+
assert expected_record.metadata.to_dict() == record.metadata.to_dict()
172+
assert expected_record.vectors.to_dict() == record.vectors.to_dict()
173+
assert expected_record.suggestions.to_dict() == record.suggestions.to_dict()
174+
144175
assert dataset.distribution == new_dataset.distribution
145176

146177
def test_create_dataset_with_custom_task_distribution(self, client: Argilla, dataset_name: str):

0 commit comments

Comments
 (0)