Skip to content

Commit 67aee5a

Browse files
Feature/persistent-record-mappings-in-settings (#5466)
# Description <!-- Please include a summary of the changes and the related issue. Please also include relevant motivation and context. List any dependencies that are required for this change. --> This PR persists the record mapping dict (the one used by the records log function) in the dataset settings. **Type of change** <!-- Please delete options that are not relevant. Remember to title the PR according to the type of change --> - New feature (non-breaking change which adds functionality) **How Has This Been Tested** <!-- Please add some reference about how your feature has been tested. --> **Checklist** <!-- Please go over the list and make sure you've taken everything into account --> - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm my changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 84c8aa7 commit 67aee5a

File tree

4 files changed

+123
-2
lines changed

4 files changed

+123
-2
lines changed

argilla/src/argilla/datasets/_resource.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def __init__(
7373
self._model = DatasetModel(name=name)
7474
self._settings = settings._copy() if settings else Settings(_dataset=self)
7575
self._settings.dataset = self
76-
self.__records = DatasetRecords(client=self._client, dataset=self)
76+
self.__records = DatasetRecords(client=self._client, dataset=self, mapping=self._settings.mapping)
7777

7878
#####################
7979
# Properties #

argilla/src/argilla/records/_dataset_records.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,14 +139,17 @@ class DatasetRecords(Iterable[Record], LoggingMixin):
139139
DEFAULT_BATCH_SIZE = 256
140140
DEFAULT_DELETE_BATCH_SIZE = 64
141141

142-
def __init__(self, client: "Argilla", dataset: "Dataset"):
142+
def __init__(
    self,
    client: "Argilla",
    dataset: "Dataset",
    mapping: Optional[Dict[str, Union[str, Sequence[str]]]] = None,
):
    """Initializes a DatasetRecords object with a client and a dataset.

    Args:
        client: An Argilla client object.
        dataset: A Dataset object.
        mapping: Optional default mapping from incoming data names to
            dataset attribute names. `_ingest_records` falls back to it
            when no mapping is passed for the call.
    """
    self.__client = client
    self.__dataset = dataset
    # Normalise a missing mapping to an empty dict so the stored default
    # is never None.
    self._mapping = mapping or {}
    self._api = self.__client.api.records
151154

152155
def __iter__(self):
@@ -380,6 +383,7 @@ def _ingest_records(
380383
) -> List[RecordModel]:
381384
"""Ingests records from a list of dictionaries, a Hugging Face Dataset, or a list of Record objects."""
382385

386+
mapping = mapping or self._mapping
383387
if len(records) == 0:
384388
raise ValueError("No records provided to ingest.")
385389

argilla/src/argilla/settings/_resource.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def __init__(
5050
guidelines: Optional[str] = None,
5151
allow_extra_metadata: bool = False,
5252
distribution: Optional[TaskDistribution] = None,
53+
mapping: Optional[Dict[str, Union[str, Sequence[str]]]] = None,
5354
_dataset: Optional["Dataset"] = None,
5455
) -> None:
5556
"""
@@ -64,11 +65,13 @@ def __init__(
6465
Dataset. Defaults to False.
6566
distribution (TaskDistribution): The annotation task distribution configuration.
6667
Default to DEFAULT_TASK_DISTRIBUTION
68+
mapping (Dict[str, Union[str, Sequence[str]]]): A dictionary that maps incoming data names to Argilla dataset attributes in DatasetRecords.
6769
"""
6870
super().__init__(client=_dataset._client if _dataset else None)
6971

7072
self._dataset = _dataset
7173
self._distribution = distribution
74+
self._mapping = mapping
7275
self.__guidelines = self.__process_guidelines(guidelines)
7376
self.__allow_extra_metadata = allow_extra_metadata
7477

@@ -137,6 +140,14 @@ def distribution(self) -> TaskDistribution:
137140
def distribution(self, value: TaskDistribution) -> None:
138141
self._distribution = value
139142

143+
@property
def mapping(self) -> Optional[Dict[str, Union[str, Sequence[str]]]]:
    """The record mapping stored in these settings.

    Returns:
        The mapping from incoming data names to dataset attribute names,
        or None when no mapping has been configured.
    """
    return self._mapping

@mapping.setter
def mapping(self, value: Optional[Dict[str, Union[str, Sequence[str]]]]) -> None:
    """Replace the stored record mapping.

    Args:
        value: The new mapping. It is stored as given, without validation.
    """
    # NOTE(review): Dataset.__init__ hands the mapping to DatasetRecords at
    # construction time, so a value assigned here afterwards may not reach
    # an already-built DatasetRecords -- confirm whether that is intended.
    self._mapping = value
150+
140151
@property
141152
def dataset(self) -> "Dataset":
142153
return self._dataset
@@ -220,6 +231,7 @@ def serialize(self):
220231
"metadata": self.metadata.serialize(),
221232
"allow_extra_metadata": self.allow_extra_metadata,
222233
"distribution": self.distribution.to_dict(),
234+
"mapping": self.mapping,
223235
}
224236
except Exception as e:
225237
raise ArgillaSerializeError(f"Failed to serialize the settings. {e.__class__.__name__}") from e
@@ -271,6 +283,7 @@ def _from_dict(cls, settings_dict: dict) -> "Settings":
271283
guidelines = settings_dict.get("guidelines")
272284
distribution = settings_dict.get("distribution")
273285
allow_extra_metadata = settings_dict.get("allow_extra_metadata")
286+
mapping = settings_dict.get("mapping")
274287

275288
questions = [question_from_dict(question) for question in settings_dict.get("questions", [])]
276289
fields = [_field_from_dict(field) for field in fields]
@@ -280,6 +293,9 @@ def _from_dict(cls, settings_dict: dict) -> "Settings":
280293
if distribution:
281294
distribution = TaskDistribution.from_dict(distribution)
282295

296+
if mapping:
297+
mapping = cls._validate_mapping(mapping)
298+
283299
return cls(
284300
questions=questions,
285301
fields=fields,
@@ -288,6 +304,7 @@ def _from_dict(cls, settings_dict: dict) -> "Settings":
288304
guidelines=guidelines,
289305
allow_extra_metadata=allow_extra_metadata,
290306
distribution=distribution,
307+
mapping=mapping,
291308
)
292309

293310
def _copy(self) -> "Settings":
@@ -362,6 +379,18 @@ def _validate_duplicate_names(self) -> None:
362379
)
363380
dataset_properties_by_name[property.name] = property
364381

