Skip to content

Commit 586d3c3

Browse files
committed
v1.2.1: fix skorch problem, fix failing test for Mac OS
1 parent 617a2ff commit 586d3c3

File tree

5 files changed

+32
-22
lines changed

5 files changed

+32
-22
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,8 @@ and https://docs.ray.io/en/latest/cluster/vms/user-guides/community/slurm.html
168168

169169
## Releases (see git tags)
170170

171+
- v1.2.1:
172+
- avoid error for older skorch versions
171173
- v1.2.0:
172174
- Included post-hoc calibration and more metrics through
173175
[probmetrics](https://github.com/dholzmueller/probmetrics).

pytabkit/__about__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
__version__ = "1.2.0"
5+
__version__ = "1.2.1"

pytabkit/models/nn_models/rtdl_resnet.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -974,20 +974,27 @@ def create_regressor_skorch(
974974
callbacks=callbacks,
975975
**kwargs)
976976

977-
try:
978-
# try the torch_load_kwargs but it's only available in newer versions of skorch
979-
model = nn_class(
980-
model_class,
981-
# Shuffle training data on each epoch
982-
**new_kwargs,
983-
torch_load_kwargs={'weights_only': False}, # quick-fix for pickling errors in torch>=2.6
984-
)
985-
except ValueError:
986-
model = nn_class(
987-
model_class,
988-
# Shuffle training data on each epoch
989-
**new_kwargs,
990-
)
977+
# cannot do the try/catch here because params are validated in fit()
978+
# try:
979+
# # try the torch_load_kwargs but it's only available in newer versions of skorch
980+
# model = nn_class(
981+
# model_class,
982+
# # Shuffle training data on each epoch
983+
# **new_kwargs,
984+
# torch_load_kwargs={'weights_only': False}, # quick-fix for pickling errors in torch>=2.6
985+
# )
986+
# except ValueError:
987+
# model = nn_class(
988+
# model_class,
989+
# # Shuffle training data on each epoch
990+
# **new_kwargs,
991+
# )
992+
993+
model = nn_class(
994+
model_class,
995+
# Shuffle training data on each epoch
996+
**new_kwargs,
997+
)
991998

992999
return model
9931000

tests/test_ensemble.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
from pytabkit.models.sklearn.sklearn_interfaces import Ensemble_HPO_Classifier, Ensemble_HPO_Regressor
77

88

9-
@pytest.mark.parametrize('model', [Ensemble_TD_Classifier(calibration_method='ts-mix', val_metric_name='ref-ll-ts'),
10-
Ensemble_TD_Regressor(),
11-
Ensemble_HPO_Classifier(calibration_method='ts-mix',
12-
val_metric_name='ref-ll-ts', n_hpo_steps=1),
13-
Ensemble_HPO_Regressor(n_hpo_steps=1),
14-
])
9+
@pytest.mark.parametrize('model', [
10+
Ensemble_TD_Classifier(calibration_method='ts-mix', val_metric_name='ref-ll-ts', device='cpu'),
11+
Ensemble_TD_Regressor(device='cpu'),
12+
Ensemble_HPO_Classifier(calibration_method='ts-mix',
13+
val_metric_name='ref-ll-ts', n_hpo_steps=1, device='cpu'),
14+
Ensemble_HPO_Regressor(n_hpo_steps=1, device='cpu'),
15+
])
1516
def test_ensemble(model):
1617
np.random.seed(0)
1718
X = np.random.randn(100, 2)
@@ -20,4 +21,3 @@ def test_ensemble(model):
2021
y = y > 0.0
2122
model.fit(X, y)
2223
model.predict(X)
23-

tests/test_sklearn_interfaces.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
LGBM_TD_Classifier(min_data_in_leaf=2, subsample=1.0), LGBM_TD_Regressor(subsample=1.0),
1818
XGB_TD_Classifier(), XGB_TD_Regressor(),
1919
CatBoost_TD_Classifier(), CatBoost_TD_Regressor(),
20+
# use CPU to avoid Mac OS errors with MPS backend
2021
RealMLP_TD_Classifier(n_epochs=8, device='cpu'), RealMLP_TD_Regressor(n_epochs=64, device='cpu'),
2122
TabM_D_Classifier(device='cpu', tabm_k=2, num_emb_type='pwl', arch_type='tabm-mini', num_emb_n_bins=2),
2223
TabM_D_Regressor(device='cpu', tabm_k=2, num_emb_type='pwl', arch_type='tabm-mini', num_emb_n_bins=2),

0 commit comments

Comments
 (0)