Skip to content

Commit 7dbe0df

Browse files
burtenshaw and pre-commit-ci[bot]
authored and committed
[ENHANCEMENT] Make exporting records with image consistent (#5454)
# Description <!-- Please include a summary of the changes and the related issue. Please also include relevant motivation and context. List any dependencies that are required for this change. --> This PR implements a casting step from data uri to pil in the export method for hf_datasets. **Type of change** <!-- Please delete options that are not relevant. Remember to title the PR according to the type of change --> - Bug fix (non-breaking change which fixes an issue) - New feature (non-breaking change which adds functionality) - Breaking change (fix or feature that would cause existing functionality to not work as expected) - Refactor (change restructuring the codebase without changing functionality) - Improvement (change adding some improvement to an existing functionality) - Documentation update **How Has This Been Tested** <!-- Please add some reference about how your feature has been tested. --> **Checklist** <!-- Please go over the list and make sure you've taken everything into account --> - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b95b89d commit 7dbe0df

File tree

5 files changed

+112
-8
lines changed

5 files changed

+112
-8
lines changed

argilla/src/argilla/records/_dataset_records.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def to_json(self, path: Union[Path, str]) -> Path:
121121
return JsonIO.to_json(records=list(self), path=path)
122122

123123
def to_datasets(self) -> "HFDataset":
124-
return HFDatasetsIO.to_datasets(records=list(self))
124+
return HFDatasetsIO.to_datasets(records=list(self), dataset=self.__dataset)
125125

126126

127127
class DatasetRecords(Iterable[Record], LoggingMixin):

argilla/src/argilla/records/_io/_datasets.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@
1919
from datasets import IterableDataset, Image
2020

2121
from argilla.records._io._generic import GenericIO
22-
from argilla._helpers._media import pil_to_data_uri
22+
from argilla._helpers._media import pil_to_data_uri, uncast_image
2323

2424
if TYPE_CHECKING:
2525
from argilla.records import Record
26+
from argilla.datasets import Dataset
2627

2728

2829
class HFDatasetsIO:
@@ -39,16 +40,19 @@ def _is_hf_dataset(dataset: Any) -> bool:
3940
return isinstance(dataset, HFDataset)
4041

4142
@staticmethod
42-
def to_datasets(records: List["Record"]) -> HFDataset:
43+
def to_datasets(records: List["Record"], dataset: "Dataset") -> HFDataset:
4344
"""
4445
Export the records to a Hugging Face dataset.
4546
4647
Returns:
4748
The dataset containing the records.
4849
"""
4950
record_dicts = GenericIO.to_dict(records, flatten=True)
50-
dataset = HFDataset.from_dict(record_dicts)
51-
return dataset
51+
hf_dataset = HFDataset.from_dict(record_dicts)
52+
image_fields = HFDatasetsIO._get_image_fields(schema=dataset.schema)
53+
if image_fields:
54+
hf_dataset = HFDatasetsIO._cast_uris_as_images(hf_dataset=hf_dataset, columns=image_fields)
55+
return hf_dataset
5256

5357
@staticmethod
5458
def _record_dicts_from_datasets(dataset: HFDataset) -> List[Dict[str, Union[str, float, int, list]]]:
@@ -72,6 +76,18 @@ def _record_dicts_from_datasets(dataset: HFDataset) -> List[Dict[str, Union[str,
7276
record_dicts.append(example)
7377
return record_dicts
7478

79+
@staticmethod
80+
def _get_image_fields(schema: Dict) -> List[str]:
81+
"""Get the names of the Argilla fields that contain image data.
82+
83+
Parameters:
84+
schema (Dict): The schema to check for image fields.
85+
86+
Returns:
87+
List[str]: The names of the Argilla fields that contain image data.
88+
"""
89+
return [field_name for field_name, field in schema.items() if field.type == "image"]
90+
7591
@staticmethod
7692
def _get_image_features(dataset: "HFDataset") -> List[str]:
7793
"""Check if the Hugging Face dataset contains image features.
@@ -114,3 +130,32 @@ def batch_fn(batch):
114130
hf_dataset = hf_dataset.rename_column(original_column_name=unique_identifier, new_column_name=column)
115131

116132
return hf_dataset
133+
134+
@staticmethod
135+
def _cast_uris_as_images(hf_dataset: "HFDataset", columns: List[str]) -> "HFDataset":
136+
"""Cast the image features in the Hugging Face dataset as PIL images.
137+
138+
Parameters:
139+
hf_dataset (HFDataset): The Hugging Face dataset to cast.
140+
columns (List[str]): The names of the columns containing the image features.
141+
142+
Returns:
143+
HFDataset: The Hugging Face dataset with image features cast as PIL images.
144+
"""
145+
unique_identifier = uuid4().hex
146+
147+
def batch_fn(batch):
148+
images = [uncast_image(sample) for sample in batch]
149+
return {unique_identifier: images}
150+
151+
for column in columns:
152+
hf_dataset = hf_dataset.map(
153+
function=batch_fn,
154+
with_indices=False,
155+
batched=True,
156+
input_columns=[column],
157+
remove_columns=[column],
158+
)
159+
hf_dataset = hf_dataset.rename_column(original_column_name=unique_identifier, new_column_name=column)
160+
161+
return hf_dataset

argilla/tests/integration/test_export_dataset.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,7 @@ def test_import_dataset_from_hub_using_wrong_settings(
333333
with_records_export: bool,
334334
):
335335
repo_id = f"argilla-internal-testing/test_import_dataset_from_hub_using_wrong_settings_with_records_{with_records_export}"
336+
mock_unique_name = f"test_import_dataset_from_hub_using_wrong_settings_{uuid.uuid4()}"
336337
dataset.records.log(records=mock_data)
337338

338339
dataset.to_hub(repo_id=repo_id, with_records=with_records_export, token=token)
@@ -346,6 +347,8 @@ def test_import_dataset_from_hub_using_wrong_settings(
346347
)
347348
if with_records_export:
348349
with pytest.raises(SettingsError):
349-
rg.Dataset.from_hub(repo_id=repo_id, client=client, token=token, settings=settings)
350+
rg.Dataset.from_hub(
351+
repo_id=repo_id, client=client, token=token, settings=settings, name=mock_unique_name
352+
)
350353
else:
351-
rg.Dataset.from_hub(repo_id=repo_id, client=client, token=token, settings=settings)
354+
rg.Dataset.from_hub(repo_id=repo_id, client=client, token=token, settings=settings, name=mock_unique_name)

argilla/tests/integration/test_export_records.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from tempfile import TemporaryDirectory
2121

2222
import pytest
23+
from PIL import Image
2324
from datasets import Dataset as HFDataset
2425

2526
import argilla as rg
@@ -331,3 +332,47 @@ def test_import_records_from_hf_dataset(dataset: rg.Dataset) -> None:
331332
assert record.fields["text"] == mock_data[i]["text"]
332333
assert record.suggestions["label"].value == mock_data[i]["label"]
333334
assert record.id == str(mock_data[i]["id"])
335+
336+
337+
def test_export_records_with_images_to_hf_datasets(client):
338+
mock_dataset_name = "".join(random.choices(ascii_lowercase, k=16))
339+
settings = rg.Settings(
340+
fields=[
341+
rg.ImageField(name="image"),
342+
],
343+
questions=[
344+
rg.TextQuestion(name="label", use_markdown=False),
345+
],
346+
)
347+
dataset = rg.Dataset(
348+
name=mock_dataset_name,
349+
settings=settings,
350+
client=client,
351+
)
352+
dataset.create()
353+
mock_data = [
354+
{
355+
"image": Image.new("RGB", (100, 100)),
356+
"label": "positive",
357+
"id": uuid.uuid4(),
358+
},
359+
{
360+
"image": Image.new("RGB", (100, 100)),
361+
"label": "negative",
362+
"id": uuid.uuid4(),
363+
},
364+
{
365+
"image": Image.new("RGB", (100, 100)),
366+
"label": "positive",
367+
"id": uuid.uuid4(),
368+
},
369+
]
370+
dataset.records.log(records=mock_data)
371+
hf_dataset = dataset.records.to_datasets()
372+
373+
assert isinstance(hf_dataset, HFDataset)
374+
assert hf_dataset.num_rows == len(mock_data)
375+
assert "image" in hf_dataset.column_names
376+
assert "label.suggestion" in hf_dataset.column_names
377+
for i, image in enumerate(hf_dataset["image"]):
378+
assert isinstance(image, Image.Image)

argilla/tests/unit/test_io/test_hf_datasets.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,17 @@
2222

2323
class TestHFDatasetsIO:
2424
def test_to_datasets_with_partial_values_in_records(self):
25+
mock_dataset = rg.Dataset(
26+
name="test",
27+
settings=rg.Settings(
28+
fields=[
29+
rg.TextField(name="field"),
30+
],
31+
questions=[
32+
rg.TextQuestion(name="question"),
33+
],
34+
),
35+
)
2536
records = [
2637
rg.Record(fields={"field": "The field"}, metadata={"a": "a"}),
2738
rg.Record(fields={"field": "Other field", "other": "Field"}, metadata={"b": "b"}),
@@ -44,7 +55,7 @@ def test_to_datasets_with_partial_values_in_records(self):
4455
),
4556
]
4657

47-
ds = HFDatasetsIO.to_datasets(records)
58+
ds = HFDatasetsIO.to_datasets(records, dataset=mock_dataset)
4859
assert ds.features == {
4960
"status": Value(dtype="string", id=None),
5061
"_server_id": Value(dtype="null", id=None),

0 commit comments

Comments (0)