Skip to content

Commit 8497978

Browse files
authored
Merge branch 'dev' into update_multilabel
2 parents 4757dd8 + 218db9a commit 8497978

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

autointent/_dataset/_dataset.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,19 +94,20 @@ def from_json(cls, filepath: str | Path) -> "Dataset":
9494
return JsonReader().read(filepath)
9595

9696
@classmethod
97-
def from_hub(cls, repo_name: str, data_split: str = "default") -> "Dataset":
97+
def from_hub(cls, repo_name: str, data_split: str = "default", intent_subset_name: str = Split.INTENTS) -> "Dataset":
9898
"""Loads a dataset from the Hugging Face Hub.
9999
100100
Args:
101101
repo_name: The name of the Hugging Face repository, like `DeepPavlov/clinc150`.
102102
data_split: The name of the dataset configuration (subset) to load, passed as the `name` argument of `load_dataset`; defaults to `default`.
103+
intent_subset_name: The name of the intent subset to load, defaults to `intents`.
103104
"""
104105
from ._reader import DictReader
105106

106107
splits = load_dataset(repo_name, data_split)
107108
mapping = dict(**splits)
108-
if Split.INTENTS in get_dataset_config_names(repo_name):
109-
mapping[Split.INTENTS] = load_dataset(repo_name, name=Split.INTENTS, split=Split.INTENTS).to_list()
109+
if intent_subset_name in get_dataset_config_names(repo_name):
110+
mapping[Split.INTENTS] = load_dataset(repo_name, name=intent_subset_name, split=Split.INTENTS).to_list()
110111

111112
return DictReader().read(mapping)
112113

0 commit comments

Comments
 (0)