Skip to content

Commit 2c2a93d

Browse files
Tiny bug: oml_id must be an int in load_openml_task (#686)
* oml_id must be an int
* Catch the error earlier to report better error messages
* Also give an early warning if the task id is defined as a str in the file

---------

Co-authored-by: PGijsbers <p.gijsbers@tue.nl>
1 parent b719142 commit 2c2a93d

File tree

4 files changed

+31
-17
lines changed

4 files changed

+31
-17
lines changed

amlb/benchmarks/file.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,9 @@ def load_file_benchmark(
3535
log.info("Loading benchmark definitions from %s.", benchmark_file)
3636
tasks = config_load(benchmark_file)
3737
benchmark_name, _ = os.path.splitext(os.path.basename(benchmark_file))
38+
for task in tasks:
39+
if task["openml_task_id"] is not None and not isinstance(
40+
task["openml_task_id"], int
41+
):
42+
raise TypeError(f"OpenML task id for task {task.name!r} must be integer.")  # f-prefix needed so the task name is interpolated
3843
return benchmark_name, benchmark_file, tasks

amlb/benchmarks/openml.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import openml
88
import pandas as pd
9+
from openml import OpenMLTask, OpenMLDataset
910

1011
from amlb.utils import Namespace, str_sanitize
1112

@@ -20,7 +21,13 @@ def is_openml_benchmark(benchmark: str) -> bool:
2021

2122
def load_oml_benchmark(benchmark: str) -> tuple[str, str | None, list[Namespace]]:
2223
"""Loads benchmark defined by openml suite or task, from openml/s/X or openml/t/Y."""
23-
domain, oml_type, oml_id = benchmark.split("/")
24+
domain, oml_type, oml_id_str = benchmark.split("/")
25+
try:
26+
oml_id = int(oml_id_str)
27+
except ValueError:
28+
raise ValueError(
29+
f"Could not convert OpenML id {oml_id_str!r} in {benchmark!r} to integer."
30+
)
2431

2532
if domain == "test.openml":
2633
log.debug("Setting openml server to the test server.")
@@ -34,7 +41,7 @@ def load_oml_benchmark(benchmark: str) -> tuple[str, str | None, list[Namespace]
3441
openml.config.set_retry_policy("robot")
3542

3643
if oml_type == "t":
37-
tasks = load_openml_task(domain, oml_id)
44+
tasks = load_openml_task_as_definition(domain, oml_id)
3845
elif oml_type == "s":
3946
tasks = load_openml_tasks_from_suite(domain, oml_id)
4047
else:
@@ -44,7 +51,7 @@ def load_oml_benchmark(benchmark: str) -> tuple[str, str | None, list[Namespace]
4451
return benchmark, None, tasks
4552

4653

47-
def load_openml_tasks_from_suite(domain: str, oml_id: str) -> list[Namespace]:
54+
def load_openml_tasks_from_suite(domain: str, oml_id: int) -> list[Namespace]:
4855
log.info("Loading openml suite %s.", oml_id)
4956
suite = openml.study.get_suite(oml_id)
5057
# Here we know the (task, dataset) pairs, so downloading only the dataset meta-data is sufficient
@@ -66,18 +73,22 @@ def load_openml_tasks_from_suite(domain: str, oml_id: str) -> list[Namespace]:
6673
return tasks
6774

6875

69-
def load_openml_task(domain: str, oml_id: str) -> list[Namespace]:
76+
def load_openml_task_as_definition(domain: str, oml_id: int) -> list[Namespace]:
7077
log.info("Loading openml task %s.", oml_id)
71-
# We first have to retrieve the task because we don't know the dataset id
72-
t = openml.tasks.get_task(oml_id, download_data=False, download_qualities=False)
73-
data = openml.datasets.get_dataset(
74-
t.dataset_id, download_data=False, download_qualities=False
75-
)
78+
task, data = load_openml_task_and_data(oml_id)
7679
return [
7780
Namespace(
7881
name=str_sanitize(data.name),
7982
description=data.description,
80-
openml_task_id=t.id,
81-
id="{}.org/t/{}".format(domain, t.id),
83+
openml_task_id=task.id,
84+
id="{}.org/t/{}".format(domain, task.id),
8285
)
8386
]
87+
88+
89+
def load_openml_task_and_data(task_id: int) -> tuple[OpenMLTask, OpenMLDataset]:
90+
task = openml.tasks.get_task(task_id, download_data=False, download_qualities=False)
91+
data = openml.datasets.get_dataset(
92+
task.dataset_id, download_data=False, download_qualities=False
93+
)
94+
return task, data

amlb/datasets/openml.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import openml as oml
2222
import xmltodict
2323

24+
from ..benchmarks.openml import load_openml_task_and_data
2425
from ..data import AM, DF, Dataset, DatasetType, Datasplit, Feature
2526
from ..datautils import impute_array
2627
from ..resources import config as rconfig, get as rget
@@ -71,10 +72,7 @@ def load(self, task_id=None, dataset_id=None, fold=0):
7172
dataset_id, task_id
7273
)
7374
)
74-
task = oml.tasks.get_task(task_id, download_qualities=False)
75-
dataset = oml.datasets.get_dataset(
76-
task.dataset_id, download_qualities=False
77-
)
75+
task, dataset = load_openml_task_and_data(task_id)
7876
_, nfolds, _ = task.get_split_dimensions()
7977
if fold >= nfolds:
8078
raise ValueError(

tests/unit/amlb/benchmarks/test_openml.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from amlb.benchmarks.openml import (
55
is_openml_benchmark,
6-
load_openml_task,
6+
load_openml_task_as_definition,
77
load_oml_benchmark,
88
)
99
from amlb.utils import Namespace
@@ -35,7 +35,7 @@ def test_load_openml_task(mocker, oml_task, oml_dataset):
3535
mocker.patch(
3636
"openml.datasets.get_dataset", new=mocker.Mock(return_value=oml_dataset)
3737
)
38-
[task] = load_openml_task("openml", oml_task.id)
38+
[task] = load_openml_task_as_definition("openml", oml_task.id)
3939
assert task.name == oml_dataset.name
4040
assert task.description == oml_dataset.description
4141
assert task.openml_task_id == oml_task.id

0 commit comments

Comments
 (0)