Skip to content

Commit ecee3e4

Browse files
committed
fixed device bug in TabM for GPU
1 parent 7cf0157 commit ecee3e4

File tree

3 files changed

+4
-3
lines changed

3 files changed

+4
-3
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ and https://docs.ray.io/en/latest/cluster/vms/user-guides/community/slurm.html
196196

197197
## Releases (see git tags)
198198

199+
- v1.5.1: fixed a device bug in TabM for GPU
199200
- v1.5.0:
200201
- added `n_repeats` parameter to scikit-learn interfaces for repeated cross-validation
201202
- HPO sklearn interfaces (the ones using random search)

pytabkit/__about__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
__version__ = "1.5.0"
5+
__version__ = "1.5.1"

pytabkit/models/alg_interfaces/tabm_interface.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def evaluate(part: str) -> float:
261261
eval_batch_size
262262
)
263263
]
264-
)
264+
).cpu()
265265
)
266266
if task_type == 'regression':
267267
# Transform the predictions back to the original label space.
@@ -277,7 +277,7 @@ def evaluate(part: str) -> float:
277277
if not average_logits:
278278
y_pred = y_pred.mean(dim=1)
279279

280-
y_true = data[part]['y']
280+
y_true = data[part]['y'].cpu()
281281
if task_type == 'regression' and len(y_true.shape) == 1:
282282
y_true = y_true.unsqueeze(-1)
283283
if task_type == 'regression' and len(y_pred.shape) == 1:

0 commit comments

Comments
 (0)