Skip to content

Commit ec7c5d5

Browse files
Merge pull request #34 from PolymathicAI/33-issue-with-dataloader-and-computation-of-statistics
33 issue with dataloader and computation of statistics
2 parents d2924b3 + 23242db commit ec7c5d5

File tree

1 file changed

+24
-12
lines changed

1 file changed

+24
-12
lines changed

the_well/benchmark/data/datasets.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -296,14 +296,26 @@ def _pad_axes(
296296
expand_dims = expand_dims + (1,) * tensor_order
297297
return np.tile(field_data, expand_dims)
298298

299+
def _postprocess_field_list(self, field_list, output_list, order):
300+
"""Postprocesses field list to apply tensor transforms"""
301+
if len(field_list) > 0:
302+
field_list = torch.stack(field_list, -(order + 1))
303+
for tensor_transform in self.tensor_transforms:
304+
field_list = tensor_transform(field_list, order=order)
305+
if self.flatten_tensors:
306+
field_list = field_list.flatten(-(order + 1))
307+
output_list.append(field_list)
308+
return output_list
309+
299310
def _reconstruct_fields(self, file, sample_idx, time_idx, n_steps, dt):
300311
"""Reconstruct space fields starting at index sample_idx, time_idx, with
301312
n_steps and dt stride. Apply transformations if provided."""
302313
variable_fields = []
303314
constant_fields = []
304315
# Iterate through field types and apply appropriate transforms to stack them
305316
for i, order_fields in enumerate(["t0_fields", "t1_fields", "t2_fields"]):
306-
sub_fields = []
317+
variable_subfields = []
318+
constant_subfields = []
307319
for field_name in file[order_fields].attrs["field_names"]:
308320
field = file[order_fields][field_name]
309321
use_dims = field.attrs["dim_varying"]
@@ -329,18 +341,18 @@ def _reconstruct_fields(self, file, sample_idx, time_idx, n_steps, dt):
329341
self._check_cache(
330342
field_name, field_data
331343
) # If constant and processed, cache
332-
sub_fields.append(field_data)
344+
if field.attrs["time_varying"]:
345+
variable_subfields.append(field_data)
346+
else:
347+
constant_subfields.append(field_data)
348+
333349
# Stack fields such that the last i dims are the tensor dims
334-
sub_fields = torch.stack(sub_fields, -(i + 1))
335-
for tensor_transform in self.tensor_transforms:
336-
sub_fields = tensor_transform(sub_fields, order=i)
337-
# If we're flattening tensors, we can then flatten last i dims
338-
if self.flatten_tensors:
339-
sub_fields = sub_fields.flatten(-(i + 1))
340-
if field.attrs["time_varying"]:
341-
variable_fields.append(sub_fields)
342-
else:
343-
constant_fields.append(sub_fields)
350+
variable_fields = self._postprocess_field_list(
351+
variable_subfields, variable_fields, i
352+
)
353+
constant_fields = self._postprocess_field_list(
354+
constant_subfields, constant_fields, i
355+
)
344356

345357
return tuple(
346358
[

0 commit comments

Comments
 (0)