diff --git a/README.md b/README.md
index b3e5bb49..6e9846b4 100644
--- a/README.md
+++ b/README.md
@@ -30,6 +30,8 @@ conda activate crai
 
 `environment-cuda.yml` should be used when working with GPUs using CUDA.
 
+Installing the required dependencies should take no more than 15 minutes over a standard, stable internet connection.
+
 ## Installation
 
 `climatereconstructionAI` can be installed using `pip` in the current directory:
@@ -37,6 +39,8 @@ conda activate crai
 pip install .
 ```
 
+Installing the Python package itself should take less than 1 minute on a regular computer.
+
 ## Usage
 
 The software can be used to:
diff --git a/climatereconstructionai/config.py b/climatereconstructionai/config.py
index 6ee9a771..370aadf8 100644
--- a/climatereconstructionai/config.py
+++ b/climatereconstructionai/config.py
@@ -54,29 +54,88 @@ def set_lambdas():
 
     lambda_dict = {}
 
-    if loss_criterion == 0:
+    if loss_criterion in ("0", "inpainting"):
         lambda_dict['valid'] = 1.
         lambda_dict['hole'] = 6.
         lambda_dict['tv'] = .1
         lambda_dict['prc'] = .05
         lambda_dict['style'] = 120.
-    elif loss_criterion == 1:
+    elif loss_criterion in ("1", "l1-hole"):
         lambda_dict['hole'] = 1.
-    elif loss_criterion == 2:
+    elif loss_criterion in ("2", "downscaling"):
         lambda_dict['valid'] = 7.
         lambda_dict['hole'] = 0.
         lambda_dict['tv'] = .1
         lambda_dict['prc'] = .05
         lambda_dict['style'] = 120.
-    elif loss_criterion == 3:
+
+    elif loss_criterion in ("3", "l1-valid"):
         lambda_dict['valid'] = 1.
 
+    elif loss_criterion in ("4", "extreme"):
+        lambda_dict['-extreme'] = 1.
+        lambda_dict['+extreme'] = 1.
+
+    if vae_zdim != 0:
+        lambda_dict['kldiv'] = 1.
+
     if lambda_loss is not None:
         lambda_dict.update(lambda_loss)
 
 
+def set_steps(evaluate=False):
+
+    assert sum(bool(x) for x in [lstm_steps, gru_steps, channel_steps]) < 2, \
+        "lstm, gru and channel options are mutually exclusive"
+
+    global recurrent_steps, n_recurrent_steps
+    global time_steps
+    time_steps = [0, 0]
+    if lstm_steps:
+        time_steps = lstm_steps
+        recurrent_steps = lstm_steps[0]
+    elif gru_steps:
+        time_steps = gru_steps
+        recurrent_steps = gru_steps[0]
+    else:
+        recurrent_steps = 0
+
+    n_recurrent_steps = sum(time_steps) + 1
+
+    global n_channel_steps, gt_channels
+    if channel_steps:
+        time_steps = channel_steps
+        n_channel_steps = sum(channel_steps) + 1
+        gt_channels = [i * n_channel_steps + channel_steps[0] for i in range(n_output_data)]
+    else:
+        n_channel_steps = 1
+        gt_channels = [0 for i in range(n_output_data)]
+
+    global n_time_steps, in_steps, out_steps, n_pred_steps, pred_timestep, out_channels
+
+    n_time_steps = sum(time_steps) + 1
+    pred_timestep = list(range(-pred_steps[0], pred_steps[1] + 1))
+    n_pred_steps = len(pred_timestep)
+
+    if evaluate:
+        in_steps = range(0, n_time_steps)
+        out_steps = [time_steps[0]]
+        out_channels = n_output_data * n_pred_steps
+    else:
+        in_step = max(pred_steps[0] - time_steps[0], 0)
+        in_steps = range(in_step, in_step + n_time_steps)
+        n_time_steps = len(in_steps)
+        interval = [max(time_steps[i], pred_steps[i]) for i in range(2)]
+        out_steps = range(interval[0] - pred_steps[0], interval[0] + pred_steps[1] + 1)
+        time_steps = interval
+
+        out_channels = n_output_data * len(out_steps)
+
+    assert len(time_steps) == 2
+
+
 def global_args(parser, arg_file=None, prog_func=None):
     import torch
@@ -109,33 +168,18 @@ def global_args(parser, arg_file=None, prog_func=None):
     if not os.path.exists(log_dir):
         os.makedirs(log_dir)
 
-    global recurrent_steps
-    global n_recurrent_steps
-    global time_steps
-    time_steps = [0, 0]
-    if lstm_steps:
-
recurrent_steps = lstm_steps[0] - time_steps = lstm_steps - elif gru_steps: - recurrent_steps = gru_steps[0] - time_steps = gru_steps - else: - recurrent_steps = 0 - - n_recurrent_steps = sum(time_steps) + 1 - - global n_channel_steps - global gt_channels + global n_output_data + if n_target_data > 0: + n_output_data = n_target_data - n_channel_steps = 1 - gt_channels = [0 for i in range(out_channels)] - if channel_steps: - time_steps = channel_steps - n_channel_steps = sum(channel_steps) + 1 - for i in range(out_channels): - gt_channels[i] = (i + 1) * channel_steps[0] + i * (channel_steps[1] + 1) + global min_bounds, max_bounds + if len(min_bounds) == 1: + min_bounds = [min_bounds[0] for i in range(n_output_data)] + if len(max_bounds) == 1: + max_bounds = [max_bounds[0] for i in range(n_output_data)] - assert len(time_steps) == 2 + assert len(min_bounds) == n_output_data + assert len(max_bounds) == n_output_data if all('.json' in data_name for data_name in data_names) and (lstm_steps or channel_steps): print('Warning: Each input file defined in your ".json" files will be considered individually.' @@ -163,6 +207,7 @@ def set_common_args(): help="Number of data-names (from last) to be used as target data") arg_parser.add_argument('--device', type=str, default='cuda', help="Device used by PyTorch (cuda or cpu)") arg_parser.add_argument('--shuffle-masks', action='store_true', help="Select mask indices randomly") + arg_parser.add_argument('--vae-zdim', type=int, default=0, help="Use VAE with latent space dimension") arg_parser.add_argument('--channel-steps', type=int_list, default=None, help="Comma separated number of considered sequences for channeled memory:" "past_steps,future_steps") @@ -170,6 +215,8 @@ def set_common_args(): help="Comma separated number of considered sequences for lstm: past_steps,future_steps") arg_parser.add_argument('--gru-steps', type=int_list, default=None, help="Comma separated number of considered sequences for gru: past_steps,future_steps") + arg_parser.add_argument('--pred-steps', type=int_list, default=[0, 0], + help="Comma separated number of considered sequences for pred: past_steps,future_steps") arg_parser.add_argument('--encoding-layers', type=int_list, default='3', help="Number of encoding layers in the CNN") arg_parser.add_argument('--pooling-layers', type=int_list, default='0', help="Number of pooling layers in the CNN") @@ -177,7 +224,7 @@ def set_common_args(): arg_parser.add_argument('--weights', type=str, default=None, help="Initialization weight") arg_parser.add_argument('--steady-masks', type=str_list, default=None, help="Comma separated list of netCDF files containing a single mask to be applied " - "to all timesteps. The number of steady-masks must be the same as out-channels") + "to all timesteps. 
The number of steady-masks must be the same as n-output-data") arg_parser.add_argument('--loop-random-seed', type=int, default=None, help="Random seed for iteration loop") arg_parser.add_argument('--cuda-random-seed', type=int, default=None, @@ -192,17 +239,18 @@ def set_common_args(): arg_parser.add_argument('--masked-bn', action='store_true', help="Use masked batch normalization instead of standard BN") arg_parser.add_argument('--lazy-load', action='store_true', help="Use lazy loading for large datasets") + arg_parser.add_argument('--standard-conv', action='store_true', help="Disable partial convolution") arg_parser.add_argument('--global-padding', action='store_true', help="Use a custom padding for global dataset") arg_parser.add_argument('--normalize-data', action='store_true', help="Normalize the input climate data to 0 mean and 1 std") arg_parser.add_argument('--n-filters', type=int, default=None, help="Number of filters for the first/last layer") - arg_parser.add_argument('--out-channels', type=int, default=1, help="Number of channels for the output data") + arg_parser.add_argument('--n-output-data', type=int, default=1, help="Number of output data") arg_parser.add_argument('--dataset-name', type=str, default=None, help="Name of the dataset for format checking") arg_parser.add_argument('--min-bounds', type=float_list, default="-inf", help="Comma separated list of values defining the permitted lower-bound of output values") arg_parser.add_argument('--max-bounds', type=float_list, default="inf", help="Comma separated list of values defining the permitted upper-bound of output values") - arg_parser.add_argument('--profile', action='store_true', help="Profile code using tensorboard profiler") + arg_parser.add_argument('--profiler', type=str, default=None, help="Use specified profiler") return arg_parser @@ -231,9 +279,8 @@ def set_train_args(arg_file=None): help="Number of final models to be saved") arg_parser.add_argument('--final-models-interval', type=int, default=1000, help="Iteration step interval at which the final models should be saved") - arg_parser.add_argument('--loss-criterion', type=int, default=0, - help="Index defining the loss function " - "(0=original from Liu et al., 1=MAE of the hole region)") + arg_parser.add_argument('--loss-criterion', type=str, default="l1-hole", + help="Index/string defining the loss function (inpainting/l1-hole/l1-valid/etc.)") arg_parser.add_argument('--eval-timesteps', type=int_list, default=None, help="Sample indices for which a snapshot is created at each iter defined by log-interval") arg_parser.add_argument('-f', '--load-from-file', type=str, action=LoadFromFile, @@ -257,6 +304,7 @@ def set_train_args(arg_file=None): help="Number of batch iterations used to average the validation loss") args = global_args(arg_parser, arg_file) + set_steps() global passed_args passed_args = get_passed_arguments(args, arg_parser) @@ -276,13 +324,14 @@ def set_train_args(arg_file=None): def set_evaluate_args(arg_file=None, prog_func=None): arg_parser = set_common_args() arg_parser.add_argument('--model-dir', type=str, default='snapshots/ckpt/', help="Directory of the trained models") - arg_parser.add_argument('--model-names', type=str_list, default='1000000.pth', help="Model names") + arg_parser.add_argument('--model-names', type=str_list, default='final.pth', help="Model names") arg_parser.add_argument('--evaluation-dirs', type=str_list, default='evaluation/', help="Directory where the output files will be stored") arg_parser.add_argument('--eval-names', 
type=str_list, default='output', help="Prefix used for the output filenames") arg_parser.add_argument('--use-train-stats', action='store_true', help="Use mean and std from training data for normalization") + arg_parser.add_argument('--n-evaluations', type=int, default=1, help="Number of evaluations") arg_parser.add_argument('--create-graph', action='store_true', help="Create a Tensorboard graph of the NN") arg_parser.add_argument('--plot-results', type=int_list, default=[], help="Create plot images of the results for the comma separated list of time indices") @@ -290,8 +339,13 @@ def set_evaluate_args(arg_file=None, prog_func=None): help="Split the climate dataset into several partitions along the time coordinate") arg_parser.add_argument('--maxmem', type=int, default=None, help="Maximum available memory in MB (overwrite partitions parameter)") - arg_parser.add_argument('--split-outputs', action='store_true', - help="Do not merge the outputs when using multiple models and/or partitions") + arg_parser.add_argument('--time-freq', type=str, default=None, + help="Time frequency for pred-steps option (only for D,H,M,S,etc.)") + arg_parser.add_argument('--split-outputs', type=str, default="all", const=None, nargs='?', + help="Split the outputs according to members and/or partitions") arg_parser.add_argument('-f', '--load-from-file', type=str, action=LoadFromFile, help="Load all the arguments from a text file") global_args(arg_parser, arg_file, prog_func) + set_steps(evaluate=True) + assert len(eval_names) == n_output_data + globals()["model_names"] *= globals()["n_evaluations"] diff --git a/climatereconstructionai/evaluate.py b/climatereconstructionai/evaluate.py index 233ef212..cc7509a0 100644 --- a/climatereconstructionai/evaluate.py +++ b/climatereconstructionai/evaluate.py @@ -15,6 +15,12 @@ def store_encoding(ds): return ds +def format_time(ds): + ds['time'].encoding = encoding + ds['time'].encoding['original_shape'] = len(ds["time"]) + return ds.transpose("time", ...).reset_coords(drop=True) + + def evaluate(arg_file=None, prog_func=None): cfg.set_evaluate_args(arg_file, prog_func) @@ -36,7 +42,7 @@ def evaluate(arg_file=None, prog_func=None): data_stats = None dataset_val = NetCDFLoader(cfg.data_root_dir, cfg.data_names, cfg.mask_dir, cfg.mask_names, "infill", - cfg.data_types, cfg.time_steps, data_stats) + cfg.data_types, cfg.time_steps, cfg.steady_masks, data_stats) n_samples = len(dataset_val) @@ -79,28 +85,38 @@ def evaluate(arg_file=None, prog_func=None): batch_size = get_batch_size(model.parameters(), n_samples, image_sizes) iterator_val = iter(DataLoader(dataset_val, batch_size=batch_size, sampler=FiniteSampler(len(dataset_val)), num_workers=0)) - infill(model, iterator_val, eval_path, output_names, data_stats, dataset_val.xr_dss, count) + infill(model, iterator_val, eval_path, output_names, dataset_val.steady_mask, data_stats, + dataset_val.xr_dss, count) for name in output_names: if len(output_names[name]) == 1 and len(output_names[name][1]) == 1: os.rename(output_names[name][1][0], name + ".nc") else: - if not cfg.split_outputs: - dss = [] - for i_model in output_names[name]: - dss.append(xr.open_mfdataset(output_names[name][i_model], preprocess=store_encoding, autoclose=True, - combine='nested', data_vars='minimal', concat_dim="time", chunks={})) - dss[-1] = dss[-1].assign_coords({"member": i_model}) - - if len(dss) == 1: - ds = dss[-1].drop("member") + if cfg.split_outputs is not None: + + if cfg.split_outputs == "time": + k = 0 + for names in 
zip(*(output_names[name].values())): + k += 1 + dss = [xr.open_dataset(names[i]).assign_coords({"member": i}) for i in range(len(names))] + xr.concat(dss, dim="member").to_netcdf("{}-{}.nc".format(name, k)) else: - ds = xr.concat(dss, dim="member") - - ds['time'].encoding = encoding - ds['time'].encoding['original_shape'] = len(ds["time"]) - ds = ds.transpose("time", ...).reset_coords(drop=True) - ds.to_netcdf(name + ".nc") + dss = [] + for i_model in output_names[name]: + ds = xr.open_mfdataset(output_names[name][i_model], preprocess=store_encoding, autoclose=True, + combine='nested', data_vars='minimal', concat_dim="time", chunks={}) + ds = ds.assign_coords({"member": i_model}) + if cfg.split_outputs == "member": + format_time(ds).to_netcdf("{}.{}.nc".format(name, i_model)) + else: + dss.append(ds) + + if cfg.split_outputs != "member": + if len(dss) == 1: + ds = dss[-1].drop("member") + else: + ds = xr.concat(dss, dim="member") + format_time(ds).to_netcdf(name + ".nc") for i_model in output_names[name]: for output_name in output_names[name][i_model]: diff --git a/climatereconstructionai/loss/extreme_loss.py b/climatereconstructionai/loss/extreme_loss.py new file mode 100644 index 00000000..bb5052b1 --- /dev/null +++ b/climatereconstructionai/loss/extreme_loss.py @@ -0,0 +1,28 @@ +import torch +import torch.nn as nn + + +class ExtremeLoss(nn.Module): + def __init__(self): + super().__init__() + self.l2 = nn.MSELoss() + self.sm = nn.Softmax(dim=0) + + def forward(self, data_dict): + loss_dict = { + '-extreme': 0.0, + '+extreme': 0.0, + } + + output = data_dict['output'] + gt = data_dict['gt'] + + # calculate loss for all channels + for channel in range(output.shape[1]): + + gt_ch = torch.unsqueeze(gt[:, channel, :, :], dim=1) + output_ch = torch.unsqueeze(output[:, channel, :, :], dim=1) + loss_dict['-extreme'] += self.l2(self.sm(-output_ch), self.sm(-gt_ch)) + loss_dict['+extreme'] += self.l2(self.sm(output_ch), self.sm(gt_ch)) + + return loss_dict diff --git a/climatereconstructionai/loss/feature_loss.py b/climatereconstructionai/loss/feature_loss.py index abc836ad..5e5a8f7b 100644 --- a/climatereconstructionai/loss/feature_loss.py +++ b/climatereconstructionai/loss/feature_loss.py @@ -3,12 +3,11 @@ from .utils import gram_matrix - class FeatureLoss(nn.Module): - def __init__(self, extractor): + def __init__(self, extractor, devices): super().__init__() self.l1 = nn.L1Loss() - self.extractor = extractor + self.extractor = {f'{device}': extractor(device) for device in devices} def forward(self, data_dict): loss_dict = { @@ -20,6 +19,8 @@ def forward(self, data_dict): output = data_dict['output'] gt = data_dict['gt'] + extractor = self.extractor[str(output.device)] + # calculate loss for all channels for channel in range(output.shape[1]): @@ -28,9 +29,9 @@ def forward(self, data_dict): output_comp_ch = torch.unsqueeze(output_comp[:, channel, :, :], dim=1) # define different loss function from features from output and output_comp - feat_output = self.extractor(output_ch) - feat_output_comp = self.extractor(output_comp_ch) - feat_gt = self.extractor(gt_ch) + feat_output = extractor(output_ch) + feat_output_comp = extractor(output_comp_ch) + feat_gt = extractor(gt_ch) for i in range(len(feat_gt)): loss_dict['prc'] += self.l1(feat_output[i], feat_gt[i]) loss_dict['prc'] += self.l1(feat_output_comp[i], feat_gt[i]) @@ -39,4 +40,4 @@ def forward(self, data_dict): loss_dict['style'] += self.l1(gram_matrix(feat_output_comp[i]), gram_matrix(feat_gt[i])) - return loss_dict \ No newline at end of file 
+ return loss_dict diff --git a/climatereconstructionai/loss/get_loss.py b/climatereconstructionai/loss/get_loss.py index e6a01bcf..be6231fb 100644 --- a/climatereconstructionai/loss/get_loss.py +++ b/climatereconstructionai/loss/get_loss.py @@ -4,47 +4,38 @@ from .hole_loss import HoleLoss from .total_variation_loss import TotalVariationLoss from .valid_loss import ValidLoss +from .kldiv_loss import KLDivLoss +from .extreme_loss import ExtremeLoss from .. import config as cfg from ..utils.featurizer import VGG16FeatureExtractor -def prepare_data_dict(img_mask, loss_mask, output, gt, tensor_keys): - data_dict = dict(zip(list(tensor_keys),[None]*len(tensor_keys))) +def prepare_data_dict(mask, output, latent_dist, gt, tensor_keys): + data_dict = dict(zip(list(tensor_keys), [None] * len(tensor_keys))) - mask = img_mask - loss_mask = img_mask - if loss_mask is not None: - mask += loss_mask - mask[mask < 0] = 0 - mask[mask > 1] = 1 - assert ((mask == 0) | (mask == 1)).all(), "Not all values in mask are zeros or ones!" - - output = output[:, cfg.recurrent_steps, :, :, :] - mask = mask[:, cfg.recurrent_steps, :, :, :] - gt = gt[:, cfg.recurrent_steps, cfg.gt_channels, :, :] - - data_dict['mask'] = mask - data_dict['output'] = output - data_dict['gt'] = gt + data_dict['mask'] = mask[:, 0] + data_dict['output'] = output[:, cfg.recurrent_steps] + data_dict['latent_dist'] = latent_dist + data_dict['gt'] = gt[:, 0] if 'comp' in tensor_keys: - data_dict['comp'] = mask * gt + (1 - mask) * output + data_dict['comp'] = data_dict['mask'] * data_dict['gt'] + (1 - data_dict['mask']) * data_dict['output'] return data_dict class loss_criterion(torch.nn.Module): - def __init__(self, lambda_dict): + def __init__(self, lambda_dict, devices): super().__init__() self.criterions = [] - self.tensors = ['output', 'gt', 'mask'] + self.tensors = ['output', 'latent_dist', 'gt', 'mask'] style_added = False for loss, lambda_ in lambda_dict.items(): if lambda_ > 0: if (loss == 'style' or loss == 'prc') and not style_added: - self.criterions.append(FeatureLoss(VGG16FeatureExtractor()).to(cfg.device)) + self.criterions.append(FeatureLoss(VGG16FeatureExtractor, devices).to(cfg.device)) self.tensors.append('comp') style_added = True @@ -61,10 +52,15 @@ def __init__(self, lambda_dict): if 'comp' not in self.tensors: self.tensors.append('comp') + elif loss == 'kldiv': + self.criterions.append(KLDivLoss().to(cfg.device)) + + elif loss == '-extreme' or loss == '+extreme': + self.criterions.append(ExtremeLoss().to(cfg.device)) - def forward(self, img_mask, loss_mask, output, gt): + def forward(self, mask, output, latent_dist, gt): - data_dict = prepare_data_dict(img_mask, loss_mask, output, gt, self.tensors) + data_dict = prepare_data_dict(mask, output, latent_dist, gt, self.tensors) loss_dict = {} for criterion in self.criterions: @@ -75,19 +71,39 @@ def forward(self, img_mask, loss_mask, output, gt): if lambda_value > 0 and loss in loss_dict.keys(): loss_w_lambda = loss_dict[loss] * lambda_value loss_dict["total"] += loss_w_lambda - loss_dict[loss] = loss_w_lambda.item() + loss_dict[loss] = loss_w_lambda#.item() return loss_dict +class ModularizedFunction(torch.nn.Module): + def __init__(self, forward_op): + super().__init__() + self.forward_op = forward_op + + def forward(self, *args, **kwargs): + return self.forward_op(*args, **kwargs) + +class CriterionParallel(torch.nn.Module): + def __init__(self, criterion): + super().__init__() + if not isinstance(criterion, torch.nn.Module): + criterion = ModularizedFunction(criterion) + 
self.criterion = torch.nn.DataParallel(criterion) + + def forward(self, *args, **kwargs): + multi_dict = self.criterion(*args, **kwargs) + for key in multi_dict.keys(): + multi_dict[key] = multi_dict[key].mean() + return multi_dict class LossComputation(torch.nn.Module): - def __init__(self, lambda_dict): + def __init__(self, lambda_dict, devices): super().__init__() if cfg.multi_gpus: - self.criterion = torch.nn.DataParallel(loss_criterion(lambda_dict)) + self.criterion = CriterionParallel(loss_criterion(lambda_dict, devices)) else: - self.criterion = loss_criterion(lambda_dict) + self.criterion = loss_criterion(lambda_dict, devices) - def forward(self, img_mask, loss_mask, output, gt): - loss_dict = self.criterion(img_mask, loss_mask, output ,gt) - return loss_dict \ No newline at end of file + def forward(self, mask, output, latent_dist, gt): + loss_dict = self.criterion(mask, output, latent_dist, gt) + return loss_dict diff --git a/climatereconstructionai/loss/hole_loss.py b/climatereconstructionai/loss/hole_loss.py index 0b6082ef..497d6a45 100644 --- a/climatereconstructionai/loss/hole_loss.py +++ b/climatereconstructionai/loss/hole_loss.py @@ -25,4 +25,4 @@ def forward(self, data_dict): # define different loss functions from output and output_comp loss_dict['hole'] += self.l1((1 - mask_ch) * output_ch, (1 - mask_ch) * gt_ch) - return loss_dict \ No newline at end of file + return loss_dict diff --git a/climatereconstructionai/loss/kldiv_loss.py b/climatereconstructionai/loss/kldiv_loss.py new file mode 100644 index 00000000..09c6a5bd --- /dev/null +++ b/climatereconstructionai/loss/kldiv_loss.py @@ -0,0 +1,14 @@ +import torch +from torch import nn + + +class KLDivLoss(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, data_dict): + + mu, logvar = data_dict['latent_dist'] + loss_dict = {'kldiv': -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())} + + return loss_dict diff --git a/climatereconstructionai/loss/total_variation_loss.py b/climatereconstructionai/loss/total_variation_loss.py index 82421b29..3ff86af5 100644 --- a/climatereconstructionai/loss/total_variation_loss.py +++ b/climatereconstructionai/loss/total_variation_loss.py @@ -20,4 +20,4 @@ def forward(self, data_dict): for channel in range(output_comp.shape[1]): output_comp_ch = torch.unsqueeze(output_comp[:, channel, :, :], dim=1) loss_dict['tv'] += total_variation_loss(output_comp_ch) - return loss_dict \ No newline at end of file + return loss_dict diff --git a/climatereconstructionai/loss/utils.py b/climatereconstructionai/loss/utils.py index 1b148c5e..2eaa318b 100644 --- a/climatereconstructionai/loss/utils.py +++ b/climatereconstructionai/loss/utils.py @@ -1,6 +1,6 @@ import torch import torch.nn.functional as F - +from .. 
import config as cfg def gram_matrix(feat): (b, ch, h, w) = feat.size() @@ -24,3 +24,12 @@ def conv_variance(data, k=11): exp2 = torch.pow(F.conv2d(data, weights, padding='valid'), 2) return (exp - exp2) + +def get_devices(model): + try: + device_ids = model.device_ids + except: + device_ids = [torch.cuda.current_device()] + + return [f'{cfg.device}:{id}' for id in device_ids] + diff --git a/climatereconstructionai/loss/valid_loss.py b/climatereconstructionai/loss/valid_loss.py index 71ac1ab8..d1c0b2e1 100644 --- a/climatereconstructionai/loss/valid_loss.py +++ b/climatereconstructionai/loss/valid_loss.py @@ -24,4 +24,4 @@ def forward(self, data_dict): # define different loss functions from output and output_comp loss_dict['valid'] += self.l1(mask_ch * output_ch, mask_ch * gt_ch) - return loss_dict \ No newline at end of file + return loss_dict diff --git a/climatereconstructionai/metrics/get_metrics.py b/climatereconstructionai/metrics/get_metrics.py index 239f08d7..7891bb52 100644 --- a/climatereconstructionai/metrics/get_metrics.py +++ b/climatereconstructionai/metrics/get_metrics.py @@ -6,8 +6,9 @@ from .. import config as cfg from ..loss import get_loss + @torch.no_grad() -def get_metrics(img_mask, loss_mask, output, gt, setname): +def get_metrics(mask, output, gt, setname): metric_settings = { 'valid': {}, 'hole': {}, @@ -33,17 +34,16 @@ def get_metrics(img_mask, loss_mask, output, gt, setname): } } - metric_dict = {} metrics = cfg.val_metrics - loss_metric_dict = dict(zip(metrics,[1]*len(metrics))) + loss_metric_dict = dict(zip(metrics, [1] * len(metrics))) if 'feature' in metrics: - loss_metric_dict.update(dict(zip(['style', 'prc'],[1,1]))) + loss_metric_dict.update(dict(zip(['style', 'prc'], [1, 1]))) loss_comp = get_loss.LossComputation(loss_metric_dict) - loss_metrics = loss_comp(img_mask, loss_mask, output, gt) + loss_metrics = loss_comp(mask, output, gt) loss_metrics['total'] = loss_metrics['total'].item() for metric in metrics: @@ -63,7 +63,7 @@ def get_metrics(img_mask, loss_mask, output, gt, setname): metric_dict[f'metric/{setname}/prc'] = loss_metrics['prc'] else: - data = get_loss.prepare_data_dict(img_mask, loss_mask, output, gt, ['mask','output','gt']) + data = get_loss.prepare_data_dict(mask, output, gt, ['mask', 'output', 'gt']) metric_outputs = calculate_metric(metric, data['mask'], data['output'], data['gt'], torchmetrics_settings=settings['torchmetric_settings']) @@ -130,4 +130,4 @@ def calculate_metric(name_expr, mask, output, gt, domain='valid', torchmetrics_s result_out += result_ch result_out = [result_out] - return result_out \ No newline at end of file + return result_out diff --git a/climatereconstructionai/model/conv_configs.py b/climatereconstructionai/model/conv_configs.py index 372b9d7f..00617951 100644 --- a/climatereconstructionai/model/conv_configs.py +++ b/climatereconstructionai/model/conv_configs.py @@ -1,8 +1,11 @@ from .. 
import config as cfg -# define configurations for convolutions +def ceildiv(size, i): + return -(size // -2**i) + +# define configurations for convolutions def init_enc_conv_configs(conv_factor, img_size, enc_dec_layers, pool_layers, start_channels): conv_configs = [] for i in range(enc_dec_layers): @@ -21,9 +24,8 @@ def init_enc_conv_configs(conv_factor, img_size, enc_dec_layers, pool_layers, st conv_config['kernel'] = (3, 3) conv_config['out_channels'] = conv_factor // (2 ** (enc_dec_layers - i - 1)) conv_config['skip_channels'] = 0 - conv_config['img_size'] = [size // (2 ** i) if size % (2 ** i) == 0 else - size // (2 ** i) + 1 for size in img_size] - conv_config['rec_size'] = [size // 2 if size % 2 == 0 else size // 2 + 1 for size in conv_config['img_size']] + conv_config['img_size'] = [ceildiv(size, i) for size in img_size] + conv_config['rec_size'] = [ceildiv(size, 1) for size in conv_config['img_size']] conv_configs.append(conv_config) for i in range(pool_layers): @@ -33,9 +35,8 @@ def init_enc_conv_configs(conv_factor, img_size, enc_dec_layers, pool_layers, st conv_config['kernel'] = (3, 3) conv_config['out_channels'] = conv_factor conv_config['skip_channels'] = 0 - conv_config['img_size'] = [size // (2 ** (enc_dec_layers + i)) if size % (2 ** (enc_dec_layers + i)) == 0 - else size // (2 ** (enc_dec_layers + i)) + 1 for size in img_size] - conv_config['rec_size'] = [size // 2 if size % 2 == 0 else size // 2 + 1 for size in conv_config['img_size']] + conv_config['img_size'] = [ceildiv(size, enc_dec_layers + i) for size in img_size] + conv_config['rec_size'] = [ceildiv(size, 1) for size in conv_config['img_size']] conv_configs.append(conv_config) return conv_configs @@ -50,10 +51,8 @@ def init_dec_conv_configs(conv_factor, img_size, enc_dec_layers, pool_layers, st conv_config['kernel'] = (3, 3) conv_config['out_channels'] = conv_factor conv_config['skip_channels'] = cfg.skip_layers * conv_factor - conv_config['img_size'] = [size // (2 ** (enc_dec_layers + pool_layers - i - 1)) - if size % (2 ** (enc_dec_layers + pool_layers - i - 1)) == 0 - else size // (2 ** (enc_dec_layers + pool_layers - i - 1)) + 1 for size in img_size] - conv_config['rec_size'] = [size // 2 if size % 2 == 0 else size // 2 + 1 for size in conv_config['img_size']] + conv_config['img_size'] = [ceildiv(size, enc_dec_layers + pool_layers - i - 1) for size in img_size] + conv_config['rec_size'] = [ceildiv(size, 1) for size in conv_config['img_size']] conv_configs.append(conv_config) for i in range(1, enc_dec_layers + 1): conv_config = {} @@ -67,9 +66,8 @@ def init_dec_conv_configs(conv_factor, img_size, enc_dec_layers, pool_layers, st else: conv_config['out_channels'] = conv_factor // (2 ** i) conv_config['skip_channels'] = cfg.skip_layers * conv_factor // (2 ** i) - conv_config['img_size'] = [size // (2 ** (enc_dec_layers - i)) if size % (2 ** (enc_dec_layers - i)) == 0 - else size // (2 ** (enc_dec_layers - i)) + 1 for size in img_size] - conv_config['rec_size'] = [size // 2 if size % 2 == 0 else size // 2 + 1 for size in conv_config['img_size']] + conv_config['img_size'] = [ceildiv(size, enc_dec_layers - i) for size in img_size] + conv_config['rec_size'] = [ceildiv(size, 1) for size in conv_config['img_size']] conv_configs.append(conv_config) return conv_configs diff --git a/climatereconstructionai/model/net.py b/climatereconstructionai/model/net.py index d687161f..328f4d9b 100644 --- a/climatereconstructionai/model/net.py +++ b/climatereconstructionai/model/net.py @@ -2,6 +2,7 @@ import torch.nn as nn from 
.attention_module import AttentionEncoderBlock +from .vae_module import VAEBlock from .conv_configs import init_enc_conv_configs, init_dec_conv_configs, \ init_enc_conv_configs_orig, init_dec_conv_configs_orig from .encoder_decoder import EncoderBlock, DecoderBlock @@ -71,6 +72,9 @@ def __init__(self, img_size=(512, 512), enc_dec_layers=4, pool_layers=4, in_chan kernel=enc_conv_configs[i]['kernel'], stride=(2, 2), activation=nn.ReLU())) self.encoder = nn.ModuleList(encoding_layers) + if cfg.vae_zdim != 0: + self.vae_module = VAEBlock(conv_config=enc_conv_configs[-1], n_steps=cfg.n_time_steps, z_dim=cfg.vae_zdim) + # define decoding layers decoding_layers = [] for i in range(self.net_depth): @@ -164,6 +168,11 @@ def forward(self, input, input_mask): h, h_mask = hs[self.net_depth], hs_mask[self.net_depth] + if cfg.vae_zdim == 0: + latent_dist = None + else: + h, latent_dist = self.vae_module(h) + # forward pass decoding layers for i in range(self.net_depth): if cfg.recurrent_steps: @@ -178,7 +187,7 @@ def forward(self, input, input_mask): h = self.binder.scale(h) # return last element of output from last decoding layer - return h + return h, latent_dist def train(self, mode=True): super().train(mode) diff --git a/climatereconstructionai/model/vae_module.py b/climatereconstructionai/model/vae_module.py new file mode 100644 index 00000000..4fe4c07d --- /dev/null +++ b/climatereconstructionai/model/vae_module.py @@ -0,0 +1,27 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class VAEBlock(nn.Module): + def __init__(self, conv_config, n_steps, z_dim): + super().__init__() + + self.h_shape = [-1, n_steps, conv_config['out_channels']] + conv_config['rec_size'] + self.h_dim = self.h_shape[1] * self.h_shape[2] * self.h_shape[3] * self.h_shape[4] + self.efc1 = nn.Linear(self.h_dim, z_dim) + self.efc2 = nn.Linear(self.h_dim, z_dim) + self.dfc1 = nn.Linear(z_dim, self.h_dim) + + def forward(self, input): + + input = input.view(-1, self.h_dim) + mu = self.efc1(input) + logvar = self.efc2(input) + std = torch.exp(logvar * 0.5) + eps = torch.randn_like(std) + z = mu + std * eps + output = F.relu(self.dfc1(z)) + output = output.view(self.h_shape) + + return output, (mu, logvar) diff --git a/climatereconstructionai/train.py b/climatereconstructionai/train.py index 08be9da9..9944eee4 100644 --- a/climatereconstructionai/train.py +++ b/climatereconstructionai/train.py @@ -10,12 +10,13 @@ from . 
import config as cfg from .loss import get_loss +from .loss.utils import get_devices from .metrics.get_metrics import get_metrics from .model.net import CRAINet from .utils import twriter, early_stopping from .utils.evaluation import create_snapshot_image from .utils.io import load_ckpt, load_model, save_ckpt -from .utils.netcdfloader import NetCDFLoader, InfiniteSampler, load_steadymask +from .utils.netcdfloader import NetCDFLoader, InfiniteSampler from .utils.profiler import load_profiler @@ -46,18 +47,16 @@ def train(arg_file=None): # create data sets dataset_train = NetCDFLoader(cfg.data_root_dir, cfg.data_names, cfg.mask_dir, cfg.mask_names, 'train', - cfg.data_types, cfg.time_steps) + cfg.data_types, cfg.time_steps, cfg.steady_masks) dataset_val = NetCDFLoader(cfg.data_root_dir, cfg.val_names, cfg.mask_dir, cfg.mask_names, 'val', cfg.data_types, - cfg.time_steps) + cfg.time_steps, cfg.steady_masks) iterator_train = iter(DataLoader(dataset_train, batch_size=cfg.batch_size, sampler=InfiniteSampler(len(dataset_train)), - num_workers=cfg.n_threads)) + num_workers=cfg.n_threads, persistent_workers=True)) iterator_val = iter(DataLoader(dataset_val, batch_size=cfg.batch_size, sampler=InfiniteSampler(len(dataset_val)), - num_workers=cfg.n_threads)) - - steady_mask = load_steadymask(cfg.mask_dir, cfg.steady_masks, cfg.data_types, cfg.device) + num_workers=cfg.n_threads, persistent_workers=True)) image_sizes = dataset_train.img_sizes if cfg.conv_factor is None: @@ -93,7 +92,6 @@ def train(arg_file=None): lr = cfg.lr early_stop = early_stopping.early_stopping() - loss_comp = get_loss.LossComputation(cfg.lambda_dict) # define optimizer and loss functions optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr) @@ -111,10 +109,15 @@ def train(arg_file=None): param_group['lr'] = lr print('Starting from iter ', start_iter) + if cfg.profiler == "dlprof": + import nvidia_dlprof_pytorch_nvtx + nvidia_dlprof_pytorch_nvtx.init() prof = load_profiler(start_iter) if cfg.multi_gpus: model = torch.nn.DataParallel(model) + + loss_comp = get_loss.LossComputation(cfg.lambda_dict, get_devices(model)) i = cfg.max_iter - (cfg.n_final_models - 1) * cfg.final_models_interval final_models = range(i, cfg.max_iter + 1, cfg.final_models_interval) @@ -130,10 +133,10 @@ def train(arg_file=None): # train model model.train() - image, mask, gt = [x.to(cfg.device) for x in next(iterator_train)[:3]] - output = model(image, mask) + image, in_mask, out_mask, gt = [x.to(cfg.device) for x in next(iterator_train)[:4]] + output, latent_dist = model(image, in_mask) - train_loss = loss_comp(mask, steady_mask, output, gt) + train_loss = loss_comp(out_mask, output, latent_dist, gt) optimizer.zero_grad() train_loss['total'].backward() @@ -145,10 +148,10 @@ def train(arg_file=None): model.eval() val_losses = [] for _ in range(cfg.n_iters_val): - image, mask, gt = [x.to(cfg.device) for x in next(iterator_val)[:3]] + image, in_mask, out_mask, gt = [x.to(cfg.device) for x in next(iterator_val)[:4]] with torch.no_grad(): - output = model(image, mask) - val_losses.append(list(loss_comp(mask, steady_mask, output, gt).values())) + output, latent_dist = model(image, in_mask) + val_losses.append(list(loss_comp(out_mask, output, latent_dist, gt).values())) val_loss = torch.tensor(val_losses).mean(dim=0) val_loss = dict(zip(train_loss.keys(), val_loss)) @@ -192,10 +195,10 @@ def train(arg_file=None): if cfg.val_metrics is not None: val_metrics = [] for _ in range(cfg.n_iters_val): - image, mask, gt = 
[x.to(cfg.device) for x in next(iterator_val)[:3]] + image, in_mask, out_mask, gt = [x.to(cfg.device) for x in next(iterator_val)[:4]] with torch.no_grad(): - output = model(image, mask) - metric_dict = get_metrics(mask, steady_mask, output, gt, 'val') + output, latent_dist = model(image, in_mask) + metric_dict = get_metrics(out_mask, output, latent_dist, gt, 'val') val_metrics.append(list(metric_dict.values())) val_metrics = torch.tensor(val_metrics).mean(dim=0) @@ -204,7 +207,7 @@ def train(arg_file=None): metric_dict.update({'iterations': n_iter, 'iterations_best_model': early_stop.global_iter_best}) writer.update_hparams(metric_dict, n_iter) - writer.add_visualizations(mask, steady_mask, output, gt, n_iter, 'val') + writer.add_visualizations(out_mask, output, gt, n_iter, 'val') writer.close() diff --git a/climatereconstructionai/utils/evaluation.py b/climatereconstructionai/utils/evaluation.py index 8d4ed611..a768b581 100644 --- a/climatereconstructionai/utils/evaluation.py +++ b/climatereconstructionai/utils/evaluation.py @@ -6,9 +6,9 @@ import torch from tensorboardX import SummaryWriter import xarray as xr +import pandas as pd from .netcdfchecker import reformat_dataset -from .netcdfloader import load_steadymask from .normalizer import renormalize from .plotdata import plot_data from .. import config as cfg @@ -19,26 +19,33 @@ def create_snapshot_image(model, dataset, filename): data_dict = {} - data_dict["image"], data_dict["mask"], data_dict["gt"], index = zip(*[dataset[int(i)] for i in cfg.eval_timesteps]) + data_dict["image"], data_dict["in_mask"], data_dict["out_mask"], data_dict["gt"], index \ + = zip(*[dataset[int(i)] for i in cfg.eval_timesteps]) for key in data_dict.keys(): data_dict[key] = torch.stack(data_dict[key]).to(cfg.device) with torch.no_grad(): - data_dict["output"] = model(data_dict["image"], data_dict["mask"]) + data_dict["output"], _ = model(data_dict["image"], data_dict["in_mask"]) - data_dict["infilled"] = data_dict["mask"] * data_dict["image"] + (1 - data_dict["mask"]) * data_dict["output"] + # data_dict["infilled"] = data_dict["mask"] * data_dict["image"] + (1 - data_dict["mask"]) * data_dict["output"] keys = list(data_dict.keys()) for key in keys: data_dict[key] = data_dict[key].to(torch.device('cpu')) + for key in ('image', 'in_mask', 'output'): + data_dict[key] = data_dict[key][:, cfg.recurrent_steps, :, :, :] + + for key in ('gt', 'out_mask'): + data_dict[key] = data_dict[key][:, 0, :, :, :] + # set mask - data_dict["mask"] = 1 - data_dict["mask"] - data_dict["image"] = np.ma.masked_array(data_dict["image"], data_dict["mask"]) - data_dict["mask"] = np.ma.masked_array(data_dict["mask"], data_dict["mask"]) + data_dict["in_mask"] = 1 - data_dict["in_mask"] + data_dict["image"] = np.ma.masked_array(data_dict["image"], data_dict["in_mask"]) + data_dict["in_mask"] = np.ma.masked_array(data_dict["in_mask"], data_dict["in_mask"]) - n_rows = sum([data_dict[key].shape[2] for key in keys]) + n_rows = sum([data_dict[key].shape[1] for key in keys]) n_cols = data_dict["image"].shape[0] # plot and save data @@ -50,11 +57,11 @@ def create_snapshot_image(model, dataset, filename): k = 0 for key in keys: - for c in range(data_dict[key].shape[2]): + for c in range(data_dict[key].shape[1]): if cfg.vlim is None: - vmin = data_dict[key][:, :, c, :, :].min().item() - vmax = data_dict[key][:, :, c, :, :].max().item() + vmin = data_dict[key][:, c, :, :].min().item() + vmax = data_dict[key][:, c, :, :].max().item() else: vmin = cfg.vlim[0] vmax = cfg.vlim[1] @@ -64,7 +71,7 @@ 
def create_snapshot_image(model, dataset, filename): for j in range(n_cols): axes[k, j].axis("off") - axes[k, j].imshow(np.squeeze(data_dict[key][j][cfg.recurrent_steps, c, :, :]), vmin=vmin, vmax=vmax) + axes[k, j].imshow(np.squeeze(data_dict[key][j][c, :, :]), vmin=vmin, vmax=vmax) k += 1 @@ -94,17 +101,15 @@ def get_batch_size(parameters, n_samples, image_sizes): return int(np.ceil(n_samples / partitions)) -def infill(model, dataset, eval_path, output_names, data_stats, xr_dss, i_model): +def infill(model, dataset, eval_path, output_names, steady_mask, data_stats, xr_dss, i_model): if not os.path.exists(cfg.evaluation_dirs[0]): os.makedirs('{:s}'.format(cfg.evaluation_dirs[0])) - steady_mask = load_steadymask(cfg.mask_dir, cfg.steady_masks, cfg.data_types, cfg.device) - - data_dict = {'image': [], 'mask': [], 'gt': [], 'output': [], 'infilled': []} + data_dict = {'image': [], 'mask': [], 'gt': [], 'output': []} for split in tqdm(range(dataset.__len__())): - # TODO: implement evaluation for multiple data paths - data_dict["image"], data_dict["mask"], data_dict["gt"], index = next(dataset) + + data_dict["image"], data_dict["mask"], _, data_dict["gt"], index = next(dataset) if split == 0 and cfg.create_graph: writer = SummaryWriter(log_dir=cfg.log_dir) @@ -113,24 +118,27 @@ def infill(model, dataset, eval_path, output_names, data_stats, xr_dss, i_model) # get results from trained network with torch.no_grad(): - data_dict["output"] = model(data_dict["image"].to(cfg.device), data_dict["mask"].to(cfg.device)) + data_dict["output"], _ = model(data_dict["image"].to(cfg.device), data_dict["mask"].to(cfg.device)) - for key in ('image', 'mask', 'gt', 'output'): + for key in ('image', 'mask', 'output'): data_dict[key] = data_dict[key][:, cfg.recurrent_steps, :, :, :].to(torch.device('cpu')) + data_dict['gt'] = data_dict['gt'][:, 0, :, :, :].to(torch.device('cpu')) - for key in ('image', 'mask', 'gt'): + for key in ('image', 'mask'): data_dict[key] = data_dict[key][:, cfg.gt_channels, :, :] if steady_mask is not None: - for key in ('image', 'gt', 'output'): - data_dict[key][:, :, steady_mask.type(torch.bool)] = np.nan - - data_dict["infilled"] = (1 - data_dict["mask"]) - data_dict["infilled"] *= data_dict["output"] - data_dict["infilled"] += data_dict["mask"] * data_dict["image"] + for key in ('gt', 'image'): + data_dict[key][:, ~steady_mask.type(torch.bool)] = np.nan + data_dict['output'][:, ~np.repeat(steady_mask, cfg.n_pred_steps, axis=0).type(torch.bool)] = np.nan data_dict["image"] /= data_dict["mask"] + if cfg.n_target_data == 0 and cfg.n_pred_steps == 1: + data_dict["infilled"] = (1 - data_dict["mask"]) + data_dict["infilled"] *= data_dict["output"] + data_dict["infilled"] += data_dict["mask"] * data_dict["gt"].nan_to_num() + create_outputs(data_dict, eval_path, output_names, data_stats, xr_dss, i_model, split, index) if cfg.progress_fwd is not None: @@ -146,8 +154,12 @@ def create_outputs(data_dict, eval_path, output_names, data_stats, xr_dss, i_mod suffix = m_label + "-" + str(split + 1) if cfg.n_target_data == 0: - cnames = ["gt", "mask", "image", "output", "infilled"] - pnames = ["image", "infilled"] + if cfg.n_pred_steps == 1: + cnames = ["gt", "mask", "image", "output", "infilled"] + pnames = ["image", "infilled"] + else: + cnames = ["gt", "mask", "image", "output"] + pnames = ["image", "output"] else: cnames = ["gt", "output"] pnames = ["gt", "output"] @@ -156,6 +168,7 @@ def create_outputs(data_dict, eval_path, output_names, data_stats, xr_dss, i_mod i_data = -cfg.n_target_data + j 
data_type = cfg.data_types[i_data] + i_plot = {} for cname in cnames: @@ -168,33 +181,47 @@ def create_outputs(data_dict, eval_path, output_names, data_stats, xr_dss, i_mod output_names[rootname][i_model] += [rootname + suffix + ".nc"] - ds = xr_dss[i_data][ds_index][1].copy() + ds = xr_dss[i_data][ds_index]["ds1"].copy() + dims = xr_dss[i_data][ds_index]["dims"].copy() + coords = xr_dss[i_data][ds_index]["coords"].copy() + coords["time"] = xr_dss[i_data][ds_index]["ds"]["time"].values[index] + if cfg.n_pred_steps > 1 and cname == "output": + dims = [dims[0]] + ["pred_time"] + dims[1:] + coords["pred_time"] = np.array(cfg.pred_timestep) + if cfg.time_freq: + coords["pred_time"] = pd.to_timedelta(coords["pred_time"], unit=cfg.time_freq) + coords['times'] = (["pred_time", "time"], np.add.outer(coords["pred_time"], coords["time"])) + i_pred = range(j * cfg.n_pred_steps, (j + 1) * cfg.n_pred_steps) + i_plot[cname] = i_pred[0] + else: + i_pred, i_plot[cname] = j, j if cfg.normalize_data and cname != "mask": - data_dict[cname][:, j, :, :] = renormalize(data_dict[cname][:, j, :, :], - data_stats["mean"][i_data], data_stats["std"][i_data]) + data_dict[cname][:, i_pred] = renormalize(data_dict[cname][:, i_pred], + data_stats["mean"][i_data], data_stats["std"][i_data]) - ds[data_type] = xr.DataArray(data_dict[cname].to(torch.device('cpu')).detach().numpy()[:, j, :, :], - dims=xr_dss[i_data][ds_index][2], coords=xr_dss[i_data][ds_index][3]) - ds["time"] = xr_dss[i_data][ds_index][0]["time"].values[index] + ds[data_type] = xr.DataArray(data_dict[cname].detach().numpy()[:, i_pred], dims=dims, coords=coords) + ds = reformat_dataset(xr_dss[i_data][ds_index]["ds"], ds, data_type) - ds = reformat_dataset(xr_dss[i_data][ds_index][0], ds, data_type) - - for var in xr_dss[i_data][ds_index][0].keys(): - if "time" in xr_dss[i_data][ds_index][0][var].dims: - ds[var] = xr_dss[i_data][ds_index][0][var].isel(time=index) + for var in xr_dss[i_data][ds_index]["ds"].keys(): + if "time" in xr_dss[i_data][ds_index]["ds"][var].dims: + ds[var] = xr_dss[i_data][ds_index]["ds"][var].isel(time=index) else: - ds[var] = xr_dss[i_data][ds_index][0][var] + ds[var] = xr_dss[i_data][ds_index]["ds"][var] + if "history" in ds.attrs: + history = "\n" + ds.attrs["history"] + else: + history = "" ds.attrs["history"] = "Infilled using CRAI (Climate Reconstruction AI: " \ - "https://github.com/FREVA-CLINT/climatereconstructionAI)\n" + ds.attrs["history"] + "https://github.com/FREVA-CLINT/climatereconstructionAI)" + history ds.to_netcdf(output_names[rootname][i_model][-1]) for time_step in cfg.plot_results: if time_step in index: output_name = '{}_{}{}_{}.png'.format(eval_path[j], "combined", m_label, time_step) - plot_data(xr_dss[i_data][ds_index][1].coords, - [data_dict[p][time_step - index[0], j, :, :].squeeze() for p in pnames], + plot_data(xr_dss[i_data][ds_index]["ds1"].coords, + [data_dict[p][time_step - index[0], i_plot[p], :, :].squeeze() for p in pnames], ["Original", "Reconstructed"], output_name, data_type, - str(xr_dss[i_data][ds_index][0]["time"][time_step].values), + str(xr_dss[i_data][ds_index]["ds"]["time"][time_step].values), *cfg.dataset_format["scale"]) diff --git a/climatereconstructionai/utils/featurizer.py b/climatereconstructionai/utils/featurizer.py index 0b4c1775..008a95a9 100644 --- a/climatereconstructionai/utils/featurizer.py +++ b/climatereconstructionai/utils/featurizer.py @@ -4,13 +4,13 @@ class VGG16FeatureExtractor(nn.Module): - def __init__(self): + def __init__(self, device): super().__init__() vgg16 = 
models.vgg16(weights=models.VGG16_Weights.DEFAULT) - self.enc_1 = nn.Sequential(*vgg16.features[:5]) - self.enc_2 = nn.Sequential(*vgg16.features[5:10]) - self.enc_3 = nn.Sequential(*vgg16.features[10:17]) + self.enc_1 = nn.Sequential(*vgg16.features[:5]).to(device) + self.enc_2 = nn.Sequential(*vgg16.features[5:10]).to(device) + self.enc_3 = nn.Sequential(*vgg16.features[10:17]).to(device) # fix the encoder for i in range(3): diff --git a/climatereconstructionai/utils/io.py b/climatereconstructionai/utils/io.py index e6319b38..eba1f0fc 100644 --- a/climatereconstructionai/utils/io.py +++ b/climatereconstructionai/utils/io.py @@ -37,8 +37,7 @@ def load_model(ckpt_dict, model, optimizer=None, label=None): if label is None: label = ckpt_dict["labels"][-1] - ckpt_dict[label]["model"] = \ - {key.replace("module.", ""): value for key, value in ckpt_dict[label]["model"].items()} + ckpt_dict[label]["model"] = {key: value for key, value in ckpt_dict[label]["model"].items()} model.load_state_dict(ckpt_dict[label]["model"]) if optimizer is not None: optimizer.load_state_dict(ckpt_dict[label]["optimizer"]) diff --git a/climatereconstructionai/utils/netcdfloader.py b/climatereconstructionai/utils/netcdfloader.py index faaa7233..5290cb84 100644 --- a/climatereconstructionai/utils/netcdfloader.py +++ b/climatereconstructionai/utils/netcdfloader.py @@ -12,17 +12,23 @@ from .. import config as cfg -def load_steadymask(path, mask_names, data_types, device): +def load_steadymask(path, mask_names, data_types): if mask_names is None: return None else: - assert len(mask_names) == cfg.out_channels if cfg.n_target_data == 0: - steady_mask = load_netcdf(path, mask_names, data_types[:cfg.out_channels])[0] + assert len(mask_names) == cfg.n_output_data + steady_mask = load_netcdf(path, mask_names, data_types[:cfg.n_output_data])[0] else: + assert len(mask_names) == cfg.n_target_data steady_mask = load_netcdf(path, mask_names, data_types[-cfg.n_target_data:])[0] - # stack + squeeze ensures that it works with steady masks with one timestep or no timestep - return torch.stack([torch.from_numpy(np.array(mask)).to(device) for mask in steady_mask]).squeeze() + + steady_mask = torch.stack([torch.from_numpy(np.array(mask[0])) for mask in steady_mask]) + # squeeze time dimension if any + + if steady_mask.ndim == 4: + steady_mask = steady_mask.squeeze(axis=1) + return steady_mask class InfiniteSampler(Sampler): @@ -83,12 +89,13 @@ def nc_loadchecker(filename, data_type): else: data = ds1[data_type].values - dims = ds1[data_type].dims + dims = list(ds1[data_type].dims) coords = {key: ds1[data_type].coords[key] for key in ds1[data_type].coords if key != "time"} ds1 = ds1.drop_vars(ds1.keys()) - ds1 = ds1.drop_dims("time") + if "time" in dims: + ds1 = ds1.drop_dims("time") - return [ds, ds1, dims, coords], data, data.shape[0], data.shape[1:] + return {"ds": ds, "ds1": ds1, "dims": dims, "coords": coords}, data, data.shape[0], data.shape[1:] def load_netcdf(path, data_names, data_types, keep_dss=False): @@ -124,11 +131,13 @@ def load_netcdf(path, data_names, data_types, keep_dss=False): class NetCDFLoader(Dataset): - def __init__(self, data_root, img_names, mask_root, mask_names, split, data_types, time_steps, train_stats=None): + def __init__(self, data_root, img_names, mask_root, mask_names, split, data_types, time_steps, steady_masks, + train_stats=None): super(NetCDFLoader, self).__init__() self.random = random.Random(cfg.loop_random_seed) + self.standard_conv = cfg.standard_conv self.data_types = data_types self.time_steps = 
time_steps @@ -161,6 +170,8 @@ def __init__(self, data_root, img_names, mask_root, mask_names, split, data_type self.bounds = bnd_normalization(self.img_mean, self.img_std) + self.steady_mask = load_steadymask(mask_root, steady_masks, data_types) + def load_data(self, ind_data, img_indices, ds_index, mask_indices, mask_ds_index): if self.mask_data is None: @@ -181,9 +192,9 @@ def get_single_item(self, ind_data, index, shuffle_masks): # get index of dataset ds_index = 0 current_index = 0 - for l in range(len(self.img_length)): - if index > current_index + self.img_length[l]: - current_index += self.img_length[l] + for ilength in range(len(self.img_length)): + if index > current_index + self.img_length[ilength]: + current_index += self.img_length[ilength] ds_index += 1 index -= current_index @@ -195,9 +206,9 @@ def get_single_item(self, ind_data, index, shuffle_masks): mask_index = self.random.randint(0, sum(self.mask_length) - 1) mask_ds_index = 0 current_index = 0 - for l in range(len(self.mask_length)): - if mask_index > current_index + self.mask_length[l]: - current_index += self.mask_length[l] + for ilength in range(len(self.mask_length)): + if mask_index > current_index + self.mask_length[ilength]: + current_index += self.mask_length[ilength] mask_ds_index += 1 mask_index -= current_index @@ -217,10 +228,25 @@ def get_single_item(self, ind_data, index, shuffle_masks): return images, masks + def create_out_mask(self, mask, i): + + out_mask = mask[cfg.out_steps] + if cfg.n_target_data > 0 or cfg.n_pred_steps > 1: + out_mask[:] = 1. + + if self.steady_mask is not None: + out_mask += self.steady_mask[i] + out_mask[out_mask < 0] = 0 + out_mask[out_mask > 1] = 1 + assert ((out_mask == 0) | (out_mask == 1)).all(), "Not all values in mask are zeros or ones!" 
+ + return out_mask + def __getitem__(self, index): images = [] - masks = [] + in_masks = [] + out_masks = [] masked = [] ndata = len(self.data_types) @@ -229,18 +255,29 @@ def __getitem__(self, index): image, mask = self.get_single_item(i, index, cfg.shuffle_masks) if i >= ndata - cfg.n_target_data: - images.append(image) + images.append(image[cfg.out_steps]) + out_masks.append(self.create_out_mask(mask, i - ndata + cfg.n_target_data)) else: - if cfg.n_target_data == 0: - images.append(image) - masks.append(mask) - masked.append(image * mask) + if cfg.n_target_data == 0 and i < cfg.n_output_data: + images.append(image[cfg.out_steps]) + out_masks.append(self.create_out_mask(mask, i)) + if self.standard_conv: + in_masks.append(torch.ones_like(mask[cfg.in_steps])) + else: + in_masks.append(mask[cfg.in_steps]) + masked.append(image[cfg.in_steps] * mask[cfg.in_steps]) if cfg.channel_steps: - return torch.cat(masked, dim=0).transpose(0, 1), torch.cat(masks, dim=0) \ - .transpose(0, 1), torch.cat(images, dim=0).transpose(0, 1), index + return (torch.cat(masked, dim=0).transpose(0, 1), + torch.cat(in_masks, dim=0).transpose(0, 1), + torch.cat(out_masks, dim=0).transpose(0, 1), + torch.cat(images, dim=0).transpose(0, 1), + index) else: - return torch.cat(masked, dim=1), torch.cat(masks, dim=1), torch.cat(images, dim=1), index + return (torch.cat(masked, dim=1), torch.cat(in_masks, dim=1), + torch.cat(out_masks, dim=0).transpose(0, 1), + torch.cat(images, dim=0).transpose(0, 1), + index) def __len__(self): return sum(self.img_length) diff --git a/climatereconstructionai/utils/normalizer.py b/climatereconstructionai/utils/normalizer.py index 14ce84ad..4883c9c9 100644 --- a/climatereconstructionai/utils/normalizer.py +++ b/climatereconstructionai/utils/normalizer.py @@ -8,8 +8,18 @@ def img_normalization(img_data, train_stats): if cfg.normalize_data: for i in range(len(img_data)): if train_stats is None: - img_mean.append(np.nanmean(np.array(img_data[i]))) - img_std.append(np.nanstd(np.array(img_data[i]))) + sums, ssum, size = 0., 0., 0. 
+ for data in img_data[i]: + if cfg.lazy_load: + sums += data.chunk('auto').sum(skipna=True).compute().item() + ssum += (data.chunk('auto') * data.chunk('auto')).sum(skipna=True).compute().item() + size += data.chunk('auto').count().compute().item() + else: + sums += np.nansum(data) + ssum += np.nansum(data * data) + size += np.sum(~np.isnan(data)) + img_mean.append(sums / size) + img_std.append(np.sqrt(ssum / size - img_mean[-1] * img_mean[-1])) else: img_mean.append(train_stats["mean"][i]) img_std.append(train_stats["std"][i]) @@ -27,16 +37,18 @@ def bnd_normalization(img_mean, img_std): bounds = np.ones((cfg.out_channels, 2)) * np.inf if cfg.n_target_data == 0: - mean_val, std_val = img_mean[:cfg.out_channels], img_std[:cfg.out_channels] + mean_val, std_val = img_mean[:cfg.n_output_data], img_std[:cfg.n_output_data] else: mean_val, std_val = img_mean[-cfg.n_target_data:], img_std[-cfg.n_target_data:] k = 0 for bound in (cfg.min_bounds, cfg.max_bounds): - bounds[:, k] = bound - if cfg.normalize_data: - bounds[:, k] = (bounds[:, k] - mean_val) / std_val + for i in range(cfg.n_output_data): + idx = range(i * cfg.n_pred_steps, (i + 1) * cfg.n_pred_steps) + bounds[idx, k] = bound[i] + if cfg.normalize_data: + bounds[idx, k] = (bounds[idx, k] - mean_val[i]) / std_val[i] k += 1 diff --git a/climatereconstructionai/utils/profiler.py b/climatereconstructionai/utils/profiler.py index 8c5a0336..034f63bc 100644 --- a/climatereconstructionai/utils/profiler.py +++ b/climatereconstructionai/utils/profiler.py @@ -3,7 +3,7 @@ def load_profiler(start_iter): - if cfg.profile: + if cfg.profiler == "tensorboard": return torch.profiler.profile( activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], schedule=torch.profiler.schedule(wait=1, warmup=1, active=cfg.max_iter - start_iter, repeat=1), diff --git a/climatereconstructionai/utils/twriter.py b/climatereconstructionai/utils/twriter.py index 503c421a..1a9a8ea2 100644 --- a/climatereconstructionai/utils/twriter.py +++ b/climatereconstructionai/utils/twriter.py @@ -65,22 +65,22 @@ def add_figure(self, fig, iter_index, name_tag=None): name_tag = self.suffix self.writer.add_figure(name_tag, fig, global_step=iter_index) - def add_error_maps(self, mask, steady_mask, output, gt, iter_index, setname): - error_maps = visualization.get_all_error_maps(mask, steady_mask, output, gt, num_samples=3) + def add_error_maps(self, mask, output, gt, iter_index, setname): + error_maps = visualization.get_all_error_maps(mask, output, gt, num_samples=3) for error_map, name in zip(error_maps, ['error', 'relative error', 'abs error', 'relative abs error']): self.add_figure(error_map, iter_index, f'map/{setname}/{name}') - def add_correlation_plots(self, mask, steady_mask, output, gt, iter_index, setname): - fig = visualization.create_correlation_plot(mask, steady_mask, output, gt) + def add_correlation_plots(self, mask, output, gt, iter_index, setname): + fig = visualization.create_correlation_plot(mask, output, gt) self.add_figure(fig, iter_index, name_tag=f'plot/{setname}/correlation') - def add_error_dist_plot(self, mask, steady_mask, output, gt, iter_index, setname): - fig = visualization.create_error_dist_plot(mask, steady_mask, output, gt) + def add_error_dist_plot(self, mask, output, gt, iter_index, setname): + fig = visualization.create_error_dist_plot(mask, output, gt) self.add_figure(fig, iter_index, name_tag=f'plot/{setname}/error_dist') - def add_maps(self, mask, steady_mask, output, gt, iter_index, setname): - fig = 
visualization.create_map(mask, steady_mask, output, gt, num_samples=3) + def add_maps(self, mask, output, gt, iter_index, setname): + fig = visualization.create_map(mask, output, gt, num_samples=3) self.add_figure(fig, iter_index, name_tag=f'map/{setname}/values') def add_distribution(self, values, iter_index, name_tag=None): @@ -88,26 +88,26 @@ def add_distribution(self, values, iter_index, name_tag=None): name_tag = self.suffix self.writer.add_histogram(name_tag, values, global_step=iter_index) - def add_distributions(self, mask, steady_mask, output, gt, iter_index, setname): + def add_distributions(self, mask, output, gt, iter_index, setname): - errors_dists = visualization.get_all_error_distributions(mask, steady_mask, output, gt, num_samples=1000) + errors_dists = visualization.get_all_error_distributions(mask, output, gt, num_samples=1000) for error_dist, suffix in zip(errors_dists, ['error', 'abs error', 'relative error', 'relative abs error']): name = f'dist/{setname}/{suffix}' # entries in value_list correspond to channels for ch, values in enumerate(error_dist): self.add_distribution(values, iter_index, name_tag=f'{name}_channel{ch}') - def add_visualizations(self, mask, steady_mask, output, gt, iter_index, setname): + def add_visualizations(self, mask, output, gt, iter_index, setname): if "correlation" in cfg.tensor_plots: - self.add_correlation_plots(mask, steady_mask, output, gt, iter_index, setname) - self.add_error_dist_plot(mask, steady_mask, output, gt, iter_index, setname) + self.add_correlation_plots(mask, output, gt, iter_index, setname) + self.add_error_dist_plot(mask, output, gt, iter_index, setname) if "distribution" in cfg.tensor_plots: - self.add_distributions(mask, steady_mask, output, gt, iter_index, setname) + self.add_distributions(mask, output, gt, iter_index, setname) if "error" in cfg.tensor_plots: - self.add_error_maps(mask, steady_mask, output, gt, iter_index, setname) - self.add_maps(mask, steady_mask, output, gt, iter_index, setname) + self.add_error_maps(mask, output, gt, iter_index, setname) + self.add_maps(mask, output, gt, iter_index, setname) def close(self): self.writer.close() diff --git a/climatereconstructionai/utils/visualization.py b/climatereconstructionai/utils/visualization.py index 8b4faa3f..f3afb1ac 100644 --- a/climatereconstructionai/utils/visualization.py +++ b/climatereconstructionai/utils/visualization.py @@ -3,12 +3,7 @@ import torch -def calculate_distributions(mask, steady_mask, output, gt, domain="valid", num_samples=1000): - if steady_mask is not None: - mask += steady_mask - mask[mask < 0] = 0 - mask[mask > 1] = 1 - assert ((mask == 0) | (mask == 1)).all(), "Not all values in mask are zeros or ones!" +def calculate_distributions(mask, output, gt, domain="valid", num_samples=1000): value_list_pred = [] value_list_target = [] @@ -36,14 +31,8 @@ def calculate_distributions(mask, steady_mask, output, gt, domain="valid", num_s return value_list_pred, value_list_target -def calculate_error_distributions(mask, steady_mask, output, gt, operation="AE", domain="valid", num_samples=1000): - preds, targets = calculate_distributions(mask, steady_mask, output, gt, domain=domain, num_samples=num_samples) - - if steady_mask is not None: - mask += steady_mask - mask[mask < 0] = 0 - mask[mask > 1] = 1 - assert ((mask == 0) | (mask == 1)).all(), "Not all values in mask are zeros or ones!" 
+def calculate_error_distributions(mask, output, gt, operation="AE", domain="valid", num_samples=1000): + preds, targets = calculate_distributions(mask, output, gt, domain=domain, num_samples=num_samples) value_list = [] for ch in range(len(preds)): @@ -65,8 +54,8 @@ def calculate_error_distributions(mask, steady_mask, output, gt, operation="AE", return value_list -def create_error_dist_plot(mask, steady_mask, output, gt, operation='E', domain="valid", num_samples=1000): - preds, targets = calculate_distributions(mask, steady_mask, output, gt, domain=domain, num_samples=num_samples) +def create_error_dist_plot(mask, output, gt, operation='E', domain="valid", num_samples=1000): + preds, targets = calculate_distributions(mask, output, gt, domain=domain, num_samples=num_samples) fig, axs = plt.subplots(1, len(preds), squeeze=False) @@ -97,8 +86,8 @@ def create_error_dist_plot(mask, steady_mask, output, gt, operation='E', domain= return fig -def create_correlation_plot(mask, steady_mask, output, gt, domain="valid", num_samples=1000): - preds, targets = calculate_distributions(mask, steady_mask, output, gt, domain=domain, num_samples=num_samples) +def create_correlation_plot(mask, output, gt, domain="valid", num_samples=1000): + preds, targets = calculate_distributions(mask, output, gt, domain=domain, num_samples=num_samples) fig, axs = plt.subplots(1, len(preds), squeeze=False) @@ -117,12 +106,7 @@ def create_correlation_plot(mask, steady_mask, output, gt, domain="valid", num_s return fig -def create_error_map(mask, steady_mask, output, gt, num_samples=3, operation="AE", domain="valid"): - if steady_mask is not None: - mask += steady_mask - mask[mask < 0] = 0 - mask[mask > 1] = 1 - assert ((mask == 0) | (mask == 1)).all(), "Not all values in mask are zeros or ones!" +def create_error_map(mask, output, gt, num_samples=3, operation="AE", domain="valid"): num_channels = output.shape[2] samples = torch.randint(output.shape[0], (num_samples,)) @@ -161,12 +145,7 @@ def create_error_map(mask, steady_mask, output, gt, num_samples=3, operation="AE return fig -def create_map(mask, steady_mask, output, gt, num_samples=3): - if steady_mask is not None: - mask += steady_mask - mask[mask < 0] = 0 - mask[mask > 1] = 1 - assert ((mask == 0) | (mask == 1)).all(), "Not all values in mask are zeros or ones!" 
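With the `steady_mask` argument dropped from the TensorBoard-writer and visualization helpers above, the merging of the per-sample mask with an optional steady mask presumably happens once at the call site instead of being repeated inside every plotting function. A hypothetical sketch of such a combination step (the function name and its placement are assumptions, not code from this diff):

```python
import torch

def combine_masks(mask, steady_mask=None):
    """Fold an optional steady mask into the per-sample mask and verify that
    the result is binary, as the plotting helpers used to do internally."""
    if steady_mask is not None:
        mask = torch.clamp(mask + steady_mask, min=0, max=1)
    assert ((mask == 0) | (mask == 1)).all(), "Not all values in mask are zeros or ones!"
    return mask
```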
+def create_map(mask, output, gt, num_samples=3): samples = torch.randint(output.shape[0], (num_samples,)) @@ -196,13 +175,13 @@ def create_map(mask, steady_mask, output, gt, num_samples=3): return fig -def get_all_error_distributions(mask, steady_mask, output, gt, domain="valid", num_samples=1000): - error_dists = [calculate_error_distributions(mask, steady_mask, output, gt, operation=op, domain=domain, +def get_all_error_distributions(mask, output, gt, domain="valid", num_samples=1000): + error_dists = [calculate_error_distributions(mask, output, gt, operation=op, domain=domain, num_samples=num_samples) for op in ['E', 'AE', 'RE', 'RAE']] return error_dists -def get_all_error_maps(mask, steady_mask, output, gt, num_samples=3): - error_maps = [create_error_map(mask, steady_mask, output, gt, num_samples=num_samples, operation=op, domain="valid") +def get_all_error_maps(mask, output, gt, num_samples=3): + error_maps = [create_error_map(mask, output, gt, num_samples=num_samples, operation=op, domain="valid") for op in ['E', 'AE', 'RE', 'RAE']] return error_maps diff --git a/data/hadcrut_missmask_1.nc b/data/masks/hadcrut_missmask_1.nc similarity index 100% rename from data/hadcrut_missmask_1.nc rename to data/masks/hadcrut_missmask_1.nc diff --git a/data/masks/masks_tas_hadcrut_187709_189308.nc b/data/masks/masks_tas_hadcrut_187709_189308.nc new file mode 100644 index 00000000..87d4f8e6 Binary files /dev/null and b/data/masks/masks_tas_hadcrut_187709_189308.nc differ diff --git a/data/masks/steady_tas_hadcrut.nc b/data/masks/steady_tas_hadcrut.nc new file mode 100644 index 00000000..852df889 Binary files /dev/null and b/data/masks/steady_tas_hadcrut.nc differ diff --git a/demo/demo_args.txt b/demo/demo_args.txt index cbbc7f11..75d6d32a 100644 --- a/demo/demo_args.txt +++ b/demo/demo_args.txt @@ -6,7 +6,6 @@ --data-types tas --device cpu --n-filters 18 ---out-channels 1 --eval-names demo --dataset-name hadcrut-mod --plot-result 0 diff --git a/environment-cuda.yml b/environment-cuda.yml index 95e5e9fd..f2d2733a 100644 --- a/environment-cuda.yml +++ b/environment-cuda.yml @@ -4,7 +4,7 @@ channels: - pytorch - defaults dependencies: - - pytorch>=1.11.0 + - pytorch-gpu>=1.11.0 - cudatoolkit>=11.7 - tqdm>=4.64.0 - torchvision>=0.12.0 diff --git a/tests/in/evaluation/minimum-1.inp b/tests/in/evaluation/minimum-1.inp index 80aee974..c65022d7 100644 --- a/tests/in/evaluation/minimum-1.inp +++ b/tests/in/evaluation/minimum-1.inp @@ -1,6 +1,6 @@ --device cpu --data-root-dir data/ ---model-dir tests/ref/ +--model-dir tests/ref/training/ --model-names minimum-1.inp.pth --evaluation-dirs tests/out/evaluation/ --data-names tas_hadcrut_187709_189308.nc diff --git a/tests/in/evaluation/normalize-bounds.inp b/tests/in/evaluation/normalize-bounds.inp new file mode 100644 index 00000000..c544b9f0 --- /dev/null +++ b/tests/in/evaluation/normalize-bounds.inp @@ -0,0 +1,17 @@ +--device cpu +--data-root-dir data/ +--mask-dir data/masks/ +--model-dir tests/ref/training/ +--model-names normalize-bounds.inp.pth +--evaluation-dirs tests/out/evaluation/ +--data-names tas_hadcrut_187709_189308.nc,tas_hadcrut_187709_189308.nc +--mask-names masks_tas_hadcrut_187709_189308.nc,masks_tas_hadcrut_187709_189308.nc +--eval-names normalize-bounds_1,normalize-bounds_2 +--data-types tas,tas +--encoding-layers 3,3 +--pooling-layers 0,0 +--normalize-data +--min-bounds=0,-inf +--max-bounds 10,inf +--steady-mask steady_tas_hadcrut.nc,steady_tas_hadcrut.nc +--n-output-data 2 diff --git a/tests/in/training/attention-channel-memory.inp 
b/tests/in/training/attention-channel-memory.inp index fe170bba..d4b4ac37 100644 --- a/tests/in/training/attention-channel-memory.inp +++ b/tests/in/training/attention-channel-memory.inp @@ -1,7 +1,7 @@ --device cpu --batch-size 2 --n-threads 4 --data-root-dir data/ ---mask-dir data/ +--mask-dir data/masks/ --log-dir tests/out/training/logs/ --snapshot-dir tests/out/training/ --data-names 20cr-1ens.nc,20cr-1ens.nc,20cr-1ens.nc diff --git a/tests/in/training/attention.inp b/tests/in/training/attention.inp index b72a5bc1..0a16dbe1 100644 --- a/tests/in/training/attention.inp +++ b/tests/in/training/attention.inp @@ -1,7 +1,7 @@ --device cpu --batch-size 2 --n-threads 4 --data-root-dir data/ ---mask-dir data/ +--mask-dir data/masks/ --log-dir tests/out/training/logs/ --snapshot-dir tests/out/training/ --data-names 20cr-1ens.nc,20cr-1ens.nc,20cr-1ens.nc diff --git a/tests/in/training/channel-infusion.inp b/tests/in/training/channel-infusion.inp index e52b1bdb..f1db689b 100644 --- a/tests/in/training/channel-infusion.inp +++ b/tests/in/training/channel-infusion.inp @@ -1,7 +1,7 @@ --device cpu --batch-size 2 --n-threads 4 --data-root-dir data/ ---mask-dir data/ +--mask-dir data/masks/ --log-dir tests/out/training/logs/ --snapshot-dir tests/out/training/ --data-names 20cr-1ens.nc,20cr-1ens.nc,20cr-1ens.nc diff --git a/tests/in/training/channel-memory.inp b/tests/in/training/channel-memory.inp index 7a71c0f2..fbd9fc4c 100644 --- a/tests/in/training/channel-memory.inp +++ b/tests/in/training/channel-memory.inp @@ -1,7 +1,7 @@ --device cpu --batch-size 2 --n-threads 4 --data-root-dir data/ ---mask-dir data/ +--mask-dir data/masks/ --log-dir tests/out/training/logs/ --snapshot-dir tests/out/training/ --data-names 20cr-1ens.nc diff --git a/tests/in/training/conv-lstm.inp b/tests/in/training/conv-lstm.inp index e350f5d0..da0a0fa4 100644 --- a/tests/in/training/conv-lstm.inp +++ b/tests/in/training/conv-lstm.inp @@ -1,7 +1,7 @@ --device cpu --batch-size 2 --n-threads 4 --data-root-dir data/ ---mask-dir data/ +--mask-dir data/masks/ --log-dir tests/out/training/logs/ --snapshot-dir tests/out/training/ --data-names 20cr-1ens.nc diff --git a/tests/in/training/json-input.inp b/tests/in/training/json-input.inp index d1f8ea94..9bb0e163 100644 --- a/tests/in/training/json-input.inp +++ b/tests/in/training/json-input.inp @@ -1,7 +1,7 @@ --device cpu --batch-size 2 --n-threads 4 --data-root-dir data/ ---mask-dir data/ +--mask-dir data/masks/ --log-dir tests/out/training/logs/ --snapshot-dir tests/out/training/ --data-names 20cr.json,20cr.json diff --git a/tests/in/training/lazy-forecast.inp b/tests/in/training/lazy-forecast.inp new file mode 100644 index 00000000..940e7679 --- /dev/null +++ b/tests/in/training/lazy-forecast.inp @@ -0,0 +1,23 @@ +--device cpu --batch-size 2 +--n-threads 4 +--data-root-dir data/ +--mask-dir data/masks/ +--log-dir tests/out/training/logs/ +--snapshot-dir tests/out/training/ +--data-names 20cr-1ens.nc,20cr-1ens.nc +--mask-names hadcrut_missmask_1.nc,hadcrut_missmask_1.nc +--max-iter 10 +--data-types tas,tas +--encoding-layers 3,3 +--pooling-layers 0,0 +--save-model-interval 5 +--loss-criterion l1-hole +--weights kaiming +--loop-random-seed 3 +--cuda-random-seed 3 +--deterministic +--shuffle-masks +--normalize-data +--lstm-steps 4,3 +--pred-steps 5,1 +--lazy-load \ No newline at end of file diff --git a/tests/in/training/minimum-1.inp b/tests/in/training/minimum-1.inp index 38cfa268..749220e8 100644 --- a/tests/in/training/minimum-1.inp +++ b/tests/in/training/minimum-1.inp @@ -1,7 
+1,7 @@ --device cpu --batch-size 2 --n-threads 4 --data-root-dir data/ ---mask-dir data/ +--mask-dir data/masks/ --log-dir tests/out/training/logs/ --snapshot-dir tests/out/training/ --data-names 20cr-1ens.nc diff --git a/tests/in/training/minimum-2.inp b/tests/in/training/minimum-2.inp index ee985d86..4c3a658a 100644 --- a/tests/in/training/minimum-2.inp +++ b/tests/in/training/minimum-2.inp @@ -1,7 +1,7 @@ --device cpu --batch-size 3 --n-threads 4 --data-root-dir data/ ---mask-dir data/ +--mask-dir data/masks/ --log-dir tests/out/training/logs/ --snapshot-dir tests/out/training/ --data-names 20cr-1ens.nc diff --git a/tests/in/training/normalize-bounds.inp b/tests/in/training/normalize-bounds.inp new file mode 100644 index 00000000..aa586ee2 --- /dev/null +++ b/tests/in/training/normalize-bounds.inp @@ -0,0 +1,24 @@ +--device cpu --batch-size 2 +--n-threads 4 +--data-root-dir data/ +--mask-dir data/masks/ +--log-dir tests/out/training/logs/ +--snapshot-dir tests/out/training/ +--data-names 20cr-1ens.nc,20cr-1ens.nc +--mask-names hadcrut_missmask_1.nc,hadcrut_missmask_1.nc +--max-iter 10 +--data-types tas,tas +--encoding-layers 3,3 +--pooling-layers 0,0 +--save-model-interval 5 +--loss-criterion l1-hole +--weights kaiming +--loop-random-seed 3 +--cuda-random-seed 3 +--deterministic +--shuffle-masks +--normalize-data +--min-bounds=0,-inf +--max-bounds 10,inf +--steady-mask steady_tas_hadcrut.nc,steady_tas_hadcrut.nc +--n-output-data 2 \ No newline at end of file diff --git a/tests/in/training/target-channel-memory.inp b/tests/in/training/target-channel-memory.inp index 88546419..d7b26d73 100644 --- a/tests/in/training/target-channel-memory.inp +++ b/tests/in/training/target-channel-memory.inp @@ -1,7 +1,7 @@ --device cpu --batch-size 2 --n-threads 4 --data-root-dir data/ ---mask-dir data/ +--mask-dir data/masks/ --log-dir tests/out/training/logs/ --snapshot-dir tests/out/training/ --data-names 20cr-1ens.nc,20cr-1ens.nc,20cr-1ens.nc diff --git a/tests/in/training/traj-gru.inp b/tests/in/training/traj-gru.inp index c0b30861..70f080be 100644 --- a/tests/in/training/traj-gru.inp +++ b/tests/in/training/traj-gru.inp @@ -1,7 +1,7 @@ --device cpu --batch-size 2 --n-threads 4 --data-root-dir data/ ---mask-dir data/ +--mask-dir data/masks/ --log-dir tests/out/training/logs/ --snapshot-dir tests/out/training/ --data-names 20cr-1ens.nc diff --git a/tests/in/training/vae.inp b/tests/in/training/vae.inp new file mode 100644 index 00000000..e6eec029 --- /dev/null +++ b/tests/in/training/vae.inp @@ -0,0 +1,20 @@ +--device cpu --batch-size 2 +--n-threads 4 +--data-root-dir data/ +--mask-dir data/masks/ +--log-dir tests/out/training/logs/ +--snapshot-dir tests/out/training/ +--data-names 20cr-1ens.nc +--mask-names hadcrut_missmask_1.nc +--max-iter 10 +--data-types tas +--encoding-layers 3 +--save-model-interval 5 +--loss-criterion inpainting +--weights kaiming +--loop-random-seed 3 +--cuda-random-seed 3 +--deterministic +--vae-zdim 10 +--lr-scheduler-patience 2 +--shuffle-masks \ No newline at end of file diff --git a/tests/ref/minimum-1_infilled.nc b/tests/ref/evaluation/minimum-1_infilled.nc similarity index 100% rename from tests/ref/minimum-1_infilled.nc rename to tests/ref/evaluation/minimum-1_infilled.nc diff --git a/tests/ref/evaluation/normalize-bounds_1_infilled.nc b/tests/ref/evaluation/normalize-bounds_1_infilled.nc new file mode 100644 index 00000000..7df2b9e7 Binary files /dev/null and b/tests/ref/evaluation/normalize-bounds_1_infilled.nc differ diff --git 
a/tests/ref/evaluation/normalize-bounds_2_infilled.nc b/tests/ref/evaluation/normalize-bounds_2_infilled.nc new file mode 100644 index 00000000..0969b5da Binary files /dev/null and b/tests/ref/evaluation/normalize-bounds_2_infilled.nc differ diff --git a/tests/ref/attention-channel-memory.inp.pth b/tests/ref/training/attention-channel-memory.inp.pth similarity index 100% rename from tests/ref/attention-channel-memory.inp.pth rename to tests/ref/training/attention-channel-memory.inp.pth diff --git a/tests/ref/attention.inp.pth b/tests/ref/training/attention.inp.pth similarity index 100% rename from tests/ref/attention.inp.pth rename to tests/ref/training/attention.inp.pth diff --git a/tests/ref/channel-infusion.inp.pth b/tests/ref/training/channel-infusion.inp.pth similarity index 100% rename from tests/ref/channel-infusion.inp.pth rename to tests/ref/training/channel-infusion.inp.pth diff --git a/tests/ref/channel-memory.inp.pth b/tests/ref/training/channel-memory.inp.pth similarity index 100% rename from tests/ref/channel-memory.inp.pth rename to tests/ref/training/channel-memory.inp.pth diff --git a/tests/ref/conv-lstm.inp.pth b/tests/ref/training/conv-lstm.inp.pth similarity index 100% rename from tests/ref/conv-lstm.inp.pth rename to tests/ref/training/conv-lstm.inp.pth diff --git a/tests/ref/json-input.inp.pth b/tests/ref/training/json-input.inp.pth similarity index 100% rename from tests/ref/json-input.inp.pth rename to tests/ref/training/json-input.inp.pth diff --git a/tests/ref/training/lazy-forecast.inp.pth b/tests/ref/training/lazy-forecast.inp.pth new file mode 100644 index 00000000..fb8cc2dc Binary files /dev/null and b/tests/ref/training/lazy-forecast.inp.pth differ diff --git a/tests/ref/minimum-1.inp.pth b/tests/ref/training/minimum-1.inp.pth similarity index 100% rename from tests/ref/minimum-1.inp.pth rename to tests/ref/training/minimum-1.inp.pth diff --git a/tests/ref/minimum-2.inp.pth b/tests/ref/training/minimum-2.inp.pth similarity index 100% rename from tests/ref/minimum-2.inp.pth rename to tests/ref/training/minimum-2.inp.pth diff --git a/tests/ref/training/normalize-bounds.inp.pth b/tests/ref/training/normalize-bounds.inp.pth new file mode 100644 index 00000000..0077b602 Binary files /dev/null and b/tests/ref/training/normalize-bounds.inp.pth differ diff --git a/tests/ref/target-channel-memory.inp.pth b/tests/ref/training/target-channel-memory.inp.pth similarity index 100% rename from tests/ref/target-channel-memory.inp.pth rename to tests/ref/training/target-channel-memory.inp.pth diff --git a/tests/ref/traj-gru.inp.pth b/tests/ref/training/traj-gru.inp.pth similarity index 100% rename from tests/ref/traj-gru.inp.pth rename to tests/ref/training/traj-gru.inp.pth diff --git a/tests/ref/training/vae.inp.pth b/tests/ref/training/vae.inp.pth new file mode 100644 index 00000000..edbc5d9b Binary files /dev/null and b/tests/ref/training/vae.inp.pth differ diff --git a/tests/test_evaluation.py b/tests/test_evaluation.py index 01e2e82b..09c4991b 100644 --- a/tests/test_evaluation.py +++ b/tests/test_evaluation.py @@ -1,18 +1,22 @@ import pytest +import os testdir = "tests/" @pytest.mark.evaluation -def test_evaluation_run(): +@pytest.mark.parametrize("file", sorted(os.listdir(testdir + "in/evaluation/"))) +def test_evaluation_run(file): from climatereconstructionai import evaluate - evaluate(testdir + "in/evaluation/minimum-1.inp") + evaluate('{}in/evaluation/{}'.format(testdir, file)) + # evaluate(testdir + "in/evaluation/minimum-1.inp") @pytest.mark.evaluation -def 
test_comp_netcdf(): +@pytest.mark.parametrize("file", sorted(os.listdir(testdir + "ref/evaluation/"))) +def test_comp_netcdf(file): import xarray as xr - ds_ref = xr.open_dataset(testdir + "ref/minimum-1_infilled.nc") - ds_run = xr.open_dataset(testdir + "out/evaluation/minimum-1_infilled.nc") + ds_ref = xr.open_dataset('{}ref/evaluation/{}'.format(testdir, file)) + ds_run = xr.open_dataset('{}out/evaluation/{}'.format(testdir, file)) # assert ds_ref.equals(ds_run) xr.testing.assert_allclose(ds_ref, ds_run, rtol=1e-15, atol=1e-8) diff --git a/tests/test_training.py b/tests/test_training.py index bc650f85..1b00c522 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -14,13 +14,13 @@ def test_training_run(file): @pytest.mark.training -@pytest.mark.parametrize("model", os.listdir(testdir + "in/training/")) +@pytest.mark.parametrize("model", os.listdir(testdir + "ref/training/")) def test_comp_models(model): import torch - ckpt_dict = torch.load(testdir + "ref/" + model + ".pth") + ckpt_dict = torch.load(testdir + "ref/training/" + model) for label in ckpt_dict["labels"]: model_ref = ckpt_dict[label]["model"] - model_run = torch.load(testdir + "out/training/ckpt/" + model + ".pth")[label]["model"] + model_run = torch.load(testdir + "out/training/ckpt/" + model)[label]["model"] for k_ref, k_run in zip(model_ref.keys(), model_run.keys()): assert k_ref == k_run print("* Checking {}...".format(k_ref))
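For reference, the parametrized evaluation check above reduces to a tolerance-based comparison of each produced netCDF file against its counterpart under `tests/ref/evaluation/`. A minimal standalone sketch for a single file, using the tolerances from `test_comp_netcdf` (the file name is just one example taken from this diff):

```python
import xarray as xr

# Compare one infilled output against its reference within the test tolerances.
ds_ref = xr.open_dataset("tests/ref/evaluation/minimum-1_infilled.nc")
ds_run = xr.open_dataset("tests/out/evaluation/minimum-1_infilled.nc")
xr.testing.assert_allclose(ds_ref, ds_run, rtol=1e-15, atol=1e-8)
```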