Skip to content

Commit a5dc0ae

Browse files
committed
add autoconvert to multilabel when read
1 parent f107fb8 commit a5dc0ae

File tree

2 files changed

+21
-6
lines changed

2 files changed

+21
-6
lines changed

autointent/_dataset/_dataset.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,16 @@ def n_classes(self) -> int:
7171
"""Returns the number of classes in the dataset."""
7272
return len(self.intents)
7373

74+
def is_multilabel_format(self) -> bool:
75+
"""Checks if the dataset is in multilabel format.
76+
77+
Returns:
78+
bool: True if the dataset is multilabel, False otherwise.
79+
"""
80+
ds_keys = list(self.keys())
81+
first_split = self[ds_keys[0]]
82+
return isinstance(first_split.features[self.label_feature], Sequence)
83+
7484
@classmethod
7585
def from_dict(cls, mapping: dict[str, Any]) -> "Dataset":
7686
"""Creates a dataset from a dictionary mapping.
@@ -80,7 +90,11 @@ def from_dict(cls, mapping: dict[str, Any]) -> "Dataset":
8090
"""
8191
from ._reader import DictReader
8292

83-
return DictReader().read(mapping)
93+
dataset = DictReader().read(mapping)
94+
95+
if dataset.is_multilabel_format():
96+
dataset = dataset.to_multilabel()
97+
return dataset
8498

8599
@classmethod
86100
def from_json(cls, filepath: str | Path) -> "Dataset":
@@ -91,7 +105,10 @@ def from_json(cls, filepath: str | Path) -> "Dataset":
91105
"""
92106
from ._reader import JsonReader
93107

94-
return JsonReader().read(filepath)
108+
dataset = JsonReader().read(filepath)
109+
if dataset.is_multilabel_format():
110+
dataset = dataset.to_multilabel()
111+
return dataset
95112

96113
@classmethod
97114
def from_hub(cls, repo_name: str, data_split: str = "default") -> "Dataset":
@@ -109,9 +126,7 @@ def from_hub(cls, repo_name: str, data_split: str = "default") -> "Dataset":
109126
mapping[Split.INTENTS] = load_dataset(repo_name, name=Split.INTENTS, split=Split.INTENTS).to_list()
110127

111128
dataset = DictReader().read(mapping)
112-
ds_keys = list(dataset.keys())
113-
first_split = dataset[ds_keys[0]]
114-
if isinstance(first_split.features[dataset.label_feature], Sequence):
129+
if dataset.is_multilabel_format():
115130
dataset = dataset.to_multilabel()
116131
return dataset
117132

autointent/custom_types/_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ class NodeType(str, Enum):
7373
decision = "decision"
7474

7575

76-
class Split:
76+
class Split(str, Enum):
7777
"""Enumeration of data splits in the AutoIntent framework.
7878
7979
Attributes:

0 commit comments

Comments
 (0)