1515 OpenMLClassificationTask ,
1616 OpenMLClusteringTask ,
1717 OpenMLLearningCurveTask ,
18- TaskTypeEnum ,
18+ TaskType ,
1919 OpenMLRegressionTask ,
2020 OpenMLSupervisedTask ,
2121 OpenMLTask ,
@@ -109,7 +109,7 @@ def _get_estimation_procedure_list():
109109 procs .append (
110110 {
111111 "id" : int (proc_ ["oml:id" ]),
112- "task_type_id" : int (proc_ ["oml:ttid" ]),
112+ "task_type_id" : TaskType ( int (proc_ ["oml:ttid" ]) ),
113113 "name" : proc_ ["oml:name" ],
114114 "type" : proc_ ["oml:type" ],
115115 }
@@ -119,22 +119,22 @@ def _get_estimation_procedure_list():
119119
120120
121121def list_tasks (
122- task_type_id : Optional [int ] = None ,
122+ task_type : Optional [TaskType ] = None ,
123123 offset : Optional [int ] = None ,
124124 size : Optional [int ] = None ,
125125 tag : Optional [str ] = None ,
126126 output_format : str = "dict" ,
127127 ** kwargs
128128) -> Union [Dict , pd .DataFrame ]:
129129 """
130- Return a number of tasks having the given tag and task_type_id
130+ Return a number of tasks having the given tag and task_type
131131
132132 Parameters
133133 ----------
134- Filter task_type_id is separated from the other filters because
135- it is used as task_type_id in the task description, but it is named
134+ Filter task_type is separated from the other filters because
135+ it is used as task_type in the task description, but it is named
136136 type when used as a filter in list tasks call.
137- task_type_id : int , optional
137+ task_type : TaskType , optional
138138 ID of the task type as detailed `here <https://www.openml.org/search?type=task_type>`_.
139139 - Supervised classification: 1
140140 - Supervised regression: 2
@@ -162,12 +162,12 @@ def list_tasks(
162162 Returns
163163 -------
164164 dict
165- All tasks having the given task_type_id and the give tag. Every task is
165+ All tasks having the given task_type and the give tag. Every task is
166166 represented by a dictionary containing the following information:
167167 task id, dataset id, task_type and status. If qualities are calculated
168168 for the associated dataset, some of these are also returned.
169169 dataframe
170- All tasks having the given task_type_id and the give tag. Every task is
170+ All tasks having the given task_type and the give tag. Every task is
171171 represented by a row in the data frame containing the following information
172172 as columns: task id, dataset id, task_type and status. If qualities are
173173 calculated for the associated dataset, some of these are also returned.
@@ -179,23 +179,23 @@ def list_tasks(
179179 return openml .utils ._list_all (
180180 output_format = output_format ,
181181 listing_call = _list_tasks ,
182- task_type_id = task_type_id ,
182+ task_type = task_type ,
183183 offset = offset ,
184184 size = size ,
185185 tag = tag ,
186186 ** kwargs
187187 )
188188
189189
190- def _list_tasks (task_type_id = None , output_format = "dict" , ** kwargs ):
190+ def _list_tasks (task_type = None , output_format = "dict" , ** kwargs ):
191191 """
192192 Perform the api call to return a number of tasks having the given filters.
193193 Parameters
194194 ----------
195- Filter task_type_id is separated from the other filters because
196- it is used as task_type_id in the task description, but it is named
195+ Filter task_type is separated from the other filters because
196+ it is used as task_type in the task description, but it is named
197197 type when used as a filter in list tasks call.
198- task_type_id : int , optional
198+ task_type : TaskType , optional
199199 ID of the task type as detailed
200200 `here <https://www.openml.org/search?type=task_type>`_.
201201 - Supervised classification: 1
@@ -220,8 +220,8 @@ def _list_tasks(task_type_id=None, output_format="dict", **kwargs):
220220 dict or dataframe
221221 """
222222 api_call = "task/list"
223- if task_type_id is not None :
224- api_call += "/type/%d" % int ( task_type_id )
223+ if task_type is not None :
224+ api_call += "/type/%d" % task_type . value
225225 if kwargs is not None :
226226 for operator , value in kwargs .items ():
227227 if operator == "task_id" :
@@ -259,7 +259,7 @@ def __list_tasks(api_call, output_format="dict"):
259259 tid = int (task_ ["oml:task_id" ])
260260 task = {
261261 "tid" : tid ,
262- "ttid" : int (task_ ["oml:task_type_id" ]),
262+ "ttid" : TaskType ( int (task_ ["oml:task_type_id" ]) ),
263263 "did" : int (task_ ["oml:did" ]),
264264 "name" : task_ ["oml:name" ],
265265 "task_type" : task_ ["oml:task_type" ],
@@ -417,18 +417,18 @@ def _create_task_from_xml(xml):
417417 "oml:evaluation_measure"
418418 ]
419419
420- task_type_id = int (dic ["oml:task_type_id" ])
420+ task_type = TaskType ( int (dic ["oml:task_type_id" ]) )
421421 common_kwargs = {
422422 "task_id" : dic ["oml:task_id" ],
423423 "task_type" : dic ["oml:task_type" ],
424- "task_type_id" : dic [ "oml:task_type_id" ] ,
424+ "task_type_id" : task_type ,
425425 "data_set_id" : inputs ["source_data" ]["oml:data_set" ]["oml:data_set_id" ],
426426 "evaluation_measure" : evaluation_measures ,
427427 }
428- if task_type_id in (
429- TaskTypeEnum .SUPERVISED_CLASSIFICATION ,
430- TaskTypeEnum .SUPERVISED_REGRESSION ,
431- TaskTypeEnum .LEARNING_CURVE ,
428+ if task_type in (
429+ TaskType .SUPERVISED_CLASSIFICATION ,
430+ TaskType .SUPERVISED_REGRESSION ,
431+ TaskType .LEARNING_CURVE ,
432432 ):
433433 # Convert some more parameters
434434 for parameter in inputs ["estimation_procedure" ]["oml:estimation_procedure" ][
@@ -448,18 +448,18 @@ def _create_task_from_xml(xml):
448448 ]["oml:data_splits_url" ]
449449
450450 cls = {
451- TaskTypeEnum .SUPERVISED_CLASSIFICATION : OpenMLClassificationTask ,
452- TaskTypeEnum .SUPERVISED_REGRESSION : OpenMLRegressionTask ,
453- TaskTypeEnum .CLUSTERING : OpenMLClusteringTask ,
454- TaskTypeEnum .LEARNING_CURVE : OpenMLLearningCurveTask ,
455- }.get (task_type_id )
451+ TaskType .SUPERVISED_CLASSIFICATION : OpenMLClassificationTask ,
452+ TaskType .SUPERVISED_REGRESSION : OpenMLRegressionTask ,
453+ TaskType .CLUSTERING : OpenMLClusteringTask ,
454+ TaskType .LEARNING_CURVE : OpenMLLearningCurveTask ,
455+ }.get (task_type )
456456 if cls is None :
457457 raise NotImplementedError ("Task type %s not supported." % common_kwargs ["task_type" ])
458458 return cls (** common_kwargs )
459459
460460
461461def create_task (
462- task_type_id : int ,
462+ task_type : TaskType ,
463463 dataset_id : int ,
464464 estimation_procedure_id : int ,
465465 target_name : Optional [str ] = None ,
@@ -480,7 +480,7 @@ def create_task(
480480
481481 Parameters
482482 ----------
483- task_type_id : int
483+ task_type : TaskType
484484 Id of the task type.
485485 dataset_id : int
486486 The id of the dataset for the task.
@@ -501,17 +501,17 @@ def create_task(
501501 OpenMLLearningCurveTask, OpenMLClusteringTask
502502 """
503503 task_cls = {
504- TaskTypeEnum .SUPERVISED_CLASSIFICATION : OpenMLClassificationTask ,
505- TaskTypeEnum .SUPERVISED_REGRESSION : OpenMLRegressionTask ,
506- TaskTypeEnum .CLUSTERING : OpenMLClusteringTask ,
507- TaskTypeEnum .LEARNING_CURVE : OpenMLLearningCurveTask ,
508- }.get (task_type_id )
504+ TaskType .SUPERVISED_CLASSIFICATION : OpenMLClassificationTask ,
505+ TaskType .SUPERVISED_REGRESSION : OpenMLRegressionTask ,
506+ TaskType .CLUSTERING : OpenMLClusteringTask ,
507+ TaskType .LEARNING_CURVE : OpenMLLearningCurveTask ,
508+ }.get (task_type )
509509
510510 if task_cls is None :
511- raise NotImplementedError ("Task type {0:d} not supported." .format (task_type_id ))
511+ raise NotImplementedError ("Task type {0:d} not supported." .format (task_type ))
512512 else :
513513 return task_cls (
514- task_type_id = task_type_id ,
514+ task_type_id = task_type ,
515515 task_type = None ,
516516 data_set_id = dataset_id ,
517517 target_name = target_name ,
0 commit comments