Skip to content

Commit d75c5e4

Browse files
authored
Merge pull request #4 from Kitware/dimension_matching
feat: Adding dimension matching for horizontal axis
2 parents e2e2732 + 4512933 commit d75c5e4

File tree

1 file changed

+74
-22
lines changed

1 file changed

+74
-22
lines changed

src/e3sm_quickview/plugins/eam_reader.py

Lines changed: 74 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,7 @@
1818
)
1919
_has_deps = False
2020

21-
dims1 = set(["ncol"])
22-
dims2 = set(["time", "ncol"])
23-
dims3i = set(["time", "ilev", "ncol"])
24-
dims3m = set(["time", "lev", "ncol"])
21+
# Dimensions will be dynamically determined from connectivity and data files
2522

2623

2724
class EAMConstants:
@@ -46,7 +43,7 @@ class VarType(Enum):
4643

4744

4845
class VarMeta:
49-
def __init__(self, name, info):
46+
def __init__(self, name, info, horizontal_dim=None):
5047
self.name = name
5148
self.type = None
5249
self.transpose = False
@@ -64,7 +61,8 @@ def __init__(self, name, info):
6461
elif "ilev" in dims:
6562
self.type = VarType._3Di
6663

67-
if "ncol" in dims[1]:
64+
# Use dynamic horizontal dimension
65+
if horizontal_dim and len(dims) > 1 and horizontal_dim in dims[1]:
6866
self.transpose = True
6967

7068

@@ -74,7 +72,7 @@ def compare(data, arrays, dim):
7472
raise Exception(
7573
"Length of hya_/hyb_ variable does not match the corresponding dimension"
7674
)
77-
for i, array in enumerate(arrays[1:], start=1):
75+
for array in arrays[1:]:
7876
comp = data[array][:].flatten()
7977
if not np.array_equal(ref, comp):
8078
return None
@@ -251,6 +249,10 @@ def __init__(self):
251249
self._cached_lev = None
252250
self._cached_ilev = None
253251
self._cached_area = None
252+
253+
# Dynamic dimension detection
254+
self._horizontal_dim = None # From connectivity file
255+
self._data_horizontal_dim = None # Matched in data file
254256

255257
def __del__(self):
256258
"""Clean up NetCDF file handles on deletion."""
@@ -308,6 +310,31 @@ def _clear(self):
308310
self._cached_lev = None
309311
self._cached_ilev = None
310312
self._cached_area = None
313+
# Clear dimension detection
314+
self._horizontal_dim = None
315+
self._data_horizontal_dim = None
316+
317+
def _identify_horizontal_dimension(self, meshdata, vardata):
318+
"""Identify horizontal dimension from connectivity and match with data file."""
319+
if self._horizontal_dim and self._data_horizontal_dim:
320+
return # Already identified
321+
322+
# Get first dimension from connectivity file
323+
conn_dims = list(meshdata.dimensions.keys())
324+
if not conn_dims:
325+
print_error("No dimensions found in connectivity file")
326+
return
327+
328+
self._horizontal_dim = conn_dims[0]
329+
conn_size = meshdata.dimensions[self._horizontal_dim].size
330+
331+
# Match dimension in data file by size
332+
for dim_name, dim_obj in vardata.dimensions.items():
333+
if dim_obj.size == conn_size:
334+
self._data_horizontal_dim = dim_name
335+
return
336+
337+
print_error(f"Could not match horizontal dimension size {conn_size} in data file")
311338

312339
def _clear_geometry_cache(self):
313340
"""Clear cached geometry data."""
@@ -381,18 +408,14 @@ def _build_geometry(self, meshdata):
381408
return
382409

383410
dims = meshdata.dimensions
384-
mdims = np.array(list(meshdata.dimensions.keys()))
385411
mvars = np.array(list(meshdata.variables.keys()))
386-
387-
# Find ncells2D
388-
ncells2D = dims[
389-
mdims[
390-
np.where(
391-
(np.char.find(mdims, "grid_size") > -1)
392-
| (np.char.find(mdims, "ncol") > -1)
393-
)[0][0]
394-
]
395-
].size
412+
413+
# Use the identified horizontal dimension
414+
if not self._horizontal_dim:
415+
print_error("Horizontal dimension not identified in connectivity file")
416+
return
417+
418+
ncells2D = dims[self._horizontal_dim].size
396419
self._cached_ncells2D = ncells2D
397420

398421
# Find lat/lon dimensions
@@ -436,20 +459,35 @@ def _build_geometry(self, meshdata):
436459
)
437460

438461
def _populate_variable_metadata(self):
439-
if self._DataFileName is None:
462+
if self._DataFileName is None or self._ConnFileName is None:
440463
return
464+
465+
meshdata = self._get_mesh_dataset()
441466
vardata = self._get_var_dataset()
442-
467+
468+
# Identify horizontal dimensions first
469+
self._identify_horizontal_dimension(meshdata, vardata)
470+
471+
if not self._data_horizontal_dim:
472+
print_error("Could not detect horizontal dimension in data file")
473+
return
474+
443475
# Clear existing selection arrays BEFORE adding new ones
444476
self._surface_selection.RemoveAllArrays()
445477
self._midpoint_selection.RemoveAllArrays()
446478
self._interface_selection.RemoveAllArrays()
479+
480+
# Define dimension sets dynamically based on detected dimension
481+
dims1 = set([self._data_horizontal_dim])
482+
dims2 = set(['time', self._data_horizontal_dim])
483+
dims3m = set(['time', 'lev', self._data_horizontal_dim])
484+
dims3i = set(['time', 'ilev', self._data_horizontal_dim])
447485

448486
for name, info in vardata.variables.items():
449487
dims = set(info.dimensions)
450488
if not (dims == dims1 or dims == dims2 or dims == dims3m or dims == dims3i):
451489
continue
452-
varmeta = VarMeta(name, info)
490+
varmeta = VarMeta(name, info, self._data_horizontal_dim)
453491
if varmeta.type == VarType._1D:
454492
self._info_vars.append(varmeta)
455493
if "area" in name:
@@ -507,15 +545,18 @@ def SetConnFileName(self, fname):
507545
self._surface_update = True
508546
self._midpoint_update = True
509547
self._interface_update = True
548+
self._clear() # Clear dimension cache
510549
# Close old dataset if filename changed
511550
if self._cached_mesh_filename != fname and self._mesh_dataset is not None:
512551
try:
513552
self._mesh_dataset.close()
514553
except Exception:
515554
pass
516555
self._mesh_dataset = None
517-
# Clear geometry cache when connectivity file changes
518556
self._clear_geometry_cache()
557+
# Re-populate metadata if data file is already set
558+
if self._DataFileName:
559+
self._populate_variable_metadata()
519560
self.Modified()
520561

521562
def SetMiddleLayer(self, lev):
@@ -616,8 +657,19 @@ def RequestData(self, request, inInfo, outInfo):
616657
meshdata = self._get_mesh_dataset()
617658
vardata = self._get_var_dataset()
618659

660+
# Ensure dimensions are identified
661+
self._identify_horizontal_dimension(meshdata, vardata)
662+
663+
if not self._horizontal_dim or not self._data_horizontal_dim:
664+
print_error("Could not identify required dimensions from files")
665+
return 0
666+
619667
# Build geometry if not cached
620668
self._build_geometry(meshdata)
669+
670+
if self._cached_points is None:
671+
print_error("Could not build geometry from connectivity file")
672+
return 0
621673

622674
output_mesh = dsa.WrapDataObject(self._output)
623675

0 commit comments

Comments
 (0)