Commit 774944a

v1.6.0: TabM training metrics, RealMLP ensembles
1 parent: c8be0b2

File tree

11 files changed: +181 −69 lines

README.md

Lines changed: 10 additions & 1 deletion
@@ -51,7 +51,7 @@ pip install pytabkit[models]
   [faiss](https://github.com/facebookresearch/faiss/blob/main/INSTALL.md),
   which is only available on **conda**.
 - Please install torch separately if you want to control the version (CPU/GPU etc.)
-- Use `pytabkit[models,autogluon,extra,hpo,bench,dev]` to install additional dependencies for
+- Use `pytabkit[models,autogluon,extra,hpo,bench,dev]` to install additional dependencies for the other models,
   AutoGluon models, extra preprocessing,
   hyperparameter optimization methods beyond random search (hyperopt/SMAC),
   the benchmarking part, and testing/documentation. For the hpo part,
@@ -196,6 +196,15 @@ and https://docs.ray.io/en/latest/cluster/vms/user-guides/community/slurm.html

 ## Releases (see git tags)

+- v1.6.0:
+  - Added support for other training losses in TabM through the `train_metric_name` parameter,
+    for example, (multi)quantile regression via `train_metric_name='multi_pinball(0.05,0.95)'`.
+  - RealMLP-TD now adds the `n_ens` hyperparameter, which can be set to values >1
+    to train ensembles per train-validation split (called PackedEnsemble in the TabM paper).
+    This is especially useful when using holdout validation instead of cross-validation ensembles,
+    and to get more reliable validation predictions and scores for tuning/ensembling.
+  - fixed RealMLP TabArena search space (`hpo_space_name='tabarena'`) for classification
+    (allow no label smoothing through `use_ls=False` instead of `use_ls="auto"`).
 - v1.5.2: fixed more device bugs for HPO and ensembling
 - v1.5.1: fixed a device bug in TabM for GPU
 - v1.5.0:
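Note (illustration, not part of the diff): a minimal usage sketch of the new `train_metric_name` option through the sklearn-style interface. The estimator name `TabM_D_Regressor` and the top-level import path are assumptions here; only the parameter itself is introduced by this commit.

```python
# Hedged sketch: TabM_D_Regressor and the import path are assumed, not shown in this commit.
from sklearn.datasets import make_regression
from pytabkit import TabM_D_Regressor  # assumed import

X, y = make_regression(n_samples=500, n_features=8, noise=10.0, random_state=0)

# Train TabM with a pinball loss for the 5% and 95% quantiles instead of the default MSE.
model = TabM_D_Regressor(train_metric_name='multi_pinball(0.05,0.95)')
model.fit(X, y)
pred = model.predict(X[:5])  # expected: one output column per requested quantile
print(pred.shape)
```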

pytabkit/__about__.py

Lines changed: 1 addition & 1 deletion
@@ -2,4 +2,4 @@
 #
 # SPDX-License-Identifier: Apache-2.0

-__version__ = "1.5.2"
+__version__ = "1.6.0"

pytabkit/models/alg_interfaces/ensemble_interfaces.py

Lines changed: 12 additions & 4 deletions
@@ -64,7 +64,15 @@ def fit(self, ds: DictDataset, idxs_list: List[SplitIdxs], interface_resources:
             with alg_ctx as alg_interface:
                 sub_tmp_folders = [tmp_folder / str(alg_idx) if tmp_folder is not None else None for tmp_folder in
                                    tmp_folders]
-                alg_interface.fit(ds, idxs_list, interface_resources, logger, sub_tmp_folders, name + f'sub-alg-{alg_idx}')
+                if self.config.get('diversify_seeds', False):
+                    sub_idxs_list = [SplitIdxs(train_idxs=idxs.train_idxs, val_idxs=idxs.val_idxs,
+                                               test_idxs=idxs.test_idxs, split_seed=idxs.split_seed + alg_idx,
+                                               sub_split_seeds=[sss + alg_idx for sss in idxs.sub_split_seeds],
+                                               split_id=idxs.split_id) for idxs in idxs_list]
+                else:
+                    sub_idxs_list = idxs_list
+                alg_interface.fit(ds, sub_idxs_list, interface_resources, logger, sub_tmp_folders,
+                                  name + f'sub-alg-{alg_idx}')
                 sub_fit_params.append(alg_interface.get_fit_params()[0])

         if self.fit_params is not None:
@@ -132,7 +140,6 @@ def fit(self, ds: DictDataset, idxs_list: List[SplitIdxs], interface_resources:

                 weights[weight_idx] += 1

-
             if best_step_loss < best_loss:
                 best_loss = best_step_loss
                 best_weights = np.copy(best_step_weights)
@@ -202,7 +209,7 @@ def fit(self, ds: DictDataset, idxs_list: List[SplitIdxs], interface_resources:
                            tmp_folders]
         with self.alg_contexts_[best_alg_idx] as alg_interface:
             alg_interface.fit(ds, idxs_list, interface_resources, logger, sub_tmp_folders,
-                        name + f'sub-alg-{best_alg_idx}')
+                              name + f'sub-alg-{best_alg_idx}')

         return

@@ -228,7 +235,8 @@ def fit(self, ds: DictDataset, idxs_list: List[SplitIdxs], interface_resources:
             with alg_ctx as alg_interface:
                 sub_tmp_folders = [tmp_folder / str(alg_idx) if tmp_folder is not None else None for tmp_folder in
                                    tmp_folders]
-                alg_interface.fit(ds, idxs_list, interface_resources, logger, sub_tmp_folders, name + f'sub-alg-{alg_idx}')
+                alg_interface.fit(ds, idxs_list, interface_resources, logger, sub_tmp_folders,
+                                  name + f'sub-alg-{alg_idx}')
                 y_preds = alg_interface.predict(ds)
                 # get out-of-bag predictions
                 y_pred_oob = cat_if_necessary([y_preds[j, idxs_list[0].val_idxs[j]]
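Note (illustration, not part of the diff): the `diversify_seeds` branch above simply offsets every seed in each `SplitIdxs` by the sub-algorithm index, so that each ensemble member trains with different randomness. A stripped-down sketch of the same idea:

```python
# Simplified stand-in for SplitIdxs; the real class lives in pytabkit and has more fields.
from dataclasses import dataclass, replace
from typing import List

@dataclass
class SplitIdxsSketch:
    split_seed: int
    sub_split_seeds: List[int]

def diversify(idxs_list: List[SplitIdxsSketch], alg_idx: int) -> List[SplitIdxsSketch]:
    # Shift all seeds by the sub-algorithm index so member alg_idx sees different randomness.
    return [replace(idxs, split_seed=idxs.split_seed + alg_idx,
                    sub_split_seeds=[s + alg_idx for s in idxs.sub_split_seeds])
            for idxs in idxs_list]

print(diversify([SplitIdxsSketch(split_seed=0, sub_split_seeds=[10, 11])], alg_idx=2))
# [SplitIdxsSketch(split_seed=2, sub_split_seeds=[12, 13])]
```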

pytabkit/models/alg_interfaces/nn_interfaces.py

Lines changed: 5 additions & 3 deletions
@@ -167,7 +167,7 @@ def get_required_resources(self, ds: DictDataset, n_cv: int, n_refit: int, n_spl
         static_tensor_infos = static_fitter.forward_tensor_infos(tensor_infos)
         n_params = fitter.get_n_params(tensor_infos)
         n_forward = fitter.get_n_forward(tensor_infos)
-        n_parallel = max(n_cv, n_refit) * n_splits
+        n_parallel = max(n_cv, n_refit) * n_splits * self.config.get('n_ens', 1)
         batch_size = self.config.get('batch_size', 256)
         if batch_size == 'auto':
             batch_size = get_realmlp_auto_batch_size(n_train)
@@ -192,6 +192,8 @@ def get_required_resources(self, ds: DictDataset, n_cv: int, n_refit: int, n_spl
         init_ram_gb = min(init_ram_gb_max, init_ram_gb_full)
         # init_ram_gb = 1.5

+        # print(f'{ds_ram_gb=}, {pass_memory/(1024**3)=}, {param_memory/(1024**3)=}, {init_ram_gb=}')
+
         factor = 1.2  # to go safe on ram
         gpu_ram_gb = fixed_ram_gb + ds_ram_gb + max(init_ram_gb,
                                                     factor * (n_parallel * (pass_memory + param_memory)) / (1024 ** 3))
@@ -639,7 +641,7 @@ def sample_params(self, seed: int) -> Dict[str, Any]:
             'p_drop_sched': 'flat_cos',
             'lr': np.exp(rng.uniform(np.log(2e-2), np.log(3e-1))),
             'wd': np.exp(rng.uniform(np.log(1e-3), np.log(5e-2))),
-            'use_ls': rng.choice(["auto", True]),  # use label smoothing (will be ignored for regression)
+            'use_ls': rng.choice([False, True]),  # use label smoothing (will be ignored for regression)
         }

         if rng.uniform(0.0, 1.0) > 0.5:
@@ -685,7 +687,7 @@ def _create_sub_interface(self, ds: DictDataset, seed: int):
         is_classification = not ds.tensor_infos['y'].is_cont()
         self.fit_params = [RealMLPParamSampler(is_classification, **self.config).sample_params(hparam_seed)]
         # todo: need epoch for refit
-        params = utils.update_dict(self.config, self.fit_params[0])
+        params = utils.join_dicts(self.config, self.fit_params[0], self.config.get('override_params', dict()) or dict())
         # params = utils.update_dict(self.fit_params[0], self.config)
         if 'n_epochs' in self.config:
             params['n_epochs'] = self.config['n_epochs']

pytabkit/models/alg_interfaces/tabm_interface.py

Lines changed: 51 additions & 15 deletions
@@ -1,3 +1,4 @@
+import functools
 import math
 import random
 from pathlib import Path
@@ -76,6 +77,7 @@ def fit(self, ds: DictDataset, idxs_list: List[SplitIdxs], interface_resources:
         # set default to True for backward compatibility
         share_training_batches = self.config.get("share_training_batches", False)
         val_metric_name = self.config.get('val_metric_name', None)
+        train_metric_name = self.config.get('train_metric_name', None)

         weight_decay = self.config.get('weight_decay', 0.0)
         gradient_clipping_norm = self.config.get('gradient_clipping_norm', None)
@@ -145,9 +147,11 @@ def fit(self, ds: DictDataset, idxs_list: List[SplitIdxs], interface_resources:

         Y_train = ds_parts['train'].tensors['y'].clone()
         if task_type == 'regression':
-            assert ds.tensor_infos['y'].get_n_features() == 1
-            self.y_mean_ = ds_parts['train'].tensors['y'].mean().item()
-            self.y_std_ = ds_parts['train'].tensors['y'].std(correction=0).item()
+            assert Y_train.shape[-1] == 1
+            self.y_mean_ = ds_parts['train'].tensors['y'].mean(dim=0, keepdim=True).item()
+            self.y_std_ = ds_parts['train'].tensors['y'].std(dim=0, keepdim=True, correction=0).item()
+            self.y_max_ = ds_parts['train'].tensors['y'].max().item()
+            self.y_min_ = ds_parts['train'].tensors['y'].min().item()

             Y_train = (Y_train - self.y_mean_) / (self.y_std_ + 1e-30)

@@ -170,7 +174,7 @@ def fit(self, ds: DictDataset, idxs_list: List[SplitIdxs], interface_resources:
             else None
         )
         # Changing False to True will result in faster training on compatible hardware.
-        amp_enabled = allow_amp and amp_dtype is not None
+        amp_enabled = allow_amp and amp_dtype is not None and device.type == 'cuda'
         grad_scaler = torch.cuda.amp.GradScaler() if amp_dtype is torch.float16 else None  # type: ignore

         # fmt: off
@@ -186,11 +190,14 @@ def fit(self, ds: DictDataset, idxs_list: List[SplitIdxs], interface_resources:

         # TabM
         bins = None if num_emb_type != 'pwl' or n_cont_features == 0 else rtdl_num_embeddings.compute_bins(data['train']['x_cont'], n_bins=num_emb_n_bins)
+        d_out = n_classes if n_classes > 0 else 1
+        if train_metric_name is not None and train_metric_name.startswith('multi_pinball'):
+            d_out = train_metric_name.count(',')+1

         model = Model(
             n_num_features=n_cont_features,
             cat_cardinalities=cat_cardinalities,
-            n_classes=n_classes if n_classes > 0 else None,
+            n_classes=d_out,
             backbone={
                 'type': 'MLP',
                 'n_blocks': n_blocks if n_blocks != 'auto' else (3 if bins is None else 2),
@@ -212,6 +219,27 @@ def fit(self, ds: DictDataset, idxs_list: List[SplitIdxs], interface_resources:
             k=tabm_k,
             share_training_batches=share_training_batches,
         ).to(device)
+
+        # import tabm
+        # num_embeddings = None if bins is None else rtdl_num_embeddings.PiecewiseLinearEmbeddings(
+        #     bins=bins,
+        #     d_embedding=d_embedding,
+        #     activation=False,
+        #     version='B',
+        # )
+        # model = tabm.TabM(
+        #     n_num_features=n_cont_features,
+        #     cat_cardinalities=cat_cardinalities,
+        #     d_out = n_classes if n_classes > 0 else 1,
+        #     num_embeddings = num_embeddings,
+        #     n_blocks=n_blocks if n_blocks != 'auto' else (3 if bins is None else 2),
+        #     d_block=d_block,
+        #     dropout=dropout,
+        #     arch_type=arch_type,
+        #     k=tabm_k,
+        #     # todo: can introduce activation
+        #     share_training_batches=share_training_batches,  # todo: disappeared?
+        # )
         optimizer = torch.optim.AdamW(make_parameter_groups(model), lr=lr, weight_decay=weight_decay)


@@ -231,11 +259,17 @@ def apply_model(part: str, idx: torch.Tensor) -> torch.Tensor:
                 data[part]['x_cont'][idx],
                 data[part]['x_cat'][idx] if 'x_cat' in data[part] else None,
             )
-            .squeeze(-1)  # Remove the last dimension for regression tasks.
             .float()
         )

-        base_loss_fn = torch.nn.functional.mse_loss if task_type == 'regression' else torch.nn.functional.cross_entropy
+        if train_metric_name is None:
+            base_loss_fn = torch.nn.functional.mse_loss if self.n_classes_ == 0 else torch.nn.functional.cross_entropy  # defaults
+        elif train_metric_name == 'mse':
+            base_loss_fn = torch.nn.functional.mse_loss
+        elif train_metric_name == 'cross_entropy':
+            base_loss_fn = torch.nn.functional.cross_entropy
+        else:
+            base_loss_fn = functools.partial(Metrics.apply, metric_name=train_metric_name)

         def loss_fn(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
             # TabM produces k predictions per object. Each of them must be trained separately.
@@ -244,7 +278,7 @@ def loss_fn(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
             k = y_pred.shape[1]
             return base_loss_fn(
                 y_pred.flatten(0, 1),
-                y_true.repeat_interleave(k) if model.share_training_batches else y_true.squeeze(-1),
+                y_true.repeat_interleave(k) if model.share_training_batches else y_true,
             )

         @evaluation_mode()
@@ -261,7 +295,7 @@ def evaluate(part: str) -> float:
                         eval_batch_size
                     )
                 ]
-            ).cpu()
+            )
         )
         if task_type == 'regression':
             # Transform the predictions back to the original label space.
@@ -278,6 +312,8 @@ def evaluate(part: str) -> float:
             y_pred = y_pred.mean(dim=1)

         y_true = data[part]['y'].cpu()
+        y_pred = y_pred.cpu()
+
         if task_type == 'regression' and len(y_true.shape) == 1:
             y_true = y_true.unsqueeze(-1)
         if task_type == 'regression' and len(y_pred.shape) == 1:
@@ -390,7 +426,6 @@ def predict(self, ds: DictDataset) -> torch.Tensor:
                     ds.tensors['x_cont'][idx],
                     ds.tensors['x_cat'][idx] if not ds.tensor_infos['x_cat'].is_empty() else None,
                 )
-                .squeeze(-1)  # Remove the last dimension for regression tasks.
                 .float()
                 for idx in torch.arange(ds.n_samples, device=self.device_).split(
                     eval_batch_size
@@ -400,9 +435,10 @@ def predict(self, ds: DictDataset) -> torch.Tensor:
         )
         if self.task_type_ == 'regression':
             # Transform the predictions back to the original label space.
-            y_pred = y_pred * self.y_std_ + self.y_mean_
             y_pred = y_pred.mean(1)
-            y_pred = y_pred.unsqueeze(-1)  # add extra "features" dimension
+            y_pred = y_pred * self.y_std_ + self.y_mean_
+            if self.config.get('clamp_output', False):
+                y_pred = torch.clamp(y_pred, self.y_min_, self.y_max_)
         else:
             average_logits = self.config.get('average_logits', False)
             if average_logits:
@@ -411,7 +447,7 @@ def predict(self, ds: DictDataset) -> torch.Tensor:
                 # For classification, the mean must be computed in the probability space.
                 y_pred = torch.log(torch.softmax(y_pred, dim=-1).mean(1) + 1e-30)

-        return y_pred[None]  # add n_models dimension
+        return y_pred[None].cpu()  # add n_models dimension

     def get_required_resources(self, ds: DictDataset, n_cv: int, n_refit: int, n_splits: int,
                                split_seeds: List[int], n_train: int) -> RequiredResources:
@@ -440,7 +476,7 @@ def _sample_params(self, is_classification: bool, seed: int, n_train: int):
             params = {
                 "batch_size": "auto",
                 "patience": 16,
-                "amp": True,
+                "allow_amp": True,
                 "arch_type": "tabm-mini",
                 "tabm_k": 32,
                 "gradient_clipping_norm": 1.0,
@@ -461,7 +497,7 @@ def _sample_params(self, is_classification: bool, seed: int, n_train: int):
             params = {
                 "batch_size": "auto",
                 "patience": 16,
-                "amp": False,  # only for GPU, maybe we should change it to True?
+                "allow_amp": False,  # only for GPU, maybe we should change it to True?
                 "arch_type": "tabm-mini",
                 "tabm_k": 32,
                 "gradient_clipping_norm": 1.0,

pytabkit/models/sklearn/sklearn_base.py

Lines changed: 3 additions & 0 deletions
@@ -346,6 +346,9 @@ def fit(self, X, y, X_val: Optional = None, y_val: Optional = None, val_idxs: Op
         if val_idxs.shape[1] == 0:
             val_idxs = None  # no validation set

+        # print(f'{val_idxs=}')
+        # print(f'{np.mean(X / (1e-8 + np.linalg.norm(X, axis=0, keepdims=True)))=}')
+
         idxs_list = [SplitIdxs(train_idxs=train_idxs, val_idxs=val_idxs, test_idxs=None, split_seed=split_seed,
                                sub_split_seeds=sub_split_seeds, split_id=0)]


pytabkit/models/sklearn/sklearn_interfaces.py

Lines changed: 13 additions & 0 deletions
@@ -86,6 +86,8 @@ def __init__(self, device: Optional[str] = None, random_state: Optional[Union[in
                  calibration_method: Optional[str] = None,
                  sort_quantile_predictions: Optional[bool] = None,
                  stop_epoch: Optional[int] = None,
+                 use_best_mean_epoch_for_cv: Optional[bool] = None,
+                 n_ens: Optional[int] = None,
                  ):
         """
         Constructor for RealMLP, using the default parameters from RealMLP-TD.
@@ -251,6 +253,11 @@ def __init__(self, device: Optional[str] = None, random_state: Optional[Union[in
             Epoch at which training should be stopped (for refitting).
             The total length of training used for the schedules will be determined by n_epochs,
             but the stopping epoch will be min(stop_epoch, n_epochs).
+        :param use_best_mean_epoch_for_cv: If training an ensemble,
+            whether they should all use a checkpoint from the same epoch with the best average loss,
+            instead of using the best individual epochs (default=False).
+        :param n_ens: Number of ensemble members that should be used per train-validation split (default=1).
+            For best-epoch selection, the validation scores of averaged predictions will be used.
         """
         super().__init__()  # call the constructor of the other superclass for multiple inheritance
         self.device = device
@@ -323,6 +330,8 @@ def __init__(self, device: Optional[str] = None, random_state: Optional[Union[in
         self.calibration_method = calibration_method
         self.sort_quantile_predictions = sort_quantile_predictions
         self.stop_epoch = stop_epoch
+        self.use_best_mean_epoch_for_cv = use_best_mean_epoch_for_cv
+        self.n_ens = n_ens


 class RealMLP_TD_Classifier(RealMLPConstructorMixin, AlgInterfaceClassifier):
@@ -1762,6 +1771,7 @@ def __init__(self, device: Optional[str] = None, random_state: Optional[Union[in
                  calibration_method: Optional[str] = None,
                  share_training_batches: Optional[bool] = None,
                  val_metric_name: Optional[str] = None,
+                 train_metric_name: Optional[str] = None,
                  ):
         """

@@ -1826,6 +1836,9 @@ def __init__(self, device: Optional[str] = None, random_state: Optional[Union[in
         :param val_metric_name: Name of the validation metric used for early stopping.
             For classification, the default is 'class_error' but could be 'cross_entropy', 'brier', '1-auc_ovr' etc.
             For regression, the default is 'rmse' but could be 'mae'.
+        :param train_metric_name: Name of the metric (loss) used for training.
+            For classification, the default is 'cross_entropy'.
+            For regression, it is 'mse' but could be set to something like 'multi_pinball(0.05,0.95)'.
         """
         self.device = device
         self.random_state = random_state
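Note (illustration, not part of the diff): the new RealMLP constructor arguments documented above could be used as in the sketch below. The import path and the data setup are assumptions; only `n_ens` and `use_best_mean_epoch_for_cv` come from this commit, and `RealMLP_TD_Classifier` appears in the diff itself.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from pytabkit import RealMLP_TD_Classifier  # assumed import

X, y = make_classification(n_samples=1000, n_features=20, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Train a packed ensemble of 4 RealMLP-TD members on the single train-validation split;
# with use_best_mean_epoch_for_cv=True, all members use the checkpoint from the epoch
# with the best average validation loss instead of their individual best epochs.
clf = RealMLP_TD_Classifier(n_ens=4, use_best_mean_epoch_for_cv=True)
clf.fit(X_train, y_train)
print(clf.predict_proba(X_test)[:3])
```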
