Skip to content

Commit fcf61a3

Browse files
Samoedvoorhs
and authored
load from hub (#33)
* load from hub * type and lint * move `datasets` from dev dependencies to common --------- Co-authored-by: voorhs <[email protected]>
1 parent c7b5c4e commit fcf61a3

File tree

4 files changed

+69
-2
lines changed

4 files changed

+69
-2
lines changed

CONTRIBUTING.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ autointent hydra.job_logging.root.level=ERROR
7070
Еще можно изменить параметры логгера через yaml файлы:
7171
1. Создадим папку с конфиг. файлами: test_config
7272
2. test_config/config.yaml:
73-
```
73+
```yaml
7474
defaults:
7575
- optimization_config
7676
- _self_

autointent/context/data_handler/schemas.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from enum import Enum
22
from functools import cached_property
3+
from typing import Any
34

5+
import datasets
46
from pydantic import BaseModel
7+
from typing_extensions import Self
58

69
from autointent.custom_types import LabelType
710

@@ -83,3 +86,35 @@ def n_classes(self) -> int:
8386

8487
def to_multilabel(self) -> "Dataset":
8588
return Dataset(utterances=[utterance.to_multilabel() for utterance in self.utterances], intents=self.intents)
89+
90+
@classmethod
91+
def from_datasets(
92+
cls,
93+
dataset_name: str,
94+
split: str = "train",
95+
utterances_kwargs: dict[str, Any] | None = None,
96+
intents_kwargs: dict[str, Any] | None = None,
97+
# tags_kwargs: dict[str, Any] | None = None,
98+
) -> Self:
99+
configs = datasets.get_dataset_config_names(dataset_name)
100+
101+
utterances = []
102+
intents = []
103+
if "utterances" in configs:
104+
utterance_ds = datasets.load_dataset(
105+
dataset_name, name="utterances", split=split, **(utterances_kwargs or {})
106+
)
107+
utterances = [Utterance(**item) for item in utterance_ds]
108+
# tags = []
109+
# if "tags" in configs:
110+
# tags_ds = datasets.load_dataset(dataset_name, name="tags", split=split, **(tags_kwargs or {}))
111+
if "intents" in configs:
112+
intents_ds = datasets.load_dataset(dataset_name, name="intents", split=split, **(intents_kwargs or {}))
113+
intents = [Intent(**item) for item in intents_ds]
114+
return cls(utterances=utterances, intents=intents)
115+
116+
def push_to_hub(self, dataset_name: str, split: str = "train") -> None:
117+
utterances_ds = datasets.Dataset.from_list([utterance.model_dump() for utterance in self.utterances])
118+
intents_ds = datasets.Dataset.from_list([intent.model_dump() for intent in self.intents])
119+
utterances_ds.push_to_hub(dataset_name, config_name="utterances", split=split)
120+
intents_ds.push_to_hub(dataset_name, config_name="intents", split=split)

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@ pydantic = "^2.9.2"
1919
hydra-core = "^1.3.2"
2020
faiss-cpu = "^1.9.0"
2121
openai = "^1.52.1"
22+
datasets = "2.20.0"
2223

2324

2425
[tool.poetry.group.dev]
2526
optional = true
2627

2728
[tool.poetry.group.dev.dependencies]
28-
datasets = "2.20.0"
2929
tach = "^0.11.3"
3030
ipykernel = "^6.29.5"
3131
ipywidgets = "^8.1.5"
@@ -139,6 +139,7 @@ module = [
139139
"hydra.*",
140140
"transformers",
141141
"faiss",
142+
"datasets",
142143
"joblib",
143144
]
144145
ignore_missing_imports = true
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from datasets import Dataset, load_dataset, DatasetDict
2+
3+
4+
def transform_dataset(
    path: str,
) -> tuple[Dataset | None, Dataset | None, Dataset | None]:
    """Split a single JSON dataset dump into per-config hub datasets.

    The JSON file is expected to contain (some of) the top-level keys
    ``utterances``, ``tags`` and ``intents``, each holding a list of
    records. A key absent from the file maps to ``None`` in the result.

    :param path: path to the JSON file to load
    :return: ``(utterances, tags, intents)`` as ``Dataset`` objects,
        with ``None`` for any missing section
    """
    # Fix: load_dataset returns a DatasetDict; indexing "train" yields a
    # single Dataset, so the original `ds: DatasetDict` annotation was wrong.
    ds: Dataset = load_dataset("json", data_files=path)["train"]

    def _extract(column: str) -> Dataset | None:
        """Pull one section out of the loaded dump, or None if absent."""
        if column not in ds.column_names:
            return None
        # Each column holds a single row whose value is the record list,
        # hence the [0] before rebuilding a Dataset from it.
        return Dataset.from_list(ds[column][0])

    return _extract("utterances"), _extract("tags"), _extract("intents")
18+
19+
20+
def push_json_to_hub(path: str, ds_name: str) -> None:
    """Upload a local JSON dataset dump to the Hugging Face Hub.

    Each section present in the file (utterances / tags / intents) is
    pushed as its own config under *ds_name*; missing sections are
    skipped.

    :param path: path to the JSON file to upload
    :param ds_name: hub repository id to push the configs to
    """
    utterance_ds, tags_ds, intents_ds = transform_dataset(path)
    sections = (
        ("utterances", utterance_ds),
        ("tags", tags_ds),
        ("intents", intents_ds),
    )
    for config_name, section_ds in sections:
        if section_ds is not None:
            section_ds.push_to_hub(ds_name, config_name=config_name)
28+
29+
30+
# Manual entry point: pushes the bundled CLINC multilabel test fixture to
# the hub as "clinc_subset_multilabel".
# NOTE(review): the fixture path is relative, so this only resolves when
# run from this script's own directory — confirm intended usage.
if __name__ == "__main__":
    push_json_to_hub("../tests/assets/data/clinc_subset_multilabel.json", "clinc_subset_multilabel")

0 commit comments

Comments
 (0)