1313from dataclasses import dataclass , field
1414from itertools import product
1515from pathlib import Path
16+ from typing import Literal
1617
1718import pandas as pd
1819import ray
@@ -53,13 +54,15 @@ class BenchmarkSetup:
5354 )
5455 """Python script to run the benchmark. This should point to the script that runs the benchmark
5556 for TabArena."""
56- openml_cache_from_base_path : str = ".openml-cache"
57- """OpenML cache directory. This is used to store dataset and tasks data from OpenML."""
58- tabrepo_cache_dir_from_base_path : str = "input_data/tabrepo"
59- """TabRepo cache directory."""
60- slurm_log_output_from_base_path : str = "slurm_out/runs_1711"
61- """Directory for the SLURM output logs. This is used to store the output logs from the
62- SLURM jobs."""
57+ openml_cache_from_base_path : str | Literal ["auto" ] = ".openml-cache"
58+ """OpenML cache directory. This is used to store dataset and task data from OpenML.
59+
60+ If "auto", we use the default cache from OpenML.
61+ If any other string, this is interpreted as the path (appended to `base_path`) of the folder for a custom OpenML cache.
62+ """
63+ slurm_log_output_from_base_path : str = "slurm_out/"
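64+ """Directory for the SLURM output logs. In this folder, a `benchmark_name` folder will be created
65+ and used to store the output logs from the SLURM jobs. Keep the trailing "/": the final path is built by plain string concatenation with `benchmark_name`."""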
6366 output_dir_base_from_base_path : str = "output/"
6467 """Output directory for the benchmark. In this folder a `benchmark_name` folder will be created."""
6568 configs_path_from_base_path : str = (
@@ -79,6 +82,7 @@ class BenchmarkSetup:
7982 """Extra SLURM gres to use for the jobs. Adjust as needed for your cluster setup."""
8083 # Task/Data Settings
8184 # ------------------
85+ # TODO: update metadata and usage for non-IID tasks that are not fold-based in the future.
8286 custom_metadata : pd .DataFrame | str | None = None
8387 """Custom metadata to use for defining the tasks and datasets to run.
8488
@@ -337,17 +341,16 @@ def run_script(self) -> str:
337341 @property
338342 def openml_cache (self ) -> str :
339343 """OpenML cache directory."""
344+ if self .openml_cache_from_base_path == "auto" :
345+ return self .openml_cache_from_base_path
340346 return self .base_path + self .openml_cache_from_base_path
341347
342- @property
343- def tabrepo_cache_dir (self ) -> str :
344- """TabRepo cache directory."""
345- return self .base_path + self .tabrepo_cache_dir_from_base_path
346-
347348 @property
348349 def slurm_log_output (self ) -> str :
349350 """Directory for the SLURM output logs."""
350- return self .base_path + self .slurm_log_output_from_base_path
351+ return (
352+ self .base_path + self .slurm_log_output_from_base_path + self .benchmark_name
353+ )
351354
352355 @property
353356 def slurm_base_command (self ):
@@ -373,7 +376,9 @@ def slurm_base_command(self):
373376 gres += self .slurm_extra_gres
374377 gres = f"--gres={ gres } " if len (gres ) > 0 else None
375378
376- time_in_h = self .time_limit // 3600 * self .configs_per_job + self .time_limit_overhead
379+ time_in_h = (
380+ self .time_limit // 3600 * self .configs_per_job + self .time_limit_overhead
381+ )
377382 time_in_h = f"--time={ time_in_h } :00:00"
378383 cpus = f"--cpus-per-task={ self .num_cpus } "
379384 if is_gpu_job :
@@ -393,8 +398,8 @@ def get_jobs_to_run(self): # noqa: C901
393398 """Determine all jobs to run by checking the cache and filtering
394399 invalid jobs.
395400 """
396- Path ( self .openml_cache ). mkdir ( parents = True , exist_ok = True )
397- Path (self .tabrepo_cache_dir ).mkdir (parents = True , exist_ok = True )
401+ if self .openml_cache != "auto" :
402+ Path (self .openml_cache ).mkdir (parents = True , exist_ok = True )
398403 Path (self .output_dir ).mkdir (parents = True , exist_ok = True )
399404 Path (self .slurm_log_output ).mkdir (parents = True , exist_ok = True )
400405
@@ -421,9 +426,13 @@ def get_jobs_to_run(self): # noqa: C901
421426 def yield_all_jobs ():
422427 for row in metadata .itertuples ():
423428 task_id = row .task_id
424- n_samples_train_per_fold = int (
425- row .num_instances - int (row .num_instances / row .num_folds )
426- )
429+ if hasattr (row , "n_samples_train_per_fold" ):
430+ n_samples_train_per_fold = row .n_samples_train_per_fold
431+ else :
432+ # Fall back to estimating the number of training samples per fold
433+ n_samples_train_per_fold = int (
434+ row .num_instances - int (row .num_instances / row .num_folds )
435+ )
427436 n_features = int (row .num_features )
428437 n_classes = (
429438 int (row .num_classes )
@@ -525,7 +534,9 @@ def generate_configs_yaml(self):
525534 "init_kwargs" : {"verbosity" : self .verbosity },
526535 }
527536 if self .model_artifacts_base_path is not None :
528- method_kwargs ["init_kwargs" ]["default_base_path" ] = self .model_artifacts_base_path
537+ method_kwargs ["init_kwargs" ]["default_base_path" ] = (
538+ self .model_artifacts_base_path
539+ )
529540 if not self .model_agnostic_preprocessing :
530541 method_kwargs ["fit_kwargs" ] = {"feature_generator" : None }
531542
@@ -620,7 +631,6 @@ def get_jobs_dict(self):
620631 "run_script" : self .run_script ,
621632 "openml_cache_dir" : self .openml_cache ,
622633 "configs_yaml_file" : self .configs ,
623- "tabrepo_cache_dir" : self .tabrepo_cache_dir ,
624634 "output_dir" : self .output_dir ,
625635 "num_cpus" : self .num_cpus ,
626636 "num_gpus" : self .num_gpus ,
0 commit comments