Skip to content

Commit 7db53a6

Browse files
committed
added requirement and skorch argument to avoid skorch unpickling errors in torch 2.6
1 parent 55f472a commit 7db53a6

File tree

3 files changed

+22
-8
lines changed

3 files changed

+22
-8
lines changed

docs/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ seaborn>=0.0.13
2727
skorch>=0.15
2828
sphinx>=7.0
2929
sphinx_rtd_theme>=2.0
30-
torch>=2.0
30+
torch>=2.0,<2.6
3131
torchmetrics>=1.2.1
3232
tqdm
3333
tueplots>=0.0.12

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ classifiers = [
2626
"License :: OSI Approved :: Apache Software License",
2727
]
2828
dependencies = [
29-
"torch>=2.0",
29+
# use <2.6 for now since it can run into pickling issues with skorch if the skorch version is too old
30+
# see https://github.com/skorch-dev/skorch/commit/be93b7769d61aa22fb928d2e89e258c629bfeaf9
31+
"torch>=2.0,<2.6",
3032
"numpy>=1.25,<2.0",
3133
"pandas>=2.0",
3234
"scikit-learn>=1.3,<1.6",

pytabkit/models/nn_models/rtdl_resnet.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -960,10 +960,8 @@ def create_regressor_skorch(
960960
model_class = RTDL_MLP
961961
else:
962962
raise ValueError(f'Model {model_name} not implemented here! Choose from "ft_transformer", "resnet", "mlp"')
963-
model = nn_class(
964-
model_class,
965-
# Shuffle training data on each epoch
966-
optimizer=optimizer,
963+
964+
new_kwargs = dict(optimizer=optimizer,
967965
batch_size=max(
968966
batch_size, 1
969967
), # if batch size is float, it will be reset during fit
@@ -974,8 +972,22 @@ def create_regressor_skorch(
974972
module__regression=True,
975973
module__categorical_indicator=None, # will be change when fitted
976974
callbacks=callbacks,
977-
**kwargs,
978-
)
975+
**kwargs)
976+
977+
try:
978+
# try the torch_load_kwargs but it's only available in newer versions of skorch
979+
model = nn_class(
980+
model_class,
981+
# Shuffle training data on each epoch
982+
**new_kwargs,
983+
torch_load_kwargs={'weights_only': False}, # quick-fix for pickling errors in torch>=2.6
984+
)
985+
except ValueError:
986+
model = nn_class(
987+
model_class,
988+
# Shuffle training data on each epoch
989+
**new_kwargs,
990+
)
979991

980992
return model
981993

0 commit comments

Comments
 (0)