|
1 | 1 | import shutil |
| 2 | +import time |
2 | 3 | from pathlib import Path |
3 | 4 | from typing import Callable, List, Optional, Dict |
4 | 5 |
|
|
9 | 10 | from pytabkit.models.alg_interfaces.autogluon_model_interfaces import AutoGluonModelAlgInterface |
10 | 11 | from pytabkit.models.alg_interfaces.catboost_interfaces import CatBoostSubSplitInterface, CatBoostHyperoptAlgInterface, \ |
11 | 12 | CatBoostSklearnSubSplitInterface, RandomParamsCatBoostAlgInterface |
| 13 | +from pytabkit.models.alg_interfaces.custom_interfaces import TabPFNV2SubSplitInterface |
12 | 14 | from pytabkit.models.alg_interfaces.ensemble_interfaces import PrecomputedPredictionsAlgInterface, \ |
13 | 15 | CaruanaEnsembleAlgInterface, AlgorithmSelectionAlgInterface |
14 | 16 | from pytabkit.models.alg_interfaces.lightgbm_interfaces import LGBMSubSplitInterface, LGBMHyperoptAlgInterface, \ |
|
34 | 36 | NNHyperoptAlgInterface |
35 | 37 | from pytabkit.models.alg_interfaces.xgboost_interfaces import XGBSubSplitInterface, XGBHyperoptAlgInterface, \ |
36 | 38 | XGBSklearnSubSplitInterface, RandomParamsXGBAlgInterface |
| 39 | +from pytabkit.models.alg_interfaces.xrfm_interfaces import xRFMSubSplitInterface, RandomParamsxRFMAlgInterface |
37 | 40 | from pytabkit.models.data.data import TaskType, DictDataset |
38 | 41 | from pytabkit.models.nn_models.models import PreprocessingFactory |
| 42 | +from pytabkit.models.torch_utils import TorchTimer |
39 | 43 | from pytabkit.models.training.logging import Logger |
40 | 44 | from pytabkit.models.training.metrics import Metrics |
41 | 45 |
|
@@ -115,15 +119,13 @@ def run(self, task_package: TaskPackage, logger: Logger, assigned_resources: Nod |
115 | 119 |
|
116 | 120 | interface_resources = assigned_resources.get_interface_resources() |
117 | 121 |
|
118 | | - |
119 | 122 | old_torch_n_threads = torch.get_num_threads() |
120 | 123 | old_torch_n_interop_threads = torch.get_num_interop_threads() |
121 | 124 | torch.set_num_threads(interface_resources.n_threads) |
122 | 125 | # don't set this because it can throw |
123 | 126 | # Error: cannot set number of interop threads after parallel work has started or set_num_interop_threads called |
124 | 127 | # torch.set_num_interop_threads(interface_resources.n_threads) |
125 | 128 |
|
126 | | - |
127 | 129 | ds = task.ds |
128 | 130 | name = 'alg ' + task_package.alg_name + ' on task ' + str(task_desc) |
129 | 131 |
|
@@ -185,22 +187,33 @@ def run(self, task_package: TaskPackage, logger: Logger, assigned_resources: Nod |
185 | 187 |
|
186 | 188 | rms = {name: [ResultManager() for _ in task_package.split_infos] for name in pred_param_names} |
187 | 189 |
|
188 | | - cv_alg_interface.fit(ds, cv_idxs_list, interface_resources, logger, cv_tmp_folders, name) |
| 190 | + with TorchTimer() as cv_fit_timer: |
| 191 | + cv_alg_interface.fit(ds, cv_idxs_list, interface_resources, logger, cv_tmp_folders, name) |
189 | 192 |
|
190 | 193 | for pred_param_name in pred_param_names: |
191 | 194 | cv_alg_interface.set_current_predict_params(pred_param_name) |
192 | 195 |
|
193 | | - cv_results_list = cv_alg_interface.eval(ds, cv_idxs_list, metrics, return_preds) |
| 196 | + with TorchTimer() as cv_eval_timer: |
| 197 | + cv_results_list = cv_alg_interface.eval(ds, cv_idxs_list, metrics, return_preds) |
194 | 198 |
|
195 | 199 | for rm, cv_results in zip(rms[pred_param_name], cv_results_list): |
196 | | - rm.add_results(is_cv=True, results_dict=cv_results.get_dict()) |
| 200 | + rm.add_results(is_cv=True, results_dict=cv_results.get_dict() | |
| 201 | + dict(fit_time_s=cv_fit_timer.elapsed, |
| 202 | + eval_time_s=cv_eval_timer.elapsed)) |
197 | 203 |
|
198 | 204 | if n_refit > 0: |
199 | 205 | refit_alg_interface = cv_alg_interface.get_refit_interface(n_refit) |
200 | | - refit_results_list = refit_alg_interface.fit_and_eval(ds, refit_idxs_list, interface_resources, logger, |
201 | | - refit_tmp_folders, name, metrics, return_preds) |
| 206 | + |
| 207 | + with TorchTimer() as refit_fit_timer: |
| 208 | + refit_alg_interface.fit(ds, refit_idxs_list, interface_resources, logger, refit_tmp_folders, name) |
| 209 | + |
| 210 | + with TorchTimer() as refit_eval_timer: |
| 211 | + refit_results_list = refit_alg_interface.eval(ds, refit_idxs_list, metrics, return_preds) |
202 | 212 | for rm, refit_results in zip(rms[pred_param_name], refit_results_list): |
203 | | - rm.add_results(is_cv=False, results_dict=refit_results.get_dict()) |
| 213 | + rm.add_results(is_cv=False, |
| 214 | + results_dict=refit_results.get_dict() | |
| 215 | + dict(fit_time_s=refit_fit_timer.elapsed, |
| 216 | + eval_time_s=refit_eval_timer.elapsed)) |
204 | 217 |
|
205 | 218 | torch.set_num_threads(old_torch_n_threads) |
206 | 219 | # torch.set_num_interop_threads(old_torch_n_interop_threads) |
@@ -578,3 +591,19 @@ class RandomParamsLinearModelInterfaceWrapper(AlgInterfaceWrapper): |
578 | 591 | def __init__(self, model_idx: int, **config): |
579 | 592 | # model_idx should be the random search iteration (i.e. start from zero) |
580 | 593 | super().__init__(RandomParamsLinearModelAlgInterface, model_idx=model_idx, **config) |
| 594 | + |
| 595 | + |
| 596 | +class TabPFNV2InterfaceWrapper(SubSplitInterfaceWrapper): |
| 597 | + def create_sub_split_interface(self, task_type: TaskType) -> AlgInterface: |
| 598 | + return TabPFNV2SubSplitInterface(**self.config) |
| 599 | + |
| 600 | + |
| 601 | +class xRFMInterfaceWrapper(SubSplitInterfaceWrapper): |
| 602 | + def create_sub_split_interface(self, task_type: TaskType) -> AlgInterface: |
| 603 | + return xRFMSubSplitInterface(**self.config) |
| 604 | + |
| 605 | + |
| 606 | +class RandomParamsxRFMInterfaceWrapper(MultiSplitAlgInterfaceWrapper): |
| 607 | + def create_single_alg_interface(self, n_cv: int, task_type: TaskType) \ |
| 608 | + -> AlgInterface: |
| 609 | + return RandomParamsxRFMAlgInterface(**self.config) |
0 commit comments