Skip to content

Commit 692af97

Browse files
PGijsbersmfeurer
authored andcommitted
Reinstantiate model if needed. Better errors if can't. (#722)
* Clearer error messages when trying to reinstantiate a model and this is not possible. Automatically reinstantiate flow model if possible when run_flow_on_task is called. * Updated changelog. * Fix unit test mistakes. * Check error message with regex.
1 parent ebae892 commit 692af97

File tree

6 files changed

+39
-5
lines changed

6 files changed

+39
-5
lines changed

doc/progress.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ Changelog
88

99
0.10.0
1010
~~~~~~
11+
* ADD #722: Automatic reinstantiation of flow in `run_model_on_task`. Clearer errors if that's not possible.
1112
* FIX #608: Fixing dataset_id referenced before assignment error in get_run function.
1213
* ADD #715: `list_evaluations` now has an option to sort evaluations by score (value).
1314
* FIX #589: Fixing a bug that did not successfully upload the columns to ignore when creating and publishing a dataset.

openml/flows/flow.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,15 @@ def __init__(self, name, description, model, components, parameters,
132132
self.dependencies = dependencies
133133
self.flow_id = flow_id
134134

135-
self.extension = get_extension_by_flow(self)
135+
self._extension = get_extension_by_flow(self)
136+
137+
@property
138+
def extension(self):
139+
if self._extension is not None:
140+
return self._extension
141+
else:
142+
raise RuntimeError("No extension could be found for flow {}: {}"
143+
.format(self.flow_id, self.name))
136144

137145
def __str__(self):
138146
header = "OpenML Flow"

openml/flows/functions.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ def get_flow(flow_id: int, reinstantiate: bool = False) -> OpenMLFlow:
9292

9393
if reinstantiate:
9494
flow.model = flow.extension.flow_to_model(flow)
95-
9695
return flow
9796

9897

@@ -360,7 +359,7 @@ def assert_flows_equal(flow1: OpenMLFlow, flow2: OpenMLFlow,
360359
assert_flows_equal(attr1[name], attr2[name],
361360
ignore_parameter_values_on_older_children,
362361
ignore_parameter_values)
363-
elif key == 'extension':
362+
elif key == '_extension':
364363
continue
365364
else:
366365
if key == 'parameters':

openml/runs/functions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,8 @@ def run_flow_on_task(
171171
if task.task_id is None:
172172
raise ValueError("The task should be published at OpenML")
173173

174+
if flow.model is None:
175+
flow.model = flow.extension.flow_to_model(flow)
174176
flow.model = flow.extension.seed_model(flow.model, seed=seed)
175177

176178
# We only need to sync with the server right now if we want to upload the flow,

tests/test_flows/test_flow_functions.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,3 +256,27 @@ def test_sklearn_to_flow_list_of_lists(self):
256256
server_flow = openml.flows.get_flow(flow.flow_id, reinstantiate=True)
257257
self.assertEqual(server_flow.parameters['categories'], '[[0, 1], [0, 1]]')
258258
self.assertEqual(server_flow.model.categories, flow.model.categories)
259+
260+
def test_get_flow_reinstantiate_model(self):
261+
model = sklearn.ensemble.RandomForestClassifier(n_estimators=33)
262+
extension = openml.extensions.get_extension_by_model(model)
263+
flow = extension.model_to_flow(model)
264+
flow.publish(raise_error_if_exists=False)
265+
266+
downloaded_flow = openml.flows.get_flow(flow.flow_id, reinstantiate=True)
267+
self.assertIsInstance(downloaded_flow.model, sklearn.ensemble.RandomForestClassifier)
268+
269+
def test_get_flow_reinstantiate_model_no_extension(self):
270+
# Flow 10 is a WEKA flow
271+
self.assertRaisesRegex(RuntimeError,
272+
"No extension could be found for flow 10: weka.SMO",
273+
openml.flows.get_flow,
274+
flow_id=10,
275+
reinstantiate=True)
276+
277+
@unittest.skipIf(LooseVersion(sklearn.__version__) == "0.20.0",
278+
reason="No non-0.20 scikit-learn flow known.")
279+
def test_get_flow_reinstantiate_model_wrong_version(self):
280+
# 20 is scikit-learn ==0.20.0
281+
# I can't find a != 0.20 permanent flow on the test server.
282+
self.assertRaises(ValueError, openml.flows.get_flow, flow_id=20, reinstantiate=True)

tests/test_runs/test_run_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1259,12 +1259,12 @@ def test_get_uncached_run(self):
12591259
with self.assertRaises(openml.exceptions.OpenMLCacheException):
12601260
openml.runs.functions._get_cached_run(10)
12611261

1262-
def test_run_model_on_task_downloaded_flow(self):
1262+
def test_run_flow_on_task_downloaded_flow(self):
12631263
model = sklearn.ensemble.RandomForestClassifier(n_estimators=33)
12641264
flow = self.extension.model_to_flow(model)
12651265
flow.publish(raise_error_if_exists=False)
12661266

1267-
downloaded_flow = openml.flows.get_flow(flow.flow_id, reinstantiate=True)
1267+
downloaded_flow = openml.flows.get_flow(flow.flow_id)
12681268
task = openml.tasks.get_task(119) # diabetes
12691269
run = openml.runs.run_flow_on_task(
12701270
flow=downloaded_flow,

0 commit comments

Comments
 (0)