@@ -71,6 +71,16 @@ def n_classes(self) -> int:
7171 """Returns the number of classes in the dataset."""
7272 return len (self .intents )
7373
74+ def is_multilabel_format (self ) -> bool :
75+ """Checks if the dataset is in multilabel format.
76+
77+ Returns:
78+ bool: True if the dataset is multilabel, False otherwise.
79+ """
80+ ds_keys = list (self .keys ())
81+ first_split = self [ds_keys [0 ]]
82+ return isinstance (first_split .features [self .label_feature ], Sequence )
83+
7484 @classmethod
7585 def from_dict (cls , mapping : dict [str , Any ]) -> "Dataset" :
7686 """Creates a dataset from a dictionary mapping.
@@ -80,7 +90,11 @@ def from_dict(cls, mapping: dict[str, Any]) -> "Dataset":
8090 """
8191 from ._reader import DictReader
8292
83- return DictReader ().read (mapping )
93+ dataset = DictReader ().read (mapping )
94+
95+ if dataset .is_multilabel_format ():
96+ dataset = dataset .to_multilabel ()
97+ return dataset
8498
8599 @classmethod
86100 def from_json (cls , filepath : str | Path ) -> "Dataset" :
@@ -91,7 +105,10 @@ def from_json(cls, filepath: str | Path) -> "Dataset":
91105 """
92106 from ._reader import JsonReader
93107
94- return JsonReader ().read (filepath )
108+ dataset = JsonReader ().read (filepath )
109+ if dataset .is_multilabel_format ():
110+ dataset = dataset .to_multilabel ()
111+ return dataset
95112
96113 @classmethod
97114 def from_hub (cls , repo_name : str , data_split : str = "default" ) -> "Dataset" :
@@ -109,9 +126,7 @@ def from_hub(cls, repo_name: str, data_split: str = "default") -> "Dataset":
109126 mapping [Split .INTENTS ] = load_dataset (repo_name , name = Split .INTENTS , split = Split .INTENTS ).to_list ()
110127
111128 dataset = DictReader ().read (mapping )
112- ds_keys = list (dataset .keys ())
113- first_split = dataset [ds_keys [0 ]]
114- if isinstance (first_split .features [dataset .label_feature ], Sequence ):
129+ if dataset .is_multilabel_format ():
115130 dataset = dataset .to_multilabel ()
116131 return dataset
117132
0 commit comments