
Commit e28ef13

burtenshaw, frascuchon, and pre-commit-ci[bot] authored
[REFACTOR] refactor from hub method to simplify method (#5420)
# Description

Closes #<issue_number>

**Type of change**

- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)
- Breaking change (fix or feature that would cause existing functionality to not work as expected)
- Refactor (change restructuring the codebase without changing functionality)
- Improvement (change adding some improvement to an existing functionality)
- Documentation update

**How Has This Been Tested**

**Checklist**

- I added relevant documentation
- I followed the style guidelines of this project
- I did a self-review of my code
- I made corresponding changes to the documentation
- I confirm my changes generate no new warnings
- I have added tests that prove my fix is effective or that my feature works
- I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/)

---------

Co-authored-by: Paco Aranda <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 6815e5c commit e28ef13
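
In practical terms, this refactor changes two user-visible behaviors of `Dataset.from_hub`: `name` now defaults to the repo id with `/` replaced by `_`, and multi-split Hub datasets no longer raise a `ValueError` (the first split is used and a `UserWarning` lists the alternatives). A minimal usage sketch, assuming a running Argilla server and a hypothetical repo id:

```python
import argilla as rg

# Hypothetical server URL and credentials.
client = rg.Argilla(api_url="http://localhost:6900", api_key="argilla.apikey")

# `name` may now be omitted: it defaults to the repo id with "/" replaced
# by "_", e.g. "my-org/my-dataset" -> "my-org_my-dataset".
dataset = rg.Dataset.from_hub(
    repo_id="my-org/my-dataset",  # hypothetical repo
    client=client,
    with_records=True,
)

# If the Hub dataset has several splits and no `split` kwarg is passed,
# the first split is loaded and a UserWarning lists the available splits.
```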

File tree

2 files changed: +100 -32 lines changed

argilla/src/argilla/datasets/_export/_hub.py

Lines changed: 40 additions & 32 deletions
```diff
@@ -16,14 +16,15 @@
 import warnings
 from collections import defaultdict
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Any, Optional, Type, Union
+from typing import TYPE_CHECKING, Any, Optional, Type, Union, Dict
 from uuid import UUID
 
+from datasets import DatasetDict
+from datasets.data_files import EmptyDatasetError
+
 from argilla._exceptions._api import UnprocessableEntityError
 from argilla._exceptions._records import RecordsIngestionError
 from argilla._exceptions._settings import SettingsError
-from datasets.data_files import EmptyDatasetError
-
 from argilla.datasets._export._disk import DiskImportExportMixin
 from argilla.records._mapping import IngestedRecordMapper
 from argilla.responses import Response
@@ -72,6 +73,7 @@ def to_hub(
 
         with TemporaryDirectory() as tmpdirname:
             config_dir = os.path.join(tmpdirname)
+
             self.to_disk(path=config_dir, with_records=False)
 
             if generate_card:
@@ -129,9 +131,12 @@ def from_hub(
         Returns:
             A `Dataset` loaded from the Hugging Face Hub.
         """
-        from datasets import Dataset, DatasetDict, load_dataset
+        from datasets import load_dataset
         from huggingface_hub import snapshot_download
 
+        if name is None:
+            name = repo_id.replace("/", "_")
+
         if settings is not None:
             dataset = cls(name=name, settings=settings)
             dataset.create()
@@ -150,31 +155,9 @@
 
         if with_records:
             try:
-                hf_dataset: Dataset = load_dataset(path=repo_id, **kwargs)  # type: ignore
-                if isinstance(hf_dataset, DatasetDict) and "split" not in kwargs:
-                    if len(hf_dataset.keys()) > 1:
-                        raise ValueError(
-                            "Only one dataset can be loaded at a time, use `split` to select a split, available splits"
-                            f" are: {', '.join(hf_dataset.keys())}."
-                        )
-                    hf_dataset: Dataset = hf_dataset[list(hf_dataset.keys())[0]]
-                for feature in hf_dataset.features:
-                    if feature not in dataset.settings.fields or feature not in dataset.settings.questions:
-                        warnings.warn(
-                            message=f"Feature {feature} in Hugging Face dataset is not defined in dataset settings."
-                        )
-                        warnings.warn(
-                            message=f"Available fields: {dataset.settings.fields}. Available questions: {dataset.settings.questions}."
-                        )
-                try:
-                    cls._log_dataset_records(hf_dataset=hf_dataset, dataset=dataset)
-                except (RecordsIngestionError, UnprocessableEntityError) as e:
-                    if settings is not None:
-                        raise SettingsError(
-                            message=f"Failed to load records from Hugging Face dataset. Defined settings do not match dataset schema {hf_dataset.features}"
-                        ) from e
-                    else:
-                        raise e
+                hf_dataset = load_dataset(path=repo_id, **kwargs)  # type: ignore
+                hf_dataset = cls._get_dataset_split(hf_dataset=hf_dataset, **kwargs)
+                cls._log_dataset_records(hf_dataset=hf_dataset, dataset=dataset)
             except EmptyDatasetError:
                 warnings.warn(
                     message="Trying to load a dataset `with_records=True` but dataset does not contain any records.",
@@ -221,9 +204,7 @@ def _log_dataset_records(hf_dataset: "HFDataset", dataset: "Dataset"):
         records = []
         for idx, row in enumerate(hf_dataset):
             record = mapper(row)
-            record.id = row.pop("id")
             for question_name, values in response_questions.items():
-                response_users = {}
                 response_values = values["responses"][idx]
                 response_users = values["users"][idx]
                 response_status = values["status"][idx]
@@ -240,4 +221,31 @@ def _log_dataset_records(hf_dataset: "HFDataset", dataset: "Dataset"):
                 )
                 record.responses.add(response)
             records.append(record)
-        dataset.records.log(records=records)
+
+        try:
+            dataset.records.log(records=records)
+        except (RecordsIngestionError, UnprocessableEntityError) as e:
+            raise SettingsError(
+                message=f"Failed to load records from Hugging Face dataset. Defined settings do not match dataset schema. Hugging face dataset features: {hf_dataset.features}. Argilla dataset settings : {dataset.settings}"
+            ) from e
+
+    @staticmethod
+    def _get_dataset_split(hf_dataset: "HFDataset", split: Optional[str] = None, **kwargs: Dict) -> "HFDataset":
+        """Get a single dataset from a Hugging Face dataset.
+
+        Parameters:
+            hf_dataset (HFDataset): The Hugging Face dataset to get a single dataset from.
+
+        Returns:
+            HFDataset: The single dataset.
+        """
+
+        if isinstance(hf_dataset, DatasetDict) and split is None:
+            split = next(iter(hf_dataset.keys()))
+            if len(hf_dataset.keys()) > 1:
+                warnings.warn(
+                    message=f"Multiple splits found in Hugging Face dataset. Using the first split: {split}. "
+                    f"Available splits are: {', '.join(hf_dataset.keys())}."
+                )
+            hf_dataset = hf_dataset[split]
+        return hf_dataset
```
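
To see the new `_get_dataset_split` behavior in isolation, here is a standalone sketch of the same split-selection logic run against a small in-memory `DatasetDict` (the toy data is invented for illustration):

```python
import warnings

from datasets import Dataset, DatasetDict

# A toy multi-split dataset standing in for the result of `load_dataset`.
hf_dataset = DatasetDict(
    {
        "train": Dataset.from_dict({"text": ["a", "b"]}),
        "test": Dataset.from_dict({"text": ["c"]}),
    }
)

# Mirrors the method above: pick the first split and warn if there are others.
split = next(iter(hf_dataset.keys()))  # "train"
if len(hf_dataset.keys()) > 1:
    warnings.warn(
        f"Multiple splits found in Hugging Face dataset. Using the first split: {split}. "
        f"Available splits are: {', '.join(hf_dataset.keys())}."
    )
hf_dataset = hf_dataset[split]
print(hf_dataset.num_rows)  # 2
```

Note that the old code raised a `ValueError` for multi-split datasets; downgrading that to a warning is the user-visible change here.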

argilla/tests/integration/test_export_dataset.py

Lines changed: 60 additions & 0 deletions
```diff
@@ -250,10 +250,70 @@ def test_import_dataset_from_hub_using_settings(
 
         assert new_dataset.settings.fields[0].name == "text"
         assert new_dataset.settings.questions[0].name == "label"
+
+    @pytest.mark.parametrize("with_records_import", [True, False])
+    def test_import_dataset_from_hub_using_settings(
+        self,
+        token: str,
+        dataset: rg.Dataset,
+        client,
+        mock_data: List[dict[str, Any]],
+        with_records_export: bool,
+        with_records_import: bool,
+    ):
+        repo_id = (
+            f"argilla-internal-testing/test_import_dataset_from_hub_using_settings_with_records{with_records_export}"
+        )
+        mock_dataset_name = f"test_import_dataset_from_hub_using_settings_{uuid.uuid4()}"
+        dataset.records.log(records=mock_data)
+
+        dataset.to_hub(repo_id=repo_id, with_records=with_records_export, token=token)
+        settings = rg.Settings(
+            fields=[
+                rg.TextField(name="text"),
+            ],
+            questions=[
+                rg.LabelQuestion(name="label", labels=["positive", "negative"]),
+                rg.LabelQuestion(name="extra_label", labels=["extra_positive", "extra_negative"]),
+            ],
+        )
+        if with_records_import and not with_records_export:
+            with pytest.warns(
+                expected_warning=UserWarning,
+                match="Trying to load a dataset `with_records=True` but dataset does not contain any records.",
+            ):
+                new_dataset = rg.Dataset.from_hub(
+                    repo_id=repo_id,
+                    client=client,
+                    with_records=with_records_import,
+                    token=token,
+                    settings=settings,
+                    name=mock_dataset_name,
+                )
+        else:
+            new_dataset = rg.Dataset.from_hub(
+                repo_id=repo_id,
+                client=client,
+                with_records=with_records_import,
+                token=token,
+                settings=settings,
+                name=mock_dataset_name,
+            )
+
+        if with_records_import and with_records_export:
+            for i, record in enumerate(new_dataset.records(with_suggestions=True)):
+                assert record.fields["text"] == mock_data[i]["text"]
+                assert record.suggestions["label"].value == mock_data[i]["label"]
+        else:
+            assert len(new_dataset.records.to_list()) == 0
+
+        assert new_dataset.settings.fields[0].name == "text"
+        assert new_dataset.settings.questions[0].name == "label"
         assert new_dataset.settings.questions[1].name == "extra_label"
         assert len(new_dataset.settings.questions[1].labels) == 2
         assert new_dataset.settings.questions[1].labels[0] == "extra_positive"
         assert new_dataset.settings.questions[1].labels[1] == "extra_negative"
+        assert new_dataset.name == mock_dataset_name
 
     def test_import_dataset_from_hub_using_wrong_settings(
         self,
```
