@@ -296,14 +296,26 @@ def _pad_axes(
296296 expand_dims = expand_dims + (1 ,) * tensor_order
297297 return np .tile (field_data , expand_dims )
298298
299+ def _postprocess_field_list (self , field_list , output_list , order ):
300+ """Postprocesses field list to apply tensor transforms"""
301+ if len (field_list ) > 0 :
302+ field_list = torch .stack (field_list , - (order + 1 ))
303+ for tensor_transform in self .tensor_transforms :
304+ field_list = tensor_transform (field_list , order = order )
305+ if self .flatten_tensors :
306+ field_list = field_list .flatten (- (order + 1 ))
307+ output_list .append (field_list )
308+ return output_list
309+
299310 def _reconstruct_fields (self , file , sample_idx , time_idx , n_steps , dt ):
300311 """Reconstruct space fields starting at index sample_idx, time_idx, with
301312 n_steps and dt stride. Apply transformations if provided."""
302313 variable_fields = []
303314 constant_fields = []
304315 # Iterate through field types and apply appropriate transforms to stack them
305316 for i , order_fields in enumerate (["t0_fields" , "t1_fields" , "t2_fields" ]):
306- sub_fields = []
317+ variable_subfields = []
318+ constant_subfields = []
307319 for field_name in file [order_fields ].attrs ["field_names" ]:
308320 field = file [order_fields ][field_name ]
309321 use_dims = field .attrs ["dim_varying" ]
@@ -329,18 +341,18 @@ def _reconstruct_fields(self, file, sample_idx, time_idx, n_steps, dt):
329341 self ._check_cache (
330342 field_name , field_data
331343 ) # If constant and processed, cache
332- sub_fields .append (field_data )
344+ if field .attrs ["time_varying" ]:
345+ variable_subfields .append (field_data )
346+ else :
347+ constant_subfields .append (field_data )
348+
333349 # Stack fields such that the last i dims are the tensor dims
334- sub_fields = torch .stack (sub_fields , - (i + 1 ))
335- for tensor_transform in self .tensor_transforms :
336- sub_fields = tensor_transform (sub_fields , order = i )
337- # If we're flattening tensors, we can then flatten last i dims
338- if self .flatten_tensors :
339- sub_fields = sub_fields .flatten (- (i + 1 ))
340- if field .attrs ["time_varying" ]:
341- variable_fields .append (sub_fields )
342- else :
343- constant_fields .append (sub_fields )
350+ variable_fields = self ._postprocess_field_list (
351+ variable_subfields , variable_fields , i
352+ )
353+ constant_fields = self ._postprocess_field_list (
354+ constant_subfields , constant_fields , i
355+ )
344356
345357 return tuple (
346358 [
0 commit comments