File tree Expand file tree Collapse file tree 3 files changed +22
-8
lines changed
pytabkit/models/nn_models Expand file tree Collapse file tree 3 files changed +22
-8
lines changed Original file line number Diff line number Diff line change @@ -27,7 +27,7 @@ seaborn>=0.0.13
2727skorch >= 0.15
2828sphinx >= 7.0
2929sphinx_rtd_theme >= 2.0
30- torch >= 2.0
30+ torch >= 2.0 , < 2.6
3131torchmetrics >= 1.2.1
3232tqdm
3333tueplots >= 0.0.12
Original file line number Diff line number Diff line change @@ -26,7 +26,9 @@ classifiers = [
2626 " License :: OSI Approved :: Apache Software License" ,
2727]
2828dependencies = [
29- " torch>=2.0" ,
29+ # use <2.6 for now since it can run into pickling issues with skorch if the skorch version is too old
30+ # see https://github.com/skorch-dev/skorch/commit/be93b7769d61aa22fb928d2e89e258c629bfeaf9
31+ " torch>=2.0,<2.6" ,
3032 " numpy>=1.25,<2.0" ,
3133 " pandas>=2.0" ,
3234 " scikit-learn>=1.3,<1.6" ,
Original file line number Diff line number Diff line change @@ -960,10 +960,8 @@ def create_regressor_skorch(
960960 model_class = RTDL_MLP
961961 else :
962962 raise ValueError (f'Model { model_name } not implemented here! Choose from "ft_transformer", "resnet", "mlp"' )
963- model = nn_class (
964- model_class ,
965- # Shuffle training data on each epoch
966- optimizer = optimizer ,
963+
964+ new_kwargs = dict (optimizer = optimizer ,
967965 batch_size = max (
968966 batch_size , 1
969967 ), # if batch size is float, it will be reset during fit
@@ -974,8 +972,22 @@ def create_regressor_skorch(
974972 module__regression = True ,
975973 module__categorical_indicator = None , # will be change when fitted
976974 callbacks = callbacks ,
977- ** kwargs ,
978- )
975+ ** kwargs )
976+
977+ try :
978+ # try the torch_load_kwargs but it's only available in newer versions of skorch
979+ model = nn_class (
980+ model_class ,
981+ # Shuffle training data on each epoch
982+ ** new_kwargs ,
983+ torch_load_kwargs = {'weights_only' : False }, # quick-fix for pickling errors in torch>=2.6
984+ )
985+ except ValueError :
986+ model = nn_class (
987+ model_class ,
988+ # Shuffle training data on each epoch
989+ ** new_kwargs ,
990+ )
979991
980992 return model
981993
You can’t perform that action at this time.
0 commit comments