@@ -140,6 +140,7 @@ def __init__(
140140 self .best_val_loss = None
141141 self .starting_val_loss = float ("inf" )
142142 self .dset_metadata = self .datamodule .train_dataset .metadata
143+ self .dset_norm = None
143144 if self .datamodule .train_dataset .use_normalization :
144145 self .dset_norm = self .datamodule .train_dataset .norm
145146 if formatter == "channels_first_default" :
@@ -176,38 +177,51 @@ def load_checkpoint(self, checkpoint_path: str):
176177 checkpoint ["epoch" ] + 1
177178 ) # Saves after training loop, so start at next epoch
178179
179- def normalize (self , batch ):
180+ def normalize (self , batch_dict = None , direct_tensor = None ):
180181 if hasattr (self , "dset_norm" ) and self .dset_norm :
181- batch ["input_fields" ] = self .dset_norm .normalize_flattened (
182- batch ["input_fields" ], "variable"
183- )
184- if "constant_fields" in batch :
185- batch ["constant_fields" ] = self .dset_norm .normalize_flattened (
186- batch ["constant_fields" ], "constant"
182+ if batch_dict is not None :
183+ batch_dict ["input_fields" ] = self .dset_norm .normalize_flattened (
184+ batch_dict ["input_fields" ], "variable"
187185 )
188- return batch
186+ if "constant_fields" in batch_dict :
187+ batch_dict ["constant_fields" ] = self .dset_norm .normalize_flattened (
188+ batch_dict ["constant_fields" ], "constant"
189+ )
190+ if direct_tensor is not None :
191+ if self .is_delta :
192+ direct_tensor = self .dset_norm .normalize_delta_flattened (
193+ direct_tensor , "variable"
194+ )
195+ else :
196+ direct_tensor = self .dset_norm .normalize_flattened (
197+ direct_tensor , "variable"
198+ )
199+ return batch_dict , direct_tensor
189200
190- def denormalize (self , batch , prediction ):
201+ def denormalize (self , batch_dict = None , direct_tensor = None ):
191202 if hasattr (self , "dset_norm" ) and self .dset_norm :
192- batch ["input_fields" ] = self .dset_norm .denormalize_flattened (
193- batch ["input_fields" ], "variable"
194- )
195- if "constant_fields" in batch :
196- batch ["constant_fields" ] = self .dset_norm .denormalize_flattened (
197- batch ["constant_fields" ], "constant"
198- )
199-
200- # Delta denormalization is different than full denormalization
201- if self .is_delta :
202- prediction = self .dset_norm .delta_denormalize_flattened (
203- prediction , "variable"
204- )
205- else :
206- prediction = self .dset_norm .denormalize_flattened (
207- prediction , "variable"
203+ if batch_dict is not None :
204+ batch_dict ["input_fields" ] = self .dset_norm .denormalize_flattened (
205+ batch_dict ["input_fields" ], "variable"
208206 )
207+ if "constant_fields" in batch_dict :
208+ batch_dict ["constant_fields" ] = (
209+ self .dset_norm .denormalize_flattened (
210+ batch_dict ["constant_fields" ], "constant"
211+ )
212+ )
213+ if direct_tensor is not None :
214+ # Delta denormalization is different than full denormalization
215+ if self .is_delta :
216+ direct_tensor = self .dset_norm .delta_denormalize_flattened (
217+ direct_tensor , "variable"
218+ )
219+ else :
220+ direct_tensor = self .dset_norm .denormalize_flattened (
221+ direct_tensor , "variable"
222+ )
209223
210- return batch , prediction
224+ return batch_dict , direct_tensor
211225
212226 def rollout_model (self , model , batch , formatter , train = True ):
213227 """Rollout the model for as many steps as we have data for."""
@@ -216,31 +230,37 @@ def rollout_model(self, model, batch, formatter, train=True):
216230 y_ref .shape [1 ], self .max_rollout_steps
217231 ) # Number of timesteps in target
218232 y_ref = y_ref [:, :rollout_steps ]
233+ # NOTE: This is a quick fix so we can make datamodule behavior consistent. Revisit this next release (MM).
234+ if not train :
235+ _ , y_ref = self .denormalize (None , y_ref )
236+
219237 # Create a moving batch of one step at a time
220- moving_batch = batch
238+ moving_batch = dict ( batch )
221239 moving_batch ["input_fields" ] = moving_batch ["input_fields" ].to (self .device )
222240 if "constant_fields" in moving_batch :
223241 moving_batch ["constant_fields" ] = moving_batch ["constant_fields" ].to (
224242 self .device
225243 )
226244 y_preds = []
227245 for i in range (rollout_steps ):
228- if not train :
229- moving_batch = self .normalize (moving_batch )
246+ # NOTE: This is a quick fix so we can make datamodule behavior consistent.
247+ # Including local normalization schemes means there needs to be the option of normalizing each step
248+ # and there's currently not a registry of local vs global normalization schemes.
249+ if not train and self .datamodule .val_dataset .use_normalization and i > 0 :
250+ moving_batch , _ = self .normalize (moving_batch )
230251
231252 inputs , _ = formatter .process_input (moving_batch )
232253 inputs = [x .to (self .device ) for x in inputs ]
233254 y_pred = model (* inputs )
234255
235256 y_pred = formatter .process_output_channel_last (y_pred )
236-
237257 if not train :
238258 moving_batch , y_pred = self .denormalize (moving_batch , y_pred )
239259
240260 if (not train ) and self .is_delta :
241- assert {
261+ assert (
242262 moving_batch ["input_fields" ][:, - 1 , ...].shape == y_pred .shape
243- } , f"Mismatching shapes between last input timestep { moving_batch [:, - 1 , ...].shape } \
263+ ) , f"Mismatching shapes between last input timestep { moving_batch ["input_fields" ][:, - 1 , ...].shape } \
244264 and prediction { y_pred .shape } "
245265 y_pred = moving_batch ["input_fields" ][:, - 1 , ...] + y_pred
246266 y_pred = formatter .process_output_expand_time (y_pred )
0 commit comments