@@ -324,32 +324,55 @@ def test_get_flow_reinstantiate_model_no_extension(self):
324324 )
325325
326326 @unittest .skipIf (
327- LooseVersion (sklearn .__version__ ) == "0.19.1" , reason = "Target flow is from sklearn 0.19.1"
327+ LooseVersion (sklearn .__version__ ) == "0.19.1" ,
328+ reason = "Requires scikit-learn!=0.19.1, because target flow is from that version." ,
328329 )
329- def test_get_flow_reinstantiate_model_wrong_version (self ):
330- # Note that CI does not test against 0.19.1.
330+ def test_get_flow_with_reinstantiate_strict_with_wrong_version_raises_exception (self ):
331331 openml .config .server = self .production_server
332- _ , sklearn_major , _ = LooseVersion (sklearn .__version__ ).version [:3 ]
333- if sklearn_major > 23 :
334- flow = 18587 # 18687, 18725 --- flows building random forest on >= 0.23
335- flow_sklearn_version = "0.23.1"
336- else :
337- flow = 8175
338- flow_sklearn_version = "0.19.1"
339- expected = (
340- "Trying to deserialize a model with dependency "
341- "sklearn=={} not satisfied." .format (flow_sklearn_version )
342- )
332+ flow = 8175
333+ expected = "Trying to deserialize a model with dependency sklearn==0.19.1 not satisfied."
343334 self .assertRaisesRegex (
344- ValueError , expected , openml .flows .get_flow , flow_id = flow , reinstantiate = True
335+ ValueError ,
336+ expected ,
337+ openml .flows .get_flow ,
338+ flow_id = flow ,
339+ reinstantiate = True ,
340+ strict_version = True ,
345341 )
346- if LooseVersion (sklearn .__version__ ) > "0.19.1" :
347- # 0.18 actually can't deserialize this because of incompatibility
348- flow = openml .flows .get_flow (flow_id = flow , reinstantiate = True , strict_version = False )
349- # ensure that a new flow was created
350- assert flow .flow_id is None
351- assert "sklearn==0.19.1" not in flow .dependencies
352- assert "sklearn>=0.19.1" not in flow .dependencies
342+
343+ @unittest .skipIf (
344+ LooseVersion (sklearn .__version__ ) < "1" and LooseVersion (sklearn .__version__ ) != "1.0.0" ,
345+ reason = "Requires scikit-learn < 1.0.1."
346+ # Because scikit-learn dropped min_impurity_split hyperparameter in 1.0,
347+ # and the requested flow is from 1.0.0 exactly.
348+ )
349+ def test_get_flow_reinstantiate_flow_not_strict_post_1 (self ):
350+ openml .config .server = self .production_server
351+ flow = openml .flows .get_flow (flow_id = 19190 , reinstantiate = True , strict_version = False )
352+ assert flow .flow_id is None
353+ assert "sklearn==1.0.0" not in flow .dependencies
354+
355+ @unittest .skipIf (
356+ (LooseVersion (sklearn .__version__ ) < "0.23.2" )
357+ or ("1.0" < LooseVersion (sklearn .__version__ )),
358+ reason = "Requires scikit-learn 0.23.2 or ~0.24."
359+ # Because these still have min_impurity_split, but with new scikit-learn module structure."
360+ )
361+ def test_get_flow_reinstantiate_flow_not_strict_023_and_024 (self ):
362+ openml .config .server = self .production_server
363+ flow = openml .flows .get_flow (flow_id = 18587 , reinstantiate = True , strict_version = False )
364+ assert flow .flow_id is None
365+ assert "sklearn==0.23.1" not in flow .dependencies
366+
367+ @unittest .skipIf (
368+ "0.23" < LooseVersion (sklearn .__version__ ),
369+ reason = "Requires scikit-learn<=0.23, because the scikit-learn module structure changed." ,
370+ )
371+ def test_get_flow_reinstantiate_flow_not_strict_pre_023 (self ):
372+ openml .config .server = self .production_server
373+ flow = openml .flows .get_flow (flow_id = 8175 , reinstantiate = True , strict_version = False )
374+ assert flow .flow_id is None
375+ assert "sklearn==0.19.1" not in flow .dependencies
353376
354377 def test_get_flow_id (self ):
355378 if self .long_version :
0 commit comments