Skip to content

Commit 04c4d0e

Browse files
janvanrijn authored and mfeurer committed
Fix #569: crash when sklearn version does not collide (#601)
* reinstantiate flow
* reinstantiate flow fix
* pep8 problems
* pep8 fix
1 parent aae0e5b commit 04c4d0e

File tree

4 files changed

+23
-18
lines changed

4 files changed

+23
-18
lines changed

openml/flows/flow.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -310,19 +310,6 @@ def _from_dict(cls, xml_dict):
310310
arguments['model'] = None
311311
flow = cls(**arguments)
312312

313-
# try to parse to a model because not everything that can be
314-
# deserialized has to come from scikit-learn. If it can't be
315-
# serialized, but comes from scikit-learn this is worth an exception
316-
if (
317-
arguments['external_version'].startswith('sklearn==')
318-
or ',sklearn==' in arguments['external_version']
319-
):
320-
from .sklearn_converter import flow_to_sklearn
321-
model = flow_to_sklearn(flow)
322-
else:
323-
model = None
324-
flow.model = model
325-
326313
return flow
327314

328315
def publish(self):

openml/flows/functions.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,36 @@
88
import openml.utils
99

1010

11-
def get_flow(flow_id):
11+
def get_flow(flow_id, reinstantiate=False):
1212
"""Download the OpenML flow for a given flow ID.
1313
1414
Parameters
1515
----------
1616
flow_id : int
1717
The OpenML flow id.
18+
19+
reinstantiate: bool
20+
Whether to reinstantiate the flow to a sklearn model.
21+
Note that this can only be done with sklearn flows, and
22+
when
23+
24+
Returns
25+
-------
26+
flow : OpenMLFlow
27+
the flow
1828
"""
1929
flow_id = int(flow_id)
2030
flow_xml = openml._api_calls._perform_api_call("flow/%d" % flow_id)
2131

2232
flow_dict = xmltodict.parse(flow_xml)
2333
flow = OpenMLFlow._from_dict(flow_dict)
2434

35+
if reinstantiate:
36+
if not (flow.external_version.startswith('sklearn==') or
37+
',sklearn==' in flow.external_version):
38+
raise ValueError('Only sklearn flows can be reinstantiated')
39+
flow.model = openml.flows.flow_to_sklearn(flow)
40+
2541
return flow
2642

2743

tests/test_flows/test_flow.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -275,9 +275,9 @@ def test_existing_flow_exists(self):
275275
for classifier in [nb, complicated]:
276276
flow = openml.flows.sklearn_to_flow(classifier)
277277
flow, _ = self._add_sentinel_to_flow_name(flow, None)
278-
#publish the flow
278+
# publish the flow
279279
flow = flow.publish()
280-
#redownload the flow
280+
# redownload the flow
281281
flow = openml.flows.get_flow(flow.flow_id)
282282

283283
# check if flow exists can find it
@@ -329,7 +329,8 @@ def test_sklearn_to_upload_to_flow(self):
329329
# Check whether we can load the flow again
330330
# Remove the sentinel from the name again so that we can reinstantiate
331331
# the object again
332-
new_flow = openml.flows.get_flow(flow_id=flow.flow_id)
332+
new_flow = openml.flows.get_flow(flow_id=flow.flow_id,
333+
reinstantiate=True)
333334

334335
local_xml = flow._to_xml()
335336
server_xml = new_flow._to_xml()

tests/test_runs/test_run_functions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -627,7 +627,8 @@ def test_get_run_trace(self):
627627
flow_exists = openml.flows.flow_exists(flow.name, flow.external_version)
628628
self.assertIsInstance(flow_exists, int)
629629
self.assertGreater(flow_exists, 0)
630-
downloaded_flow = openml.flows.get_flow(flow_exists)
630+
downloaded_flow = openml.flows.get_flow(flow_exists,
631+
reinstantiate=True)
631632
setup_exists = openml.setups.setup_exists(downloaded_flow)
632633
self.assertIsInstance(setup_exists, int)
633634
self.assertGreater(setup_exists, 0)

0 commit comments

Comments
 (0)