Skip to content

Commit 08cb68a

Browse files
authored
Merge pull request #11 from Kitware/slicing_averaging_fixes
fix: Slicing axes and averaging fixes
2 parents d96aea6 + 2edad77 commit 08cb68a

File tree

2 files changed

+15
-38
lines changed

2 files changed

+15
-38
lines changed

src/e3sm_quickview/plugins/eam_reader.py

Lines changed: 8 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
try:
1010
import netCDF4
1111
import numpy as np
12+
import json
1213

1314
_has_deps = True
1415
except ImportError as ie:
@@ -230,9 +231,6 @@ def __init__(self):
230231
self._DataFileName = None
231232
self._ConnFileName = None
232233
self._dirty = False
233-
self._surface_update = True
234-
self._midpoint_update = True
235-
self._interface_update = True
236234

237235
# Variables for dimension sliders
238236
self._time = 0
@@ -396,23 +394,18 @@ def _get_cached_area(self, vardata):
396394
self._cached_area[mask] = np.nan
397395
return self._cached_area
398396

399-
def _load_variable(self, vardata, varmeta, timeInd):
397+
def _load_variable(self, vardata, varmeta):
400398
"""Load variable data with dimension-based slicing."""
401399
try:
402400
# Build slice tuple based on variable's dimensions and user-selected slices
403401
slice_tuple = []
404402
for dim in varmeta.dimensions:
405403
if dim == self._data_horizontal_dim:
406-
continue
407-
# elif dim == "time":
408-
# # Use timeInd for time dimension
409-
# slice_tuple.append(timeInd)
410-
elif hasattr(self, "_slices") and dim in self._slices:
411-
# Use user-specified slice for this dimension
412-
slice_tuple.append(self._slices[dim])
404+
slice_tuple.append(slice(None))
413405
else:
414406
# Use all data for unspecified dimensions
415-
slice_tuple.append(slice(None))
407+
slice_tuple.append(self._slices.get(dim, 0))
408+
416409
# Get data with proper slicing
417410
data = vardata[varmeta.name][tuple(slice_tuple)].data.flatten()
418411
data = np.where(data == varmeta.fillval, np.nan, data)
@@ -511,7 +504,7 @@ def _populate_variable_metadata(self):
511504
if self._data_horizontal_dim not in dims:
512505
continue
513506
varmeta = VarMeta(name, info, self._data_horizontal_dim)
514-
if len(dims) == 1 and "area" in name:
507+
if len(dims) == 1 and "area" in name.lower():
515508
self._areavar = varmeta
516509
if len(dims) > 1:
517510
all_dimensions.update(dims)
@@ -588,20 +581,10 @@ def SetConnFileName(self, fname):
588581

589582
def SetSlicing(self, slice_str):
590583
# Parse JSON string containing dimension slices and update self._slices
591-
# Initialize _slices if not already done
592-
if not hasattr(self, "_slices"):
593-
self._slices = {}
594-
595-
# Initialize dimensions if not already done
596-
if not hasattr(self, "_dimensions"):
597-
self._dimensions = {}
598584

599585
if slice_str and slice_str.strip(): # Check for non-empty string
600586
try:
601-
import json
602-
603587
slice_dict = json.loads(slice_str)
604-
605588
# Validate and update slices for provided dimensions
606589
invalid_slices = []
607590
for dim, slice_val in slice_dict.items():
@@ -713,16 +696,6 @@ def RequestData(self, request, inInfo, outInfo):
713696
print_error("Required Python module 'netCDF4' or 'numpy' missing!")
714697
return 0
715698

716-
# Getting the correct time index
717-
executive = self.GetExecutive()
718-
from_port = request.Get(executive.FROM_OUTPUT_PORT())
719-
timeInd = self.get_time_index(outInfo, executive, from_port)
720-
if self._time != timeInd:
721-
self._time = timeInd
722-
self._surface_update = True
723-
self._midpoint_update = True
724-
self._interface_update = True
725-
726699
meshdata = self._get_mesh_dataset()
727700
vardata = self._get_var_dataset()
728701

@@ -766,9 +739,8 @@ def RequestData(self, request, inInfo, outInfo):
766739
if self._variable_selection.ArrayIsEnabled(name):
767740
if output_mesh.CellData.HasArray(name):
768741
to_remove.remove(name)
769-
if not output_mesh.CellData.HasArray(name) or self._surface_update:
770-
data = self._load_variable(vardata, varmeta, timeInd)
771-
output_mesh.CellData.append(data, name)
742+
data = self._load_variable(vardata, varmeta)
743+
output_mesh.CellData.append(data, name)
772744

773745
area_var_name = "area"
774746
if self._areavar and not output_mesh.CellData.HasArray(area_var_name):

src/e3sm_quickview/utils/compute.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,12 @@ def extract_avgs(pv_data, array_names):
3939
area_array = vtk_data.GetCellData().GetArray("area")
4040
for name in array_names:
4141
vtk_array = vtk_data.GetCellData().GetArray(name)
42-
avg_value = calculate_weighted_average(vtk_array, area_array)
42+
if vtk_array is None:
43+
results[name] = np.nan
44+
continue
45+
if area_array:
46+
avg_value = calculate_weighted_average(vtk_array, area_array)
47+
else:
48+
avg_value = float(np.nanmean(np.array(vtk_array)))
4349
results[name] = avg_value
44-
4550
return results

0 commit comments

Comments
 (0)