 import yaml

 from autointent.configs.optimization_cli import (
-    AugmentationConfig,
     DataConfig,
     EmbedderConfig,
     LoggingConfig,
     VectorIndexConfig,
 )

-from .data_handler import DataAugmenter, DataHandler, Dataset
+from .data_handler import DataHandler, Dataset
 from .optimization_info import OptimizationInfo
 from .utils import NumpyEncoder, load_data
 from .vector_index_client import VectorIndex, VectorIndexClient
@@ -71,43 +70,29 @@ def configure_vector_index(self, config: VectorIndexConfig, embedder_config: Emb
             self.embedder_config.max_length,
         )

-    def configure_data(self, config: DataConfig, augmentation_config: AugmentationConfig | None = None) -> None:
+    def configure_data(self, config: DataConfig) -> None:
         """
-        Configure data handling and augmentation.
+        Configure data handling.

         :param config: Configuration for the data handling process.
-        :param augmentation_config: Configuration for data augmentation. If None, no augmentation is applied.
-        """
-        if augmentation_config is not None:
-            self.augmentation_config = AugmentationConfig()
-            augmenter = DataAugmenter(
-                self.augmentation_config.multilabel_generation_config,
-                self.augmentation_config.regex_sampling,
-                self.seed,
-            )
-        else:
-            augmenter = None
-
+        """
         self.data_handler = DataHandler(
             dataset=load_data(config.train_path),
-            test_dataset=None if config.test_path is None else load_data(config.test_path),
             random_seed=self.seed,
             force_multilabel=config.force_multilabel,
-            augmenter=augmenter,
         )

-    def set_datasets(
-        self, train_data: Dataset, val_data: Dataset | None = None, force_multilabel: bool = False
-    ) -> None:
+    def set_dataset(self, dataset: Dataset, force_multilabel: bool = False) -> None:
         """
-        Set the datasets for training and validation.
+        Set the dataset for training, validation and testing.

-        :param train_data: Training dataset.
-        :param val_data: Validation dataset. If None, only training data is used.
+        :param dataset: Dataset containing all splits.
         :param force_multilabel: Whether to force multilabel classification.
         """
         self.data_handler = DataHandler(
-            dataset=train_data, test_dataset=val_data, random_seed=self.seed, force_multilabel=force_multilabel
+            dataset=dataset,
+            force_multilabel=force_multilabel,
+            random_seed=self.seed,
         )

     def get_best_index(self) -> VectorIndex:
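
Taken together, this hunk drops the augmentation path from `configure_data` and replaces the two-argument `set_datasets` with a single `set_dataset`. A minimal sketch of the new call sites follows; it assumes the methods live on some context object (its class name is not visible in this diff) and that `DataConfig` accepts the fields the hunk reads (`train_path`, `force_multilabel`) as keyword arguments:

```python
from autointent.configs.optimization_cli import DataConfig

def reconfigure(ctx, dataset):
    # `ctx` stands for whatever object defines configure_data/set_dataset
    # above; the class name is not shown in this diff.

    # Path-based setup: configure_data now takes only a DataConfig;
    # the augmentation_config parameter is gone.
    ctx.configure_data(DataConfig(train_path="data/intents.json", force_multilabel=False))

    # In-memory setup: one Dataset carrying train/validation/test splits
    # replaces the old (train_data, val_data) pair of set_datasets.
    ctx.set_dataset(dataset, force_multilabel=False)
```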
@@ -159,13 +144,12 @@ def dump(self) -> None:
         with logs_path.open("w") as file:
             json.dump(optimization_results, file, indent=4, ensure_ascii=False, cls=NumpyEncoder)

-        train_data, test_data = self.data_handler.dump()
-        train_path = logs_dir / "train_data.json"
-        test_path = logs_dir / "test_data.json"
-        with train_path.open("w") as file:
-            json.dump(train_data, file, indent=4, ensure_ascii=False)
-        with test_path.open("w") as file:
-            json.dump(test_data, file, indent=4, ensure_ascii=False)
+        # self._logger.info(make_report(optimization_results, nodes=nodes))
+
+        # dump train and test data splits
+        dataset_path = logs_dir / "dataset.json"
+        with dataset_path.open("w") as file:
+            json.dump(self.data_handler.dump(), file, indent=4, ensure_ascii=False)

         self._logger.info("logs and other assets are saved to %s", logs_dir)

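With this change, `dump()` writes one combined `dataset.json` instead of separate `train_data.json`/`test_data.json` files. A sketch of reading it back, assuming only that the file is valid JSON; the diff does not show the schema that `DataHandler.dump()` produces:

```python
import json
from pathlib import Path

logs_dir = Path("runs/latest")  # placeholder; the real directory comes from LoggingConfig

with (logs_dir / "dataset.json").open() as file:
    dataset_dump = json.load(file)

# Inspect the top-level structure before relying on specific keys,
# since the schema of DataHandler.dump() is not shown in this diff.
if isinstance(dataset_dump, dict):
    print(sorted(dataset_dump))
else:
    print(type(dataset_dump).__name__, len(dataset_dump))
```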