Skip to content

Commit dbd9e84

Browse files
add: update slurm scripts runner
1 parent 317240e commit dbd9e84

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

tabflow_slurm/setup_slurm_base.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -641,25 +641,27 @@ def get_jobs_dict(self):
641641
}
642642
return {"defaults": default_args, "jobs": jobs}
643643

644-
def setup_jobs(self):
644+
def setup_jobs(self) -> str:
645645
"""Setup the jobs to run by generating the SLURM job JSON file."""
646646
jobs_dict = self.get_jobs_dict()
647647
n_jobs = len(jobs_dict["jobs"])
648648
if n_jobs == 0:
649649
print("No jobs to run.")
650650
Path(self.slurm_job_json).unlink(missing_ok=True)
651651
Path(self.configs).unlink(missing_ok=True)
652-
return
652+
return "N/A"
653653

654654
with open(self.slurm_job_json, "w") as f:
655655
json.dump(jobs_dict, f)
656656

657+
run_command = f"sbatch --array=0-{n_jobs - 1}%100 {self.slurm_base_command} {self.slurm_job_json}"
657658
print(
658659
f"##### Setup Jobs for {self._safe_benchmark_name}"
659660
"\nRun the following command to start the jobs:"
660-
f"\nsbatch --array=0-{n_jobs - 1}%100 {self.slurm_base_command} {self.slurm_job_json}"
661+
f"\n{run_command}"
661662
"\n"
662663
)
664+
return run_command
663665

664666
@property
665667
def models_to_constraints(self) -> dict[str, dict[str, int]]:

tabflow_slurm/submit_template.sh

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ PYTHON_PATH=$(jq -r '.defaults.python' "$JSON_FILE")
4040
RUNSCRIPT=$(jq -r '.defaults.run_script' "$JSON_FILE")
4141
OPENML_CACHE_DIR=$(jq -r '.defaults.openml_cache_dir' "$JSON_FILE")
4242
CONFIGS_YAML_FILE=$(jq -r '.defaults.configs_yaml_file' "$JSON_FILE")
43-
TABREPO_CACHE_DIR=$(jq -r '.defaults.tabrepo_cache_dir' "$JSON_FILE")
4443
OUTPUT_DIR=$(jq -r '.defaults.output_dir' "$JSON_FILE")
4544
NUM_CPUS=$(jq -r '.defaults.num_cpus' "$JSON_FILE")
4645
NUM_GPUS=$(jq -r '.defaults.num_gpus' "$JSON_FILE")
@@ -53,7 +52,6 @@ echo "Python Path: $PYTHON_PATH"
5352
echo "Run Script: $RUNSCRIPT"
5453
echo "OpenML Cache Directory: $OPENML_CACHE_DIR"
5554
echo "Configs YAML File: $CONFIGS_YAML_FILE"
56-
echo "Tabrepo Cache Directory: $TABREPO_CACHE_DIR"
5755
echo "Output Directory: $OUTPUT_DIR"
5856
echo "Number of CPUs: $NUM_CPUS"
5957
echo "Number of GPUs: $NUM_GPUS"
@@ -89,7 +87,6 @@ for CI in "${CONFIG_ARRAY[@]}"; do
8987
--config_index $CI \
9088
--configs_yaml_file $CONFIGS_YAML_FILE \
9189
--openml_cache_dir $OPENML_CACHE_DIR \
92-
--tabrepo_cache_dir $TABREPO_CACHE_DIR \
9390
--output_dir $OUTPUT_DIR \
9491
--num_cpus $NUM_CPUS \
9592
--num_gpus $NUM_GPUS \

0 commit comments

Comments
 (0)