Skip to content

Commit b0c14fb

Browse files
committed
fixed a sklearn 1.6 issue but disabled it due to skorch incompatibility
1 parent fcae206 commit b0c14fb

File tree

4 files changed

+8
-4
lines changed

4 files changed

+8
-4
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,8 @@ and https://docs.ray.io/en/latest/cluster/vms/user-guides/community/slurm.html
136136
## Releases (see git tags)
137137

138138
- v1.1.2:
139-
- Compatibility fixes for scikit-learn 1.6.
139+
- Some compatibility improvements for scikit-learn 1.6
140+
(but disabled 1.6 since skorch is not compatible with it).
140141
- Improved documentation for Pytorch-Lightning interface.
141142
- Other small bugfixes and improvements.
142143
- v1.1.1:

docs/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ pytorch_lightning>=2.0
2121
pyyaml>=5.0
2222
ray>=2.8
2323
requests>=2.0
24-
scikit-learn>=1.3
24+
scikit-learn>=1.3,<1.6
2525
seaborn>=0.0.13
2626
skorch>=0.15
2727
sphinx>=7.0

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ dependencies = [
2929
"torch>=2.0",
3030
"numpy>=1.25,<2.0",
3131
"pandas>=2.0",
32-
"scikit-learn>=1.3",
32+
"scikit-learn>=1.3,<1.6",
3333
"xgboost>=2.0",
3434
"catboost>=1.2",
3535
"lightgbm>=4.1",

pytabkit/models/data/conversion.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,10 @@ def transform(self, x: Union[np.ndarray, pd.DataFrame, pd.Series, DictDataset])
8787

8888
if set(x.columns) != self.fitted_columns:
8989
print('Raising column error')
90-
raise ValueError(f'Different columns during fit() and predict(): {self.fitted_columns} and {set(x.columns)}')
90+
# second line is to satisfy the sklearn test
91+
# check_n_features_in_after_fitting in scikit-learn >= 1.6
92+
raise ValueError(f'Different columns during fit() and predict(): {self.fitted_columns} and {set(x.columns)}\n'
93+
f'X has {len(x.columns)} features, but estimator is expecting {len(self.fitted_columns)} features as input')
9194

9295
x_cont = torch.as_tensor(self.num_tf.transform(x), dtype=torch.float32)
9396
x_cat = torch.as_tensor(self.cat_tf.transform(x) + 1, dtype=torch.long)

0 commit comments

Comments
 (0)