Skip to content

Commit 622cd68

Browse files
committed
Refactor wandb parameters structure and update script to include params_wrapper for improved organization
1 parent b09be80 commit 622cd68

File tree

2 files changed

+37
-25
lines changed

2 files changed

+37
-25
lines changed

batch_run/dispatcher.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,15 @@
3535

3636

3737
params_template = {
38-
"kwargs_wandb_init": {
39-
"project": "face_rhythm",
40-
"name": name_wandb,
41-
"dir_save": dir_save,
42-
"reinit": True,
38+
"params_wrapper": {
39+
"kwargs_wandb_init": {
40+
"project": "face_rhythm",
41+
"name": name_wandb,
42+
"dir": dir_save,
43+
},
44+
"path_script": path_script,
45+
"period_logger": 2,
4346
},
44-
"path_script": path_script,
4547

4648
"params_script": {
4749
"steps": [
@@ -269,9 +271,6 @@
269271

270272
## Prepare call str and set environment variables
271273
os.environ['WANDB_API_KEY'] = wandb_api_key
272-
# Initialize a wandb run with minimal settings.
273-
import wandb
274-
wandb.init(project='face_rhythm', name=name_wandb, config=params[0], reinit=True)
275274

276275
## define slurm SBATCH parameters
277276
# sbatch_config_list = \
@@ -320,12 +319,12 @@
320319
#SBATCH --account=kempner_bsabatini_lab # The account name for the job.
321320
#SBATCH --job-name={name_slurm} # Job name
322321
#SBATCH --output={path} # File to write: STDOUT (and STDERR if --error is not used)
323-
#SBATCH --partition=kempner_requeue # Partition (job queue)
322+
#SBATCH --partition=kempner_requeue # Partition (job queue)
324323
#SBATCH --gres=gpu:1 # Number of GPUs
325324
#SBATCH -c 16 # Number of cores (-c) on one node
326325
#SBATCH -n 1 # Number of nodes (-n)
327-
#SBATCH --mem=64GB # Memory pool for all cores (see also --mem-per-cpu)
328-
#SBATCH --time=0-1:00:00 # Runtime in D-HH:MM:SS
326+
#SBATCH --mem=128GB # Memory pool for all cores (see also --mem-per-cpu)
327+
#SBATCH --time=0-0:10:00 # Runtime in D-HH:MM:SS
329328
#SBATCH --requeue # Requeue the job if it is preempted
330329
#SBATCH --export=WANDB_API_KEY # Export the WANDB_API_KEY environment variable to the job
331330
@@ -335,7 +334,7 @@
335334
echo "activating environment"
336335
source activate {name_env}
337336
338-
echo "starting job with call: python $@"
337+
echo "starting job"
339338
python "$@"
340339
""" for path in paths_log]
341340

scripts/wandb_script_wrapper.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,12 @@
1212
1. Parses command-line arguments, expecting two positional arguments:
1313
- path_params: Path to a JSON file containing parameters for the target script. This
1414
JSON must include at minimum:
15-
* "path_script": String path to the target script to be executed.
16-
* "kwargs_wandb_init": (Optional) A dictionary of keyword arguments for
17-
wandb.init().
18-
* "params_script" can be provided, which will be saved as a JSON to the output
19-
directory and --path_params will be passed to the target script.
15+
* "params_wrapper": A dictionary containing the following keys:
16+
* "path_script": String path to the target script to be executed.
17+
* "kwargs_wandb_init": (Optional) A dictionary of keyword arguments for
18+
wandb.init().
* "period_logger": (Optional) Interval passed to monitor_system_metrics(). Defaults to 30.
19+
* "params_script": (Optional) If provided, it will be saved as a JSON file in the output
20+
directory and --path_params will be passed to the target script.
2021
- directory_save: Directory path where output files (such as logs and saved parameters)
2122
will be stored. Will be passed to the target script as --directory_save.
2223
@@ -37,11 +38,14 @@
3738
"/path/to/save_dir". The parameter JSON file should include entries like:
3839
3940
{
40-
"path_script": "/path/to/your_target_script.py",
41-
"kwargs_wandb_init": {
42-
"project": "face_rhythm",
43-
"entity": "your_wandb_username",
44-
"name": "example_run"
41+
"params_wrapper": {
42+
"path_script": "/path/to/your_target_script.py",
43+
"kwargs_wandb_init": {
44+
"project": "face_rhythm",
45+
"entity": "your_wandb_username",
46+
"name": "example_run"
47+
},
48+
"period_logger": 2
4549
},
4650
"params_script": {
4751
"example_param": "value"
@@ -62,6 +66,8 @@
6266
import time
6367
import wandb
6468
import psutil
69+
import functools
70+
6571

6672
def stream_reader(pipe, log_label):
6773
"""
@@ -113,11 +119,15 @@ def monitor_system_metrics(interval: int = 30):
113119
import json
114120
with open(path_params, 'r') as f:
115121
params = json.load(f)
122+
123+
# Get sub parameters for wrapper
124+
assert 'params_wrapper' in params, "Error: 'params_wrapper' is missing in the parameters file."
125+
params_wrapper = params['params_wrapper']
116126

117127
# Gather kwargs_wandb_init from the JSON file.
118-
kwargs_wandb_init = params.get('kwargs_wandb_init', None)
128+
kwargs_wandb_init = params_wrapper.get('kwargs_wandb_init', None)
119129
# Gather path_script from the JSON file. Error if missing
120-
path_script = params.get('path_script', None)
130+
path_script = params_wrapper.get('path_script', None)
121131
if path_script is None:
122132
print("Error: 'path_script' is missing in the parameters file.")
123133
sys.exit(1)
@@ -134,6 +144,9 @@ def monitor_system_metrics(interval: int = 30):
134144
json.dump(params_script, f)
135145
else:
136146
print("Warning: 'params_script' is not provided in the parameters file. Skipping saving parameters.")
147+
148+
# Prepare call to monitor_system_metrics.
149+
monitor_system_metrics = functools.partial(monitor_system_metrics, interval=params_wrapper.get('period_logger', 30))
137150

138151
# Ensure WandB is installed.
139152
try:

0 commit comments

Comments
 (0)