Skip to content

Commit f1b59c2

Browse files
committed
fixed TabM error
1 parent 81fbf54 commit f1b59c2

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

pytabkit/models/alg_interfaces/tabm_interface.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -263,11 +263,12 @@ def apply_model(part: str, idx: torch.Tensor) -> torch.Tensor:
263263
)
264264

265265
if train_metric_name is None:
266-
base_loss_fn = torch.nn.functional.mse_loss if self.n_classes_ == 0 else torch.nn.functional.cross_entropy # defaults
267-
elif train_metric_name == 'mse':
266+
train_metric_name = 'mse' if self.n_classes_ == 0 else 'cross_entropy'
267+
268+
if train_metric_name == 'mse':
268269
base_loss_fn = torch.nn.functional.mse_loss
269270
elif train_metric_name == 'cross_entropy':
270-
base_loss_fn = torch.nn.functional.cross_entropy
271+
base_loss_fn = lambda a, b: torch.nn.functional.cross_entropy(a, b.squeeze(-1))
271272
else:
272273
base_loss_fn = functools.partial(Metrics.apply, metric_name=train_metric_name)
273274

@@ -276,6 +277,7 @@ def loss_fn(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
276277
# (regression) y_pred.shape == (batch_size, k)
277278
# (classification) y_pred.shape == (batch_size, k, n_classes)
278279
k = y_pred.shape[1]
280+
print(f'{y_pred.flatten(0, 1).shape=}, {y_true.shape=}')
279281
return base_loss_fn(
280282
y_pred.flatten(0, 1),
281283
y_true.repeat_interleave(k) if model.share_training_batches else y_true,

0 commit comments

Comments
 (0)