Skip to content

Commit 666ca68

Browse files
authored
Adding support for scikit-learn > 0.22 (#936)
* Preliminary changes * Updating unit tests for sklearn 0.22 and above * Triggering sklearn tests + fixes * Refactoring to inspect.signature in extensions
1 parent 9c93f5b commit 666ca68

File tree

5 files changed

+216
-91
lines changed

5 files changed

+216
-91
lines changed

.travis.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,13 @@ env:
1515
- TEST_DIR=/tmp/test_dir/
1616
- MODULE=openml
1717
matrix:
18-
- DISTRIB="conda" PYTHON_VERSION="3.6" SKLEARN_VERSION="0.21.2" TEST_DIST="true"
1918
- DISTRIB="conda" PYTHON_VERSION="3.7" SKLEARN_VERSION="0.21.2" RUN_FLAKE8="true" SKIP_TESTS="true"
2019
- DISTRIB="conda" PYTHON_VERSION="3.7" SKLEARN_VERSION="0.21.2" COVERAGE="true" DOCPUSH="true"
20+
- DISTRIB="conda" PYTHON_VERSION="3.7" SKLEARN_VERSION="0.23.1" TEST_DIST="true"
21+
- DISTRIB="conda" PYTHON_VERSION="3.6" SKLEARN_VERSION="0.23.1" TEST_DIST="true"
22+
- DISTRIB="conda" PYTHON_VERSION="3.7" SKLEARN_VERSION="0.22.2" TEST_DIST="true"
23+
- DISTRIB="conda" PYTHON_VERSION="3.6" SKLEARN_VERSION="0.22.2" TEST_DIST="true"
24+
- DISTRIB="conda" PYTHON_VERSION="3.6" SKLEARN_VERSION="0.21.2" TEST_DIST="true"
2125
- DISTRIB="conda" PYTHON_VERSION="3.7" SKLEARN_VERSION="0.20.2"
2226
# Checks for older scikit-learn versions (which also don't nicely work with
2327
# Python3.7)

openml/extensions/sklearn/extension.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -994,12 +994,16 @@ def _get_fn_arguments_with_defaults(self, fn_name: Callable) -> Tuple[Dict, Set]
994994
a set with all parameters that do not have a default value
995995
"""
996996
# parameters with defaults are optional, all others are required.
997-
signature = inspect.getfullargspec(fn_name)
998-
if signature.defaults:
999-
optional_params = dict(zip(reversed(signature.args), reversed(signature.defaults)))
1000-
else:
1001-
optional_params = dict()
1002-
required_params = {arg for arg in signature.args if arg not in optional_params}
997+
parameters = inspect.signature(fn_name).parameters
998+
required_params = set()
999+
optional_params = dict()
1000+
for param in parameters.keys():
1001+
parameter = parameters.get(param)
1002+
default_val = parameter.default # type: ignore
1003+
if default_val is inspect.Signature.empty:
1004+
required_params.add(param)
1005+
else:
1006+
optional_params[param] = default_val
10031007
return optional_params, required_params
10041008

10051009
def _deserialize_model(
@@ -1346,7 +1350,7 @@ def _can_measure_cputime(self, model: Any) -> bool:
13461350
# check the parameters for n_jobs
13471351
n_jobs_vals = SklearnExtension._get_parameter_values_recursive(model.get_params(), "n_jobs")
13481352
for val in n_jobs_vals:
1349-
if val is not None and val != 1:
1353+
if val is not None and val != 1 and val != "deprecated":
13501354
return False
13511355
return True
13521356

0 commit comments

Comments
 (0)