Skip to content

Commit 75a69b5

Browse files
committed
Merge branch 'dev'
2 parents cecd3f6 + d944b1c commit 75a69b5

19 files changed

+822
-76
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ public_export
1111
dist
1212
files
1313
plots
14+
lightning_logs
1415

1516
docs/build
1617
docs/source/modules.rst

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,12 @@ and https://docs.ray.io/en/latest/cluster/vms/user-guides/community/slurm.html
200200

201201
## Releases (see git tags)
202202

203+
- v1.7.1:
204+
- LightGBM now processes the `extra_trees`, `max_cat_to_onehot`, and `min_data_per_group` parameters
205+
used in the `'tabarena'` search space, which should improve results.
206+
- Scikit-learn interfaces for RealMLP (TD, HPO) now support moving the model to a different device
207+
(e.g., before saving). This can be achieved using, e.g., `model.to('cpu')` (which is in-place).
208+
- Fixed an xRFM bug in handling binary categorical features.
203209
- v1.7.0:
204210
- added [xRFM](https://arxiv.org/abs/2508.10053) (D, HPO)
205211
- added new `'tabarena-new'` search space for RealMLP-HPO, including per-fold ensembling (more expensive)

pytabkit/__about__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
__version__ = "1.7.0"
5+
__version__ = "1.7.1"

pytabkit/bench/alg_wrappers/interface_wrappers.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
from pytabkit.models.alg_interfaces.other_interfaces import RFSubSplitInterface, SklearnMLPSubSplitInterface, \
2020
KANSubSplitInterface, GrandeSubSplitInterface, GBTSubSplitInterface, RandomParamsRFAlgInterface, \
2121
TabPFN2SubSplitInterface, TabICLSubSplitInterface, RandomParamsExtraTreesAlgInterface, RandomParamsKNNAlgInterface, \
22-
ExtraTreesSubSplitInterface, KNNSubSplitInterface, RandomParamsLinearModelAlgInterface, LinearModelSubSplitInterface
22+
ExtraTreesSubSplitInterface, KNNSubSplitInterface, RandomParamsLinearModelAlgInterface, \
23+
LinearModelSubSplitInterface
2324
from pytabkit.bench.scheduling.resources import NodeResources
2425
from pytabkit.models.alg_interfaces.alg_interfaces import AlgInterface, MultiSplitWrapperAlgInterface
2526
from pytabkit.models.alg_interfaces.base import SplitIdxs, RequiredResources
@@ -34,8 +35,10 @@
3435
NNHyperoptAlgInterface
3536
from pytabkit.models.alg_interfaces.xgboost_interfaces import XGBSubSplitInterface, XGBHyperoptAlgInterface, \
3637
XGBSklearnSubSplitInterface, RandomParamsXGBAlgInterface
38+
from pytabkit.models.alg_interfaces.xrfm_interfaces import xRFMSubSplitInterface, RandomParamsxRFMAlgInterface
3739
from pytabkit.models.data.data import TaskType, DictDataset
3840
from pytabkit.models.nn_models.models import PreprocessingFactory
41+
from pytabkit.models.torch_utils import TorchTimer
3942
from pytabkit.models.training.logging import Logger
4043
from pytabkit.models.training.metrics import Metrics
4144

@@ -115,15 +118,13 @@ def run(self, task_package: TaskPackage, logger: Logger, assigned_resources: Nod
115118

116119
interface_resources = assigned_resources.get_interface_resources()
117120

118-
119121
old_torch_n_threads = torch.get_num_threads()
120122
old_torch_n_interop_threads = torch.get_num_interop_threads()
121123
torch.set_num_threads(interface_resources.n_threads)
122124
# don't set this because it can throw
123125
# Error: cannot set number of interop threads after parallel work has started or set_num_interop_threads called
124126
# torch.set_num_interop_threads(interface_resources.n_threads)
125127

126-
127128
ds = task.ds
128129
name = 'alg ' + task_package.alg_name + ' on task ' + str(task_desc)
129130

@@ -185,22 +186,33 @@ def run(self, task_package: TaskPackage, logger: Logger, assigned_resources: Nod
185186

186187
rms = {name: [ResultManager() for _ in task_package.split_infos] for name in pred_param_names}
187188

188-
cv_alg_interface.fit(ds, cv_idxs_list, interface_resources, logger, cv_tmp_folders, name)
189+
with TorchTimer() as cv_fit_timer:
190+
cv_alg_interface.fit(ds, cv_idxs_list, interface_resources, logger, cv_tmp_folders, name)
189191

190192
for pred_param_name in pred_param_names:
191193
cv_alg_interface.set_current_predict_params(pred_param_name)
192194

193-
cv_results_list = cv_alg_interface.eval(ds, cv_idxs_list, metrics, return_preds)
195+
with TorchTimer() as cv_eval_timer:
196+
cv_results_list = cv_alg_interface.eval(ds, cv_idxs_list, metrics, return_preds)
194197

195198
for rm, cv_results in zip(rms[pred_param_name], cv_results_list):
196-
rm.add_results(is_cv=True, results_dict=cv_results.get_dict())
199+
rm.add_results(is_cv=True, results_dict=cv_results.get_dict() |
200+
dict(fit_time_s=cv_fit_timer.elapsed,
201+
eval_time_s=cv_eval_timer.elapsed))
197202

198203
if n_refit > 0:
199204
refit_alg_interface = cv_alg_interface.get_refit_interface(n_refit)
200-
refit_results_list = refit_alg_interface.fit_and_eval(ds, refit_idxs_list, interface_resources, logger,
201-
refit_tmp_folders, name, metrics, return_preds)
205+
206+
with TorchTimer() as refit_fit_timer:
207+
refit_alg_interface.fit(ds, refit_idxs_list, interface_resources, logger, refit_tmp_folders, name)
208+
209+
with TorchTimer() as refit_eval_timer:
210+
refit_results_list = refit_alg_interface.eval(ds, refit_idxs_list, metrics, return_preds)
202211
for rm, refit_results in zip(rms[pred_param_name], refit_results_list):
203-
rm.add_results(is_cv=False, results_dict=refit_results.get_dict())
212+
rm.add_results(is_cv=False,
213+
results_dict=refit_results.get_dict() |
214+
dict(fit_time_s=refit_fit_timer.elapsed,
215+
eval_time_s=refit_eval_timer.elapsed))
204216

205217
torch.set_num_threads(old_torch_n_threads)
206218
# torch.set_num_interop_threads(old_torch_n_interop_threads)
@@ -578,3 +590,14 @@ class RandomParamsLinearModelInterfaceWrapper(AlgInterfaceWrapper):
578590
def __init__(self, model_idx: int, **config):
579591
# model_idx should be the random search iteration (i.e. start from zero)
580592
super().__init__(RandomParamsLinearModelAlgInterface, model_idx=model_idx, **config)
593+
594+
595+
class xRFMInterfaceWrapper(SubSplitInterfaceWrapper):
596+
def create_sub_split_interface(self, task_type: TaskType) -> AlgInterface:
597+
return xRFMSubSplitInterface(**self.config)
598+
599+
600+
class RandomParamsxRFMInterfaceWrapper(MultiSplitAlgInterfaceWrapper):
601+
def create_single_alg_interface(self, n_cv: int, task_type: TaskType) \
602+
-> AlgInterface:
603+
return RandomParamsxRFMAlgInterface(**self.config)

pytabkit/bench/eval/tables.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,17 @@
1111
from pytabkit.models.data.data import TaskType
1212
from pytabkit.models.data.nested_dict import NestedDict
1313

14-
15-
def _get_table_str(table_head: List[List[str]], table_body: List[List[str]]):
16-
head_row_strs = [' & '.join(row) + r' \\' for row in table_head]
17-
body_row_strs = [' & '.join(row) + r' \\' for row in table_body]
18-
n_cols = max(len(row) for row in table_head + table_body)
14+
def _get_table_str(*parts: List[List[str]]):
15+
part_rows = [[' & '.join(row) + r' \\' for row in part] for part in parts]
16+
n_cols = max(len(row) for part in parts for row in part)
1917
begin_table_str = r'\begin{tabular}{' + ('c' * n_cols) + r'}' + '\n' + r'\toprule'
2018
end_table_str = r'\bottomrule' + '\n' + r'\end{tabular}'
21-
all_row_strs = [begin_table_str] + head_row_strs + [r'\midrule'] + body_row_strs + [end_table_str]
19+
all_row_strs = [begin_table_str]
20+
for part in part_rows[:-1]:
21+
all_row_strs.extend(part)
22+
all_row_strs.append(r'\midrule')
23+
all_row_strs.extend(part_rows[-1])
24+
all_row_strs.append(end_table_str)
2225
complete_str = '\n'.join(all_row_strs)
2326
return complete_str
2427

@@ -208,7 +211,7 @@ def generate_ablations_table(paths: Paths, tables: ResultsTables):
208211
(r'Activation=SELU', 'act-selu'),
209212
('', ''),
210213
('No dropout', 'pdrop-0.0'),
211-
('Dropout prob.\ $0.15$ (constant)', 'pdrop-0.15'),
214+
(r'Dropout prob.\ $0.15$ (constant)', 'pdrop-0.15'),
212215
('', ''),
213216
('No weight decay', 'wd-0.0'),
214217
# ('Weight decay = 0.02 ($\operatorname{flat\_cos}$)', 'wd-0.02-flatcos'),

pytabkit/models/alg_interfaces/alg_interfaces.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import functools
2+
import warnings
23
from pathlib import Path
34
from typing import List, Tuple, Any, Optional, Dict
45

@@ -246,6 +247,9 @@ def get_current_predict_params_dict(self):
246247
def set_current_predict_params(self, name: str) -> None:
247248
self.curr_pred_params_name = name
248249

250+
def to(self, device: str) -> None:
251+
warnings.warn(f'.to() method does nothing for {self.__class__} (not implemented)')
252+
249253

250254
class MultiSplitWrapperAlgInterface(AlgInterface):
251255
# todo: do we need the option to run this with a "split batch size" > 1 for the NNInterface?

pytabkit/models/alg_interfaces/calibration.py

Lines changed: 51 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,24 +41,52 @@ def fit(self, ds: DictDataset, idxs_list: List[SplitIdxs], interface_resources:
4141
self.alg_interface.fit(ds, idxs_list, interface_resources, logger, tmp_folders, name)
4242
y_preds = self.alg_interface.predict(ds)
4343

44-
for tt_split_idx, split_idxs in enumerate(idxs_list):
45-
for tv_split_idx in range(split_idxs.n_trainval_splits):
46-
val_idxs = split_idxs.val_idxs[tv_split_idx]
47-
y = ds.tensors['y'][val_idxs]
48-
y_pred = y_preds[len(self.calibrators), val_idxs]
49-
y_pred_probs = torch.softmax(y_pred, dim=-1)
44+
self.n_tv_splits_list_ = [idxs.n_trainval_splits for idxs in idxs_list]
45+
46+
if self.config.get('calibrate_per_fold', True):
47+
for tt_split_idx, split_idxs in enumerate(idxs_list):
48+
for tv_split_idx in range(split_idxs.n_trainval_splits):
49+
val_idxs = split_idxs.val_idxs[tv_split_idx]
50+
y = ds.tensors['y'][val_idxs]
51+
y_pred = y_preds[len(self.calibrators), val_idxs]
52+
y_pred_probs = torch.softmax(y_pred, dim=-1)
53+
54+
import probmetrics.calibrators
55+
import probmetrics.distributions
56+
calib = probmetrics.calibrators.get_calibrator(**self.config)
57+
if self.config.get('calibrate_with_logits', True):
58+
calib.fit_torch(y_pred=probmetrics.distributions.CategoricalLogits(y_pred.detach().cpu()),
59+
y_true_labels=y[:, 0])
60+
else:
61+
calib.fit(self._transform_probs(y_pred_probs.detach().cpu().numpy()), y.cpu().numpy()[:, 0])
62+
63+
self.calibrators.append(calib)
64+
self.n_calibs.append(val_idxs.shape[-1])
65+
else:
66+
y_pred_idx = 0
67+
for tt_split_idx, split_idxs in enumerate(idxs_list):
68+
y_pred_list = []
69+
y_list = []
70+
for tv_split_idx in range(split_idxs.n_trainval_splits):
71+
val_idxs = split_idxs.val_idxs[tv_split_idx]
72+
y_pred_list.append(y_preds[y_pred_idx, val_idxs])
73+
y_list.append(ds.tensors['y'][val_idxs])
74+
y_pred_idx += 1
75+
76+
y_pred = torch.cat(y_pred_list, dim=0)
77+
y = torch.cat(y_list, dim=0)
5078

5179
import probmetrics.calibrators
5280
import probmetrics.distributions
5381
calib = probmetrics.calibrators.get_calibrator(**self.config)
5482
if self.config.get('calibrate_with_logits', True):
5583
calib.fit_torch(y_pred=probmetrics.distributions.CategoricalLogits(y_pred.detach().cpu()),
56-
y_true_labels=y[:, 0])
84+
y_true_labels=y[:, 0].detach().cpu())
5785
else:
58-
calib.fit(self._transform_probs(y_pred_probs.detach().cpu().numpy()), y.cpu().numpy()[:, 0])
86+
calib.fit(self._transform_probs(torch.softmax(y_pred, dim=-1).detach().cpu().numpy()), y.cpu().numpy()[:, 0])
5987

60-
self.calibrators.append(calib)
61-
self.n_calibs.append(val_idxs.shape[-1])
88+
self.calibrators.extend([calib] * split_idxs.n_trainval_splits)
89+
self.n_calibs.extend([y_pred.shape[0]] * split_idxs.n_trainval_splits)
6290

6391
self.fit_params = [dict(sub_fit_params=fp) for fp in self.alg_interface.fit_params]
6492

@@ -68,6 +96,15 @@ def predict(self, ds: DictDataset) -> torch.Tensor:
6896
y_preds = self.alg_interface.predict(ds)
6997
y_preds_probs = torch.softmax(y_preds, dim=-1)
7098
y_preds_calib = []
99+
100+
if self.config.get('ensemble_before_calib', False):
101+
start_idx = 0
102+
for n_tv_splits in self.n_tv_splits_list_:
103+
avg_probs = y_preds_probs[start_idx:start_idx+n_tv_splits].mean(dim=0, keepdim=True)
104+
y_preds_probs[start_idx:start_idx + n_tv_splits] = avg_probs
105+
start_idx += n_tv_splits
106+
y_preds = torch.log(y_preds_probs + 1e-30)
107+
71108
for i in range(y_preds.shape[0]):
72109
if self.config.get('calibrate_with_logits', True):
73110
from probmetrics.distributions import CategoricalLogits
@@ -90,3 +127,7 @@ def predict(self, ds: DictDataset) -> torch.Tensor:
90127
def get_required_resources(self, ds: DictDataset, n_cv: int, n_refit: int, n_splits: int,
91128
split_seeds: List[int], n_train: int) -> RequiredResources:
92129
return self.alg_interface.get_required_resources(ds, n_cv, n_refit, n_splits, split_seeds, n_train=n_train)
130+
131+
def to(self, device: str) -> None:
132+
self.alg_interface.to(device)
133+

pytabkit/models/alg_interfaces/ensemble_interfaces.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,12 @@ def get_required_resources(self, ds: DictDataset, n_cv: int, n_refit: int, n_spl
176176
for ssi in self.alg_interfaces]
177177
return RequiredResources.combine_sequential(single_resources)
178178

179+
def to(self, device: str) -> None:
180+
for alg_idx, alg_ctx in enumerate(self.alg_contexts_):
181+
with alg_ctx as alg_interface:
182+
alg_interface.to(device)
183+
184+
179185

180186
class AlgorithmSelectionAlgInterface(SingleSplitAlgInterface):
181187
"""
@@ -277,6 +283,12 @@ def get_required_resources(self, ds: DictDataset, n_cv: int, n_refit: int, n_spl
277283
for ssi in self.alg_interfaces]
278284
return RequiredResources.combine_sequential(single_resources)
279285

286+
def to(self, device: str) -> None:
287+
for alg_idx, alg_ctx in enumerate(self.alg_contexts_):
288+
with alg_ctx as alg_interface:
289+
alg_interface.to(device)
290+
291+
280292

281293
class PrecomputedPredictionsAlgInterface(SingleSplitAlgInterface):
282294
def __init__(self, y_preds_cv: torch.Tensor, y_preds_refit: Optional[torch.Tensor],

pytabkit/models/alg_interfaces/lightgbm_interfaces.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,9 @@ def _get_params(self):
123123
('cat_smooth', None),
124124
('cat_l2', None),
125125
('early_stopping_round', ['early_stopping_round', 'early_stopping_rounds'], None),
126+
('extra_trees', None),
127+
('max_cat_to_onehot', None),
128+
('min_data_per_group', None),
126129
]
127130

128131
params = utils.extract_params(self.config, params_config)
@@ -686,10 +689,10 @@ def _sample_params(self, is_classification: bool, seed: int, n_train: int):
686689
'min_data_in_leaf': np.floor(np.exp(rng.uniform(np.log(1.0), np.log(65)))),
687690
'extra_trees': rng.choice([False, True]),
688691

689-
'min_data_per_group': np.floor(np.exp(rng.uniform(np.log(2.0), np.log(101)))),
692+
'min_data_per_group': round(np.floor(np.exp(rng.uniform(np.log(2.0), np.log(101))))),
690693
'cat_l2': np.exp(rng.uniform(np.log(5e-3), np.log(2.0))),
691694
'cat_smooth': np.exp(rng.uniform(np.log(1e-3), np.log(100.0))),
692-
'max_cat_to_onehot': np.floor(np.exp(rng.uniform(np.log(8.0), np.log(101.0)))),
695+
'max_cat_to_onehot': round(np.floor(np.exp(rng.uniform(np.log(8.0), np.log(101.0))))),
693696

694697
'lambda_l1': np.exp(rng.uniform(np.log(1e-5), np.log(1.0))),
695698
'lambda_l2': np.exp(rng.uniform(np.log(1e-5), np.log(2.0))),

0 commit comments

Comments
 (0)