Skip to content

Commit 0def226

Browse files
authored
Class to enum (#958)
* convert TaskTypeEnum class to TaskType enum * update docstrings for TaskType * fix bug in examples, import TaskType directly * use task_type instead of task_type_id
1 parent 5641828 commit 0def226

File tree

15 files changed

+134
-121
lines changed

15 files changed

+134
-121
lines changed

examples/30_extended/tasks_tutorial.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
# License: BSD 3-Clause
99

1010
import openml
11+
from openml.tasks import TaskType
1112
import pandas as pd
1213

1314
############################################################################
@@ -30,7 +31,7 @@
3031
#
3132
# We will start by simply listing only *supervised classification* tasks:
3233

33-
tasks = openml.tasks.list_tasks(task_type_id=1)
34+
tasks = openml.tasks.list_tasks(task_type=TaskType.SUPERVISED_CLASSIFICATION)
3435

3536
############################################################################
3637
# **openml.tasks.list_tasks()** returns a dictionary of dictionaries by default, which we convert
@@ -45,7 +46,9 @@
4546

4647
# As conversion to a pandas dataframe is a common task, we have added this functionality to the
4748
# OpenML-Python library which can be used by passing ``output_format='dataframe'``:
48-
tasks_df = openml.tasks.list_tasks(task_type_id=1, output_format="dataframe")
49+
tasks_df = openml.tasks.list_tasks(
50+
task_type=TaskType.SUPERVISED_CLASSIFICATION, output_format="dataframe"
51+
)
4952
print(tasks_df.head())
5053

5154
############################################################################
@@ -155,7 +158,7 @@
155158
#
156159
# Creating a task requires the following input:
157160
#
158-
# * task_type_id: The task type ID, required (see below). Required.
161+
# * task_type: The task type ID, required (see below). Required.
159162
# * dataset_id: The dataset ID. Required.
160163
# * target_name: The name of the attribute you aim to predict. Optional.
161164
# * estimation_procedure_id : The ID of the estimation procedure used to create train-test
@@ -186,9 +189,8 @@
186189
openml.config.start_using_configuration_for_example()
187190

188191
try:
189-
tasktypes = openml.tasks.TaskTypeEnum
190192
my_task = openml.tasks.create_task(
191-
task_type_id=tasktypes.SUPERVISED_CLASSIFICATION,
193+
task_type=TaskType.SUPERVISED_CLASSIFICATION,
192194
dataset_id=128,
193195
target_name="class",
194196
evaluation_measure="predictive_accuracy",

examples/40_paper/2015_neurips_feurer_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
# deactivated, which also deactivated the tasks on them. More information on active or inactive
5959
# datasets can be found in the `online docs <https://docs.openml.org/#dataset-status>`_.
6060
tasks = openml.tasks.list_tasks(
61-
task_type_id=openml.tasks.TaskTypeEnum.SUPERVISED_CLASSIFICATION,
61+
task_type=openml.tasks.TaskType.SUPERVISED_CLASSIFICATION,
6262
status="all",
6363
output_format="dataframe",
6464
)

openml/runs/functions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
)
3333
from .run import OpenMLRun
3434
from .trace import OpenMLRunTrace
35-
from ..tasks import TaskTypeEnum, get_task
35+
from ..tasks import TaskType, get_task
3636

3737
# Avoid import cycles: https://mypy.readthedocs.io/en/latest/common_issues.html#import-cycles
3838
if TYPE_CHECKING:
@@ -274,7 +274,7 @@ def run_flow_on_task(
274274
run.parameter_settings = flow.extension.obtain_parameter_values(flow)
275275

276276
# now we need to attach the detailed evaluations
277-
if task.task_type_id == TaskTypeEnum.LEARNING_CURVE:
277+
if task.task_type_id == TaskType.LEARNING_CURVE:
278278
run.sample_evaluations = sample_evaluations
279279
else:
280280
run.fold_evaluations = fold_evaluations
@@ -772,7 +772,7 @@ def obtain_field(xml_obj, fieldname, from_server, cast=None):
772772

773773
if "predictions" not in files and from_server is True:
774774
task = openml.tasks.get_task(task_id)
775-
if task.task_type_id == TaskTypeEnum.SUBGROUP_DISCOVERY:
775+
if task.task_type_id == TaskType.SUBGROUP_DISCOVERY:
776776
raise NotImplementedError("Subgroup discovery tasks are not yet supported.")
777777
else:
778778
# JvR: actually, I am not sure whether this error should be raised.
@@ -1008,7 +1008,7 @@ def __list_runs(api_call, output_format="dict"):
10081008
"setup_id": int(run_["oml:setup_id"]),
10091009
"flow_id": int(run_["oml:flow_id"]),
10101010
"uploader": int(run_["oml:uploader"]),
1011-
"task_type": int(run_["oml:task_type_id"]),
1011+
"task_type": TaskType(int(run_["oml:task_type_id"])),
10121012
"upload_time": str(run_["oml:upload_time"]),
10131013
"error_message": str((run_["oml:error_message"]) or ""),
10141014
}

openml/runs/run.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from ..flows import get_flow
1717
from ..tasks import (
1818
get_task,
19-
TaskTypeEnum,
19+
TaskType,
2020
OpenMLClassificationTask,
2121
OpenMLLearningCurveTask,
2222
OpenMLClusteringTask,
@@ -401,17 +401,13 @@ def get_metric_fn(self, sklearn_fn, kwargs=None):
401401

402402
attribute_names = [att[0] for att in predictions_arff["attributes"]]
403403
if (
404-
task.task_type_id
405-
in [TaskTypeEnum.SUPERVISED_CLASSIFICATION, TaskTypeEnum.LEARNING_CURVE]
404+
task.task_type_id in [TaskType.SUPERVISED_CLASSIFICATION, TaskType.LEARNING_CURVE]
406405
and "correct" not in attribute_names
407406
):
408407
raise ValueError('Attribute "correct" should be set for ' "classification task runs")
409-
if (
410-
task.task_type_id == TaskTypeEnum.SUPERVISED_REGRESSION
411-
and "truth" not in attribute_names
412-
):
408+
if task.task_type_id == TaskType.SUPERVISED_REGRESSION and "truth" not in attribute_names:
413409
raise ValueError('Attribute "truth" should be set for ' "regression task runs")
414-
if task.task_type_id != TaskTypeEnum.CLUSTERING and "prediction" not in attribute_names:
410+
if task.task_type_id != TaskType.CLUSTERING and "prediction" not in attribute_names:
415411
raise ValueError('Attribute "predict" should be set for ' "supervised task runs")
416412

417413
def _attribute_list_to_dict(attribute_list):
@@ -431,11 +427,11 @@ def _attribute_list_to_dict(attribute_list):
431427
predicted_idx = attribute_dict["prediction"] # Assume supervised task
432428

433429
if (
434-
task.task_type_id == TaskTypeEnum.SUPERVISED_CLASSIFICATION
435-
or task.task_type_id == TaskTypeEnum.LEARNING_CURVE
430+
task.task_type_id == TaskType.SUPERVISED_CLASSIFICATION
431+
or task.task_type_id == TaskType.LEARNING_CURVE
436432
):
437433
correct_idx = attribute_dict["correct"]
438-
elif task.task_type_id == TaskTypeEnum.SUPERVISED_REGRESSION:
434+
elif task.task_type_id == TaskType.SUPERVISED_REGRESSION:
439435
correct_idx = attribute_dict["truth"]
440436
has_samples = False
441437
if "sample" in attribute_dict:
@@ -465,14 +461,14 @@ def _attribute_list_to_dict(attribute_list):
465461
samp = 0 # No learning curve sample, always 0
466462

467463
if task.task_type_id in [
468-
TaskTypeEnum.SUPERVISED_CLASSIFICATION,
469-
TaskTypeEnum.LEARNING_CURVE,
464+
TaskType.SUPERVISED_CLASSIFICATION,
465+
TaskType.LEARNING_CURVE,
470466
]:
471467
prediction = predictions_arff["attributes"][predicted_idx][1].index(
472468
line[predicted_idx]
473469
)
474470
correct = predictions_arff["attributes"][predicted_idx][1].index(line[correct_idx])
475-
elif task.task_type_id == TaskTypeEnum.SUPERVISED_REGRESSION:
471+
elif task.task_type_id == TaskType.SUPERVISED_REGRESSION:
476472
prediction = line[predicted_idx]
477473
correct = line[correct_idx]
478474
if rep not in values_predict:

openml/tasks/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
OpenMLRegressionTask,
88
OpenMLClusteringTask,
99
OpenMLLearningCurveTask,
10-
TaskTypeEnum,
10+
TaskType,
1111
)
1212
from .split import OpenMLSplit
1313
from .functions import (
@@ -29,5 +29,5 @@
2929
"get_tasks",
3030
"list_tasks",
3131
"OpenMLSplit",
32-
"TaskTypeEnum",
32+
"TaskType",
3333
]

openml/tasks/functions.py

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
OpenMLClassificationTask,
1616
OpenMLClusteringTask,
1717
OpenMLLearningCurveTask,
18-
TaskTypeEnum,
18+
TaskType,
1919
OpenMLRegressionTask,
2020
OpenMLSupervisedTask,
2121
OpenMLTask,
@@ -109,7 +109,7 @@ def _get_estimation_procedure_list():
109109
procs.append(
110110
{
111111
"id": int(proc_["oml:id"]),
112-
"task_type_id": int(proc_["oml:ttid"]),
112+
"task_type_id": TaskType(int(proc_["oml:ttid"])),
113113
"name": proc_["oml:name"],
114114
"type": proc_["oml:type"],
115115
}
@@ -119,22 +119,22 @@ def _get_estimation_procedure_list():
119119

120120

121121
def list_tasks(
122-
task_type_id: Optional[int] = None,
122+
task_type: Optional[TaskType] = None,
123123
offset: Optional[int] = None,
124124
size: Optional[int] = None,
125125
tag: Optional[str] = None,
126126
output_format: str = "dict",
127127
**kwargs
128128
) -> Union[Dict, pd.DataFrame]:
129129
"""
130-
Return a number of tasks having the given tag and task_type_id
130+
Return a number of tasks having the given tag and task_type
131131
132132
Parameters
133133
----------
134-
Filter task_type_id is separated from the other filters because
135-
it is used as task_type_id in the task description, but it is named
134+
Filter task_type is separated from the other filters because
135+
it is used as task_type in the task description, but it is named
136136
type when used as a filter in list tasks call.
137-
task_type_id : int, optional
137+
task_type : TaskType, optional
138138
ID of the task type as detailed `here <https://www.openml.org/search?type=task_type>`_.
139139
- Supervised classification: 1
140140
- Supervised regression: 2
@@ -162,12 +162,12 @@ def list_tasks(
162162
Returns
163163
-------
164164
dict
165-
All tasks having the given task_type_id and the give tag. Every task is
165+
All tasks having the given task_type and the give tag. Every task is
166166
represented by a dictionary containing the following information:
167167
task id, dataset id, task_type and status. If qualities are calculated
168168
for the associated dataset, some of these are also returned.
169169
dataframe
170-
All tasks having the given task_type_id and the give tag. Every task is
170+
All tasks having the given task_type and the give tag. Every task is
171171
represented by a row in the data frame containing the following information
172172
as columns: task id, dataset id, task_type and status. If qualities are
173173
calculated for the associated dataset, some of these are also returned.
@@ -179,23 +179,23 @@ def list_tasks(
179179
return openml.utils._list_all(
180180
output_format=output_format,
181181
listing_call=_list_tasks,
182-
task_type_id=task_type_id,
182+
task_type=task_type,
183183
offset=offset,
184184
size=size,
185185
tag=tag,
186186
**kwargs
187187
)
188188

189189

190-
def _list_tasks(task_type_id=None, output_format="dict", **kwargs):
190+
def _list_tasks(task_type=None, output_format="dict", **kwargs):
191191
"""
192192
Perform the api call to return a number of tasks having the given filters.
193193
Parameters
194194
----------
195-
Filter task_type_id is separated from the other filters because
196-
it is used as task_type_id in the task description, but it is named
195+
Filter task_type is separated from the other filters because
196+
it is used as task_type in the task description, but it is named
197197
type when used as a filter in list tasks call.
198-
task_type_id : int, optional
198+
task_type : TaskType, optional
199199
ID of the task type as detailed
200200
`here <https://www.openml.org/search?type=task_type>`_.
201201
- Supervised classification: 1
@@ -220,8 +220,8 @@ def _list_tasks(task_type_id=None, output_format="dict", **kwargs):
220220
dict or dataframe
221221
"""
222222
api_call = "task/list"
223-
if task_type_id is not None:
224-
api_call += "/type/%d" % int(task_type_id)
223+
if task_type is not None:
224+
api_call += "/type/%d" % task_type.value
225225
if kwargs is not None:
226226
for operator, value in kwargs.items():
227227
if operator == "task_id":
@@ -259,7 +259,7 @@ def __list_tasks(api_call, output_format="dict"):
259259
tid = int(task_["oml:task_id"])
260260
task = {
261261
"tid": tid,
262-
"ttid": int(task_["oml:task_type_id"]),
262+
"ttid": TaskType(int(task_["oml:task_type_id"])),
263263
"did": int(task_["oml:did"]),
264264
"name": task_["oml:name"],
265265
"task_type": task_["oml:task_type"],
@@ -417,18 +417,18 @@ def _create_task_from_xml(xml):
417417
"oml:evaluation_measure"
418418
]
419419

420-
task_type_id = int(dic["oml:task_type_id"])
420+
task_type = TaskType(int(dic["oml:task_type_id"]))
421421
common_kwargs = {
422422
"task_id": dic["oml:task_id"],
423423
"task_type": dic["oml:task_type"],
424-
"task_type_id": dic["oml:task_type_id"],
424+
"task_type_id": task_type,
425425
"data_set_id": inputs["source_data"]["oml:data_set"]["oml:data_set_id"],
426426
"evaluation_measure": evaluation_measures,
427427
}
428-
if task_type_id in (
429-
TaskTypeEnum.SUPERVISED_CLASSIFICATION,
430-
TaskTypeEnum.SUPERVISED_REGRESSION,
431-
TaskTypeEnum.LEARNING_CURVE,
428+
if task_type in (
429+
TaskType.SUPERVISED_CLASSIFICATION,
430+
TaskType.SUPERVISED_REGRESSION,
431+
TaskType.LEARNING_CURVE,
432432
):
433433
# Convert some more parameters
434434
for parameter in inputs["estimation_procedure"]["oml:estimation_procedure"][
@@ -448,18 +448,18 @@ def _create_task_from_xml(xml):
448448
]["oml:data_splits_url"]
449449

450450
cls = {
451-
TaskTypeEnum.SUPERVISED_CLASSIFICATION: OpenMLClassificationTask,
452-
TaskTypeEnum.SUPERVISED_REGRESSION: OpenMLRegressionTask,
453-
TaskTypeEnum.CLUSTERING: OpenMLClusteringTask,
454-
TaskTypeEnum.LEARNING_CURVE: OpenMLLearningCurveTask,
455-
}.get(task_type_id)
451+
TaskType.SUPERVISED_CLASSIFICATION: OpenMLClassificationTask,
452+
TaskType.SUPERVISED_REGRESSION: OpenMLRegressionTask,
453+
TaskType.CLUSTERING: OpenMLClusteringTask,
454+
TaskType.LEARNING_CURVE: OpenMLLearningCurveTask,
455+
}.get(task_type)
456456
if cls is None:
457457
raise NotImplementedError("Task type %s not supported." % common_kwargs["task_type"])
458458
return cls(**common_kwargs)
459459

460460

461461
def create_task(
462-
task_type_id: int,
462+
task_type: TaskType,
463463
dataset_id: int,
464464
estimation_procedure_id: int,
465465
target_name: Optional[str] = None,
@@ -480,7 +480,7 @@ def create_task(
480480
481481
Parameters
482482
----------
483-
task_type_id : int
483+
task_type : TaskType
484484
Id of the task type.
485485
dataset_id : int
486486
The id of the dataset for the task.
@@ -501,17 +501,17 @@ def create_task(
501501
OpenMLLearningCurveTask, OpenMLClusteringTask
502502
"""
503503
task_cls = {
504-
TaskTypeEnum.SUPERVISED_CLASSIFICATION: OpenMLClassificationTask,
505-
TaskTypeEnum.SUPERVISED_REGRESSION: OpenMLRegressionTask,
506-
TaskTypeEnum.CLUSTERING: OpenMLClusteringTask,
507-
TaskTypeEnum.LEARNING_CURVE: OpenMLLearningCurveTask,
508-
}.get(task_type_id)
504+
TaskType.SUPERVISED_CLASSIFICATION: OpenMLClassificationTask,
505+
TaskType.SUPERVISED_REGRESSION: OpenMLRegressionTask,
506+
TaskType.CLUSTERING: OpenMLClusteringTask,
507+
TaskType.LEARNING_CURVE: OpenMLLearningCurveTask,
508+
}.get(task_type)
509509

510510
if task_cls is None:
511-
raise NotImplementedError("Task type {0:d} not supported.".format(task_type_id))
511+
raise NotImplementedError("Task type {0:d} not supported.".format(task_type))
512512
else:
513513
return task_cls(
514-
task_type_id=task_type_id,
514+
task_type_id=task_type,
515515
task_type=None,
516516
data_set_id=dataset_id,
517517
target_name=target_name,

0 commit comments

Comments
 (0)