382+
@classmethod
def _validate_mapping(
    cls, mapping: Dict[str, Union[str, Sequence[str]]]
) -> Dict[str, Union[str, Sequence[str]]]:
    """Validate a record mapping and return a normalized copy of it.

    String values are kept as they are. List and tuple values are converted
    to tuples (presumably so that a mapping reloaded from JSON, where tuples
    come back as lists, compares equal to the original -- confirm).

    Args:
        mapping: The mapping from incoming data names to dataset attribute
            names. The input dict is not modified.

    Returns:
        A new dict with the same keys and normalized values.

    Raises:
        SettingsError: If a value is neither a string nor a list/tuple.
    """
    validated = {}
    for key, value in mapping.items():
        if isinstance(value, str):
            validated[key] = value
        elif isinstance(value, (list, tuple)):
            # Elements of the sequence are not inspected here.
            validated[key] = tuple(value)
        else:
            raise SettingsError(f"Invalid mapping value for key {key!r}: {value}")
    return validated
393+
365394
def __process_guidelines(self, guidelines):
366395
if guidelines is None:
367396
return guidelines
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# Copyright 2024-present, Argilla, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from tempfile import TemporaryDirectory
16+
from uuid import uuid4
17+
18+
import pytest
19+
20+
import argilla as rg
21+
22+
23+
@pytest.fixture
def dataset():
    """Build a test dataset whose settings carry a record mapping."""
    return rg.Dataset(
        name="test_dataset",
        settings=rg.Settings(
            fields=[rg.TextField(name="prompt_field")],
            questions=[
                rg.LabelQuestion(name="label", labels=["negative", "positive"]),
                rg.TextQuestion(name="prompt_question"),
            ],
            metadata=[rg.FloatMetadataProperty(name="score")],
            vectors=[rg.VectorField(name="vector", dimensions=3)],
            # Keys are the names used in incoming records; a tuple value
            # routes one incoming key to more than one target.
            mapping={
                "true_label": "label.response",
                "my_label": "label.suggestion.value",
                "score": "label.suggestion.score",
                "model": "label.suggestion.agent",
                "my_prompt": ("prompt_field", "prompt_question"),
            },
        ),
        workspace=rg.Workspace(name="workspace", id=uuid4()),
    )
49+
50+
51+
def test_settings_with_record_mapping(dataset):
    """Ingest one record using only the mapping stored in the dataset settings.

    No mapping is passed to `_ingest_records`, so every renamed key below
    has to be resolved through the mapping configured on the fixture.
    """
    mock_user_id = uuid4()
    record_api_models = dataset.records._ingest_records(
        records=[
            {
                "my_prompt": "What is the capital of France?",
                "my_label": "positive",
                "true_label": "positive",
                "score": 0.9,
                "model": "model_name",
            }
        ],
        user_id=mock_user_id,
    )
    record = record_api_models[0]

    assert record.fields["prompt_field"] == "What is the capital of France?"

    # Find the label suggestion by question name rather than list position,
    # so the test does not depend on the order suggestions are built in.
    suggestions_by_question = {s.question_name: s for s in record.suggestions}
    label_suggestion = suggestions_by_question["label"]
    assert label_suggestion.value == "positive"
    assert label_suggestion.score == 0.9
    assert label_suggestion.agent == "model_name"

    assert record.responses[0].values["label"]["value"] == "positive"
    assert record.responses[0].user_id == mock_user_id

    # The tuple mapping for "my_prompt" should also yield a suggestion that
    # carries the prompt text.
    suggestion_values = [s.value for s in record.suggestions]
    assert "positive" in suggestion_values
    assert "What is the capital of France?" in suggestion_values
79+
80+
81+
def test_settings_with_record_mapping_export(dataset):
    """Round-trip the settings through a JSON file and compare the result."""
    with TemporaryDirectory() as workdir:
        json_path = f"{workdir}/test_dataset.json"
        dataset.settings.to_json(json_path)
        restored = rg.Settings.from_json(json_path)

    # The mapping on its own, then the settings as a whole.
    assert dataset.settings.mapping == restored.mapping
    assert dataset.settings == restored

0 commit comments

Comments
 (0)