2 changes: 1 addition & 1 deletion the_well/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "1.2.0"
+__version__ = "1.2.1"
 
 
 __all__ = ["__version__"]
75 changes: 46 additions & 29 deletions the_well/benchmark/trainer/training.py
@@ -176,38 +176,51 @@ def load_checkpoint(self, checkpoint_path: str):
             checkpoint["epoch"] + 1
         )  # Saves after training loop, so start at next epoch
 
-    def normalize(self, batch):
+    def normalize(self, batch_dict=None, direct_tensor=None):
         if hasattr(self, "dset_norm") and self.dset_norm:
-            batch["input_fields"] = self.dset_norm.normalize_flattened(
-                batch["input_fields"], "variable"
-            )
-            if "constant_fields" in batch:
-                batch["constant_fields"] = self.dset_norm.normalize_flattened(
-                    batch["constant_fields"], "constant"
-                )
-        return batch
+            if batch_dict is not None:
+                batch_dict["input_fields"] = self.dset_norm.normalize_flattened(
+                    batch_dict["input_fields"], "variable"
+                )
+                if "constant_fields" in batch_dict:
+                    batch_dict["constant_fields"] = self.dset_norm.normalize_flattened(
+                        batch_dict["constant_fields"], "constant"
+                    )
+            if direct_tensor is not None:
+                if self.is_delta:
+                    direct_tensor = self.dset_norm.normalize_delta_flattened(
+                        direct_tensor, "variable"
+                    )
+                else:
+                    direct_tensor = self.dset_norm.normalize_flattened(
+                        direct_tensor, "variable"
+                    )
+        return batch_dict, direct_tensor
 
-    def denormalize(self, batch, prediction):
+    def denormalize(self, batch_dict=None, direct_tensor=None):
         if hasattr(self, "dset_norm") and self.dset_norm:
-            batch["input_fields"] = self.dset_norm.denormalize_flattened(
-                batch["input_fields"], "variable"
-            )
-            if "constant_fields" in batch:
-                batch["constant_fields"] = self.dset_norm.denormalize_flattened(
-                    batch["constant_fields"], "constant"
-                )
-
-            # Delta denormalization is different than full denormalization
-            if self.is_delta:
-                prediction = self.dset_norm.delta_denormalize_flattened(
-                    prediction, "variable"
-                )
-            else:
-                prediction = self.dset_norm.denormalize_flattened(
-                    prediction, "variable"
-                )
+            if batch_dict is not None:
+                batch_dict["input_fields"] = self.dset_norm.denormalize_flattened(
+                    batch_dict["input_fields"], "variable"
+                )
+                if "constant_fields" in batch_dict:
+                    batch_dict["constant_fields"] = (
+                        self.dset_norm.denormalize_flattened(
+                            batch_dict["constant_fields"], "constant"
+                        )
+                    )
+            if direct_tensor is not None:
+                # Delta denormalization is different than full denormalization
+                if self.is_delta:
+                    direct_tensor = self.dset_norm.delta_denormalize_flattened(
+                        direct_tensor, "variable"
+                    )
+                else:
+                    direct_tensor = self.dset_norm.denormalize_flattened(
+                        direct_tensor, "variable"
+                    )
 
-        return batch, prediction
+        return batch_dict, direct_tensor
 
     def rollout_model(self, model, batch, formatter, train=True):
         """Rollout the model for as many steps as we have data for."""
@@ -216,6 +229,10 @@ def rollout_model(self, model, batch, formatter, train=True):
             y_ref.shape[1], self.max_rollout_steps
         )  # Number of timesteps in target
         y_ref = y_ref[:, :rollout_steps]
+        # NOTE: This is a quick fix so we can make datamodule behavior consistent. Revisit this next release (MM).
+        if not train:
+            _, y_ref = self.denormalize(None, y_ref)
+
         # Create a moving batch of one step at a time
         moving_batch = batch
         moving_batch["input_fields"] = moving_batch["input_fields"].to(self.device)
@@ -225,15 +242,15 @@ def rollout_model(self, model, batch, formatter, train=True):
         )
         y_preds = []
         for i in range(rollout_steps):
-            if not train:
-                moving_batch = self.normalize(moving_batch)
+            # NOTE: This is a quick fix so we can make datamodule behavior consistent. Revisit this next release (MM).
+            if i > 0 and not train:
+                moving_batch, _ = self.normalize(moving_batch)
 
             inputs, _ = formatter.process_input(moving_batch)
             inputs = [x.to(self.device) for x in inputs]
             y_pred = model(*inputs)
 
             y_pred = formatter.process_output_channel_last(y_pred)
 
             if not train:
                 moving_batch, y_pred = self.denormalize(moving_batch, y_pred)
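
Note on the two rollout hunks: they adjust `rollout_model` to the datamodule now delivering validation and test data already normalized. `y_ref` is denormalized once up front, the moving batch skips normalization on step 0 (it arrives normalized), and later steps re-normalize because the trailing `denormalize` call hands each step's batch and prediction back at raw scale. Below is a self-contained toy of that round-trip; the statistics and the one-line "model" are illustrative stand-ins, and the real loop also runs a formatter and the actual model.

```python
import torch

MEAN, STD = 1.0, 2.0  # toy statistics; the real ones live in dset_norm


def normalize(t: torch.Tensor) -> torch.Tensor:
    return (t - MEAN) / STD


def denormalize(t: torch.Tensor) -> torch.Tensor:
    return t * STD + MEAN


y_ref = denormalize(torch.zeros(2, 3))  # targets rescaled once, before the loop
moving = normalize(torch.ones(2, 3))    # step-0 input arrives pre-normalized
y_preds = []
for i in range(3):
    if i > 0:
        moving = normalize(moving)      # later inputs are raw again, so rescale
    y_pred = denormalize(0.5 * moving)  # toy "model" step plus trailing denormalize
    y_preds.append(y_pred)              # predictions compared to y_ref at raw scale
    moving = y_pred                     # next input starts at raw scale
```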
8 changes: 8 additions & 0 deletions the_well/data/datamodule.py
@@ -163,6 +163,8 @@ def __init__(
             well_split_name="valid",
             include_filters=include_filters,
             exclude_filters=exclude_filters,
+            use_normalization=use_normalization,
+            normalization_type=normalization_type,
             n_steps_input=n_steps_input,
             n_steps_output=n_steps_output,
             storage_options=storage_kwargs,
@@ -181,6 +183,8 @@ def __init__(
             well_split_name="valid",
             include_filters=include_filters,
             exclude_filters=exclude_filters,
+            use_normalization=use_normalization,
+            normalization_type=normalization_type,
             max_rollout_steps=max_rollout_steps,
             n_steps_input=n_steps_input,
             n_steps_output=n_steps_output,
@@ -201,6 +205,8 @@ def __init__(
             well_split_name="test",
             include_filters=include_filters,
             exclude_filters=exclude_filters,
+            use_normalization=use_normalization,
+            normalization_type=normalization_type,
             n_steps_input=n_steps_input,
             n_steps_output=n_steps_output,
             storage_options=storage_kwargs,
@@ -219,6 +225,8 @@ def __init__(
             well_split_name="test",
             include_filters=include_filters,
             exclude_filters=exclude_filters,
+            use_normalization=use_normalization,
+            normalization_type=normalization_type,
             max_rollout_steps=max_rollout_steps,
             n_steps_input=n_steps_input,
             n_steps_output=n_steps_output,
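
Note on the datamodule hunks: this change is plumbing. `use_normalization` and `normalization_type` now reach the four non-training dataset constructions (the valid and test splits and their rollout variants), matching what the training dataset presumably already received, so the configured scaling applies consistently across splits. Below is a runnable sketch of the fan-out pattern under assumed names; `ToySplitDataset` and the split list are illustrative, not the repo's classes.

```python
from dataclasses import dataclass


@dataclass
class ToySplitDataset:
    # Illustrative stand-in; the real per-split datasets take many more kwargs.
    well_split_name: str
    use_normalization: bool
    normalization_type: str


def build_splits(use_normalization: bool, normalization_type: str):
    # One pair of flags fanned out to every split, as in this datamodule change.
    names = ("train", "valid", "rollout_valid", "test", "rollout_test")
    return [
        ToySplitDataset(name, use_normalization, normalization_type)
        for name in names
    ]


splits = build_splits(use_normalization=True, normalization_type="zscore")
assert all(split.use_normalization for split in splits)
```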