diff --git a/the_well/__init__.py b/the_well/__init__.py
index fe54b95b..c7dec85a 100755
--- a/the_well/__init__.py
+++ b/the_well/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "1.2.0"
+__version__ = "1.2.1"
 
 __all__ = ["__version__"]
 
diff --git a/the_well/benchmark/trainer/training.py b/the_well/benchmark/trainer/training.py
index ffb057ba..8509c1ab 100755
--- a/the_well/benchmark/trainer/training.py
+++ b/the_well/benchmark/trainer/training.py
@@ -176,38 +176,51 @@ def load_checkpoint(self, checkpoint_path: str):
             checkpoint["epoch"] + 1
         )  # Saves after training loop, so start at next epoch
 
-    def normalize(self, batch):
+    def normalize(self, batch_dict=None, direct_tensor=None):
         if hasattr(self, "dset_norm") and self.dset_norm:
-            batch["input_fields"] = self.dset_norm.normalize_flattened(
-                batch["input_fields"], "variable"
-            )
-            if "constant_fields" in batch:
-                batch["constant_fields"] = self.dset_norm.normalize_flattened(
-                    batch["constant_fields"], "constant"
+            if batch_dict is not None:
+                batch_dict["input_fields"] = self.dset_norm.normalize_flattened(
+                    batch_dict["input_fields"], "variable"
                 )
-        return batch
+                if "constant_fields" in batch_dict:
+                    batch_dict["constant_fields"] = self.dset_norm.normalize_flattened(
+                        batch_dict["constant_fields"], "constant"
+                    )
+            if direct_tensor is not None:
+                if self.is_delta:
+                    direct_tensor = self.dset_norm.normalize_delta_flattened(
+                        direct_tensor, "variable"
+                    )
+                else:
+                    direct_tensor = self.dset_norm.normalize_flattened(
+                        direct_tensor, "variable"
+                    )
+        return batch_dict, direct_tensor
 
-    def denormalize(self, batch, prediction):
+    def denormalize(self, batch_dict=None, direct_tensor=None):
         if hasattr(self, "dset_norm") and self.dset_norm:
-            batch["input_fields"] = self.dset_norm.denormalize_flattened(
-                batch["input_fields"], "variable"
-            )
-            if "constant_fields" in batch:
-                batch["constant_fields"] = self.dset_norm.denormalize_flattened(
-                    batch["constant_fields"], "constant"
-                )
-
-            # Delta denormalization is different than full denormalization
-            if self.is_delta:
-                prediction = self.dset_norm.delta_denormalize_flattened(
-                    prediction, "variable"
-                )
-            else:
-                prediction = self.dset_norm.denormalize_flattened(
-                    prediction, "variable"
+            if batch_dict is not None:
+                batch_dict["input_fields"] = self.dset_norm.denormalize_flattened(
+                    batch_dict["input_fields"], "variable"
                 )
+                if "constant_fields" in batch_dict:
+                    batch_dict["constant_fields"] = (
+                        self.dset_norm.denormalize_flattened(
+                            batch_dict["constant_fields"], "constant"
+                        )
+                    )
+            if direct_tensor is not None:
+                # Delta denormalization is different than full denormalization
+                if self.is_delta:
+                    direct_tensor = self.dset_norm.delta_denormalize_flattened(
+                        direct_tensor, "variable"
+                    )
+                else:
+                    direct_tensor = self.dset_norm.denormalize_flattened(
+                        direct_tensor, "variable"
+                    )
 
-        return batch, prediction
+        return batch_dict, direct_tensor
 
     def rollout_model(self, model, batch, formatter, train=True):
         """Rollout the model for as many steps as we have data for."""
@@ -216,6 +229,10 @@ def rollout_model(self, model, batch, formatter, train=True):
             y_ref.shape[1], self.max_rollout_steps
         )  # Number of timesteps in target
         y_ref = y_ref[:, :rollout_steps]
+        # NOTE: This is a quick fix so we can make datamodule behavior consistent. Revisit this next release (MM).
+        if not train:
+            _, y_ref = self.denormalize(None, y_ref)
+
         # Create a moving batch of one step at a time
         moving_batch = batch
         moving_batch["input_fields"] = moving_batch["input_fields"].to(self.device)
@@ -225,15 +242,15 @@
         )
         y_preds = []
         for i in range(rollout_steps):
-            if not train:
-                moving_batch = self.normalize(moving_batch)
+            # NOTE: This is a quick fix so we can make datamodule behavior consistent. Revisit this next release (MM).
+            if i > 0 and not train:
+                moving_batch, _ = self.normalize(moving_batch)
             inputs, _ = formatter.process_input(moving_batch)
             inputs = [x.to(self.device) for x in inputs]
             y_pred = model(*inputs)
             y_pred = formatter.process_output_channel_last(y_pred)
-
             if not train:
                 moving_batch, y_pred = self.denormalize(moving_batch, y_pred)
diff --git a/the_well/data/datamodule.py b/the_well/data/datamodule.py
index 3e037ec2..f3660a55 100755
--- a/the_well/data/datamodule.py
+++ b/the_well/data/datamodule.py
@@ -163,6 +163,8 @@ def __init__(
             well_split_name="valid",
             include_filters=include_filters,
             exclude_filters=exclude_filters,
+            use_normalization=use_normalization,
+            normalization_type=normalization_type,
             n_steps_input=n_steps_input,
             n_steps_output=n_steps_output,
             storage_options=storage_kwargs,
@@ -181,6 +183,8 @@ def __init__(
             well_split_name="valid",
             include_filters=include_filters,
             exclude_filters=exclude_filters,
+            use_normalization=use_normalization,
+            normalization_type=normalization_type,
             max_rollout_steps=max_rollout_steps,
             n_steps_input=n_steps_input,
             n_steps_output=n_steps_output,
@@ -201,6 +205,8 @@ def __init__(
             well_split_name="test",
             include_filters=include_filters,
             exclude_filters=exclude_filters,
+            use_normalization=use_normalization,
+            normalization_type=normalization_type,
             n_steps_input=n_steps_input,
             n_steps_output=n_steps_output,
             storage_options=storage_kwargs,
@@ -219,6 +225,8 @@ def __init__(
             well_split_name="test",
             include_filters=include_filters,
             exclude_filters=exclude_filters,
+            use_normalization=use_normalization,
+            normalization_type=normalization_type,
             max_rollout_steps=max_rollout_steps,
             n_steps_input=n_steps_input,
             n_steps_output=n_steps_output,
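
Reviewer note (not part of the patch): the sketch below illustrates the new calling convention for normalize/denormalize introduced above -- both arguments are now optional, and a (batch_dict, direct_tensor) tuple is always returned, so call sites unpack only the slot they need. _StubNorm, _StubTrainer, and the placeholder scalings are hypothetical stand-ins for illustration, not the_well APIs.

import torch


class _StubNorm:
    """Hypothetical stand-in for the trainer's dset_norm statistics object."""

    def normalize_flattened(self, x, kind):
        return x / 2.0  # placeholder for (x - mean) / std

    def denormalize_flattened(self, x, kind):
        return x * 2.0  # placeholder for x * std + mean

    def delta_denormalize_flattened(self, x, kind):
        return x * 2.0  # deltas skip the mean shift, so only the scale is undone


class _StubTrainer:
    dset_norm = _StubNorm()
    is_delta = False

    def denormalize(self, batch_dict=None, direct_tensor=None):
        # Mirrors the patched signature: each argument is optional, and both
        # slots come back as a tuple regardless of which were passed in.
        if batch_dict is not None:
            batch_dict["input_fields"] = self.dset_norm.denormalize_flattened(
                batch_dict["input_fields"], "variable"
            )
        if direct_tensor is not None:
            fn = (
                self.dset_norm.delta_denormalize_flattened
                if self.is_delta
                else self.dset_norm.denormalize_flattened
            )
            direct_tensor = fn(direct_tensor, "variable")
        return batch_dict, direct_tensor


trainer = _StubTrainer()

# Tensor-only call, as in the new y_ref fix in rollout_model:
_, y_ref = trainer.denormalize(None, torch.ones(2, 3))

# Batch-and-tensor call, as inside the rollout loop:
batch = {"input_fields": torch.ones(2, 3)}
batch, y_pred = trainer.denormalize(batch, torch.ones(2, 3))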