|
19 | 19 | from pytabkit.models.alg_interfaces.other_interfaces import RFSubSplitInterface, SklearnMLPSubSplitInterface, \ |
20 | 20 | KANSubSplitInterface, GrandeSubSplitInterface, GBTSubSplitInterface, RandomParamsRFAlgInterface, \ |
21 | 21 | TabPFN2SubSplitInterface, TabICLSubSplitInterface, RandomParamsExtraTreesAlgInterface, RandomParamsKNNAlgInterface, \ |
22 | | - ExtraTreesSubSplitInterface, KNNSubSplitInterface, RandomParamsLinearModelAlgInterface, LinearModelSubSplitInterface |
| 22 | + ExtraTreesSubSplitInterface, KNNSubSplitInterface, RandomParamsLinearModelAlgInterface, \ |
| 23 | + LinearModelSubSplitInterface |
23 | 24 | from pytabkit.bench.scheduling.resources import NodeResources |
24 | 25 | from pytabkit.models.alg_interfaces.alg_interfaces import AlgInterface, MultiSplitWrapperAlgInterface |
25 | 26 | from pytabkit.models.alg_interfaces.base import SplitIdxs, RequiredResources |
|
34 | 35 | NNHyperoptAlgInterface |
35 | 36 | from pytabkit.models.alg_interfaces.xgboost_interfaces import XGBSubSplitInterface, XGBHyperoptAlgInterface, \ |
36 | 37 | XGBSklearnSubSplitInterface, RandomParamsXGBAlgInterface |
| 38 | +from pytabkit.models.alg_interfaces.xrfm_interfaces import xRFMSubSplitInterface, RandomParamsxRFMAlgInterface |
37 | 39 | from pytabkit.models.data.data import TaskType, DictDataset |
38 | 40 | from pytabkit.models.nn_models.models import PreprocessingFactory |
| 41 | +from pytabkit.models.torch_utils import TorchTimer |
39 | 42 | from pytabkit.models.training.logging import Logger |
40 | 43 | from pytabkit.models.training.metrics import Metrics |
41 | 44 |
|
@@ -115,15 +118,13 @@ def run(self, task_package: TaskPackage, logger: Logger, assigned_resources: Nod |
115 | 118 |
|
116 | 119 | interface_resources = assigned_resources.get_interface_resources() |
117 | 120 |
|
118 | | - |
119 | 121 | old_torch_n_threads = torch.get_num_threads() |
120 | 122 | old_torch_n_interop_threads = torch.get_num_interop_threads() |
121 | 123 | torch.set_num_threads(interface_resources.n_threads) |
122 | 124 | # don't set this because it can throw |
123 | 125 | # Error: cannot set number of interop threads after parallel work has started or set_num_interop_threads called |
124 | 126 | # torch.set_num_interop_threads(interface_resources.n_threads) |
125 | 127 |
|
126 | | - |
127 | 128 | ds = task.ds |
128 | 129 | name = 'alg ' + task_package.alg_name + ' on task ' + str(task_desc) |
129 | 130 |
|
@@ -185,22 +186,33 @@ def run(self, task_package: TaskPackage, logger: Logger, assigned_resources: Nod |
185 | 186 |
|
186 | 187 | rms = {name: [ResultManager() for _ in task_package.split_infos] for name in pred_param_names} |
187 | 188 |
|
188 | | - cv_alg_interface.fit(ds, cv_idxs_list, interface_resources, logger, cv_tmp_folders, name) |
| 189 | + with TorchTimer() as cv_fit_timer: |
| 190 | + cv_alg_interface.fit(ds, cv_idxs_list, interface_resources, logger, cv_tmp_folders, name) |
189 | 191 |
|
190 | 192 | for pred_param_name in pred_param_names: |
191 | 193 | cv_alg_interface.set_current_predict_params(pred_param_name) |
192 | 194 |
|
193 | | - cv_results_list = cv_alg_interface.eval(ds, cv_idxs_list, metrics, return_preds) |
| 195 | + with TorchTimer() as cv_eval_timer: |
| 196 | + cv_results_list = cv_alg_interface.eval(ds, cv_idxs_list, metrics, return_preds) |
194 | 197 |
|
195 | 198 | for rm, cv_results in zip(rms[pred_param_name], cv_results_list): |
196 | | - rm.add_results(is_cv=True, results_dict=cv_results.get_dict()) |
| 199 | + rm.add_results(is_cv=True, results_dict=cv_results.get_dict() | |
| 200 | + dict(fit_time_s=cv_fit_timer.elapsed, |
| 201 | + eval_time_s=cv_eval_timer.elapsed)) |
197 | 202 |
|
198 | 203 | if n_refit > 0: |
199 | 204 | refit_alg_interface = cv_alg_interface.get_refit_interface(n_refit) |
200 | | - refit_results_list = refit_alg_interface.fit_and_eval(ds, refit_idxs_list, interface_resources, logger, |
201 | | - refit_tmp_folders, name, metrics, return_preds) |
| 205 | + |
| 206 | + with TorchTimer() as refit_fit_timer: |
| 207 | + refit_alg_interface.fit(ds, refit_idxs_list, interface_resources, logger, refit_tmp_folders, name) |
| 208 | + |
| 209 | + with TorchTimer() as refit_eval_timer: |
| 210 | + refit_results_list = refit_alg_interface.eval(ds, refit_idxs_list, metrics, return_preds) |
202 | 211 | for rm, refit_results in zip(rms[pred_param_name], refit_results_list): |
203 | | - rm.add_results(is_cv=False, results_dict=refit_results.get_dict()) |
| 212 | + rm.add_results(is_cv=False, |
| 213 | + results_dict=refit_results.get_dict() | |
| 214 | + dict(fit_time_s=refit_fit_timer.elapsed, |
| 215 | + eval_time_s=refit_eval_timer.elapsed)) |
204 | 216 |
|
205 | 217 | torch.set_num_threads(old_torch_n_threads) |
206 | 218 | # torch.set_num_interop_threads(old_torch_n_interop_threads) |
@@ -578,3 +590,14 @@ class RandomParamsLinearModelInterfaceWrapper(AlgInterfaceWrapper): |
578 | 590 | def __init__(self, model_idx: int, **config): |
579 | 591 | # model_idx should be the random search iteration (i.e. start from zero) |
580 | 592 | super().__init__(RandomParamsLinearModelAlgInterface, model_idx=model_idx, **config) |
| 593 | + |
| 594 | + |
| 595 | +class xRFMInterfaceWrapper(SubSplitInterfaceWrapper): |
| 596 | + def create_sub_split_interface(self, task_type: TaskType) -> AlgInterface: |
| 597 | + return xRFMSubSplitInterface(**self.config) |
| 598 | + |
| 599 | + |
| 600 | +class RandomParamsxRFMInterfaceWrapper(MultiSplitAlgInterfaceWrapper): |
| 601 | + def create_single_alg_interface(self, n_cv: int, task_type: TaskType) \ |
| 602 | + -> AlgInterface: |
| 603 | + return RandomParamsxRFMAlgInterface(**self.config) |
0 commit comments