diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py
index 04ac47746..6a9137f60 100644
--- a/src/weathergen/train/trainer.py
+++ b/src/weathergen/train/trainer.py
@@ -33,6 +33,7 @@
 from weathergen.train.trainer_base import TrainerBase
 from weathergen.train.utils import (
     extract_batch_metadata,
+    filter_config_by_enabled,
     get_batch_size_from_config,
     get_target_idxs_from_cfg,
 )
@@ -98,10 +99,19 @@ def init(self, cf: Config, devices):

         self.freeze_modules = cf.get("freeze_modules", "")

+        # keys to filter for enabled/disabled
+        keys_to_filter = ["losses", "model_input", "target_input"]
+
+        # get training config and remove disabled options (e.g. because of overrides)
         self.training_cfg = cf.get("training_config")
+        self.training_cfg = filter_config_by_enabled(self.training_cfg, keys_to_filter)

+        # validation and test configs are training configs, updated by specified keys
         self.validation_cfg = merge_configs(self.training_cfg, cf.get("validation_config", {}))
+        self.validation_cfg = filter_config_by_enabled(self.validation_cfg, keys_to_filter)
+        # test cfg is derived from validation cfg with specified keys overwritten
         self.test_cfg = merge_configs(self.validation_cfg, cf.get("test_config", {}))
+        self.test_cfg = filter_config_by_enabled(self.test_cfg, keys_to_filter)

         # batch sizes
         self.batch_size_per_gpu = get_batch_size_from_config(self.training_cfg)
diff --git a/src/weathergen/train/utils.py b/src/weathergen/train/utils.py
index ebfafab25..67dee7e54 100644
--- a/src/weathergen/train/utils.py
+++ b/src/weathergen/train/utils.py
@@ -143,6 +143,21 @@ def get_target_idxs_from_cfg(cfg, loss_name) -> list[int] | None:

     tc = [v.get("target_source_correspondence") for _, v in cfg.losses[loss_name].loss_fcts.items()]
     tc = [list(t.keys()) for t in tc if t is not None]
-    target_idxs = list(set([i for t in tc for i in t])) if len(tc) > 0 else None
+    target_idxs = list(set([int(i) for t in tc for i in t])) if len(tc) > 0 else None

     return target_idxs
+
+
+def filter_config_by_enabled(cfg, keys):
+    """
+    Filter out entries marked as disabled (enabled: false) under the given config keys.
+    """
+
+    for key in keys:
+        filtered = {}
+        for k, v in cfg.get(key, {}).items():
+            if v.get("enabled", True):
+                filtered[k] = v
+        cfg[key] = filtered
+
+    return cfg
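For reference, a minimal sketch of what the new filter_config_by_enabled helper does, using plain dicts in place of the repo's Config objects; the loss and input names below are made up. In trainer.py the same filtering is applied after each merge_configs step, so entries disabled by a validation_config or test_config override are dropped as well.

    from weathergen.train.utils import filter_config_by_enabled

    # hypothetical training config; only the "enabled" flags matter here
    training_cfg = {
        "losses": {
            "mse": {"enabled": True, "weight": 1.0},
            "stats": {"enabled": False, "weight": 0.1},  # disabled, e.g. via an override
        },
        "model_input": {"obs_a": {"enabled": True}},
        "target_input": {"obs_a": {}},  # no "enabled" key -> treated as enabled
    }

    filtered = filter_config_by_enabled(training_cfg, ["losses", "model_input", "target_input"])

    # the config is modified in place and returned; disabled entries are gone
    assert "stats" not in filtered["losses"]
    assert "mse" in filtered["losses"] and "obs_a" in filtered["target_input"]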
diff --git a/src/weathergen/utils/plot_training.py b/src/weathergen/utils/plot_training.py
index 35bfafe3e..2f517a7f3 100644
--- a/src/weathergen/utils/plot_training.py
+++ b/src/weathergen/utils/plot_training.py
@@ -206,12 +206,13 @@ def plot_lr(
         if run_data.train.is_empty():
             continue
         run_id = run_data.run_id
-        x_col = next(filter(lambda c: x_axis in c, run_data.train.columns))
-        data_cols = list(filter(lambda c: "learning_rate" in c, run_data.train.columns))
+        # x_col = next(filter(lambda c: x_axis in c, run_data.train.columns))
+        # data_cols = list(filter(lambda c: "learning_rate" in c, run_data.train.columns))

         plt.plot(
-            run_data.train[x_col],
-            run_data.train[data_cols],
+            # run_data.train[x_col],
+            # run_data.train[data_cols],
+            run_data.train,
             linestyle,
             color=colors[j % len(colors)],
         )
@@ -384,7 +385,7 @@ def plot_loss_per_stream(
             if run_data_mode.is_empty():
                 continue
             # find the col of the request x-axis (e.g. samples)
-            x_col = next(filter(lambda c: x_axis in c, run_data_mode.columns))
+            # x_col = next(filter(lambda c: x_axis in c, run_data_mode.columns))
             # find the cols of the requested metric (e.g. mse) for all streams
             # TODO: fix captialization
             data_cols = filter(
@@ -393,11 +394,11 @@
             )

             for col in data_cols:
-                x_vals = np.array(run_data_mode[x_col])
+                # x_vals = np.array(run_data_mode[x_col])
                 y_data = np.array(run_data_mode[col])

                 plt.plot(
-                    x_vals,
+                    # x_vals,
                     y_data,
                     linestyle,
                     color=colors[j % len(colors)],
@@ -512,7 +513,7 @@ def plot_loss_per_run(
             alpha = 0.35 if "train" in mode else alpha
             run_data_mode = run_data.by_mode(mode)

-            x_col = [c for _, c in enumerate(run_data_mode.columns) if x_axis in c][0]
+            # x_col = [c for _, c in enumerate(run_data_mode.columns) if x_axis in c][0]
             # find the cols of the requested metric (e.g. mse) for all streams
             data_cols = [c for _, c in enumerate(run_data_mode.columns) if err in c]

@@ -525,11 +526,11 @@
                 if run_data_mode[col].shape[0] == 0:
                     continue

-                x_vals = np.array(run_data_mode[x_col])
+                # x_vals = np.array(run_data_mode[x_col])
                 y_data = np.array(run_data_mode[col])

                 plt.plot(
-                    x_vals,
+                    # x_vals,
                     y_data,
                     linestyle,
                     color=colors[j % len(colors)],
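The plotting changes above comment out the explicit x_col lookup, so plt.plot falls back to matplotlib's default of drawing each column against the row index. A small standalone sketch of that default behaviour; the array below is hypothetical stand-in data, not the project's run_data:

    import matplotlib.pyplot as plt
    import numpy as np

    # hypothetical stand-in for a metrics table: 100 rows, two columns
    y = np.column_stack([np.linspace(1e-3, 1e-4, 100), np.linspace(5e-4, 5e-5, 100)])

    # with no x values given, each column is plotted against the row index 0..99
    plt.plot(y, "-")
    plt.xlabel("row index (previously: the samples / x_axis column)")
    plt.ylabel("learning_rate")
    plt.show()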
diff --git a/src/weathergen/utils/train_logger.py b/src/weathergen/utils/train_logger.py
index 929a78ea9..a640b926b 100644
--- a/src/weathergen/utils/train_logger.py
+++ b/src/weathergen/utils/train_logger.py
@@ -156,20 +156,21 @@ def read(run_id: str, model_path: str = None, mini_epoch: int = -1) -> Metrics:
     cols_train = ["dtime", "samples", "mse", "lr"]
     cols1 = [_weathergen_timestamp, "num_samples", "loss_avg_mean", "learning_rate"]
     for si in cf.streams:
-        for lf in cf.loss_fcts:
+        for lf, _ in cf.training_config.losses.items():
             cols1 += [_key_loss(si["name"], lf[0])]
             cols_train += [
-                si["name"].replace(",", "").replace("/", "_").replace(" ", "_") + ", " + lf[0]
-            ]
-    with_stddev = [("stats" in lf) for lf in cf.loss_fcts]
-    if with_stddev:
-        for si in cf.streams:
-            cols1 += [_key_stddev(si["name"])]
-            cols_train += [
-                si["name"].replace(",", "").replace("/", "_").replace(" ", "_")
-                + ", "
-                + "stddev"
+                si["name"].replace(",", "").replace("/", "_").replace(" ", "_") + ", " + lf
             ]
+    # with_stddev = [("stats" in lf) for lf, _ in cf.training_config.losses.items()]
+    # if with_stddev:
+    #     for si in cf.streams:
+    #         cols1 += [_key_stddev(si["name"])]
+    #         cols_train += [
+    #             si["name"].replace(",", "").replace("/", "_").replace(" ", "_")
+    #             + ", "
+    #             + "stddev"
+    #         ]
+
     # read training log data
     try:
         with open(fname_log_train, "rb") as f:
@@ -214,20 +215,25 @@ def read(run_id: str, model_path: str = None, mini_epoch: int = -1) -> Metrics:
     cols_val = ["dtime", "samples"]
     cols2 = [_weathergen_timestamp, "num_samples"]
     for si in cf.streams:
-        for lf in cf.loss_fcts_val:
+        cfg = (
+            cf.validation_config
+            if cf.validation_config.get("losses") is not None
+            else cf.training_config
+        )
+        for lf, _ in cfg.losses.items():
             cols_val += [
-                si["name"].replace(",", "").replace("/", "_").replace(" ", "_") + ", " + lf[0]
+                si["name"].replace(",", "").replace("/", "_").replace(" ", "_") + ", " + lf
             ]
             cols2 += [_key_loss(si["name"], lf[0])]
-    with_stddev = [("stats" in lf) for lf in cf.loss_fcts_val]
-    if with_stddev:
-        for si in cf.streams:
-            cols2 += [_key_stddev(si["name"])]
-            cols_val += [
-                si["name"].replace(",", "").replace("/", "_").replace(" ", "_")
-                + ", "
-                + "stddev"
-            ]
+    # with_stddev = [("stats" in lf) for lf in cf.loss_fcts_val]
+    # if with_stddev:
+    #     for si in cf.streams:
+    #         cols2 += [_key_stddev(si["name"])]
+    #         cols_val += [
+    #             si["name"].replace(",", "").replace("/", "_").replace(" ", "_")
+    #             + ", "
+    #             + "stddev"
+    #         ]
     # read validation log data
     try:
         with open(fname_log_val, "rb") as f:
@@ -370,6 +376,8 @@ def clean_df(df, columns: list[str] | None):
     idcs = [i for i in range(len(columns)) if columns[i] == "loss_avg_mean"]
     if len(idcs) > 0:
         columns[idcs[0]] = "loss_avg_0_mean"
+    # TODO, TODO, TODO
+    columns = ["LossPhysical.loss_avg"]
     df = df.select(columns)
     # Remove all rows where all columns are null
     df = df.filter(~pl.all_horizontal(pl.col(c).is_null() for c in columns))
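For reference, a small sketch of how the per-stream training columns are now assembled in read(): one column per (stream, loss) pair, where the loss names are the keys of training_config.losses and stream names are sanitised with the same replace() chain as above. The stream and loss names below are invented for illustration.

    # hypothetical stream metadata and losses config
    streams = [{"name": "ERA5, surface/pressure levels"}]
    losses = {"LossPhysical": {}, "LossLatent": {}}  # stand-in for cf.training_config.losses

    cols_train = ["dtime", "samples", "mse", "lr"]
    for si in streams:
        for lf, _ in losses.items():
            # same sanitisation as in train_logger.read()
            cols_train += [
                si["name"].replace(",", "").replace("/", "_").replace(" ", "_") + ", " + lf
            ]

    print(cols_train[-2:])
    # ['ERA5_surface_pressure_levels, LossPhysical', 'ERA5_surface_pressure_levels, LossLatent']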