Skip to content

Commit 317240e

Browse files
add: minor updates to TabArena SLURM runner
1 parent 153cecc commit 317240e

File tree

2 files changed

+36
-30
lines changed

2 files changed

+36
-30
lines changed

tabflow_slurm/run_tabarena_experiment.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
def setup_slurm_job(
1212
*,
1313
openml_cache_dir: str,
14-
tabrepo_cache_dir: str,
1514
setup_ray_for_slurm_shared_resources_environment: bool,
1615
num_cpus: int,
1716
num_gpus: int,
@@ -23,8 +22,6 @@ def setup_slurm_job(
2322
----------
2423
openml_cache_dir : str
2524
The path to the OpenML cache directory.
26-
tabrepo_cache_dir : str
27-
The path to the TabRepo cache directory.
2825
num_cpus : int
2926
The number of CPUs to use for the experiment (needed for proper Ray setup).
3027
num_gpus : int
@@ -36,8 +33,11 @@ def setup_slurm_job(
3633
Otherwise, given the shared filesystem, Ray will try to use the same temp dir for all workers and
3734
crash (semi-randomly).
3835
"""
39-
openml.config.set_root_cache_directory(root_cache_directory=openml_cache_dir)
40-
os.environ["TABREPO_CACHE"] = tabrepo_cache_dir
36+
if openml_cache_dir == "auto":
37+
print("Using the default OpenML cache directory.")
38+
else:
39+
print(f"Setting OpenML cache directory to: {openml_cache_dir}")
40+
openml.config.set_root_cache_directory(root_cache_directory=openml_cache_dir)
4141

4242
# SLURM-safe Ray setup in a shared-resource system
4343
ray_dir = None
@@ -241,9 +241,6 @@ def parse_int_list(s):
241241
parser.add_argument(
242242
"--openml_cache_dir", type=str, help="Path to the OpenML cache directory."
243243
)
244-
parser.add_argument(
245-
"--tabrepo_cache_dir", type=str, help="Path to the TabRepo cache directory."
246-
)
247244
parser.add_argument(
248245
"--output_dir",
249246
type=str,
@@ -281,7 +278,6 @@ def parse_int_list(s):
281278
args = parser.parse_args()
282279
ray_temp_dir = setup_slurm_job(
283280
openml_cache_dir=args.openml_cache_dir,
284-
tabrepo_cache_dir=args.tabrepo_cache_dir,
285281
setup_ray_for_slurm_shared_resources_environment=args.setup_ray_for_slurm_shared_resources_environment,
286282
num_cpus=args.num_cpus,
287283
num_gpus=args.num_gpus,

tabflow_slurm/setup_slurm_base.py

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from dataclasses import dataclass, field
1414
from itertools import product
1515
from pathlib import Path
16+
from typing import Literal
1617

1718
import pandas as pd
1819
import ray
@@ -53,13 +54,15 @@ class BenchmarkSetup:
5354
)
5455
"""Python script to run the benchmark. This should point to the script that runs the benchmark
5556
for TabArena."""
56-
openml_cache_from_base_path: str = ".openml-cache"
57-
"""OpenML cache directory. This is used to store dataset and tasks data from OpenML."""
58-
tabrepo_cache_dir_from_base_path: str = "input_data/tabrepo"
59-
"""TabRepo cache directory."""
60-
slurm_log_output_from_base_path: str = "slurm_out/runs_1711"
61-
"""Directory for the SLURM output logs. This is used to store the output logs from the
62-
SLURM jobs."""
57+
openml_cache_from_base_path: str | Literal["auto"] = ".openml-cache"
58+
"""OpenML cache directory. This is used to store dataset and tasks data from OpenML.
59+
60+
If "auto", we use the default cache from OpenML.
61+
If any other string, this is interpreted as the path to the folder for a custom OpenML cache.
62+
"""
63+
slurm_log_output_from_base_path: str = "slurm_out/"
64+
"""Directory for the SLURM output logs. In this folder a `benchmark_name` folder will be created
65+
and used to store the output logs from the SLURM jobs."""
6366
output_dir_base_from_base_path: str = "output/"
6467
"""Output directory for the benchmark. In this folder a `benchmark_name` folder will be created."""
6568
configs_path_from_base_path: str = (
@@ -79,6 +82,7 @@ class BenchmarkSetup:
7982
"""Extra SLURM gres to use for the jobs. Adjust as needed for your cluster setup."""
8083
# Task/Data Settings
8184
# ------------------
85+
# TODO: update metadata and usage for non-IID tasks that are not fold-based in the future.
8286
custom_metadata: pd.DataFrame | str | None = None
8387
"""Custom metadata to use for defining the tasks and datasets to run.
8488
@@ -337,17 +341,16 @@ def run_script(self) -> str:
337341
@property
338342
def openml_cache(self) -> str:
339343
"""OpenML cache directory."""
344+
if self.openml_cache_from_base_path == "auto":
345+
return self.openml_cache_from_base_path
340346
return self.base_path + self.openml_cache_from_base_path
341347

342-
@property
343-
def tabrepo_cache_dir(self) -> str:
344-
"""TabRepo cache directory."""
345-
return self.base_path + self.tabrepo_cache_dir_from_base_path
346-
347348
@property
348349
def slurm_log_output(self) -> str:
349350
"""Directory for the SLURM output logs."""
350-
return self.base_path + self.slurm_log_output_from_base_path
351+
return (
352+
self.base_path + self.slurm_log_output_from_base_path + self.benchmark_name
353+
)
351354

352355
@property
353356
def slurm_base_command(self):
@@ -373,7 +376,9 @@ def slurm_base_command(self):
373376
gres += self.slurm_extra_gres
374377
gres = f"--gres={gres}" if len(gres) > 0 else None
375378

376-
time_in_h = self.time_limit // 3600 * self.configs_per_job + self.time_limit_overhead
379+
time_in_h = (
380+
self.time_limit // 3600 * self.configs_per_job + self.time_limit_overhead
381+
)
377382
time_in_h = f"--time={time_in_h}:00:00"
378383
cpus = f"--cpus-per-task={self.num_cpus}"
379384
if is_gpu_job:
@@ -393,8 +398,8 @@ def get_jobs_to_run(self): # noqa: C901
393398
"""Determine all jobs to run by checking the cache and filtering
394399
invalid jobs.
395400
"""
396-
Path(self.openml_cache).mkdir(parents=True, exist_ok=True)
397-
Path(self.tabrepo_cache_dir).mkdir(parents=True, exist_ok=True)
401+
if self.openml_cache != "auto":
402+
Path(self.openml_cache).mkdir(parents=True, exist_ok=True)
398403
Path(self.output_dir).mkdir(parents=True, exist_ok=True)
399404
Path(self.slurm_log_output).mkdir(parents=True, exist_ok=True)
400405

@@ -421,9 +426,13 @@ def get_jobs_to_run(self): # noqa: C901
421426
def yield_all_jobs():
422427
for row in metadata.itertuples():
423428
task_id = row.task_id
424-
n_samples_train_per_fold = int(
425-
row.num_instances - int(row.num_instances / row.num_folds)
426-
)
429+
if hasattr(row, "n_samples_train_per_fold"):
430+
n_samples_train_per_fold = row.n_samples_train_per_fold
431+
else:
432+
# Fallback to estimating the number of training samples per fold
433+
n_samples_train_per_fold = int(
434+
row.num_instances - int(row.num_instances / row.num_folds)
435+
)
427436
n_features = int(row.num_features)
428437
n_classes = (
429438
int(row.num_classes)
@@ -525,7 +534,9 @@ def generate_configs_yaml(self):
525534
"init_kwargs": {"verbosity": self.verbosity},
526535
}
527536
if self.model_artifacts_base_path is not None:
528-
method_kwargs["init_kwargs"]["default_base_path"] = self.model_artifacts_base_path
537+
method_kwargs["init_kwargs"]["default_base_path"] = (
538+
self.model_artifacts_base_path
539+
)
529540
if not self.model_agnostic_preprocessing:
530541
method_kwargs["fit_kwargs"] = {"feature_generator": None}
531542

@@ -620,7 +631,6 @@ def get_jobs_dict(self):
620631
"run_script": self.run_script,
621632
"openml_cache_dir": self.openml_cache,
622633
"configs_yaml_file": self.configs,
623-
"tabrepo_cache_dir": self.tabrepo_cache_dir,
624634
"output_dir": self.output_dir,
625635
"num_cpus": self.num_cpus,
626636
"num_gpus": self.num_gpus,

0 commit comments

Comments (0)