Skip to content

Commit 4e1d43f

Browse files
truff4ut and voorhs authored
Refactor datasets logic (#43)
Co-authored-by: voorhs <[email protected]>
1 parent 1cb4760 commit 4e1d43f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+1623
-1781
lines changed

autointent/configs/optimization_cli.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,6 @@ class OptimizationConfig:
123123
"""Configuration for the logging"""
124124
vector_index: VectorIndexConfig = field(default_factory=VectorIndexConfig)
125125
"""Configuration for the vector index"""
126-
augmentation: AugmentationConfig = field(default_factory=AugmentationConfig)
127-
"""Configuration for the augmentation"""
128126
embedder: EmbedderConfig = field(default_factory=EmbedderConfig)
129127
"""Configuration for the embedder"""
130128

@@ -133,7 +131,7 @@ class OptimizationConfig:
133131
"_self_",
134132
{"override hydra/job_logging": "autointent_standard_job_logger"},
135133
{"override hydra/help": "autointent_help"},
136-
]
134+
],
137135
)
138136

139137

autointent/context/context.py

Lines changed: 16 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,13 @@
99
import yaml
1010

1111
from autointent.configs.optimization_cli import (
12-
AugmentationConfig,
1312
DataConfig,
1413
EmbedderConfig,
1514
LoggingConfig,
1615
VectorIndexConfig,
1716
)
1817

19-
from .data_handler import DataAugmenter, DataHandler, Dataset
18+
from .data_handler import DataHandler, Dataset
2019
from .optimization_info import OptimizationInfo
2120
from .utils import NumpyEncoder, load_data
2221
from .vector_index_client import VectorIndex, VectorIndexClient
@@ -71,43 +70,29 @@ def configure_vector_index(self, config: VectorIndexConfig, embedder_config: Emb
7170
self.embedder_config.max_length,
7271
)
7372

74-
def configure_data(self, config: DataConfig, augmentation_config: AugmentationConfig | None = None) -> None:
73+
def configure_data(self, config: DataConfig) -> None:
7574
"""
76-
Configure data handling and augmentation.
75+
Configure data handling.
7776
7877
:param config: Configuration for the data handling process.
79-
:param augmentation_config: Configuration for data augmentation. If None, no augmentation is applied.
80-
"""
81-
if augmentation_config is not None:
82-
self.augmentation_config = AugmentationConfig()
83-
augmenter = DataAugmenter(
84-
self.augmentation_config.multilabel_generation_config,
85-
self.augmentation_config.regex_sampling,
86-
self.seed,
87-
)
88-
else:
89-
augmenter = None
90-
78+
"""
9179
self.data_handler = DataHandler(
9280
dataset=load_data(config.train_path),
93-
test_dataset=None if config.test_path is None else load_data(config.test_path),
9481
random_seed=self.seed,
9582
force_multilabel=config.force_multilabel,
96-
augmenter=augmenter,
9783
)
9884

99-
def set_datasets(
100-
self, train_data: Dataset, val_data: Dataset | None = None, force_multilabel: bool = False
101-
) -> None:
85+
def set_dataset(self, dataset: Dataset, force_multilabel: bool = False) -> None:
10286
"""
103-
Set the datasets for training and validation.
87+
Set the datasets for training, validation and testing.
10488
105-
:param train_data: Training dataset.
106-
:param val_data: Validation dataset. If None, only training data is used.
89+
:param dataset: Dataset.
10790
:param force_multilabel: Whether to force multilabel classification.
10891
"""
10992
self.data_handler = DataHandler(
110-
dataset=train_data, test_dataset=val_data, random_seed=self.seed, force_multilabel=force_multilabel
93+
dataset=dataset,
94+
force_multilabel=force_multilabel,
95+
random_seed=self.seed,
11196
)
11297

11398
def get_best_index(self) -> VectorIndex:
@@ -159,13 +144,12 @@ def dump(self) -> None:
159144
with logs_path.open("w") as file:
160145
json.dump(optimization_results, file, indent=4, ensure_ascii=False, cls=NumpyEncoder)
161146

162-
train_data, test_data = self.data_handler.dump()
163-
train_path = logs_dir / "train_data.json"
164-
test_path = logs_dir / "test_data.json"
165-
with train_path.open("w") as file:
166-
json.dump(train_data, file, indent=4, ensure_ascii=False)
167-
with test_path.open("w") as file:
168-
json.dump(test_data, file, indent=4, ensure_ascii=False)
147+
# self._logger.info(make_report(optimization_results, nodes=nodes))
148+
149+
# dump train and test data splits
150+
dataset_path = logs_dir / "dataset.json"
151+
with dataset_path.open("w") as file:
152+
json.dump(self.data_handler.dump(), file, indent=4, ensure_ascii=False)
169153

170154
self._logger.info("logs and other assets are saved to %s", logs_dir)
171155

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from .data_handler import DataAugmenter, DataHandler
2-
from .schemas import Dataset
3-
from .tags import Tag
1+
from .data_handler import DataHandler
2+
from .dataset import Dataset
3+
from .schemas import Intent, Sample, Tag
44

5-
__all__ = ["DataAugmenter", "DataHandler", "Dataset", "Tag"]
5+
__all__ = ["DataHandler", "Dataset", "Intent", "Sample", "Tag"]

0 commit comments

Comments (0)