Skip to content

Commit 5cd6973

Browse files
authored
Refactor out different test cases to separate tests (#1176)
The previous test combined two test conditions (strict and not strict) with branches for several scikit-learn versions, because of two distinct changes within scikit-learn: the removal of min_impurity_split in 1.0, and the restructuring of public/private modules in 0.24. I split these cases into separate tests, which greatly simplifies each individual test, and I added a test case for scikit-learn>=1.0, which was previously not covered.
1 parent 22ee9cd commit 5cd6973

File tree

1 file changed

+45
-22
lines changed

1 file changed

+45
-22
lines changed

tests/test_flows/test_flow_functions.py

Lines changed: 45 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -324,32 +324,55 @@ def test_get_flow_reinstantiate_model_no_extension(self):
324324
)
325325

326326
@unittest.skipIf(
327-
LooseVersion(sklearn.__version__) == "0.19.1", reason="Target flow is from sklearn 0.19.1"
327+
LooseVersion(sklearn.__version__) == "0.19.1",
328+
reason="Requires scikit-learn!=0.19.1, because target flow is from that version.",
328329
)
329-
def test_get_flow_reinstantiate_model_wrong_version(self):
330-
# Note that CI does not test against 0.19.1.
330+
def test_get_flow_with_reinstantiate_strict_with_wrong_version_raises_exception(self):
331331
openml.config.server = self.production_server
332-
_, sklearn_major, _ = LooseVersion(sklearn.__version__).version[:3]
333-
if sklearn_major > 23:
334-
flow = 18587 # 18687, 18725 --- flows building random forest on >= 0.23
335-
flow_sklearn_version = "0.23.1"
336-
else:
337-
flow = 8175
338-
flow_sklearn_version = "0.19.1"
339-
expected = (
340-
"Trying to deserialize a model with dependency "
341-
"sklearn=={} not satisfied.".format(flow_sklearn_version)
342-
)
332+
flow = 8175
333+
expected = "Trying to deserialize a model with dependency sklearn==0.19.1 not satisfied."
343334
self.assertRaisesRegex(
344-
ValueError, expected, openml.flows.get_flow, flow_id=flow, reinstantiate=True
335+
ValueError,
336+
expected,
337+
openml.flows.get_flow,
338+
flow_id=flow,
339+
reinstantiate=True,
340+
strict_version=True,
345341
)
346-
if LooseVersion(sklearn.__version__) > "0.19.1":
347-
# 0.18 actually can't deserialize this because of incompatibility
348-
flow = openml.flows.get_flow(flow_id=flow, reinstantiate=True, strict_version=False)
349-
# ensure that a new flow was created
350-
assert flow.flow_id is None
351-
assert "sklearn==0.19.1" not in flow.dependencies
352-
assert "sklearn>=0.19.1" not in flow.dependencies
342+
343+
@unittest.skipIf(
344+
LooseVersion(sklearn.__version__) < "1" and LooseVersion(sklearn.__version__) != "1.0.0",
345+
reason="Requires scikit-learn < 1.0.1."
346+
# Because scikit-learn dropped min_impurity_split hyperparameter in 1.0,
347+
# and the requested flow is from 1.0.0 exactly.
348+
)
349+
def test_get_flow_reinstantiate_flow_not_strict_post_1(self):
350+
openml.config.server = self.production_server
351+
flow = openml.flows.get_flow(flow_id=19190, reinstantiate=True, strict_version=False)
352+
assert flow.flow_id is None
353+
assert "sklearn==1.0.0" not in flow.dependencies
354+
355+
@unittest.skipIf(
356+
(LooseVersion(sklearn.__version__) < "0.23.2")
357+
or ("1.0" < LooseVersion(sklearn.__version__)),
358+
reason="Requires scikit-learn 0.23.2 or ~0.24."
359+
# Because these still have min_impurity_split, but with new scikit-learn module structure."
360+
)
361+
def test_get_flow_reinstantiate_flow_not_strict_023_and_024(self):
362+
openml.config.server = self.production_server
363+
flow = openml.flows.get_flow(flow_id=18587, reinstantiate=True, strict_version=False)
364+
assert flow.flow_id is None
365+
assert "sklearn==0.23.1" not in flow.dependencies
366+
367+
@unittest.skipIf(
368+
"0.23" < LooseVersion(sklearn.__version__),
369+
reason="Requires scikit-learn<=0.23, because the scikit-learn module structure changed.",
370+
)
371+
def test_get_flow_reinstantiate_flow_not_strict_pre_023(self):
372+
openml.config.server = self.production_server
373+
flow = openml.flows.get_flow(flow_id=8175, reinstantiate=True, strict_version=False)
374+
assert flow.flow_id is None
375+
assert "sklearn==0.19.1" not in flow.dependencies
353376

354377
def test_get_flow_id(self):
355378
if self.long_version:

0 commit comments

Comments
 (0)