Skip to content

Commit 0b21024

Browse files
burtenshaw, pre-commit-ci[bot], davidberenstein1957
authored
[FEATURE] pass dataset Settings to from_hub method (#5418)
# Description

<!-- Please include a summary of the changes and the related issue. Please also include relevant motivation and context. List any dependencies that are required for this change. -->

This PR adds a `settings` parameter to the `from_hub` method so that it is compatible with datasets that do not have a `.argilla` directory.

**Type of change**

<!-- Please delete options that are not relevant. Remember to title the PR according to the type of change -->

- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)
- Breaking change (fix or feature that would cause existing functionality to not work as expected)
- Refactor (change restructuring the codebase without changing functionality)
- Improvement (change adding some improvement to an existing functionality)
- Documentation update

**How Has This Been Tested**

<!-- Please add some reference about how your feature has been tested. -->

**Checklist**

<!-- Please go over the list and make sure you've taken everything into account -->

- I added relevant documentation
- I followed the style guidelines of this project
- I did a self-review of my code
- I made corresponding changes to the documentation
- I confirm my changes generate no new warnings
- I have added tests that prove my fix is effective or that my feature works
- I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: David Berenstein <[email protected]>
1 parent 1eb44cb commit 0b21024

File tree

2 files changed

+115
-14
lines changed

2 files changed

+115
-14
lines changed

argilla/src/argilla/datasets/_export/_hub.py

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
from typing import TYPE_CHECKING, Any, Optional, Type, Union
2020
from uuid import UUID
2121

22+
from argilla._exceptions._api import UnprocessableEntityError
23+
from argilla._exceptions._records import RecordsIngestionError
24+
from argilla._exceptions._settings import SettingsError
2225
from datasets.data_files import EmptyDatasetError
2326

2427
from argilla.datasets._export._disk import DiskImportExportMixin
@@ -28,7 +31,7 @@
2831
if TYPE_CHECKING:
2932
from datasets import Dataset as HFDataset
3033

31-
from argilla import Argilla, Dataset, Workspace
34+
from argilla import Argilla, Dataset, Workspace, Settings
3235

3336

