@@ -211,6 +211,61 @@ def remove_all_in_parentheses(string: str) -> str:
211211
212212 return short_name .format (pipeline )
213213
214+ @classmethod
215+ def _min_dependency_str (cls , sklearn_version : str ) -> str :
216+ """ Returns a string containing the minimum dependencies for the sklearn version passed.
217+
218+ Parameters
219+ ----------
220+ sklearn_version : str
221+ A version string of the xx.xx.xx
222+
223+ Returns
224+ -------
225+ str
226+ """
227+ openml_major_version = int (LooseVersion (openml .__version__ ).version [1 ])
228+ # This explicit check is necessary to support existing entities on the OpenML servers
229+ # that used the fixed dependency string (in the else block)
230+ if openml_major_version > 11 :
231+ # OpenML v0.11 onwards supports sklearn>=0.24
232+ # assumption: 0.24 onwards sklearn should contain a _min_dependencies.py file with
233+ # variables declared for extracting minimum dependency for that version
234+ if LooseVersion (sklearn_version ) >= "0.24" :
235+ from sklearn import _min_dependencies as _mindep
236+
237+ dependency_list = {
238+ "numpy" : "{}" .format (_mindep .NUMPY_MIN_VERSION ),
239+ "scipy" : "{}" .format (_mindep .SCIPY_MIN_VERSION ),
240+ "joblib" : "{}" .format (_mindep .JOBLIB_MIN_VERSION ),
241+ "threadpoolctl" : "{}" .format (_mindep .THREADPOOLCTL_MIN_VERSION ),
242+ }
243+ elif LooseVersion (sklearn_version ) >= "0.23" :
244+ dependency_list = {
245+ "numpy" : "1.13.3" ,
246+ "scipy" : "0.19.1" ,
247+ "joblib" : "0.11" ,
248+ "threadpoolctl" : "2.0.0" ,
249+ }
250+ if LooseVersion (sklearn_version ).version [2 ] == 0 :
251+ dependency_list .pop ("threadpoolctl" )
252+ elif LooseVersion (sklearn_version ) >= "0.21" :
253+ dependency_list = {"numpy" : "1.11.0" , "scipy" : "0.17.0" , "joblib" : "0.11" }
254+ elif LooseVersion (sklearn_version ) >= "0.19" :
255+ dependency_list = {"numpy" : "1.8.2" , "scipy" : "0.13.3" }
256+ else :
257+ dependency_list = {"numpy" : "1.6.1" , "scipy" : "0.9" }
258+ else :
259+ # this is INCORRECT for sklearn versions >= 0.19 and < 0.24
260+ # given that OpenML has existing flows uploaded with such dependency information,
261+ # we change no behaviour for older sklearn version, however from 0.24 onwards
262+ # the dependency list will be accurately updated for any flow uploaded to OpenML
263+ dependency_list = {"numpy" : "1.6.1" , "scipy" : "0.9" }
264+
265+ sklearn_dep = "sklearn=={}" .format (sklearn_version )
266+ dep_str = "\n " .join (["{}>={}" .format (k , v ) for k , v in dependency_list .items ()])
267+ return "\n " .join ([sklearn_dep , dep_str ])
268+
214269 ################################################################################################
215270 # Methods for flow serialization and de-serialization
216271
@@ -769,20 +824,13 @@ def _serialize_model(self, model: Any) -> OpenMLFlow:
769824 tags = tags ,
770825 extension = self ,
771826 language = "English" ,
772- # TODO fill in dependencies!
773827 dependencies = dependencies ,
774828 )
775829
776830 return flow
777831
778832 def _get_dependencies (self ) -> str :
779- dependencies = "\n " .join (
780- [
781- self ._format_external_version ("sklearn" , sklearn .__version__ ,),
782- "numpy>=1.6.1" ,
783- "scipy>=0.9" ,
784- ]
785- )
833+ dependencies = self ._min_dependency_str (sklearn .__version__ )
786834 return dependencies
787835
788836 def _get_tags (self ) -> List [str ]:
0 commit comments