Skip to content

Commit e83bf5c

Browse files
Merge pull request #69 from PolymathicAI/68-bug-welldatamodule-normalization-not-applied-to-val-and-test-data
Fixes #68: bug where WellDataModule normalization was not applied to validation and test data
2 parents 6cd3c44 + 4775658 commit e83bf5c

File tree

3 files changed

+61
-33
lines changed

3 files changed

+61
-33
lines changed

the_well/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "1.2.0"
1+
__version__ = "1.2.1"
22

33

44
__all__ = ["__version__"]

the_well/benchmark/trainer/training.py

Lines changed: 52 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def __init__(
140140
self.best_val_loss = None
141141
self.starting_val_loss = float("inf")
142142
self.dset_metadata = self.datamodule.train_dataset.metadata
143+
self.dset_norm = None
143144
if self.datamodule.train_dataset.use_normalization:
144145
self.dset_norm = self.datamodule.train_dataset.norm
145146
if formatter == "channels_first_default":
@@ -176,38 +177,51 @@ def load_checkpoint(self, checkpoint_path: str):
176177
checkpoint["epoch"] + 1
177178
) # Saves after training loop, so start at next epoch
178179

179-
def normalize(self, batch):
180+
def normalize(self, batch_dict=None, direct_tensor=None):
180181
if hasattr(self, "dset_norm") and self.dset_norm:
181-
batch["input_fields"] = self.dset_norm.normalize_flattened(
182-
batch["input_fields"], "variable"
183-
)
184-
if "constant_fields" in batch:
185-
batch["constant_fields"] = self.dset_norm.normalize_flattened(
186-
batch["constant_fields"], "constant"
182+
if batch_dict is not None:
183+
batch_dict["input_fields"] = self.dset_norm.normalize_flattened(
184+
batch_dict["input_fields"], "variable"
187185
)
188-
return batch
186+
if "constant_fields" in batch_dict:
187+
batch_dict["constant_fields"] = self.dset_norm.normalize_flattened(
188+
batch_dict["constant_fields"], "constant"
189+
)
190+
if direct_tensor is not None:
191+
if self.is_delta:
192+
direct_tensor = self.dset_norm.normalize_delta_flattened(
193+
direct_tensor, "variable"
194+
)
195+
else:
196+
direct_tensor = self.dset_norm.normalize_flattened(
197+
direct_tensor, "variable"
198+
)
199+
return batch_dict, direct_tensor
189200

190-
def denormalize(self, batch, prediction):
201+
def denormalize(self, batch_dict=None, direct_tensor=None):
191202
if hasattr(self, "dset_norm") and self.dset_norm:
192-
batch["input_fields"] = self.dset_norm.denormalize_flattened(
193-
batch["input_fields"], "variable"
194-
)
195-
if "constant_fields" in batch:
196-
batch["constant_fields"] = self.dset_norm.denormalize_flattened(
197-
batch["constant_fields"], "constant"
198-
)
199-
200-
# Delta denormalization is different than full denormalization
201-
if self.is_delta:
202-
prediction = self.dset_norm.delta_denormalize_flattened(
203-
prediction, "variable"
204-
)
205-
else:
206-
prediction = self.dset_norm.denormalize_flattened(
207-
prediction, "variable"
203+
if batch_dict is not None:
204+
batch_dict["input_fields"] = self.dset_norm.denormalize_flattened(
205+
batch_dict["input_fields"], "variable"
208206
)
207+
if "constant_fields" in batch_dict:
208+
batch_dict["constant_fields"] = (
209+
self.dset_norm.denormalize_flattened(
210+
batch_dict["constant_fields"], "constant"
211+
)
212+
)
213+
if direct_tensor is not None:
214+
# Delta denormalization is different than full denormalization
215+
if self.is_delta:
216+
direct_tensor = self.dset_norm.delta_denormalize_flattened(
217+
direct_tensor, "variable"
218+
)
219+
else:
220+
direct_tensor = self.dset_norm.denormalize_flattened(
221+
direct_tensor, "variable"
222+
)
209223

210-
return batch, prediction
224+
return batch_dict, direct_tensor
211225

212226
def rollout_model(self, model, batch, formatter, train=True):
213227
"""Rollout the model for as many steps as we have data for."""
@@ -216,31 +230,37 @@ def rollout_model(self, model, batch, formatter, train=True):
216230
y_ref.shape[1], self.max_rollout_steps
217231
) # Number of timesteps in target
218232
y_ref = y_ref[:, :rollout_steps]
233+
# NOTE: This is a quick fix so we can make datamodule behavior consistent. Revisit this next release (MM).
234+
if not train:
235+
_, y_ref = self.denormalize(None, y_ref)
236+
219237
# Create a moving batch of one step at a time
220-
moving_batch = batch
238+
moving_batch = dict(batch)
221239
moving_batch["input_fields"] = moving_batch["input_fields"].to(self.device)
222240
if "constant_fields" in moving_batch:
223241
moving_batch["constant_fields"] = moving_batch["constant_fields"].to(
224242
self.device
225243
)
226244
y_preds = []
227245
for i in range(rollout_steps):
228-
if not train:
229-
moving_batch = self.normalize(moving_batch)
246+
# NOTE: This is a quick fix so we can make datamodule behavior consistent.
247+
# Including local normalization schemes means there needs to be the option of normalizing each step
248+
# and there's currently not a registry of local vs global normalization schemes.
249+
if not train and self.datamodule.val_dataset.use_normalization and i > 0:
250+
moving_batch, _ = self.normalize(moving_batch)
230251

231252
inputs, _ = formatter.process_input(moving_batch)
232253
inputs = [x.to(self.device) for x in inputs]
233254
y_pred = model(*inputs)
234255

235256
y_pred = formatter.process_output_channel_last(y_pred)
236-
237257
if not train:
238258
moving_batch, y_pred = self.denormalize(moving_batch, y_pred)
239259

240260
if (not train) and self.is_delta:
241-
assert {
261+
assert (
242262
moving_batch["input_fields"][:, -1, ...].shape == y_pred.shape
243-
}, f"Mismatching shapes between last input timestep {moving_batch[:, -1, ...].shape}\
263+
), f"Mismatching shapes between last input timestep {moving_batch[:, -1, ...].shape}\
244264
and prediction {y_pred.shape}"
245265
y_pred = moving_batch["input_fields"][:, -1, ...] + y_pred
246266
y_pred = formatter.process_output_expand_time(y_pred)

the_well/data/datamodule.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,8 @@ def __init__(
163163
well_split_name="valid",
164164
include_filters=include_filters,
165165
exclude_filters=exclude_filters,
166+
use_normalization=use_normalization,
167+
normalization_type=normalization_type,
166168
n_steps_input=n_steps_input,
167169
n_steps_output=n_steps_output,
168170
storage_options=storage_kwargs,
@@ -181,6 +183,8 @@ def __init__(
181183
well_split_name="valid",
182184
include_filters=include_filters,
183185
exclude_filters=exclude_filters,
186+
use_normalization=use_normalization,
187+
normalization_type=normalization_type,
184188
max_rollout_steps=max_rollout_steps,
185189
n_steps_input=n_steps_input,
186190
n_steps_output=n_steps_output,
@@ -201,6 +205,8 @@ def __init__(
201205
well_split_name="test",
202206
include_filters=include_filters,
203207
exclude_filters=exclude_filters,
208+
use_normalization=use_normalization,
209+
normalization_type=normalization_type,
204210
n_steps_input=n_steps_input,
205211
n_steps_output=n_steps_output,
206212
storage_options=storage_kwargs,
@@ -219,6 +225,8 @@ def __init__(
219225
well_split_name="test",
220226
include_filters=include_filters,
221227
exclude_filters=exclude_filters,
228+
use_normalization=use_normalization,
229+
normalization_type=normalization_type,
222230
max_rollout_steps=max_rollout_steps,
223231
n_steps_input=n_steps_input,
224232
n_steps_output=n_steps_output,

0 commit comments

Comments (0)