3437
class HubImportExportMixin(DiskImportExportMixin):
@@ -110,6 +113,7 @@ def from_hub(
110113
workspace: Optional[Union["Workspace", str]] = None,
111114
client: Optional["Argilla"] = None,
112115
with_records: bool = True,
116+
settings: Optional["Settings"] = None,
113117
**kwargs: Any,
114118
):
115119
"""Loads a `Dataset` from the Hugging Face Hub.
@@ -128,17 +132,21 @@ def from_hub(
128132
from datasets import Dataset, DatasetDict, load_dataset
129133
from huggingface_hub import snapshot_download
130134

131-
# download both files in parallel
132-
folder_path = snapshot_download(
133-
repo_id=repo_id,
134-
repo_type="dataset",
135-
allow_patterns=cls._DEFAULT_CONFIGURATION_FILES,
136-
token=kwargs.get("token"),
137-
)
135+
if settings is not None:
136+
dataset = cls(name=name, settings=settings)
137+
dataset.create()
138+
else:
139+
# download configuration files from the hub
140+
folder_path = snapshot_download(
141+
repo_id=repo_id,
142+
repo_type="dataset",
143+
allow_patterns=cls._DEFAULT_CONFIGURATION_FILES,
144+
token=kwargs.get("token"),
145+
)
138146

139-
dataset = cls.from_disk(
140-
path=folder_path, workspace=workspace, name=name, client=client, with_records=with_records
141-
)
147+
dataset = cls.from_disk(
148+
path=folder_path, workspace=workspace, name=name, client=client, with_records=with_records
149+
)
142150

143151
if with_records:
144152
try:
@@ -150,8 +158,23 @@ def from_hub(
150158
f" are: {', '.join(hf_dataset.keys())}."
151159
)
152160
hf_dataset: Dataset = hf_dataset[list(hf_dataset.keys())[0]]
153-
154-
cls._log_dataset_records(hf_dataset=hf_dataset, dataset=dataset)
161+
for feature in hf_dataset.features:
162+
if feature not in dataset.settings.fields or feature not in dataset.settings.questions:
163+
warnings.warn(
164+
message=f"Feature {feature} in Hugging Face dataset is not defined in dataset settings."
165+
)
166+
warnings.warn(
167+
message=f"Available fields: {dataset.settings.fields}. Available questions: {dataset.settings.questions}."
168+
)
169+
try:
170+
cls._log_dataset_records(hf_dataset=hf_dataset, dataset=dataset)
171+
except (RecordsIngestionError, UnprocessableEntityError) as e:
172+
if settings is not None:
173+
raise SettingsError(
174+
message=f"Failed to load records from Hugging Face dataset. Defined settings do not match dataset schema {hf_dataset.features}"
175+
) from e
176+
else:
177+
raise e
155178
except EmptyDatasetError:
156179
warnings.warn(
157180
message="Trying to load a dataset `with_records=True` but dataset does not contain any records.",

argilla/tests/integration/test_export_dataset.py

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
import argilla as rg
2424
import pytest
25-
from argilla._exceptions import ConflictError
25+
from argilla._exceptions import ConflictError, SettingsError
2626
from huggingface_hub.utils._errors import BadRequestError, FileMetadataError, HfHubHTTPError
2727

2828
_RETRIES = 5
@@ -202,3 +202,81 @@ def test_import_dataset_from_hub(
202202

203203
assert new_dataset.settings.fields[0].name == "text"
204204
assert new_dataset.settings.questions[0].name == "label"
205+
206+
@pytest.mark.parametrize("with_records_import", [True, False])
207+
def test_import_dataset_from_hub_using_settings(
208+
self,
209+
token: str,
210+
dataset: rg.Dataset,
211+
client,
212+
mock_data: List[dict[str, Any]],
213+
with_records_export: bool,
214+
with_records_import: bool,
215+
):
216+
repo_id = (
217+
f"argilla-internal-testing/test_import_dataset_from_hub_using_settings_with_records{with_records_export}"
218+
)
219+
dataset.records.log(records=mock_data)
220+
221+
dataset.to_hub(repo_id=repo_id, with_records=with_records_export, token=token)
222+
settings = rg.Settings(
223+
fields=[
224+
rg.TextField(name="text"),
225+
],
226+
questions=[
227+
rg.LabelQuestion(name="label", labels=["positive", "negative"]),
228+
rg.LabelQuestion(name="extra_label", labels=["extra_positive", "extra_negative"]),
229+
],
230+
)
231+
if with_records_import and not with_records_export:
232+
with pytest.warns(
233+
expected_warning=UserWarning,
234+
match="Trying to load a dataset `with_records=True` but dataset does not contain any records.",
235+
):
236+
new_dataset = rg.Dataset.from_hub(
237+
repo_id=repo_id, client=client, with_records=with_records_import, token=token, settings=settings
238+
)
239+
else:
240+
new_dataset = rg.Dataset.from_hub(
241+
repo_id=repo_id, client=client, with_records=with_records_import, token=token, settings=settings
242+
)
243+
244+
if with_records_import and with_records_export:
245+
for i, record in enumerate(new_dataset.records(with_suggestions=True)):
246+
assert record.fields["text"] == mock_data[i]["text"]
247+
assert record.suggestions["label"].value == mock_data[i]["label"]
248+
else:
249+
assert len(new_dataset.records.to_list()) == 0
250+
251+
assert new_dataset.settings.fields[0].name == "text"
252+
assert new_dataset.settings.questions[0].name == "label"
253+
assert new_dataset.settings.questions[1].name == "extra_label"
254+
assert len(new_dataset.settings.questions[1].labels) == 2
255+
assert new_dataset.settings.questions[1].labels[0] == "extra_positive"
256+
assert new_dataset.settings.questions[1].labels[1] == "extra_negative"
257+
258+
def test_import_dataset_from_hub_using_wrong_settings(
259+
self,
260+
token: str,
261+
dataset: rg.Dataset,
262+
client,
263+
mock_data: List[dict[str, Any]],
264+
with_records_export: bool,
265+
):
266+
repo_id = f"argilla-internal-testing/test_import_dataset_from_hub_using_wrong_settings_with_records_{with_records_export}"
267+
dataset.records.log(records=mock_data)
268+
269+
dataset.to_hub(repo_id=repo_id, with_records=with_records_export, token=token)
270+
settings = rg.Settings(
271+
fields=[
272+
rg.TextField(name="text"),
273+
],
274+
questions=[
275+
rg.RatingQuestion(name="label", values=[1, 2, 3, 4, 5]),
276+
],
277+
)
278+
if with_records_export:
279+
with pytest.raises(SettingsError):
280+
rg.Dataset.from_hub(repo_id=repo_id, client=client, token=token, settings=settings)
281+
else:
282+
rg.Dataset.from_hub(repo_id=repo_id, client=client, token=token, settings=settings)

0 commit comments

Comments
 (0)