diff --git a/config/runs_plot_train.yml b/config/runs_plot_train.yml new file mode 100644 index 000000000..0583a4f55 --- /dev/null +++ b/config/runs_plot_train.yml @@ -0,0 +1,6 @@ +train : + plot : + lnjzhore : + slurm_id: 0 + description: "Christian's naoj54ch with new code" + eval: vgbndhco diff --git a/config/runs_plot_train_seq.yml b/config/runs_plot_train_seq.yml new file mode 100644 index 000000000..47e0bb79b --- /dev/null +++ b/config/runs_plot_train_seq.yml @@ -0,0 +1,35 @@ +train : + plot : + mzp2xvd1 : + slurm_id: 0 + description: "1000 samples [3,4,5,6,7,8]" + epoch: 5 + eval: yv0scz3w + + js45xaol : + slurm_id: 0 + description: "4096 samples [3,4,5,6,7,8]" + eval: jq63t9vj + + tt4561ai : + slurm_id: 0 + description: "1000 samples [3,4,5,6,7,8,9,10,11,12]" + epoch: 9 + eval: soxhzpf2 + + uf743jr2 : + slurm_id: 0 + description: "1000 samples [4,6,8]" + epoch: 2 + eval: elnk0mvo + + gygh6uzm : + slurm_id: 0 + description: "4096 samples [4,6,8]" + eval: y7lruh4d + + uj94vuey : + slurm_id: 0 + description: "4096 samples [4...4,6...6,8...8]" + eval: q5krziaj + diff --git a/src/weathergen/utils/plot_training.py b/src/weathergen/utils/plot_training.py index 573c084c4..8e1bf6c61 100644 --- a/src/weathergen/utils/plot_training.py +++ b/src/weathergen/utils/plot_training.py @@ -127,10 +127,12 @@ def _read_yaml_config(yaml_file_path): for k, v in config_dict_temp.items(): assert isinstance(v["slurm_id"], int), "slurm_id has to be int." assert isinstance(v["description"], str), "description has to be str." - config_dict[k] = [v["slurm_id"], v["description"]] + assert isinstance(v.get("epoch",-1),int) + config_dict[k] = [v["slurm_id"], v["description"],v.get("epoch",-1)] # Validate the structure: {run_id: [job_id, experiment_name]} - _check_run_id_dict(config_dict) + #_check_run_id_dict(config_dict) + return config_dict @@ -150,7 +152,7 @@ def clean_plot_folder(plot_dir: Path): #################################################################################################### -def get_stream_names(run_id: str, model_path: Path | None = "./model"): +def get_stream_names(run_id: str, model_path: Path | None = "./model", mini_epoch = -1): """ Get the stream names from the model configuration file. @@ -167,7 +169,7 @@ def get_stream_names(run_id: str, model_path: Path | None = "./model"): List of stream names """ # return col names from training (should be identical to validation) - cf = config.load_model_config(run_id, -1, model_path=model_path) + cf = config.load_model_config(run_id, mini_epoch, model_path=model_path) return [si["name"].replace(",", "").replace("/", "_").replace(" ", "_") for si in cf.streams] @@ -673,8 +675,7 @@ def plot_train(args=None): clean_plot_folder(out_dir) # read logged data - - runs_data = [TrainLogger.read(run_id, model_path=model_base_dir) for run_id in runs_ids] + runs_data = [TrainLogger.read(run_id, model_path=model_base_dir, mini_epoch=runs_ids[run_id][2]) for run_id in runs_ids] # determine which runs are still alive (as a process, though they might hang internally) ret = subprocess.run(["squeue"], capture_output=True) @@ -730,7 +731,7 @@ def plot_train(args=None): run_id, runs_ids[run_id], run_data, - get_stream_names(run_id, model_path=model_base_dir), # limit to available streams + get_stream_names(run_id, model_path=model_base_dir, mini_epoch=runs_ids[run_id][2]), # limit to available streams plot_dir=out_dir, ) plot_loss_per_run( @@ -738,7 +739,7 @@ def plot_train(args=None): run_id, runs_ids[run_id], run_data, - get_stream_names(run_id, model_path=model_base_dir), # limit to available streams + get_stream_names(run_id, model_path=model_base_dir, mini_epoch=runs_ids[run_id][2]), # limit to available streams plot_dir=out_dir, )