From 715226c9ecc4d996188cb28fe0397114d7df997f Mon Sep 17 00:00:00 2001
From: Mike McCabe
Date: Wed, 10 Dec 2025 13:53:13 -0500
Subject: [PATCH 1/4] Move normalization into dataloader - perform denormalization in trainer

---
 the_well/benchmark/trainer/training.py | 57 +++++++++++++++-----------
 the_well/data/datamodule.py            |  8 ++++
 2 files changed, 42 insertions(+), 23 deletions(-)

diff --git a/the_well/benchmark/trainer/training.py b/the_well/benchmark/trainer/training.py
index ffb057ba..ba5587bc 100755
--- a/the_well/benchmark/trainer/training.py
+++ b/the_well/benchmark/trainer/training.py
@@ -176,38 +176,45 @@ def load_checkpoint(self, checkpoint_path: str):
             checkpoint["epoch"] + 1
         )  # Saves after training loop, so start at next epoch
 
-    def normalize(self, batch):
+    def normalize(self, batch_dict=None, direct_tensor=None):
         if hasattr(self, "dset_norm") and self.dset_norm:
-            batch["input_fields"] = self.dset_norm.normalize_flattened(
-                batch["input_fields"], "variable"
-            )
-            if "constant_fields" in batch:
-                batch["constant_fields"] = self.dset_norm.normalize_flattened(
-                    batch["constant_fields"], "constant"
+            if batch_dict is not None:
+                batch_dict["input_fields"] = self.dset_norm.normalize_flattened(
+                    batch_dict["input_fields"], "variable"
+                )
+                if "constant_fields" in batch_dict:
+                    batch_dict["constant_fields"] = self.dset_norm.normalize_flattened(
+                        batch_dict["constant_fields"], "constant"
+                    )
+            if direct_tensor is not None:
+                direct_tensor = self.dset_norm.normalize_flattened(
+                    direct_tensor, "variable"
                 )
-        return batch
+            return batch_dict, direct_tensor
+        return batch_dict, direct_tensor
 
-    def denormalize(self, batch, prediction):
+    def denormalize(self, batch_dict=None, direct_tensor=None):
         if hasattr(self, "dset_norm") and self.dset_norm:
-            batch["input_fields"] = self.dset_norm.denormalize_flattened(
-                batch["input_fields"], "variable"
-            )
-            if "constant_fields" in batch:
-                batch["constant_fields"] = self.dset_norm.denormalize_flattened(
-                    batch["constant_fields"], "constant"
+            if batch_dict is not None:
+                batch_dict["input_fields"] = self.dset_norm.denormalize_flattened(
+                    batch_dict["input_fields"], "variable"
                 )
+                if "constant_fields" in batch_dict:
+                    batch_dict["constant_fields"] = self.dset_norm.denormalize_flattened(
+                        batch_dict["constant_fields"], "constant"
+                    )
 
             # Delta denormalization is different than full denormalization
             if self.is_delta:
-                prediction = self.dset_norm.delta_denormalize_flattened(
-                    prediction, "variable"
+                direct_tensor = self.dset_norm.delta_denormalize_flattened(
+                    direct_tensor, "variable"
                 )
             else:
-                prediction = self.dset_norm.denormalize_flattened(
-                    prediction, "variable"
+                direct_tensor = self.dset_norm.denormalize_flattened(
+                    direct_tensor, "variable"
                 )
-        return batch, prediction
+        return batch_dict, direct_tensor
 
     def rollout_model(self, model, batch, formatter, train=True):
         """Rollout the model for as many steps as we have data for."""
@@ -216,6 +223,10 @@ def rollout_model(self, model, batch, formatter, train=True):
             y_ref.shape[1], self.max_rollout_steps
         )  # Number of timesteps in target
         y_ref = y_ref[:, :rollout_steps]
+        # NOTE: This is a quick fix so we can make datamodule behavior consistent. Revisit this next release (MM).
+        if not train:
+            _, y_ref = self.denormalize(None, y_ref)
+
         # Create a moving batch of one step at a time
         moving_batch = batch
         moving_batch["input_fields"] = moving_batch["input_fields"].to(self.device)
@@ -225,15 +236,15 @@ def rollout_model(self, model, batch, formatter, train=True):
         )
         y_preds = []
         for i in range(rollout_steps):
-            if not train:
-                moving_batch = self.normalize(moving_batch)
+            # NOTE: This is a quick fix so we can make datamodule behavior consistent. Revisit this next release (MM).
+            if i > 0 and not train:
+                moving_batch, _ = self.normalize(moving_batch)
 
             inputs, _ = formatter.process_input(moving_batch)
             inputs = [x.to(self.device) for x in inputs]
             y_pred = model(*inputs)
             y_pred = formatter.process_output_channel_last(y_pred)
-
             if not train:
                 moving_batch, y_pred = self.denormalize(moving_batch, y_pred)

diff --git a/the_well/data/datamodule.py b/the_well/data/datamodule.py
index 3e037ec2..f3660a55 100755
--- a/the_well/data/datamodule.py
+++ b/the_well/data/datamodule.py
@@ -163,6 +163,8 @@ def __init__(
             well_split_name="valid",
             include_filters=include_filters,
             exclude_filters=exclude_filters,
+            use_normalization=use_normalization,
+            normalization_type=normalization_type,
             n_steps_input=n_steps_input,
             n_steps_output=n_steps_output,
             storage_options=storage_kwargs,
@@ -181,6 +183,8 @@ def __init__(
             well_split_name="valid",
             include_filters=include_filters,
             exclude_filters=exclude_filters,
+            use_normalization=use_normalization,
+            normalization_type=normalization_type,
             max_rollout_steps=max_rollout_steps,
             n_steps_input=n_steps_input,
             n_steps_output=n_steps_output,
@@ -201,6 +205,8 @@ def __init__(
             well_split_name="test",
             include_filters=include_filters,
             exclude_filters=exclude_filters,
+            use_normalization=use_normalization,
+            normalization_type=normalization_type,
             n_steps_input=n_steps_input,
             n_steps_output=n_steps_output,
             storage_options=storage_kwargs,
@@ -219,6 +225,8 @@ def __init__(
             well_split_name="test",
             include_filters=include_filters,
             exclude_filters=exclude_filters,
+            use_normalization=use_normalization,
+            normalization_type=normalization_type,
             max_rollout_steps=max_rollout_steps,
             n_steps_input=n_steps_input,
             n_steps_output=n_steps_output,

From 1f626ac97622f93b593a40b8bfbedb62a400b693 Mon Sep 17 00:00:00 2001
From: Mike McCabe
Date: Wed, 10 Dec 2025 14:02:09 -0500
Subject: [PATCH 2/4] Increment hotfix version

---
 the_well/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/the_well/__init__.py b/the_well/__init__.py
index fe54b95b..c7dec85a 100755
--- a/the_well/__init__.py
+++ b/the_well/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "1.2.0"
+__version__ = "1.2.1"
 
 __all__ = ["__version__"]
 

From 5e22beec296ef99447bcc901f452d66cf591ec1b Mon Sep 17 00:00:00 2001
From: Mike McCabe
Date: Wed, 10 Dec 2025 14:07:15 -0500
Subject: [PATCH 3/4] Finish synchronizing the norm/denorm funcs

---
 the_well/benchmark/trainer/training.py | 32 +++++++++++++++-----------
 1 file changed, 18 insertions(+), 14 deletions(-)

diff --git a/the_well/benchmark/trainer/training.py b/the_well/benchmark/trainer/training.py
index ba5587bc..11c86543 100755
--- a/the_well/benchmark/trainer/training.py
+++ b/the_well/benchmark/trainer/training.py
@@ -187,10 +187,14 @@ def normalize(self, batch_dict=None, direct_tensor=None):
                         batch_dict["constant_fields"], "constant"
                     )
             if direct_tensor is not None:
-                direct_tensor = self.dset_norm.normalize_flattened(
-                    direct_tensor, "variable"
-                )
-            return batch_dict, direct_tensor
+                if self.is_delta:
+                    direct_tensor = self.dset_norm.normalize_delta_flattened(
+                        direct_tensor, "variable"
+                    )
+                else:
+                    direct_tensor = self.dset_norm.normalize_flattened(
+                        direct_tensor, "variable"
+                    )
         return batch_dict, direct_tensor
 
     def denormalize(self, batch_dict=None, direct_tensor=None):
@@ -203,16 +207,16 @@ def denormalize(self, batch_dict=None, direct_tensor=None):
                     batch_dict["input_fields"], "variable"
                 )
                 if "constant_fields" in batch_dict:
                     batch_dict["constant_fields"] = self.dset_norm.denormalize_flattened(
                         batch_dict["constant_fields"], "constant"
                     )
-
-            # Delta denormalization is different than full denormalization
-            if self.is_delta:
-                direct_tensor = self.dset_norm.delta_denormalize_flattened(
-                    direct_tensor, "variable"
-                )
-            else:
-                direct_tensor = self.dset_norm.denormalize_flattened(
-                    direct_tensor, "variable"
-                )
+            if direct_tensor is not None:
+                # Delta denormalization is different than full denormalization
+                if self.is_delta:
+                    direct_tensor = self.dset_norm.delta_denormalize_flattened(
+                        direct_tensor, "variable"
+                    )
+                else:
+                    direct_tensor = self.dset_norm.denormalize_flattened(
+                        direct_tensor, "variable"
+                    )
 
         return batch_dict, direct_tensor

From 1540ae73d75a8384b1dfbe9ae16edc82e28063b0 Mon Sep 17 00:00:00 2001
From: Mike McCabe
Date: Wed, 10 Dec 2025 14:22:05 -0500
Subject: [PATCH 4/4] linter

---
 the_well/benchmark/trainer/training.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/the_well/benchmark/trainer/training.py b/the_well/benchmark/trainer/training.py
index 11c86543..8509c1ab 100755
--- a/the_well/benchmark/trainer/training.py
+++ b/the_well/benchmark/trainer/training.py
@@ -204,8 +204,10 @@ def denormalize(self, batch_dict=None, direct_tensor=None):
                     batch_dict["input_fields"], "variable"
                 )
                 if "constant_fields" in batch_dict:
-                    batch_dict["constant_fields"] = self.dset_norm.denormalize_flattened(
-                        batch_dict["constant_fields"], "constant"
+                    batch_dict["constant_fields"] = (
+                        self.dset_norm.denormalize_flattened(
+                            batch_dict["constant_fields"], "constant"
+                        )
                     )
             if direct_tensor is not None:
                 # Delta denormalization is different than full denormalization
@@ -241,7 +243,7 @@ def rollout_model(self, model, batch, formatter, train=True):
         y_preds = []
         for i in range(rollout_steps):
             # NOTE: This is a quick fix so we can make datamodule behavior consistent. Revisit this next release (MM).
-            if i > 0 and not train: 
+            if i > 0 and not train:
                 moving_batch, _ = self.normalize(moving_batch)
 
             inputs, _ = formatter.process_input(moving_batch)
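
Reviewer note (not part of the patch series): the synchronization in PATCH 3/4 only holds if each normalize/denormalize pair is an exact inverse for both the full and the delta statistics. rollout_model now depends on this when it denormalizes y_ref once up front and re-normalizes moving_batch on every validation step after the first. Below is a minimal self-contained sketch of that round-trip invariant. Only the four method names are taken from the diffs above; FakeNorm, its z-score formulas, and the scalar statistics are illustrative assumptions, not the real dset_norm implementation.

import torch

class FakeNorm:
    # Hypothetical z-score stand-in for `self.dset_norm`. Only the method
    # names below come from the patches; the stats handling is assumed.
    def __init__(self, mean, std, delta_mean, delta_std):
        self.mean, self.std = mean, std
        self.delta_mean, self.delta_std = delta_mean, delta_std

    def normalize_flattened(self, x, kind):
        # `kind` ("variable" / "constant") selects a stats table in the
        # real class; this sketch ignores it for brevity.
        return (x - self.mean) / self.std

    def denormalize_flattened(self, x, kind):
        return x * self.std + self.mean

    def normalize_delta_flattened(self, x, kind):
        return (x - self.delta_mean) / self.delta_std

    def delta_denormalize_flattened(self, x, kind):
        return x * self.delta_std + self.delta_mean

norm = FakeNorm(mean=2.0, std=4.0, delta_mean=0.1, delta_std=0.5)
x = torch.randn(2, 8, 8, 3)

# Full-field round trip: what the validation rollout relies on when it
# denormalizes y_ref once and re-normalizes moving_batch per step.
assert torch.allclose(
    norm.denormalize_flattened(norm.normalize_flattened(x, "variable"), "variable"),
    x,
    atol=1e-6,
)

# Delta round trip: the is_delta branch PATCH 3/4 adds to `normalize` so it
# stays the inverse of `delta_denormalize_flattened`. Before that patch,
# `normalize` applied the full stats unconditionally, so this pair would not
# round-trip whenever the delta stats differ from the full stats.
assert torch.allclose(
    norm.delta_denormalize_flattened(
        norm.normalize_delta_flattened(x, "variable"), "variable"
    ),
    x,
    atol=1e-6,
)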