+import functools
 import math
 import random
 from pathlib import Path
@@ -76,6 +77,7 @@ def fit(self, ds: DictDataset, idxs_list: List[SplitIdxs], interface_resources:
         # set default to True for backward compatibility
         share_training_batches = self.config.get("share_training_batches", False)
         val_metric_name = self.config.get('val_metric_name', None)
+        train_metric_name = self.config.get('train_metric_name', None)
 
         weight_decay = self.config.get('weight_decay', 0.0)
         gradient_clipping_norm = self.config.get('gradient_clipping_norm', None)
@@ -145,9 +147,11 @@ def fit(self, ds: DictDataset, idxs_list: List[SplitIdxs], interface_resources:
 
         Y_train = ds_parts['train'].tensors['y'].clone()
         if task_type == 'regression':
-            assert ds.tensor_infos['y'].get_n_features() == 1
-            self.y_mean_ = ds_parts['train'].tensors['y'].mean().item()
-            self.y_std_ = ds_parts['train'].tensors['y'].std(correction=0).item()
+            assert Y_train.shape[-1] == 1
+            self.y_mean_ = ds_parts['train'].tensors['y'].mean(dim=0, keepdim=True).item()
+            self.y_std_ = ds_parts['train'].tensors['y'].std(dim=0, keepdim=True, correction=0).item()
+            self.y_max_ = ds_parts['train'].tensors['y'].max().item()
+            self.y_min_ = ds_parts['train'].tensors['y'].min().item()
 
         Y_train = (Y_train - self.y_mean_) / (self.y_std_ + 1e-30)
 
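A minimal numeric sketch of the standardization above and its inverse (illustrative only, not part of the diff; `correction=0` means the population standard deviation is used). The newly stored `y_min_` / `y_max_` are only used later for the optional `clamp_output` clamp in `predict()`.

import torch

y = torch.tensor([[1.0], [2.0], [3.0]])                   # toy regression targets, shape (n, 1)
y_mean = y.mean(dim=0, keepdim=True).item()               # 2.0
y_std = y.std(dim=0, keepdim=True, correction=0).item()   # sqrt(2/3) ~= 0.8165
y_standardized = (y - y_mean) / (y_std + 1e-30)           # what the model is trained on
y_restored = y_standardized * y_std + y_mean              # inverse transform used in predict()
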
@@ -170,7 +174,7 @@ def fit(self, ds: DictDataset, idxs_list: List[SplitIdxs], interface_resources:
             else None
         )
         # Changing False to True will result in faster training on compatible hardware.
-        amp_enabled = allow_amp and amp_dtype is not None
+        amp_enabled = allow_amp and amp_dtype is not None and device.type == 'cuda'
         grad_scaler = torch.cuda.amp.GradScaler() if amp_dtype is torch.float16 else None  # type: ignore
 
         # fmt: off
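The training step itself is outside this hunk; the sketch below only illustrates how `amp_enabled`, `amp_dtype` and `grad_scaler` are typically combined in a PyTorch step (assumed pattern, `batch_idx` is a placeholder, not a name from this file):

optimizer.zero_grad()
with torch.autocast(device.type, dtype=amp_dtype, enabled=amp_enabled):
    loss = loss_fn(apply_model('train', batch_idx), Y_train[batch_idx])
if grad_scaler is not None:                  # float16 path: scale gradients to avoid underflow
    grad_scaler.scale(loss).backward()
    if gradient_clipping_norm is not None:
        grad_scaler.unscale_(optimizer)      # clipping must happen on unscaled gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping_norm)
    grad_scaler.step(optimizer)
    grad_scaler.update()
else:                                        # bfloat16 or AMP disabled
    loss.backward()
    if gradient_clipping_norm is not None:
        torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping_norm)
    optimizer.step()
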
@@ -186,11 +190,14 @@ def fit(self, ds: DictDataset, idxs_list: List[SplitIdxs], interface_resources:
 
         # TabM
         bins = None if num_emb_type != 'pwl' or n_cont_features == 0 else rtdl_num_embeddings.compute_bins(data['train']['x_cont'], n_bins=num_emb_n_bins)
+        d_out = n_classes if n_classes > 0 else 1
+        if train_metric_name is not None and train_metric_name.startswith('multi_pinball'):
+            d_out = train_metric_name.count(',') + 1
 
         model = Model(
             n_num_features=n_cont_features,
             cat_cardinalities=cat_cardinalities,
-            n_classes=n_classes if n_classes > 0 else None,
+            n_classes=d_out,
             backbone={
                 'type': 'MLP',
                 'n_blocks': n_blocks if n_blocks != 'auto' else (3 if bins is None else 2),
@@ -212,6 +219,27 @@ def fit(self, ds: DictDataset, idxs_list: List[SplitIdxs], interface_resources:
             k=tabm_k,
             share_training_batches=share_training_batches,
         ).to(device)
+
+        # import tabm
+        # num_embeddings = None if bins is None else rtdl_num_embeddings.PiecewiseLinearEmbeddings(
+        #     bins=bins,
+        #     d_embedding=d_embedding,
+        #     activation=False,
+        #     version='B',
+        # )
+        # model = tabm.TabM(
+        #     n_num_features=n_cont_features,
+        #     cat_cardinalities=cat_cardinalities,
+        #     d_out=n_classes if n_classes > 0 else 1,
+        #     num_embeddings=num_embeddings,
+        #     n_blocks=n_blocks if n_blocks != 'auto' else (3 if bins is None else 2),
+        #     d_block=d_block,
+        #     dropout=dropout,
+        #     arch_type=arch_type,
+        #     k=tabm_k,
+        #     # todo: can introduce activation
+        #     share_training_batches=share_training_batches,  # todo: disappeared?
+        # )
         optimizer = torch.optim.AdamW(make_parameter_groups(model), lr=lr, weight_decay=weight_decay)
 
 
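For the new `d_out` logic above: with the default metrics the head keeps one output per class (or a single output for regression), while a `multi_pinball` training metric gets one output per quantile level. A quick sketch, assuming the quantile levels are encoded comma-separated inside the metric name (the exact naming scheme is pytabkit's and is not shown in this diff):

train_metric_name = 'multi_pinball(0.05,0.5,0.95)'   # hypothetical name format
d_out = train_metric_name.count(',') + 1             # -> 3 quantile heads
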
@@ -231,11 +259,17 @@ def apply_model(part: str, idx: torch.Tensor) -> torch.Tensor:
                     data[part]['x_cont'][idx],
                     data[part]['x_cat'][idx] if 'x_cat' in data[part] else None,
                 )
-                .squeeze(-1)  # Remove the last dimension for regression tasks.
                 .float()
             )
 
-        base_loss_fn = torch.nn.functional.mse_loss if task_type == 'regression' else torch.nn.functional.cross_entropy
+        if train_metric_name is None:
+            base_loss_fn = torch.nn.functional.mse_loss if self.n_classes_ == 0 else torch.nn.functional.cross_entropy  # defaults
+        elif train_metric_name == 'mse':
+            base_loss_fn = torch.nn.functional.mse_loss
+        elif train_metric_name == 'cross_entropy':
+            base_loss_fn = torch.nn.functional.cross_entropy
+        else:
+            base_loss_fn = functools.partial(Metrics.apply, metric_name=train_metric_name)
 
         def loss_fn(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
             # TabM produces k predictions per object. Each of them must be trained separately.
@@ -244,7 +278,7 @@ def loss_fn(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
             k = y_pred.shape[1]
             return base_loss_fn(
                 y_pred.flatten(0, 1),
-                y_true.repeat_interleave(k) if model.share_training_batches else y_true.squeeze(-1),
+                y_true.repeat_interleave(k) if model.share_training_batches else y_true,
             )
 
         @evaluation_mode()
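`Metrics.apply` is pytabkit's own dispatcher, so the actual `multi_pinball` implementation is not shown in this diff; as orientation only, a self-contained multi-quantile pinball loss could look like the sketch below. The shapes match what `loss_fn` above produces after the flatten: `y_pred` of shape (batch, n_quantiles) and a 1-D `y_true`.

import torch

def multi_pinball_loss(y_pred: torch.Tensor, y_true: torch.Tensor,
                       quantiles=(0.05, 0.5, 0.95)) -> torch.Tensor:
    # y_pred: (batch, n_quantiles), one column per quantile head
    # y_true: (batch,) or (batch, 1)
    q = torch.as_tensor(quantiles, dtype=y_pred.dtype, device=y_pred.device)
    diff = y_true.reshape(-1, 1) - y_pred            # positive when the model under-predicts
    return torch.maximum(q * diff, (q - 1.0) * diff).mean()
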
@@ -261,7 +295,7 @@ def evaluate(part: str) -> float:
                                 eval_batch_size
                             )
                         ]
-                    ).cpu()
+                    )
                 )
                 if task_type == 'regression':
                     # Transform the predictions back to the original label space.
@@ -278,6 +312,8 @@ def evaluate(part: str) -> float:
             y_pred = y_pred.mean(dim=1)
 
             y_true = data[part]['y'].cpu()
+            y_pred = y_pred.cpu()
+
             if task_type == 'regression' and len(y_true.shape) == 1:
                 y_true = y_true.unsqueeze(-1)
             if task_type == 'regression' and len(y_pred.shape) == 1:
@@ -390,7 +426,6 @@ def predict(self, ds: DictDataset) -> torch.Tensor:
                     ds.tensors['x_cont'][idx],
                     ds.tensors['x_cat'][idx] if not ds.tensor_infos['x_cat'].is_empty() else None,
                 )
-                .squeeze(-1)  # Remove the last dimension for regression tasks.
                 .float()
                 for idx in torch.arange(ds.n_samples, device=self.device_).split(
                     eval_batch_size
@@ -400,9 +435,10 @@ def predict(self, ds: DictDataset) -> torch.Tensor:
         )
         if self.task_type_ == 'regression':
             # Transform the predictions back to the original label space.
-            y_pred = y_pred * self.y_std_ + self.y_mean_
             y_pred = y_pred.mean(1)
-            y_pred = y_pred.unsqueeze(-1)  # add extra "features" dimension
+            y_pred = y_pred * self.y_std_ + self.y_mean_
+            if self.config.get('clamp_output', False):
+                y_pred = torch.clamp(y_pred, self.y_min_, self.y_max_)
         else:
             average_logits = self.config.get('average_logits', False)
             if average_logits:
@@ -411,7 +447,7 @@ def predict(self, ds: DictDataset) -> torch.Tensor:
                 # For classification, the mean must be computed in the probability space.
                 y_pred = torch.log(torch.softmax(y_pred, dim=-1).mean(1) + 1e-30)
 
-        return y_pred[None]  # add n_models dimension
+        return y_pred[None].cpu()  # add n_models dimension
 
     def get_required_resources(self, ds: DictDataset, n_cv: int, n_refit: int, n_splits: int,
                                split_seeds: List[int], n_train: int) -> RequiredResources:
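On the classification side of `predict()` above, the shown path averages the k TabM members in probability space and returns log-probabilities, while `average_logits` (whose branch is elided from this hunk) presumably averages the raw logits first; the two are not equivalent whenever the members disagree, as this toy sketch shows (illustrative values only, not from the code). For regression, the mean over k is now taken before undoing the target standardization, with an optional clamp to the [y_min_, y_max_] range seen in training.

import torch

logits = torch.tensor([[[2.0, 0.0], [0.0, 1.0]]])     # (n_samples=1, k=2, n_classes=2)
prob_space = torch.softmax(logits, dim=-1).mean(1)    # roughly [[0.57, 0.43]]
logit_space = torch.softmax(logits.mean(1), dim=-1)   # roughly [[0.62, 0.38]]
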
@@ -440,7 +476,7 @@ def _sample_params(self, is_classification: bool, seed: int, n_train: int):
         params = {
             "batch_size": "auto",
             "patience": 16,
-            "amp": True,
+            "allow_amp": True,
             "arch_type": "tabm-mini",
             "tabm_k": 32,
             "gradient_clipping_norm": 1.0,
@@ -461,7 +497,7 @@ def _sample_params(self, is_classification: bool, seed: int, n_train: int):
         params = {
             "batch_size": "auto",
             "patience": 16,
-            "amp": False,  # only for GPU, maybe we should change it to True?
+            "allow_amp": False,  # only for GPU, maybe we should change it to True?
             "arch_type": "tabm-mini",
             "tabm_k": 32,
             "gradient_clipping_norm": 1.0,