Skip to content

Commit e560e99

Browse files
authored
Load data should download data to disk (#692)
1 parent 4d87260 commit e560e99

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

amlb/benchmarks/openml.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,9 @@ def load_openml_task_as_definition(domain: str, oml_id: int) -> list[Namespace]:
8686
]
8787

8888

89-
def load_openml_task_and_data(task_id: int) -> tuple[OpenMLTask, OpenMLDataset]:
89+
def load_openml_task_and_data(task_id: int, with_data: bool = False) -> tuple[OpenMLTask, OpenMLDataset]:
9090
task = openml.tasks.get_task(task_id, download_data=False, download_qualities=False)
9191
data = openml.datasets.get_dataset(
92-
task.dataset_id, download_data=False, download_qualities=False
92+
task.dataset_id, download_data=with_data, download_qualities=False
9393
)
9494
return task, data

amlb/datasets/openml.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def load(self, task_id=None, dataset_id=None, fold=0):
7272
dataset_id, task_id
7373
)
7474
)
75-
task, dataset = load_openml_task_and_data(task_id)
75+
task, dataset = load_openml_task_and_data(task_id, with_data=True)
7676
_, nfolds, _ = task.get_split_dimensions()
7777
if fold >= nfolds:
7878
raise ValueError(

0 commit comments

Comments
 (0)