Skip to content

Commit ff7a251

Browse files
authored
Adding sklearn min. dependencies for all versions (#1022)
* Squashing commits * All flow dependencies for sklearn>=0.24 will change now * Dependency string changes only for OpenML>v0.11
1 parent 7553281 commit ff7a251

File tree

3 files changed

+60
-11
lines changed

3 files changed

+60
-11
lines changed

openml/extensions/sklearn/extension.py

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,61 @@ def remove_all_in_parentheses(string: str) -> str:
211211

212212
return short_name.format(pipeline)
213213

214+
@classmethod
def _min_dependency_str(cls, sklearn_version: str) -> str:
    """Return the dependency string recorded for the given scikit-learn version.

    The first line pins scikit-learn itself (``sklearn==<version>``); every
    following line is a ``<package>>=<minimum version>`` requirement.

    Parameters
    ----------
    sklearn_version : str
        A scikit-learn version string of the form ``xx.xx.xx``. A missing patch
        component (e.g. ``"0.23"``) is treated as patch level 0.

    Returns
    -------
    str
        Newline-separated dependency specifiers.
    """
    # This explicit check is necessary to support existing entities on the OpenML servers
    # that used the fixed dependency string (in the else block).
    # The full OpenML version is compared (rather than only its minor component) so that a
    # 1.x release is not mistaken for a release older than 0.12.
    if LooseVersion(openml.__version__) >= "0.12":
        parsed = LooseVersion(sklearn_version)
        if parsed >= "0.24":
            # assumption: 0.24 onwards sklearn should contain a _min_dependencies.py file
            # with variables declared for extracting minimum dependency for that version
            from sklearn import _min_dependencies as _mindep

            min_versions = {
                "numpy": str(_mindep.NUMPY_MIN_VERSION),
                "scipy": str(_mindep.SCIPY_MIN_VERSION),
                "joblib": str(_mindep.JOBLIB_MIN_VERSION),
                "threadpoolctl": str(_mindep.THREADPOOLCTL_MIN_VERSION),
            }
        elif parsed >= "0.23":
            min_versions = {
                "numpy": "1.13.3",
                "scipy": "0.19.1",
                "joblib": "0.11",
                "threadpoolctl": "2.0.0",
            }
            # Guard the patch lookup: "0.23" parses to only two components and would
            # otherwise raise IndexError.
            components = parsed.version
            patch = components[2] if len(components) > 2 else 0
            # NOTE(review): the x.23.0 release is recorded without threadpoolctl;
            # confirm against that release's declared requirements.
            if patch == 0:
                min_versions.pop("threadpoolctl")
        elif parsed >= "0.21":
            min_versions = {"numpy": "1.11.0", "scipy": "0.17.0", "joblib": "0.11"}
        elif parsed >= "0.19":
            min_versions = {"numpy": "1.8.2", "scipy": "0.13.3"}
        else:
            min_versions = {"numpy": "1.6.1", "scipy": "0.9"}
    else:
        # this is INCORRECT for sklearn versions >= 0.19 and < 0.24
        # given that OpenML has existing flows uploaded with such dependency information,
        # we change no behaviour for older sklearn version, however from 0.24 onwards
        # the dependency list will be accurately updated for any flow uploaded to OpenML
        min_versions = {"numpy": "1.6.1", "scipy": "0.9"}

    sklearn_dep = "sklearn=={}".format(sklearn_version)
    dep_str = "\n".join(
        "{}>={}".format(package, minimum) for package, minimum in min_versions.items()
    )
    return "\n".join([sklearn_dep, dep_str])
268+
214269
################################################################################################
215270
# Methods for flow serialization and de-serialization
216271

@@ -769,20 +824,13 @@ def _serialize_model(self, model: Any) -> OpenMLFlow:
769824
tags=tags,
770825
extension=self,
771826
language="English",
772-
# TODO fill in dependencies!
773827
dependencies=dependencies,
774828
)
775829

776830
return flow
777831

778832
def _get_dependencies(self) -> str:
779-
dependencies = "\n".join(
780-
[
781-
self._format_external_version("sklearn", sklearn.__version__,),
782-
"numpy>=1.6.1",
783-
"scipy>=0.9",
784-
]
785-
)
833+
dependencies = self._min_dependency_str(sklearn.__version__)
786834
return dependencies
787835

788836
def _get_tags(self) -> List[str]:

tests/test_extensions/test_sklearn_extension/test_sklearn_extension.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def test_serialize_model(self):
146146
fixture_short_name = "sklearn.DecisionTreeClassifier"
147147
# str obtained from self.extension._get_sklearn_description(model)
148148
fixture_description = "A decision tree classifier."
149-
version_fixture = "sklearn==%s\nnumpy>=1.6.1\nscipy>=0.9" % sklearn.__version__
149+
version_fixture = self.extension._min_dependency_str(sklearn.__version__)
150150

151151
presort_val = "false" if LooseVersion(sklearn.__version__) < "0.22" else '"deprecated"'
152152
# min_impurity_decrease has been introduced in 0.20
@@ -227,7 +227,7 @@ def test_serialize_model_clustering(self):
227227
fixture_description = "K-Means clustering{}".format(
228228
"" if LooseVersion(sklearn.__version__) < "0.22" else "."
229229
)
230-
version_fixture = "sklearn==%s\nnumpy>=1.6.1\nscipy>=0.9" % sklearn.__version__
230+
version_fixture = self.extension._min_dependency_str(sklearn.__version__)
231231

232232
n_jobs_val = "null" if LooseVersion(sklearn.__version__) < "0.23" else '"deprecated"'
233233
precomp_val = '"auto"' if LooseVersion(sklearn.__version__) < "0.23" else '"deprecated"'

tests/test_flows/test_flow_functions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,8 @@ def test_get_flow_reinstantiate_model_wrong_version(self):
343343
flow = openml.flows.get_flow(flow_id=flow, reinstantiate=True, strict_version=False)
344344
# ensure that a new flow was created
345345
assert flow.flow_id is None
346-
assert "0.19.1" not in flow.dependencies
346+
assert "sklearn==0.19.1" not in flow.dependencies
347+
assert "sklearn>=0.19.1" not in flow.dependencies
347348

348349
def test_get_flow_id(self):
349350
if self.long_version:

0 commit comments

Comments
 (0)