Skip to content

Commit 0348f4e

Browse files
burtenshawfrascuchonpre-commit-ci[bot]
authored
[FEAT] Make adding and accessing suggestion and response from a record consistent (#5056)
This PR makes adding and accessing suggestion and response from a record consistent. It does that by: - implementing an `add` method in record suggestions and responses - switch record suggetions and responses to key not index and attributeaccess **Type of change** - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Refactor (change restructuring the codebase without changing functionality) - [x] Improvement (change adding some improvement to an existing functionality) - [ ] Documentation update **How Has This Been Tested** (Please describe the tests that you ran to verify your changes. And ideally, reference `tests`) - updated and replaced `test_update_records` - updated all other assertions in tests - **Checklist** - [x] I added relevant documentation - [x] follows the style guidelines of this project - [x] I did a self-review of my code - [x] I made corresponding changes to the documentation - [ ] My changes generate no new warnings - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK) (see text above) - [ ] I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --------- Co-authored-by: Francisco Aranda <[email protected]> Co-authored-by: Ben Burtenshaw <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent dd1d81d commit 0348f4e

17 files changed

+298
-224
lines changed

argilla/docs/how_to_guides/record.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,7 @@ for record in dataset.records(
415415

416416
# Access the responses of the record
417417
for response in record.responses:
418-
print(record.question_name.value)
418+
print(record.["<question_name>"].value)
419419
```
420420

421421
## Update records

argilla/docs/reference/argilla/records/records.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ for record in dataset.records(with_metadata=True):
4747
record.metadata = {"department": "toys"}
4848
```
4949

50-
For changes to take effect, the user must call the `update` method on the `Dataset` object, or pass the updated records to `Dataset.records.log`.
50+
For changes to take effect, the user must call the `update` method on the `Dataset` object, or pass the updated records to `Dataset.records.log`. All core record atttributes can be updated in this way. Check their respective documentation for more information: [Suggestions](suggestions.md), [Responses](responses.md), [Metadata](metadata.md), [Vectors](vectors/md).
51+
5152

5253
---
5354

argilla/docs/reference/argilla/records/responses.md

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,19 +44,26 @@ Responses can be accessed from a `Record` via their question name as an attribut
4444
# iterate over the records and responses
4545

4646
for record in dataset.records:
47-
for response in record.responses.label:
47+
for response in record.responses["label"]: # (1)
4848
print(response.value)
4949
print(response.user_id)
5050

5151
# validate that the record has a response
5252

5353
for record in dataset.records:
54-
if record.responses.label:
55-
for response in record.responses.label:
54+
if record.responses["label"]:
55+
for response in record.responses["label"]:
5656
print(response.value)
5757
print(response.user_id)
58+
else:
59+
record.responses.add(
60+
rg.Response("label", "positive", user_id=user.id)
61+
) # (2)
5862

5963
```
64+
1. Access the responses for the question named `label` for each record like a dictionary containing a list of `Response` objects.
65+
2. Add a response to the record if it does not already have one.
66+
6067

6168
---
6269

argilla/docs/reference/argilla/records/suggestions.md

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,22 @@ Just like responses, suggestions can be accessed from a `Record` via their quest
6464

6565
```python
6666
for record in dataset.records(with_suggestions=True):
67-
print(record.suggestions.label)
67+
print(record.suggestions["label"].value)
6868
```
6969

70+
We can also add suggestions to records as we iterate over them using the `add` method:
71+
72+
```python
73+
for record in dataset.records(with_suggestions=True):
74+
if not record.suggestions["label"]: # (1)
75+
record.suggestions.add(
76+
rg.Suggestion("positive", "label", score=0.9, agent="model_name")
77+
) # (2)
78+
```
79+
80+
1. Validate that the record has a suggestion
81+
2. Add a suggestion to the record if it does not already have one
82+
7083
---
7184

7285
## Class Reference

argilla/pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ dynamic = ["version"]
1313
dependencies = [
1414
"httpx>=0.26.0",
1515
"pydantic>=2.6.0, <3.0.0",
16-
"argilla-v1[listeners]"
16+
"argilla-v1[listeners]",
17+
"tqdm>=4.60.0",
18+
"rich>=10.0.0",
1719
]
1820

1921
[project.optional-dependencies]

argilla/src/argilla/records/_dataset_records.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def log(
211211
mapping: Optional[Dict[str, str]] = None,
212212
user_id: Optional[UUID] = None,
213213
batch_size: int = DEFAULT_BATCH_SIZE,
214-
) -> List[Record]:
214+
) -> "DatasetRecords":
215215
"""Add or update records in a dataset on the server using the provided records.
216216
If the record includes a known `id` field, the record will be updated.
217217
If the record does not include a known `id` field, the record will be added as a new record.
@@ -253,7 +253,7 @@ def log(
253253
level="info",
254254
)
255255

256-
return created_or_updated
256+
return self
257257

258258
def delete(
259259
self,

argilla/src/argilla/records/_resource.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -321,10 +321,7 @@ def __init__(self, responses: List[Response], record: Record) -> None:
321321
def __iter__(self):
322322
return iter(self.__responses)
323323

324-
def __getitem__(self, index: int):
325-
return self.__responses[index]
326-
327-
def __getattr__(self, name) -> List[Response]:
324+
def __getitem__(self, name: str):
328325
return self.__responses_by_question_name[name]
329326

330327
def __repr__(self) -> str:
@@ -352,6 +349,15 @@ def api_models(self) -> List[UserResponseModel]:
352349
for responses in responses_by_user_id.values()
353350
]
354351

352+
def add(self, response: Response) -> None:
353+
"""Adds a response to the record and updates the record. Records can have multiple responses per question.
354+
Args:
355+
response: The response to add.
356+
"""
357+
response.record = self.record
358+
self.__responses.append(response)
359+
self.__responses_by_question_name[response.question_name].append(response)
360+
355361

356362
class RecordSuggestions(Iterable[Suggestion]):
357363
"""This is a container class for the suggestions of a Record.
@@ -360,17 +366,17 @@ class RecordSuggestions(Iterable[Suggestion]):
360366

361367
def __init__(self, suggestions: List[Suggestion], record: Record) -> None:
362368
self.record = record
363-
364-
self.__suggestions = suggestions or []
365-
for suggestion in self.__suggestions:
369+
self._suggestion_by_question_name: Dict[str, Suggestion] = {}
370+
suggestions = suggestions or []
371+
for suggestion in suggestions:
366372
suggestion.record = self.record
367-
setattr(self, suggestion.question_name, suggestion)
373+
self._suggestion_by_question_name[suggestion.question_name] = suggestion
368374

369375
def __iter__(self):
370-
return iter(self.__suggestions)
376+
return iter(self._suggestion_by_question_name.values())
371377

372-
def __getitem__(self, index: int):
373-
return self.__suggestions[index]
378+
def __getitem__(self, question_name: str):
379+
return self._suggestion_by_question_name[question_name]
374380

375381
def __repr__(self) -> str:
376382
return self.to_dict().__repr__()
@@ -380,14 +386,24 @@ def to_dict(self) -> Dict[str, List[str]]:
380386
Returns:
381387
A dictionary of suggestions.
382388
"""
383-
suggestion_dict: dict = {}
384-
for suggestion in self.__suggestions:
385-
suggestion_dict[suggestion.question_name] = {
389+
suggestion_dict = {}
390+
for question_name, suggestion in self._suggestion_by_question_name.items():
391+
suggestion_dict[question_name] = {
386392
"value": suggestion.value,
387393
"score": suggestion.score,
388394
"agent": suggestion.agent,
389395
}
390396
return suggestion_dict
391397

392398
def api_models(self) -> List[SuggestionModel]:
393-
return [suggestion.api_model() for suggestion in self.__suggestions]
399+
suggestions = self._suggestion_by_question_name.values()
400+
return [suggestion.api_model() for suggestion in suggestions]
401+
402+
def add(self, suggestion: Suggestion) -> None:
403+
"""Adds a suggestion to the record and updates the record. Records can have only one suggestion per question, so
404+
adding a new suggestion will overwrite the previous suggestion.
405+
Args:
406+
suggestion: The suggestion to add.
407+
"""
408+
suggestion.record = self.record
409+
self._suggestion_by_question_name[suggestion.question_name] = suggestion

argilla/tests/integration/test_add_records.py

Lines changed: 48 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def test_add_dict_records(client: Argilla):
129129

130130
for record, data in zip(ds.records(batch_size=1, with_suggestions=True), mock_data):
131131
assert record.id == data["id"]
132-
assert record.suggestions.label.value == data["label"]
132+
assert record.suggestions["label"].value == data["label"]
133133

134134

135135
def test_add_records_with_suggestions(client) -> None:
@@ -188,21 +188,21 @@ def test_add_records_with_suggestions(client) -> None:
188188
dataset_records = list(dataset.records(with_suggestions=True))
189189

190190
assert dataset_records[0].id == str(mock_data[0]["id"])
191-
assert dataset_records[0].suggestions.comment.value == "I'm doing great, thank you!"
192-
assert dataset_records[0].suggestions.comment.score is None
193-
assert dataset_records[0].suggestions.topics.value == ["topic1", "topic2"]
194-
assert dataset_records[0].suggestions.topics.score == [0.9, 0.8]
191+
assert dataset_records[0].suggestions["comment"].value == "I'm doing great, thank you!"
192+
assert dataset_records[0].suggestions["comment"].score is None
193+
assert dataset_records[0].suggestions["topics"].value == ["topic1", "topic2"]
194+
assert dataset_records[0].suggestions["topics"].score == [0.9, 0.8]
195195

196196
assert dataset_records[1].fields["text"] == mock_data[1]["text"]
197-
assert dataset_records[1].suggestions.comment.value == "I'm doing great, thank you!"
198-
assert dataset_records[1].suggestions.comment.score is None
199-
assert dataset_records[1].suggestions.topics.value == ["topic3"]
200-
assert dataset_records[1].suggestions.topics.score == [0.9]
197+
assert dataset_records[1].suggestions["comment"].value == "I'm doing great, thank you!"
198+
assert dataset_records[1].suggestions["comment"].score is None
199+
assert dataset_records[1].suggestions["topics"].value == ["topic3"]
200+
assert dataset_records[1].suggestions["topics"].score == [0.9]
201201

202-
assert dataset_records[2].suggestions.comment.value == "I'm doing great, thank you!"
203-
assert dataset_records[2].suggestions.comment.score is None
204-
assert dataset_records[2].suggestions.topics.value == ["topic1", "topic2", "topic3"]
205-
assert dataset_records[2].suggestions.topics.score == [0.9, 0.8, 0.7]
202+
assert dataset_records[2].suggestions["comment"].value == "I'm doing great, thank you!"
203+
assert dataset_records[2].suggestions["comment"].score is None
204+
assert dataset_records[2].suggestions["topics"].value == ["topic1", "topic2", "topic3"]
205+
assert dataset_records[2].suggestions["topics"].score == [0.9, 0.8, 0.7]
206206

207207

208208
def test_add_records_with_responses(client) -> None:
@@ -259,8 +259,8 @@ def test_add_records_with_responses(client) -> None:
259259
for record, mock_record in zip(dataset_records, mock_data):
260260
assert record.id == str(mock_record["id"])
261261
assert record.fields["text"] == mock_record["text"]
262-
assert record.responses.label[0].value == mock_record["my_label"]
263-
assert record.responses.label[0].user_id == user.id
262+
assert record.responses["label"][0].value == mock_record["my_label"]
263+
assert record.responses["label"][0].user_id == user.id
264264

265265

266266
def test_add_records_with_responses_and_suggestions(client) -> None:
@@ -320,9 +320,9 @@ def test_add_records_with_responses_and_suggestions(client) -> None:
320320

321321
assert dataset_records[0].id == str(mock_data[0]["id"])
322322
assert dataset_records[1].fields["text"] == mock_data[1]["text"]
323-
assert dataset_records[2].suggestions.label.value == "positive"
324-
assert dataset_records[2].responses.label[0].value == "negative"
325-
assert dataset_records[2].responses.label[0].user_id == user.id
323+
assert dataset_records[2].suggestions["label"].value == "positive"
324+
assert dataset_records[2].responses["label"][0].value == "negative"
325+
assert dataset_records[2].responses["label"][0].user_id == user.id
326326

327327

328328
def test_add_records_with_fields_mapped(client) -> None:
@@ -387,11 +387,11 @@ def test_add_records_with_fields_mapped(client) -> None:
387387

388388
assert dataset_records[0].id == str(mock_data[0]["id"])
389389
assert dataset_records[1].fields["text"] == mock_data[1]["x"]
390-
assert dataset_records[2].suggestions.label.value == "positive"
391-
assert dataset_records[2].suggestions.label.score == 0.5
392-
assert dataset_records[2].responses.label[0].value == "negative"
393-
assert dataset_records[2].responses.label[0].value == "negative"
394-
assert dataset_records[2].responses.label[0].user_id == user.id
390+
assert dataset_records[2].suggestions["label"].value == "positive"
391+
assert dataset_records[2].suggestions["label"].score == 0.5
392+
assert dataset_records[2].responses["label"][0].value == "negative"
393+
assert dataset_records[2].responses["label"][0].value == "negative"
394+
assert dataset_records[2].responses["label"][0].user_id == user.id
395395

396396

397397
def test_add_records_with_id_mapped(client) -> None:
@@ -448,9 +448,9 @@ def test_add_records_with_id_mapped(client) -> None:
448448

449449
assert dataset_records[0].id == str(mock_data[0]["uuid"])
450450
assert dataset_records[1].fields["text"] == mock_data[1]["x"]
451-
assert dataset_records[2].suggestions.label.value == "positive"
452-
assert dataset_records[2].responses.label[0].value == "negative"
453-
assert dataset_records[2].responses.label[0].user_id == user.id
451+
assert dataset_records[2].suggestions["label"].value == "positive"
452+
assert dataset_records[2].responses["label"][0].value == "negative"
453+
assert dataset_records[2].responses["label"][0].user_id == user.id
454454

455455

456456
def test_add_record_resources(client):
@@ -507,22 +507,22 @@ def test_add_record_resources(client):
507507
assert dataset.name == mock_dataset_name
508508

509509
assert dataset_records[0].id == str(mock_resources[0].id)
510-
assert dataset_records[0].suggestions.label.value == "positive"
511-
assert dataset_records[0].suggestions.label.score == 0.9
512-
assert dataset_records[0].suggestions.topics.value == ["topic1", "topic2"]
513-
assert dataset_records[0].suggestions.topics.score == [0.9, 0.8]
510+
assert dataset_records[0].suggestions["label"].value == "positive"
511+
assert dataset_records[0].suggestions["label"].score == 0.9
512+
assert dataset_records[0].suggestions["topics"].value == ["topic1", "topic2"]
513+
assert dataset_records[0].suggestions["topics"].score == [0.9, 0.8]
514514

515515
assert dataset_records[1].id == str(mock_resources[1].id)
516-
assert dataset_records[1].suggestions.label.value == "positive"
517-
assert dataset_records[1].suggestions.label.score == 0.9
518-
assert dataset_records[1].suggestions.topics.value == ["topic1", "topic2"]
519-
assert dataset_records[1].suggestions.topics.score == [0.9, 0.8]
516+
assert dataset_records[1].suggestions["label"].value == "positive"
517+
assert dataset_records[1].suggestions["label"].score == 0.9
518+
assert dataset_records[1].suggestions["topics"].value == ["topic1", "topic2"]
519+
assert dataset_records[1].suggestions["topics"].score == [0.9, 0.8]
520520

521521
assert dataset_records[2].id == str(mock_resources[2].id)
522-
assert dataset_records[2].suggestions.label.value == "positive"
523-
assert dataset_records[2].suggestions.label.score == 0.9
524-
assert dataset_records[2].suggestions.topics.value == ["topic1", "topic2"]
525-
assert dataset_records[2].suggestions.topics.score == [0.9, 0.8]
522+
assert dataset_records[2].suggestions["label"].value == "positive"
523+
assert dataset_records[2].suggestions["label"].score == 0.9
524+
assert dataset_records[2].suggestions["topics"].value == ["topic1", "topic2"]
525+
assert dataset_records[2].suggestions["topics"].score == [0.9, 0.8]
526526

527527

528528
def test_add_records_with_responses_and_same_schema_name(client: Argilla):
@@ -572,8 +572,8 @@ def test_add_records_with_responses_and_same_schema_name(client: Argilla):
572572
dataset_records = list(dataset.records(with_responses=True))
573573

574574
assert dataset_records[0].fields["text"] == mock_data[1]["text"]
575-
assert dataset_records[1].responses.label[0].value == "negative"
576-
assert dataset_records[1].responses.label[0].user_id == user.id
575+
assert dataset_records[1].responses["label"][0].value == "negative"
576+
assert dataset_records[1].responses["label"][0].user_id == user.id
577577

578578

579579
def test_add_records_objects_with_responses(client: Argilla):
@@ -631,17 +631,17 @@ def test_add_records_objects_with_responses(client: Argilla):
631631

632632
assert dataset.name == mock_dataset_name
633633
assert dataset_records[0].id == records[0].id
634-
assert dataset_records[0].responses.label[0].value == "negative"
635-
assert dataset_records[0].responses.label[0].status == "submitted"
634+
assert dataset_records[0].responses["label"][0].value == "negative"
635+
assert dataset_records[0].responses["label"][0].status == "submitted"
636636

637637
assert dataset_records[1].id == records[1].id
638-
assert dataset_records[1].responses.label[0].value == "positive"
639-
assert dataset_records[1].responses.label[0].status == "discarded"
638+
assert dataset_records[1].responses["label"][0].value == "positive"
639+
assert dataset_records[1].responses["label"][0].status == "discarded"
640640

641641
assert dataset_records[2].id == records[2].id
642-
assert dataset_records[2].responses.comment[0].value == "The comment"
643-
assert dataset_records[2].responses.comment[0].status == "draft"
642+
assert dataset_records[2].responses["comment"][0].value == "The comment"
643+
assert dataset_records[2].responses["comment"][0].status == "draft"
644644

645645
assert dataset_records[3].id == records[3].id
646-
assert dataset_records[3].responses.comment[0].value == "The comment"
647-
assert dataset_records[3].responses.comment[0].status == "draft"
646+
assert dataset_records[3].responses["comment"][0].value == "The comment"
647+
assert dataset_records[3].responses["comment"][0].status == "draft"

argilla/tests/integration/test_export_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def test_import_dataset_from_disk(dataset: rg.Dataset, client):
119119

120120
for i, record in enumerate(new_dataset.records(with_suggestions=True)):
121121
assert record.fields["text"] == mock_data[i]["text"]
122-
assert record.suggestions.label.value == mock_data[i]["label"]
122+
assert record.suggestions["label"].value == mock_data[i]["label"]
123123

124124
assert new_dataset.settings.fields[0].name == "text"
125125
assert new_dataset.settings.questions[0].name == "label"

argilla/tests/integration/test_export_records.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def test_export_records_from_json(dataset: rg.Dataset):
273273

274274
for i, record in enumerate(dataset.records(with_suggestions=True)):
275275
assert record.fields["text"] == mock_data[i]["text"]
276-
assert record.suggestions.label.value == mock_data[i]["label"]
276+
assert record.suggestions["label"].value == mock_data[i]["label"]
277277
assert record.id == str(mock_data[i]["id"])
278278

279279

@@ -329,5 +329,5 @@ def test_import_records_from_hf_dataset(dataset: rg.Dataset) -> None:
329329

330330
for i, record in enumerate(dataset.records(with_suggestions=True)):
331331
assert record.fields["text"] == mock_data[i]["text"]
332-
assert record.suggestions.label.value == mock_data[i]["label"]
332+
assert record.suggestions["label"].value == mock_data[i]["label"]
333333
assert record.id == str(mock_data[i]["id"])

0 commit comments

Comments
 (0)