Skip to content

Commit 57d61c4

Browse files
ArlindKadramfeurer
authored andcommitted
Single input task partial fix (#541)
* Partial starting fix for single input task, cache dir multiplatform change * Reduce line size * changing type to isinstance * Refactoring the cache directory path to be more general * Fixing problem with clustering task in accordance with the different tasks implementation * Fixing flake8 problem, adding unit test for clustering task * Fixing bug with regression tasks, adding more checks to the get_task unit tests
1 parent 070b363 commit 57d61c4

File tree

5 files changed

+50
-34
lines changed

5 files changed

+50
-34
lines changed

ci_scripts/flake8_diff.sh

100644100755
File mode changed.

openml/config.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@
1919
'apikey': None,
2020
'server': "https://www.openml.org/api/v1/xml",
2121
'verbosity': 0,
22-
'cachedir': os.path.expanduser('~/.openml/cache'),
22+
'cachedir': os.path.expanduser(os.path.join('~', '.openml', 'cache')),
2323
'avoid_duplicate_runs': 'True',
2424
}
2525

26-
config_file = os.path.expanduser('~/.openml/config')
26+
config_file = os.path.expanduser(os.path.join('~', '.openml' 'config'))
2727

2828
# Default values are actually added here in the _setup() function which is
2929
# called at the end of this module
@@ -48,7 +48,7 @@ def _setup():
4848
global avoid_duplicate_runs
4949
# read config file, create cache directory
5050
try:
51-
os.mkdir(os.path.expanduser('~/.openml'))
51+
os.mkdir(os.path.expanduser(os.path.join('~', '.openml')))
5252
except (IOError, OSError):
5353
# TODO add debug information
5454
pass
@@ -96,7 +96,7 @@ def get_cache_directory():
9696
9797
"""
9898
url_suffix = urlparse(server).netloc
99-
reversed_url_suffix = '/'.join(url_suffix.split('.')[::-1])
99+
reversed_url_suffix = os.sep.join(url_suffix.split('.')[::-1])
100100
if not cache_directory:
101101
_cachedir = _defaults(cache_directory)
102102
else:

openml/tasks/functions.py

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010
from ..datasets import get_dataset
1111
from .task import (
1212
OpenMLClassificationTask,
13-
OpenMLRegressionTask,
1413
OpenMLClusteringTask,
1514
OpenMLLearningCurveTask,
15+
OpenMLRegressionTask,
16+
OpenMLSupervisedTask
1617
)
1718
import openml.utils
1819
import openml._api_calls
@@ -292,9 +293,13 @@ def get_task(task_id):
292293
try:
293294
task = _get_task_description(task_id)
294295
dataset = get_dataset(task.dataset_id)
295-
class_labels = dataset.retrieve_class_labels(task.target_name)
296-
task.class_labels = class_labels
297-
task.download_split()
296+
# Clustering tasks do not have class labels
297+
# and do not offer download_split
298+
if isinstance(task, OpenMLSupervisedTask):
299+
task.download_split()
300+
if isinstance(task, OpenMLClassificationTask):
301+
task.class_labels = \
302+
dataset.retrieve_class_labels(task.target_name)
298303
except Exception as e:
299304
openml.utils._remove_cache_dir_for_id(
300305
TASKS_CACHE_DIR_NAME,
@@ -323,6 +328,7 @@ def _get_task_description(task_id):
323328
fh.write(task_xml)
324329
return _create_task_from_xml(task_xml)
325330

331+
326332
def _create_task_from_xml(xml):
327333
"""Create a task given a xml string.
328334
@@ -336,46 +342,53 @@ def _create_task_from_xml(xml):
336342
OpenMLTask
337343
"""
338344
dic = xmltodict.parse(xml)["oml:task"]
339-
340345
estimation_parameters = dict()
341346
inputs = dict()
342347
# Due to the unordered structure we obtain, we first have to extract
343348
# the possible keys of oml:input; dic["oml:input"] is a list of
344349
# OrderedDicts
345-
for input_ in dic["oml:input"]:
346-
name = input_["@name"]
347-
inputs[name] = input_
350+
351+
# Check if there is a list of inputs
352+
if isinstance(dic["oml:input"], list):
353+
for input_ in dic["oml:input"]:
354+
name = input_["@name"]
355+
inputs[name] = input_
356+
# Single input case
357+
elif isinstance(dic["oml:input"], dict):
358+
name = dic["oml:input"]["@name"]
359+
inputs[name] = dic["oml:input"]
348360

349361
evaluation_measures = None
350362
if 'evaluation_measures' in inputs:
351363
evaluation_measures = inputs["evaluation_measures"][
352364
"oml:evaluation_measures"]["oml:evaluation_measure"]
353365

354-
# Convert some more parameters
355-
for parameter in \
356-
inputs["estimation_procedure"]["oml:estimation_procedure"][
357-
"oml:parameter"]:
358-
name = parameter["@name"]
359-
text = parameter.get("#text", "")
360-
estimation_parameters[name] = text
361-
362366
task_type = dic["oml:task_type"]
363367
common_kwargs = {
364368
'task_id': dic["oml:task_id"],
365369
'task_type': task_type,
366370
'task_type_id': dic["oml:task_type_id"],
367371
'data_set_id': inputs["source_data"][
368372
"oml:data_set"]["oml:data_set_id"],
369-
'estimation_procedure_type': inputs["estimation_procedure"][
370-
"oml:estimation_procedure"]["oml:type"],
371-
'estimation_parameters': estimation_parameters,
372373
'evaluation_measure': evaluation_measures,
373374
}
374375
if task_type in (
375376
"Supervised Classification",
376377
"Supervised Regression",
377378
"Learning Curve"
378379
):
380+
# Convert some more parameters
381+
for parameter in \
382+
inputs["estimation_procedure"]["oml:estimation_procedure"][
383+
"oml:parameter"]:
384+
name = parameter["@name"]
385+
text = parameter.get("#text", "")
386+
estimation_parameters[name] = text
387+
388+
common_kwargs['estimation_procedure_type'] = inputs[
389+
"estimation_procedure"][
390+
"oml:estimation_procedure"]["oml:type"],
391+
common_kwargs['estimation_parameters'] = estimation_parameters,
379392
common_kwargs['target_name'] = inputs[
380393
"source_data"]["oml:data_set"]["oml:target_feature"]
381394
common_kwargs['data_splits_url'] = inputs["estimation_procedure"][

openml/tasks/task.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,11 @@
99

1010
class OpenMLTask(object):
1111
def __init__(self, task_id, task_type_id, task_type, data_set_id,
12-
estimation_procedure_type, estimation_parameters,
1312
evaluation_measure):
1413
self.task_id = int(task_id)
1514
self.task_type_id = int(task_type_id)
1615
self.task_type = task_type
1716
self.dataset_id = int(data_set_id)
18-
self.estimation_procedure = dict()
19-
self.estimation_procedure["type"] = estimation_procedure_type
20-
self.estimation_procedure["parameters"] = estimation_parameters
21-
self.estimation_parameters = estimation_parameters
2217
self.evaluation_measure = evaluation_measure
2318

2419
def get_dataset(self):
@@ -57,12 +52,14 @@ def __init__(self, task_id, task_type_id, task_type, data_set_id,
5752
task_type_id=task_type_id,
5853
task_type=task_type,
5954
data_set_id=data_set_id,
60-
estimation_procedure_type=estimation_procedure_type,
61-
estimation_parameters=estimation_parameters,
6255
evaluation_measure=evaluation_measure,
6356
)
64-
self.target_name = target_name
57+
self.estimation_procedure = dict()
58+
self.estimation_procedure["type"] = estimation_procedure_type
59+
self.estimation_procedure["parameters"] = estimation_parameters
60+
self.estimation_parameters = estimation_parameters
6561
self.estimation_procedure["data_splits_url"] = data_splits_url
62+
self.target_name = target_name
6663
self.split = None
6764

6865
def get_X_and_y(self):
@@ -169,15 +166,12 @@ def __init__(self, task_id, task_type_id, task_type, data_set_id,
169166

170167
class OpenMLClusteringTask(OpenMLTask):
171168
def __init__(self, task_id, task_type_id, task_type, data_set_id,
172-
estimation_procedure_type, estimation_parameters,
173169
evaluation_measure, number_of_clusters=None):
174170
super(OpenMLClusteringTask, self).__init__(
175171
task_id=task_id,
176172
task_type_id=task_type_id,
177173
task_type=task_type,
178174
data_set_id=data_set_id,
179-
estimation_procedure_type=estimation_procedure_type,
180-
estimation_parameters=estimation_parameters,
181175
evaluation_measure=evaluation_measure,
182176
)
183177
self.number_of_clusters = number_of_clusters

tests/test_tasks/test_task_functions.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,15 @@ def test_get_task_with_cache(self):
156156
task = openml.tasks.get_task(1)
157157
self.assertIsInstance(task, OpenMLTask)
158158

159+
def test_get_task_different_types(self):
160+
openml.config.server = self.production_server
161+
# Regression task
162+
openml.tasks.functions.get_task(5001)
163+
# Learning curve
164+
openml.tasks.functions.get_task(64)
165+
# Issue 538, get_task failing with clustering task.
166+
openml.tasks.functions.get_task(126033)
167+
159168
def test_download_split(self):
160169
task = openml.tasks.get_task(1)
161170
split = task.download_split()

0 commit comments

Comments
 (0)