Skip to content

Commit e8c9410

Browse files
committed
v1.7.1: move realmlp to different devices, implement missing lgbm parameters, fix xRFM bug
1 parent a9631bc commit e8c9410

19 files changed

+823
-77
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ public_export
1111
dist
1212
files
1313
plots
14+
lightning_logs
1415

1516
docs/build
1617
docs/source/modules.rst

README.md

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ on our benchmarks.
2828
(e.g., `class_error`, `cross_entropy`, `brier`, `1-auc_ovr`), or the corresponding `Regressor`.
2929
(This might take very long to fit.)
3030
- For only a single model, we recommend using
31-
`RealMLP_HPO_Classifier(n_cv=8, hpo_space_name='tabarena', use_caruana_ensembling=True, n_hyperopt_steps=50)`,
31+
`RealMLP_HPO_Classifier(n_cv=8, hpo_space_name='tabarena-new', use_caruana_ensembling=True, n_hyperopt_steps=50)`,
3232
also with `val_metric_name` as above, or the corresponding `Regressor`.
3333
- **Models**: [TabArena](https://github.com/AutoGluon/tabarena)
3434
also includes some newer models like RealMLP and TabM
@@ -184,9 +184,9 @@ If you use this repository for research purposes, please cite our [paper](https:
184184
- Léo Grinsztajn (deep learning baselines, plotting)
185185
- Ingo Steinwart (UCI dataset download)
186186
- Katharina Strecker (PyTorch-Lightning interface)
187+
- Daniel Beaglehole (part of the xRFM implementation)
187188
- Lennart Purucker (some features/fixes)
188189
- Jérôme Dockès (deployment, continuous integration)
189-
-
190190

191191
## Acknowledgements
192192

@@ -200,6 +200,12 @@ and https://docs.ray.io/en/latest/cluster/vms/user-guides/community/slurm.html
200200

201201
## Releases (see git tags)
202202

203+
- v1.7.1:
204+
- LightGBM now processes the `extra_trees`, `max_cat_to_onehot`, and `min_data_per_group` parameters
205+
used in the `'tabarena'` search space, which should improve results.
206+
- Scikit-learn interfaces for RealMLP (TD, HPO) now support moving the model to a different device
207+
(e.g., before saving). This can be achieved using, e.g., `model.to('cpu')` (which is in-place).
208+
- Fixed an xRFM bug in handling binary categorical features.
203209
- v1.7.0:
204210
- added [xRFM](https://arxiv.org/abs/2508.10053) (D, HPO)
205211
- added new `'tabarena-new'` search space for RealMLP-HPO, including per-fold ensembling (more expensive)

pytabkit/__about__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
__version__ = "1.7.0"
5+
__version__ = "1.7.1"

pytabkit/bench/alg_wrappers/interface_wrappers.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import shutil
2+
import time
23
from pathlib import Path
34
from typing import Callable, List, Optional, Dict
45

@@ -9,6 +10,7 @@
910
from pytabkit.models.alg_interfaces.autogluon_model_interfaces import AutoGluonModelAlgInterface
1011
from pytabkit.models.alg_interfaces.catboost_interfaces import CatBoostSubSplitInterface, CatBoostHyperoptAlgInterface, \
1112
CatBoostSklearnSubSplitInterface, RandomParamsCatBoostAlgInterface
13+
from pytabkit.models.alg_interfaces.custom_interfaces import TabPFNV2SubSplitInterface
1214
from pytabkit.models.alg_interfaces.ensemble_interfaces import PrecomputedPredictionsAlgInterface, \
1315
CaruanaEnsembleAlgInterface, AlgorithmSelectionAlgInterface
1416
from pytabkit.models.alg_interfaces.lightgbm_interfaces import LGBMSubSplitInterface, LGBMHyperoptAlgInterface, \
@@ -34,8 +36,10 @@
3436
NNHyperoptAlgInterface
3537
from pytabkit.models.alg_interfaces.xgboost_interfaces import XGBSubSplitInterface, XGBHyperoptAlgInterface, \
3638
XGBSklearnSubSplitInterface, RandomParamsXGBAlgInterface
39+
from pytabkit.models.alg_interfaces.xrfm_interfaces import xRFMSubSplitInterface, RandomParamsxRFMAlgInterface
3740
from pytabkit.models.data.data import TaskType, DictDataset
3841
from pytabkit.models.nn_models.models import PreprocessingFactory
42+
from pytabkit.models.torch_utils import TorchTimer
3943
from pytabkit.models.training.logging import Logger
4044
from pytabkit.models.training.metrics import Metrics
4145

@@ -115,15 +119,13 @@ def run(self, task_package: TaskPackage, logger: Logger, assigned_resources: Nod
115119

116120
interface_resources = assigned_resources.get_interface_resources()
117121

118-
119122
old_torch_n_threads = torch.get_num_threads()
120123
old_torch_n_interop_threads = torch.get_num_interop_threads()
121124
torch.set_num_threads(interface_resources.n_threads)
122125
# don't set this because it can throw
123126
# Error: cannot set number of interop threads after parallel work has started or set_num_interop_threads called
124127
# torch.set_num_interop_threads(interface_resources.n_threads)
125128

126-
127129
ds = task.ds
128130
name = 'alg ' + task_package.alg_name + ' on task ' + str(task_desc)
129131

@@ -185,22 +187,33 @@ def run(self, task_package: TaskPackage, logger: Logger, assigned_resources: Nod
185187

186188
rms = {name: [ResultManager() for _ in task_package.split_infos] for name in pred_param_names}
187189

188-
cv_alg_interface.fit(ds, cv_idxs_list, interface_resources, logger, cv_tmp_folders, name)
190+
with TorchTimer() as cv_fit_timer:
191+
cv_alg_interface.fit(ds, cv_idxs_list, interface_resources, logger, cv_tmp_folders, name)
189192

190193
for pred_param_name in pred_param_names:
191194
cv_alg_interface.set_current_predict_params(pred_param_name)
192195

193-
cv_results_list = cv_alg_interface.eval(ds, cv_idxs_list, metrics, return_preds)
196+
with TorchTimer() as cv_eval_timer:
197+
cv_results_list = cv_alg_interface.eval(ds, cv_idxs_list, metrics, return_preds)
194198

195199
for rm, cv_results in zip(rms[pred_param_name], cv_results_list):
196-
rm.add_results(is_cv=True, results_dict=cv_results.get_dict())
200+
rm.add_results(is_cv=True, results_dict=cv_results.get_dict() |
201+
dict(fit_time_s=cv_fit_timer.elapsed,
202+
eval_time_s=cv_eval_timer.elapsed))
197203

198204
if n_refit > 0:
199205
refit_alg_interface = cv_alg_interface.get_refit_interface(n_refit)
200-
refit_results_list = refit_alg_interface.fit_and_eval(ds, refit_idxs_list, interface_resources, logger,
201-
refit_tmp_folders, name, metrics, return_preds)
206+
207+
with TorchTimer() as refit_fit_timer:
208+
refit_alg_interface.fit(ds, refit_idxs_list, interface_resources, logger, refit_tmp_folders, name)
209+
210+
with TorchTimer() as refit_eval_timer:
211+
refit_results_list = refit_alg_interface.eval(ds, refit_idxs_list, metrics, return_preds)
202212
for rm, refit_results in zip(rms[pred_param_name], refit_results_list):
203-
rm.add_results(is_cv=False, results_dict=refit_results.get_dict())
213+
rm.add_results(is_cv=False,
214+
results_dict=refit_results.get_dict() |
215+
dict(fit_time_s=refit_fit_timer.elapsed,
216+
eval_time_s=refit_eval_timer.elapsed))
204217

205218
torch.set_num_threads(old_torch_n_threads)
206219
# torch.set_num_interop_threads(old_torch_n_interop_threads)
@@ -578,3 +591,19 @@ class RandomParamsLinearModelInterfaceWrapper(AlgInterfaceWrapper):
578591
def __init__(self, model_idx: int, **config):
579592
# model_idx should be the random search iteration (i.e. start from zero)
580593
super().__init__(RandomParamsLinearModelAlgInterface, model_idx=model_idx, **config)
594+
595+
596+
class TabPFNV2InterfaceWrapper(SubSplitInterfaceWrapper):
597+
def create_sub_split_interface(self, task_type: TaskType) -> AlgInterface:
598+
return TabPFNV2SubSplitInterface(**self.config)
599+
600+
601+
class xRFMInterfaceWrapper(SubSplitInterfaceWrapper):
602+
def create_sub_split_interface(self, task_type: TaskType) -> AlgInterface:
603+
return xRFMSubSplitInterface(**self.config)
604+
605+
606+
class RandomParamsxRFMInterfaceWrapper(MultiSplitAlgInterfaceWrapper):
607+
def create_single_alg_interface(self, n_cv: int, task_type: TaskType) \
608+
-> AlgInterface:
609+
return RandomParamsxRFMAlgInterface(**self.config)

pytabkit/bench/eval/tables.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,17 @@
1111
from pytabkit.models.data.data import TaskType
1212
from pytabkit.models.data.nested_dict import NestedDict
1313

14-
15-
def _get_table_str(table_head: List[List[str]], table_body: List[List[str]]):
16-
head_row_strs = [' & '.join(row) + r' \\' for row in table_head]
17-
body_row_strs = [' & '.join(row) + r' \\' for row in table_body]
18-
n_cols = max(len(row) for row in table_head + table_body)
14+
def _get_table_str(*parts: List[List[str]]):
15+
part_rows = [[' & '.join(row) + r' \\' for row in part] for part in parts]
16+
n_cols = max(len(row) for part in parts for row in part)
1917
begin_table_str = r'\begin{tabular}{' + ('c' * n_cols) + r'}' + '\n' + r'\toprule'
2018
end_table_str = r'\bottomrule' + '\n' + r'\end{tabular}'
21-
all_row_strs = [begin_table_str] + head_row_strs + [r'\midrule'] + body_row_strs + [end_table_str]
19+
all_row_strs = [begin_table_str]
20+
for part in part_rows[:-1]:
21+
all_row_strs.extend(part)
22+
all_row_strs.append(r'\midrule')
23+
all_row_strs.extend(part_rows[-1])
24+
all_row_strs.append(end_table_str)
2225
complete_str = '\n'.join(all_row_strs)
2326
return complete_str
2427

@@ -208,7 +211,7 @@ def generate_ablations_table(paths: Paths, tables: ResultsTables):
208211
(r'Activation=SELU', 'act-selu'),
209212
('', ''),
210213
('No dropout', 'pdrop-0.0'),
211-
('Dropout prob.\ $0.15$ (constant)', 'pdrop-0.15'),
214+
(r'Dropout prob.\ $0.15$ (constant)', 'pdrop-0.15'),
212215
('', ''),
213216
('No weight decay', 'wd-0.0'),
214217
# ('Weight decay = 0.02 ($\operatorname{flat\_cos}$)', 'wd-0.02-flatcos'),

pytabkit/models/alg_interfaces/alg_interfaces.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import functools
2+
import warnings
23
from pathlib import Path
34
from typing import List, Tuple, Any, Optional, Dict
45

@@ -246,6 +247,9 @@ def get_current_predict_params_dict(self):
246247
def set_current_predict_params(self, name: str) -> None:
247248
self.curr_pred_params_name = name
248249

250+
def to(self, device: str) -> None:
251+
warnings.warn(f'.to() method does nothing for {self.__class__} (not implemented)')
252+
249253

250254
class MultiSplitWrapperAlgInterface(AlgInterface):
251255
# todo: do we need the option to run this with a "split batch size" > 1 for the NNInterface?

pytabkit/models/alg_interfaces/calibration.py

Lines changed: 51 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,24 +41,52 @@ def fit(self, ds: DictDataset, idxs_list: List[SplitIdxs], interface_resources:
4141
self.alg_interface.fit(ds, idxs_list, interface_resources, logger, tmp_folders, name)
4242
y_preds = self.alg_interface.predict(ds)
4343

44-
for tt_split_idx, split_idxs in enumerate(idxs_list):
45-
for tv_split_idx in range(split_idxs.n_trainval_splits):
46-
val_idxs = split_idxs.val_idxs[tv_split_idx]
47-
y = ds.tensors['y'][val_idxs]
48-
y_pred = y_preds[len(self.calibrators), val_idxs]
49-
y_pred_probs = torch.softmax(y_pred, dim=-1)
44+
self.n_tv_splits_list_ = [idxs.n_trainval_splits for idxs in idxs_list]
45+
46+
if self.config.get('calibrate_per_fold', True):
47+
for tt_split_idx, split_idxs in enumerate(idxs_list):
48+
for tv_split_idx in range(split_idxs.n_trainval_splits):
49+
val_idxs = split_idxs.val_idxs[tv_split_idx]
50+
y = ds.tensors['y'][val_idxs]
51+
y_pred = y_preds[len(self.calibrators), val_idxs]
52+
y_pred_probs = torch.softmax(y_pred, dim=-1)
53+
54+
import probmetrics.calibrators
55+
import probmetrics.distributions
56+
calib = probmetrics.calibrators.get_calibrator(**self.config)
57+
if self.config.get('calibrate_with_logits', True):
58+
calib.fit_torch(y_pred=probmetrics.distributions.CategoricalLogits(y_pred.detach().cpu()),
59+
y_true_labels=y[:, 0])
60+
else:
61+
calib.fit(self._transform_probs(y_pred_probs.detach().cpu().numpy()), y.cpu().numpy()[:, 0])
62+
63+
self.calibrators.append(calib)
64+
self.n_calibs.append(val_idxs.shape[-1])
65+
else:
66+
y_pred_idx = 0
67+
for tt_split_idx, split_idxs in enumerate(idxs_list):
68+
y_pred_list = []
69+
y_list = []
70+
for tv_split_idx in range(split_idxs.n_trainval_splits):
71+
val_idxs = split_idxs.val_idxs[tv_split_idx]
72+
y_pred_list.append(y_preds[y_pred_idx, val_idxs])
73+
y_list.append(ds.tensors['y'][val_idxs])
74+
y_pred_idx += 1
75+
76+
y_pred = torch.cat(y_pred_list, dim=0)
77+
y = torch.cat(y_list, dim=0)
5078

5179
import probmetrics.calibrators
5280
import probmetrics.distributions
5381
calib = probmetrics.calibrators.get_calibrator(**self.config)
5482
if self.config.get('calibrate_with_logits', True):
5583
calib.fit_torch(y_pred=probmetrics.distributions.CategoricalLogits(y_pred.detach().cpu()),
56-
y_true_labels=y[:, 0])
84+
y_true_labels=y[:, 0].detach().cpu())
5785
else:
58-
calib.fit(self._transform_probs(y_pred_probs.detach().cpu().numpy()), y.cpu().numpy()[:, 0])
86+
calib.fit(self._transform_probs(torch.softmax(y_pred, dim=-1).detach().cpu().numpy()), y.cpu().numpy()[:, 0])
5987

60-
self.calibrators.append(calib)
61-
self.n_calibs.append(val_idxs.shape[-1])
88+
self.calibrators.extend([calib] * split_idxs.n_trainval_splits)
89+
self.n_calibs.extend([y_pred.shape[0]] * split_idxs.n_trainval_splits)
6290

6391
self.fit_params = [dict(sub_fit_params=fp) for fp in self.alg_interface.fit_params]
6492

@@ -68,6 +96,15 @@ def predict(self, ds: DictDataset) -> torch.Tensor:
6896
y_preds = self.alg_interface.predict(ds)
6997
y_preds_probs = torch.softmax(y_preds, dim=-1)
7098
y_preds_calib = []
99+
100+
if self.config.get('ensemble_before_calib', False):
101+
start_idx = 0
102+
for n_tv_splits in self.n_tv_splits_list_:
103+
avg_probs = y_preds_probs[start_idx:start_idx+n_tv_splits].mean(dim=0, keepdim=True)
104+
y_preds_probs[start_idx:start_idx + n_tv_splits] = avg_probs
105+
start_idx += n_tv_splits
106+
y_preds = torch.log(y_preds_probs + 1e-30)
107+
71108
for i in range(y_preds.shape[0]):
72109
if self.config.get('calibrate_with_logits', True):
73110
from probmetrics.distributions import CategoricalLogits
@@ -90,3 +127,7 @@ def predict(self, ds: DictDataset) -> torch.Tensor:
90127
def get_required_resources(self, ds: DictDataset, n_cv: int, n_refit: int, n_splits: int,
91128
split_seeds: List[int], n_train: int) -> RequiredResources:
92129
return self.alg_interface.get_required_resources(ds, n_cv, n_refit, n_splits, split_seeds, n_train=n_train)
130+
131+
def to(self, device: str) -> None:
132+
self.alg_interface.to(device)
133+

pytabkit/models/alg_interfaces/ensemble_interfaces.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,12 @@ def get_required_resources(self, ds: DictDataset, n_cv: int, n_refit: int, n_spl
176176
for ssi in self.alg_interfaces]
177177
return RequiredResources.combine_sequential(single_resources)
178178

179+
def to(self, device: str) -> None:
180+
for alg_idx, alg_ctx in enumerate(self.alg_contexts_):
181+
with alg_ctx as alg_interface:
182+
alg_interface.to(device)
183+
184+
179185

180186
class AlgorithmSelectionAlgInterface(SingleSplitAlgInterface):
181187
"""
@@ -277,6 +283,12 @@ def get_required_resources(self, ds: DictDataset, n_cv: int, n_refit: int, n_spl
277283
for ssi in self.alg_interfaces]
278284
return RequiredResources.combine_sequential(single_resources)
279285

286+
def to(self, device: str) -> None:
287+
for alg_idx, alg_ctx in enumerate(self.alg_contexts_):
288+
with alg_ctx as alg_interface:
289+
alg_interface.to(device)
290+
291+
280292

281293
class PrecomputedPredictionsAlgInterface(SingleSplitAlgInterface):
282294
def __init__(self, y_preds_cv: torch.Tensor, y_preds_refit: Optional[torch.Tensor],

pytabkit/models/alg_interfaces/lightgbm_interfaces.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,9 @@ def _get_params(self):
123123
('cat_smooth', None),
124124
('cat_l2', None),
125125
('early_stopping_round', ['early_stopping_round', 'early_stopping_rounds'], None),
126+
('extra_trees', None),
127+
('max_cat_to_onehot', None),
128+
('min_data_per_group', None),
126129
]
127130

128131
params = utils.extract_params(self.config, params_config)
@@ -686,10 +689,10 @@ def _sample_params(self, is_classification: bool, seed: int, n_train: int):
686689
'min_data_in_leaf': np.floor(np.exp(rng.uniform(np.log(1.0), np.log(65)))),
687690
'extra_trees': rng.choice([False, True]),
688691

689-
'min_data_per_group': np.floor(np.exp(rng.uniform(np.log(2.0), np.log(101)))),
692+
'min_data_per_group': round(np.floor(np.exp(rng.uniform(np.log(2.0), np.log(101))))),
690693
'cat_l2': np.exp(rng.uniform(np.log(5e-3), np.log(2.0))),
691694
'cat_smooth': np.exp(rng.uniform(np.log(1e-3), np.log(100.0))),
692-
'max_cat_to_onehot': np.floor(np.exp(rng.uniform(np.log(8.0), np.log(101.0)))),
695+
'max_cat_to_onehot': round(np.floor(np.exp(rng.uniform(np.log(8.0), np.log(101.0))))),
693696

694697
'lambda_l1': np.exp(rng.uniform(np.log(1e-5), np.log(1.0))),
695698
'lambda_l2': np.exp(rng.uniform(np.log(1e-5), np.log(2.0))),

0 commit comments

Comments
 (0)