Commit 0c1a8f9

Fix and improve working with datasets (#78)
* Add user defined dataset splits
* Make encode labels method private
* Fix encode_labels method usage
* Add n_classes validation in DatasetReader
* Add methods to export Dataset to json and dict
* Change methods order in Dataset
* Add tests for dataset validation
* Forbid extra splits in Dataset
* Add PT011 rule ignore for file
* Refactor tests
1 parent ecad794 commit 0c1a8f9

8 files changed: +301 −120 lines changed


autointent/_dataset/_dataset.py

Lines changed: 46 additions & 30 deletions
@@ -1,5 +1,6 @@
 """File with Dataset definition."""
 
+import json
 from collections import defaultdict
 from functools import cached_property
 from pathlib import Path
@@ -48,12 +49,15 @@ def __init__(self, *args: Any, intents: list[Intent], **kwargs: Any) -> None: #
 
         self.intents = intents
 
+        self._encoded_labels = False
+
+        if self.multilabel:
+            self._encode_labels()
+
         oos_split = self._create_oos_split()
         if oos_split is not None:
             self[Split.OOS] = oos_split
 
-        self._encoded_labels = False
-
     @property
     def multilabel(self) -> bool:
         """
@@ -71,31 +75,31 @@ def n_classes(self) -> int:
 
         :return: Number of classes.
         """
-        return self.get_n_classes(Split.TRAIN)
+        return len(self.intents)
 
     @classmethod
-    def from_json(cls, filepath: str | Path) -> "Dataset":
+    def from_dict(cls, mapping: dict[str, Any]) -> "Dataset":
         """
-        Load a dataset from a JSON file.
+        Load a dataset from a dictionary mapping.
 
-        :param filepath: Path to the JSON file.
+        :param mapping: Dictionary representing the dataset.
         :return: Initialized Dataset object.
         """
-        from ._reader import JsonReader
+        from ._reader import DictReader
 
-        return JsonReader().read(filepath)
+        return DictReader().read(mapping)
 
     @classmethod
-    def from_dict(cls, mapping: dict[str, Any]) -> "Dataset":
+    def from_json(cls, filepath: str | Path) -> "Dataset":
         """
-        Load a dataset from a dictionary mapping.
+        Load a dataset from a JSON file.
 
-        :param mapping: Dictionary representing the dataset.
+        :param filepath: Path to the JSON file.
         :return: Initialized Dataset object.
         """
-        from ._reader import DictReader
+        from ._reader import JsonReader
 
-        return DictReader().read(mapping)
+        return JsonReader().read(filepath)
 
     @classmethod
     def from_hub(cls, repo_id: str) -> "Dataset":
@@ -113,34 +117,35 @@ def from_hub(cls, repo_id: str) -> "Dataset":
             intents=[Intent.model_validate(intent) for intent in intents],
         )
 
-    def dump(self) -> dict[str, list[dict[str, Any]]]:
+    def to_multilabel(self) -> "Dataset":
         """
-        Convert the dataset splits to a dictionary of lists.
+        Convert dataset labels to multilabel format.
 
-        :return: Dictionary containing dataset splits as lists.
+        :return: Self, with labels converted to multilabel.
         """
-        return {split_name: split.to_list() for split_name, split in self.items()}
+        for split_name, split in self.items():
+            self[split_name] = split.map(self._to_multilabel)
+        self._encode_labels()
+        return self
 
-    def encode_labels(self) -> "Dataset":
+    def to_dict(self) -> dict[str, list[dict[str, Any]]]:
         """
-        Encode dataset labels into one-hot or multilabel format.
+        Convert the dataset splits and intents to a dictionary of lists.
 
-        :return: Self, with labels encoded.
+        :return: A dictionary containing dataset splits and intents as lists of dictionaries.
         """
-        for split_name, split in self.items():
-            self[split_name] = split.map(self._encode_label)
-        self._encoded_labels = True
-        return self
+        mapping = {split_name: split.to_list() for split_name, split in self.items()}
+        mapping[Split.INTENTS] = [intent.model_dump() for intent in self.intents]
+        return mapping
 
-    def to_multilabel(self) -> "Dataset":
+    def to_json(self, filepath: str | Path) -> None:
         """
-        Convert dataset labels to multilabel format.
+        Save the dataset splits and intents to a JSON file.
 
-        :return: Self, with labels converted to multilabel.
+        :param filepath: The path to the file where the JSON data will be saved.
         """
-        for split_name, split in self.items():
-            self[split_name] = split.map(self._to_multilabel)
-        return self
+        with Path(filepath).open("w") as file:
+            json.dump(self.to_dict(), file, indent=4, ensure_ascii=False)
 
     def push_to_hub(self, repo_id: str, private: bool = False) -> None:
         """
@@ -188,6 +193,17 @@ def get_n_classes(self, split: str) -> int:
                        classes.add(idx)
        return len(classes)
 
+    def _encode_labels(self) -> "Dataset":
+        """
+        Encode dataset labels into one-hot or multilabel format.
+
+        :return: Self, with labels encoded.
+        """
+        for split_name, split in self.items():
+            self[split_name] = split.map(self._encode_label)
+        self._encoded_labels = True
+        return self
+
     def _is_oos(self, sample: Sample) -> bool:
         """
         Check if a sample is out-of-scope.
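Taken together, the _dataset.py changes replace the public dump()/encode_labels() pair with to_dict()/to_json() and a private _encode_labels() that is called automatically. A minimal usage sketch of the reworked API, assuming a top-level `from autointent import Dataset` export and a Sample schema with `utterance`/`label` fields (neither is shown in this diff); the toy data and file name are illustrative:

from autointent import Dataset  # assumed public import; the class lives in autointent/_dataset/_dataset.py

# Build a dataset from an in-memory mapping (toy data, not from the commit).
dataset = Dataset.from_dict(
    {
        "train": [
            {"utterance": "hello there", "label": 0},
            {"utterance": "what is the weather", "label": 1},
        ],
        "intents": [{"id": 0}, {"id": 1}],
    }
)

# to_multilabel() now calls the private _encode_labels() itself,
# so no separate encode_labels() call is needed anymore.
dataset = dataset.to_multilabel()

# New export helpers: splits plus intents round-trip through dict/JSON.
dataset.to_json("dataset.json")
restored = Dataset.from_json("dataset.json")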

autointent/_dataset/_validation.py

Lines changed: 76 additions & 8 deletions
@@ -1,6 +1,6 @@
 """File with definitions of DatasetReader and DatasetValidator."""
 
-from pydantic import BaseModel, model_validator
+from pydantic import BaseModel, ConfigDict, model_validator
 
 from autointent.schemas import Intent, Sample
 
@@ -9,16 +9,27 @@ class DatasetReader(BaseModel):
     """
     A class to represent a dataset reader for handling training, validation, and test data.
 
-    :param train: List of samples for training.
+    :param train: List of samples for training. Defaults to an empty list.
+    :param train_0: List of samples for scoring module training. Defaults to an empty list.
+    :param train_1: List of samples for decision module training. Defaults to an empty list.
     :param validation: List of samples for validation. Defaults to an empty list.
+    :param validation_0: List of samples for scoring module validation. Defaults to an empty list.
+    :param validation_1: List of samples for decision module validation. Defaults to an empty list.
     :param test: List of samples for testing. Defaults to an empty list.
     :param intents: List of intents associated with the dataset.
     """
 
-    train: list[Sample]
+    train: list[Sample] = []
+    train_0: list[Sample] = []
+    train_1: list[Sample] = []
+    validation: list[Sample] = []
+    validation_0: list[Sample] = []
+    validation_1: list[Sample] = []
     test: list[Sample] = []
     intents: list[Intent] = []
 
+    model_config = ConfigDict(extra="forbid")
+
     @model_validator(mode="after")
     def validate_dataset(self) -> "DatasetReader":
         """
@@ -27,19 +38,78 @@ def validate_dataset(self) -> "DatasetReader":
         :raises ValueError: If intents or samples are not properly validated.
         :return: The validated DatasetReader instance.
         """
-        self._validate_intents()
-        for split in [self.train, self.test]:
+        if self.train and (self.train_0 or self.train_1):
+            message = "If `train` is provided, `train_0` and `train_1` should be empty."
+            raise ValueError(message)
+        if not self.train and (not self.train_0 or not self.train_1):
+            message = "Both `train_0` and `train_1` must be provided if `train` is empty."
+            raise ValueError(message)
+
+        if self.validation and (self.validation_0 or self.validation_1):
+            message = "If `validation` is provided, `validation_0` and `validation_1` should be empty."
+            raise ValueError(message)
+        if not self.validation:
+            message = "Either both `validation_0` and `validation_1` must be provided, or neither of them."
+            if not self.validation_0 and self.validation_1:
+                raise ValueError(message)
+            if self.validation_0 and not self.validation_1:
+                raise ValueError(message)
+
+        splits = [
+            self.train,
+            self.train_0,
+            self.train_1,
+            self.validation,
+            self.validation_0,
+            self.validation_1,
+            self.test,
+        ]
+        splits = [split for split in splits if split]
+
+        n_classes = [self._get_n_classes(split) for split in splits]
+        if len(set(n_classes)) != 1:
+            message = (
+                f"Mismatch in number of classes across splits. Found class counts: {n_classes}. "
+                "Ensure all splits have the same number of classes."
+            )
+            raise ValueError(message)
+        if not n_classes[0]:
+            message = "Number of classes is zero or undefined. " "Ensure at least one class is present in the splits."
+            raise ValueError(message)
+
+        self._validate_intents(n_classes[0])
+
+        for split in splits:
             self._validate_split(split)
         return self
 
-    def _validate_intents(self) -> "DatasetReader":
+    def _get_n_classes(self, split: list[Sample]) -> int:
+        """
+        Get the number of classes in a dataset split.
+
+        :param split: List of samples in a dataset split (train, validation, or test).
+        :return: The number of classes.
+        """
+        classes = set()
+        for sample in split:
+            match sample.label:
+                case int():
+                    classes.add(sample.label)
+                case list():
+                    for label in sample.label:
+                        classes.add(label)
+        return len(classes)
+
+    def _validate_intents(self, n_classes: int) -> "DatasetReader":
         """
         Validate the intents by checking their IDs for sequential order.
 
+        :param n_classes: The number of classes in the dataset.
         :raises ValueError: If intent IDs are not sequential starting from 0.
         :return: The DatasetReader instance after validation.
         """
         if not self.intents:
+            self.intents = [Intent(id=idx) for idx in range(n_classes)]
             return self
         self.intents = sorted(self.intents, key=lambda intent: intent.id)
         intent_ids = [intent.id for intent in self.intents]
@@ -59,8 +129,6 @@ def _validate_split(self, split: list[Sample]) -> "DatasetReader":
         :raises ValueError: If a sample references an invalid or non-existent intent ID.
         :return: The DatasetReader instance after validation.
         """
-        if not split or not self.intents:
-            return self
         intent_ids = {intent.id for intent in self.intents}
         for sample in split:
             message = (
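The new validate_dataset logic makes `train` mutually exclusive with the user-defined `train_0`/`train_1` splits, requires `validation_0`/`validation_1` to come as a pair, and checks that every non-empty split exposes the same number of classes. A short sketch of what this means for callers; the Sample field name `utterance` is an assumption (only `label` appears in this diff), and the try/except relies on pydantic's ValidationError being a ValueError subclass:

from autointent._dataset._validation import DatasetReader  # module path as in this diff
from autointent.schemas import Sample

# Accepted: user-defined train splits for the scoring and decision modules,
# both present and both covering the same two classes; intents are auto-filled.
reader = DatasetReader(
    train_0=[Sample(utterance="hi", label=0), Sample(utterance="bye", label=1)],
    train_1=[Sample(utterance="hello", label=0), Sample(utterance="goodbye", label=1)],
)

# Rejected: `train` mixed with `train_0` raises a ValueError
# (unknown split names are likewise rejected via extra="forbid").
try:
    DatasetReader(
        train=[Sample(utterance="hi", label=0)],
        train_0=[Sample(utterance="hey", label=0)],
    )
except ValueError as error:  # pydantic wraps the message raised in validate_dataset
    print(error)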

autointent/context/_context.py

Lines changed: 1 addition & 3 deletions
@@ -137,9 +137,7 @@ def dump(self) -> None:
         # self._logger.info(make_report(optimization_results, nodes=nodes))
 
         # dump train and test data splits
-        dataset_path = logs_dir / "dataset.json"
-        with dataset_path.open("w") as file:
-            json.dump(self.data_handler.dump(), file, indent=4, ensure_ascii=False)
+        self.data_handler.dump(logs_dir / "dataset.json")
 
         self._logger.info("logs and other assets are saved to %s", logs_dir)
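Context.dump now hands dataset serialization off to the data handler. The handler's side of that call is not part of this diff; a plausible shape, stated purely as an assumption, is that DataHandler.dump forwards to the new Dataset.to_json:

from pathlib import Path

class DataHandler:  # hypothetical excerpt, not from this commit
    def __init__(self, dataset) -> None:
        self.dataset = dataset

    def dump(self, filepath: str | Path) -> None:
        # Delegate serialization to the Dataset itself (see to_json in _dataset.py above).
        self.dataset.to_json(filepath)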

autointent/context/_utils.py

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ def load_data(filepath: str | Path) -> Dataset:
     if filepath == "default-multiclass":
         return Dataset.from_hub("AutoIntent/clinc150_subset")
     if filepath == "default-multilabel":
-        return Dataset.from_hub("AutoIntent/clinc150_subset").to_multilabel().encode_labels()
+        return Dataset.from_hub("AutoIntent/clinc150_subset").to_multilabel()
     if not Path(filepath).exists():
         return Dataset.from_hub(str(filepath))
     return Dataset.from_json(filepath)
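With to_multilabel() encoding labels internally, the loader no longer needs to chain encode_labels(). A small usage sketch; the module path matches this diff, while the local file name is illustrative:

from autointent.context._utils import load_data  # module path as in this diff

multilabel_dataset = load_data("default-multilabel")  # hub subset, converted and encoded in one call
local_dataset = load_data("my_dataset.json")          # existing local files go through Dataset.from_json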
