
Commit 2e69fe0

janvanrijn authored and mfeurer committed
Per fold evals (#613)
* added ability to obtain per fold evaluation measures
* added json loads
* updated unit test
1 parent 4a7db0e commit 2e69fe0


5 files changed: 82 additions, 20 deletions

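In short, list_evaluations gains a per_fold flag and OpenMLEvaluation gains a values attribute holding the per repeat/fold measurements. A minimal usage sketch, assuming the production server used by the new unit test; the task ID and size below are illustrative:

    import openml

    # Production server, as set in the unit tests (assumption: your default config may differ).
    openml.config.server = 'https://www.openml.org/api/v1/xml'

    # Default behaviour: aggregated results, .value is set and .values is None.
    aggregated = openml.evaluations.list_evaluations(
        'predictive_accuracy', task=[6], size=100)

    # Per-fold behaviour: .values holds one float per repeat/fold, .value is None.
    per_fold = openml.evaluations.list_evaluations(
        'predictive_accuracy', task=[6], size=100, per_fold=True)
    for run_id, evaluation in per_fold.items():
        print(run_id, evaluation.values)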

openml/evaluations/evaluation.py

Lines changed: 6 additions & 3 deletions
@@ -1,6 +1,6 @@
 
 class OpenMLEvaluation(object):
-    '''
+    """
     Contains all meta-information about a run / evaluation combination,
     according to the evaluation/list function
 
@@ -26,11 +26,13 @@ class OpenMLEvaluation(object):
         the time of evaluation
     value : float
         the value of this evaluation
+    values : List[float]
+        the values per repeat and fold (if requested)
     array_data : str
         list of information per class (e.g., in case of precision, auroc, recall)
-    '''
+    """
     def __init__(self, run_id, task_id, setup_id, flow_id, flow_name,
-                 data_id, data_name, function, upload_time, value,
+                 data_id, data_name, function, upload_time, value, values,
                  array_data=None):
         self.run_id = run_id
         self.task_id = task_id
@@ -42,4 +44,5 @@ def __init__(self, run_id, task_id, setup_id, flow_id, flow_name,
         self.function = function
         self.upload_time = upload_time
         self.value = value
+        self.values = values
         self.array_data = array_data
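For illustration, the widened constructor can be exercised directly; a minimal sketch with made-up IDs and scores (real objects are normally built from the server response in openml/evaluations/functions.py):

    from openml.evaluations import OpenMLEvaluation

    # All identifiers and scores below are invented for illustration only.
    evaluation = OpenMLEvaluation(
        run_id=1, task_id=6, setup_id=2, flow_id=100, flow_name='example.flow',
        data_id=6, data_name='letter', function='predictive_accuracy',
        upload_time='2018-01-01 12:00:00', value=None,
        values=[0.93, 0.95, 0.91], array_data=None)

    print(evaluation.value)   # None when only per-fold results were requested
    print(evaluation.values)  # [0.93, 0.95, 0.91]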

openml/evaluations/functions.py

Lines changed: 29 additions & 12 deletions
@@ -1,13 +1,14 @@
+import json
 import xmltodict
 
-from openml.exceptions import OpenMLServerNoResult
 import openml.utils
 import openml._api_calls
 from ..evaluations import OpenMLEvaluation
 
 
 def list_evaluations(function, offset=None, size=None, id=None, task=None,
-                     setup=None, flow=None, uploader=None, tag=None):
+                     setup=None, flow=None, uploader=None, tag=None,
+                     per_fold=None):
     """
     List all run-evaluation pairs matching all of the given filters.
     (Supports large amount of results)
@@ -33,13 +34,19 @@ def list_evaluations(function, offset=None, size=None, id=None, task=None,
 
     tag : str, optional
 
+    per_fold : bool, optional
+
     Returns
     -------
     dict
     """
+    if per_fold is not None:
+        per_fold = str(per_fold).lower()
 
-    return openml.utils._list_all(_list_evaluations, function, offset=offset, size=size,
-                                  id=id, task=task, setup=setup, flow=flow, uploader=uploader, tag=tag)
+    return openml.utils._list_all(_list_evaluations, function, offset=offset,
+                                  size=size, id=id, task=task, setup=setup,
+                                  flow=flow, uploader=uploader, tag=tag,
+                                  per_fold=per_fold)
 
 
 def _list_evaluations(function, id=None, task=None,
@@ -97,24 +104,34 @@ def __list_evaluations(api_call):
     evals_dict = xmltodict.parse(xml_string, force_list=('oml:evaluation',))
     # Minimalistic check if the XML is useful
     if 'oml:evaluations' not in evals_dict:
-        raise ValueError('Error in return XML, does not contain "oml:evaluations": %s'
-                         % str(evals_dict))
+        raise ValueError('Error in return XML, does not contain '
+                         '"oml:evaluations": %s' % str(evals_dict))
 
     assert type(evals_dict['oml:evaluations']['oml:evaluation']) == list, \
         type(evals_dict['oml:evaluations'])
 
     evals = dict()
     for eval_ in evals_dict['oml:evaluations']['oml:evaluation']:
         run_id = int(eval_['oml:run_id'])
+        value = None
+        values = None
         array_data = None
+        if 'oml:value' in eval_:
+            value = float(eval_['oml:value'])
+        if 'oml:values' in eval_:
+            values = json.loads(eval_['oml:values'])
         if 'oml:array_data' in eval_:
             array_data = eval_['oml:array_data']
 
-        evals[run_id] = OpenMLEvaluation(int(eval_['oml:run_id']), int(eval_['oml:task_id']),
-                                         int(eval_['oml:setup_id']), int(eval_['oml:flow_id']),
-                                         eval_['oml:flow_name'], eval_['oml:data_id'],
-                                         eval_['oml:data_name'], eval_['oml:function'],
-                                         eval_['oml:upload_time'], float(eval_['oml:value']),
-                                         array_data)
+        evals[run_id] = OpenMLEvaluation(int(eval_['oml:run_id']),
+                                         int(eval_['oml:task_id']),
+                                         int(eval_['oml:setup_id']),
+                                         int(eval_['oml:flow_id']),
+                                         eval_['oml:flow_name'],
+                                         eval_['oml:data_id'],
+                                         eval_['oml:data_name'],
+                                         eval_['oml:function'],
+                                         eval_['oml:upload_time'],
+                                         value, values, array_data)
 
     return evals
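Two small mechanics in this change are worth noting: the boolean per_fold argument is lowercased into the 'true'/'false' string passed along with the listing call, and the per-fold measurements arrive as a JSON-encoded array under oml:values, which json.loads turns into a Python list. A standalone sketch of both steps (the example string is an assumed server payload):

    import json

    per_fold = True
    # Filter value as forwarded to openml.utils._list_all / the REST call.
    print(str(per_fold).lower())          # 'true'

    # Assumed example of the text found under 'oml:values' in the response XML.
    values_field = '[0.93, 0.95, 0.91]'
    print(json.loads(values_field))       # [0.93, 0.95, 0.91], one entry per repeat/fold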

openml/runs/functions.py

Lines changed: 1 addition & 1 deletion
@@ -184,7 +184,7 @@ def _publish_flow_if_necessary(flow):
     except OpenMLServerException as e:
         if e.message == "flow already exists":
             # TODO: JvR: the following lines of code can be replaced by
-            # a pass (after changing the unit test) as run_flow_on_task does
+            # a pass (after changing the unit tests) as run_flow_on_task does
             # not longer rely on it
             flow_id = openml.flows.flow_exists(flow.name,
                                                flow.external_version)

tests/test_evaluations/test_evaluation_functions.py

Lines changed: 45 additions & 3 deletions
@@ -2,6 +2,7 @@
 import openml.evaluations
 from openml.testing import TestBase
 
+
 class TestEvaluationFunctions(TestBase):
     _multiprocess_can_split_ = True
 
@@ -15,6 +16,10 @@ def test_evaluation_list_filter_task(self):
         self.assertGreater(len(evaluations), 100)
         for run_id in evaluations.keys():
             self.assertEquals(evaluations[run_id].task_id, task_id)
+            # default behaviour of this method: return aggregated results (not
+            # per fold)
+            self.assertIsNotNone(evaluations[run_id].value)
+            self.assertIsNone(evaluations[run_id].values)
 
     def test_evaluation_list_filter_uploader_ID_16(self):
         openml.config.server = self.production_server
@@ -23,7 +28,7 @@ def test_evaluation_list_filter_uploader_ID_16(self):
 
         evaluations = openml.evaluations.list_evaluations("predictive_accuracy", uploader=[uploader_id])
 
-        self.assertGreater(len(evaluations), 100)
+        self.assertGreater(len(evaluations), 50)
 
     def test_evaluation_list_filter_uploader_ID_10(self):
         openml.config.server = self.production_server
@@ -32,9 +37,13 @@ def test_evaluation_list_filter_uploader_ID_10(self):
 
         evaluations = openml.evaluations.list_evaluations("predictive_accuracy", setup=[setup_id])
 
-        self.assertGreater(len(evaluations), 100)
+        self.assertGreater(len(evaluations), 50)
         for run_id in evaluations.keys():
             self.assertEquals(evaluations[run_id].setup_id, setup_id)
+            # default behaviour of this method: return aggregated results (not
+            # per fold)
+            self.assertIsNotNone(evaluations[run_id].value)
+            self.assertIsNone(evaluations[run_id].values)
 
     def test_evaluation_list_filter_flow(self):
         openml.config.server = self.production_server
@@ -46,17 +55,25 @@ def test_evaluation_list_filter_flow(self):
         self.assertGreater(len(evaluations), 2)
         for run_id in evaluations.keys():
             self.assertEquals(evaluations[run_id].flow_id, flow_id)
+            # default behaviour of this method: return aggregated results (not
+            # per fold)
+            self.assertIsNotNone(evaluations[run_id].value)
+            self.assertIsNone(evaluations[run_id].values)
 
     def test_evaluation_list_filter_run(self):
         openml.config.server = self.production_server
 
-        run_id = 1
+        run_id = 12
 
         evaluations = openml.evaluations.list_evaluations("predictive_accuracy", id=[run_id])
 
         self.assertEquals(len(evaluations), 1)
         for run_id in evaluations.keys():
             self.assertEquals(evaluations[run_id].run_id, run_id)
+            # default behaviour of this method: return aggregated results (not
+            # per fold)
+            self.assertIsNotNone(evaluations[run_id].value)
+            self.assertIsNone(evaluations[run_id].values)
 
     def test_evaluation_list_limit(self):
         openml.config.server = self.production_server
@@ -70,3 +87,28 @@ def test_list_evaluations_empty(self):
             raise ValueError('UnitTest Outdated, got somehow results')
 
         self.assertIsInstance(evaluations, dict)
+
+    def test_evaluation_list_per_fold(self):
+        openml.config.server = self.production_server
+        size = 1000
+        task_ids = [6]
+        uploader_ids = [1]
+        flow_ids = [6969]
+
+        evaluations = openml.evaluations.list_evaluations(
+            "predictive_accuracy", size=size, offset=0, task=task_ids,
+            flow=flow_ids, uploader=uploader_ids, per_fold=True)
+
+        self.assertEquals(len(evaluations), size)
+        for run_id in evaluations.keys():
+            self.assertIsNone(evaluations[run_id].value)
+            self.assertIsNotNone(evaluations[run_id].values)
+            # potentially we could also test array values, but these might be
+            # added in the future
+
+        evaluations = openml.evaluations.list_evaluations(
+            "predictive_accuracy", size=size, offset=0, task=task_ids,
+            flow=flow_ids, uploader=uploader_ids, per_fold=False)
+        for run_id in evaluations.keys():
+            self.assertIsNotNone(evaluations[run_id].value)
+            self.assertIsNone(evaluations[run_id].values)

tests/test_runs/test_run_functions.py

Lines changed: 1 addition & 1 deletion
@@ -999,7 +999,7 @@ def _check_run(self, run):
     def test_get_runs_list(self):
         # TODO: comes from live, no such lists on test
         openml.config.server = self.production_server
-        runs = openml.runs.list_runs(id=[2])
+        runs = openml.runs.list_runs(id=[2], show_errors=True)
        self.assertEqual(len(runs), 1)
        for rid in runs:
            self._check_run(runs[rid])
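The test change suggests that run 2 on the production server is now only listed when errored runs are included; a minimal sketch of the adjusted call (the server URL is an assumption):

    import openml

    openml.config.server = 'https://www.openml.org/api/v1/xml'  # production server (assumption)

    # show_errors=True also returns runs that are flagged with an error state.
    runs = openml.runs.list_runs(id=[2], show_errors=True)
    print(len(runs))  # expected to be 1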
