Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions config/runs_plot_train.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
train :
plot :
lnjzhore :
slurm_id: 0
description: "Christian's naoj54ch with new code"
eval: vgbndhco
35 changes: 35 additions & 0 deletions config/runs_plot_train_seq.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
train :
plot :
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

did you intend to commit these configs?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not big of a change but doesn't affect anything in terms of functionality of code? I feel I (maybe others) may need it later also.

mzp2xvd1 :
slurm_id: 0
description: "1000 samples [3,4,5,6,7,8]"
epoch: 5
eval: yv0scz3w

js45xaol :
slurm_id: 0
description: "4096 samples [3,4,5,6,7,8]"
eval: jq63t9vj

tt4561ai :
slurm_id: 0
description: "1000 samples [3,4,5,6,7,8,9,10,11,12]"
epoch: 9
eval: soxhzpf2

uf743jr2 :
slurm_id: 0
description: "1000 samples [4,6,8]"
epoch: 2
eval: elnk0mvo

gygh6uzm :
slurm_id: 0
description: "4096 samples [4,6,8]"
eval: y7lruh4d

uj94vuey :
slurm_id: 0
description: "4096 samples [4...4,6...6,8...8]"
eval: q5krziaj

17 changes: 9 additions & 8 deletions src/weathergen/utils/plot_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,12 @@ def _read_yaml_config(yaml_file_path):
for k, v in config_dict_temp.items():
assert isinstance(v["slurm_id"], int), "slurm_id has to be int."
assert isinstance(v["description"], str), "description has to be str."
config_dict[k] = [v["slurm_id"], v["description"]]
assert isinstance(v.get("epoch",-1),int)
config_dict[k] = [v["slurm_id"], v["description"],v.get("epoch",-1)]

# Validate the structure: {run_id: [job_id, experiment_name]}
_check_run_id_dict(config_dict)
#_check_run_id_dict(config_dict)


return config_dict

Expand All @@ -150,7 +152,7 @@ def clean_plot_folder(plot_dir: Path):


####################################################################################################
def get_stream_names(run_id: str, model_path: Path | None = "./model"):
def get_stream_names(run_id: str, model_path: Path | None = "./model", mini_epoch = -1):
"""
Get the stream names from the model configuration file.

Expand All @@ -167,7 +169,7 @@ def get_stream_names(run_id: str, model_path: Path | None = "./model"):
List of stream names
"""
# return col names from training (should be identical to validation)
cf = config.load_model_config(run_id, -1, model_path=model_path)
cf = config.load_model_config(run_id, mini_epoch, model_path=model_path)
return [si["name"].replace(",", "").replace("/", "_").replace(" ", "_") for si in cf.streams]


Expand Down Expand Up @@ -673,8 +675,7 @@ def plot_train(args=None):
clean_plot_folder(out_dir)

# read logged data

runs_data = [TrainLogger.read(run_id, model_path=model_base_dir) for run_id in runs_ids]
runs_data = [TrainLogger.read(run_id, model_path=model_base_dir, mini_epoch=runs_ids[run_id][2]) for run_id in runs_ids]

# determine which runs are still alive (as a process, though they might hang internally)
ret = subprocess.run(["squeue"], capture_output=True)
Expand Down Expand Up @@ -730,15 +731,15 @@ def plot_train(args=None):
run_id,
runs_ids[run_id],
run_data,
get_stream_names(run_id, model_path=model_base_dir), # limit to available streams
get_stream_names(run_id, model_path=model_base_dir, mini_epoch=runs_ids[run_id][2]), # limit to available streams
plot_dir=out_dir,
)
plot_loss_per_run(
["val"],
run_id,
runs_ids[run_id],
run_data,
get_stream_names(run_id, model_path=model_base_dir), # limit to available streams
get_stream_names(run_id, model_path=model_base_dir, mini_epoch=runs_ids[run_id][2]), # limit to available streams
plot_dir=out_dir,
)

Expand Down
Loading