Skip to content

Commit f7ff1ba

Browse files
committed
o When handling different element types - Fix the Exodus reader. Add num_elem to Exodus writer, addresses #1242. o Modify tests - they are far from nice good, needs a complete overhaul of all io module
1 parent dea22c9 commit f7ff1ba

File tree

7 files changed

+210
-46
lines changed

7 files changed

+210
-46
lines changed

test/test_esmf.py

Lines changed: 70 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
from pathlib import Path
44
import pytest
55
import xarray as xr
6-
6+
import numpy as np
7+
from uxarray.constants import ERROR_TOLERANCE
78
current_path = Path(os.path.dirname(os.path.realpath(__file__)))
89

910
esmf_ne30_grid_path = current_path / 'meshfiles' / "esmf" / "ne30" / "ne30pg3.grid.nc"
@@ -38,10 +39,71 @@ def test_read_esmf_dataset():
3839
for dim in dims:
3940
assert dim in uxds.dims
4041

41-
def test_write_esmf():
42-
"""Tests the writing of a UxDataset to an ESMF Grid file."""
43-
uxds = ux.open_grid(gridfile_ne30)
44-
out_ds = uxds.to_xarray("ESMF")
45-
assert isinstance(out_ds, xr.Dataset)
46-
assert 'nodeCoords' in out_ds
47-
assert 'elementConn' in out_ds
42+
def test_esmf_round_trip_consistency():
43+
"""Test round-trip serialization of grid objects through ESMF xarray format.
44+
45+
Validates that grid objects can be successfully converted to ESMF xarray.Dataset
46+
format, serialized to disk, and reloaded while maintaining numerical accuracy
47+
and topological integrity.
48+
49+
The test verifies:
50+
- Successful conversion to ESMF xarray format
51+
- File I/O round-trip consistency
52+
- Preservation of face-node connectivity (exact)
53+
- Preservation of node coordinates (within numerical tolerance)
54+
55+
Raises:
56+
AssertionError: If any round-trip validation fails
57+
"""
58+
# Load original grid
59+
original_grid = ux.open_grid(gridfile_ne30)
60+
61+
# Convert to ESMF xarray format
62+
esmf_dataset = original_grid.to_xarray("ESMF")
63+
64+
# Verify dataset structure
65+
assert isinstance(esmf_dataset, xr.Dataset)
66+
assert 'nodeCoords' in esmf_dataset
67+
assert 'elementConn' in esmf_dataset
68+
69+
# Define output file path
70+
esmf_filepath = "test_esmf_ne30.nc"
71+
72+
# Remove existing test file to ensure clean state
73+
if os.path.exists(esmf_filepath):
74+
os.remove(esmf_filepath)
75+
76+
try:
77+
# Serialize dataset to disk
78+
esmf_dataset.to_netcdf(esmf_filepath)
79+
80+
# Reload grid from serialized file
81+
reloaded_grid = ux.open_grid(esmf_filepath)
82+
83+
# Validate topological consistency (face-node connectivity)
84+
# Integer connectivity arrays must be exactly preserved
85+
np.testing.assert_array_equal(
86+
original_grid.face_node_connectivity.values,
87+
reloaded_grid.face_node_connectivity.values,
88+
err_msg="ESMF face connectivity mismatch"
89+
)
90+
91+
# Validate coordinate consistency with numerical tolerance
92+
# Coordinate transformations and I/O precision may introduce minor differences
93+
np.testing.assert_allclose(
94+
original_grid.node_lon.values,
95+
reloaded_grid.node_lon.values,
96+
err_msg="ESMF longitude mismatch",
97+
rtol=ERROR_TOLERANCE
98+
)
99+
np.testing.assert_allclose(
100+
original_grid.node_lat.values,
101+
reloaded_grid.node_lat.values,
102+
err_msg="ESMF latitude mismatch",
103+
rtol=ERROR_TOLERANCE
104+
)
105+
106+
finally:
107+
# Clean up temporary test file
108+
if os.path.exists(esmf_filepath):
109+
os.remove(esmf_filepath)

test/test_exodus.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def test_mixed_exodus():
5050
assert np.array_equal(ugrid_load_saved.node_lon.values, uxgrid.node_lon.values)
5151
assert np.array_equal(uxgrid.node_lon.values, exodus_load_saved.node_lon.values)
5252
assert np.array_equal(ugrid_load_saved.node_lat.values, uxgrid.node_lat.values)
53+
5354
# Cleanup
5455
os.remove("test_ugrid.nc")
5556
os.remove("test_exo.exo")

test/test_grid.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@ def test_grid_ugrid_exodus_roundtrip():
144144
# Define output file paths
145145
ugrid_filepath = f"test_ugrid_{grid_name}.nc"
146146
exodus_filepath = f"test_exodus_{grid_name}.exo"
147+
test_files.append(ugrid_filepath)
148+
test_files.append(exodus_filepath)
147149

148150
# Serialize datasets to disk
149151
ugrid_dataset.to_netcdf(ugrid_filepath)

test/test_healpix.py

Lines changed: 82 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import xarray as xr
77
import pandas as pd
88
from pathlib import Path
9+
from uxarray.constants import ERROR_TOLERANCE
910

1011

1112
current_path = Path(os.path.dirname(os.path.realpath(__file__)))
@@ -100,35 +101,88 @@ def test_invalid_cells():
100101
with pytest.raises(ValueError):
101102
uxda = ux.UxDataset.from_healpix(xrda)
102103

103-
def test_healpix_to_netcdf(tmp_path):
104-
"""Test that HEALPix grid can be encoded as UGRID and saved to netCDF.
105-
Using pytest tmp_path fixture to create a temporary file.
104+
def test_healpix_round_trip_consistency(tmp_path):
105+
"""Test round-trip serialization of HEALPix grid through UGRID and Exodus formats.
106+
107+
Validates that HEALPix grid objects can be successfully converted to xarray.Dataset
108+
objects in both UGRID and Exodus formats, serialized to disk, and reloaded
109+
while maintaining numerical accuracy and topological integrity.
110+
111+
Args:
112+
tmp_path: pytest fixture providing temporary directory
113+
114+
Raises:
115+
AssertionError: If any round-trip validation fails
106116
"""
107117
# Create HEALPix grid
108-
h = ux.Grid.from_healpix(zoom=3)
118+
original_grid = ux.Grid.from_healpix(zoom=3)
109119

110120
# Access node coordinates to ensure they're generated before encoding
111-
_ = h.node_lon
112-
_ = h.node_lat
113-
114-
# Convert to different formats
115-
uxa_ugrid = h.to_xarray("UGRID")
116-
uxa_exodus = h.to_xarray("Exodus")
117-
118-
tmp_filename_ugrid = tmp_path / "healpix_test_ugrid.nc"
119-
tmp_filename_exodus = tmp_path / "healpix_test_exodus.exo"
120-
121-
# Save to netCDF
122-
uxa_ugrid.to_netcdf(tmp_filename_ugrid)
123-
uxa_exodus.to_netcdf(tmp_filename_exodus)
124-
125-
# Assertions
126-
assert tmp_filename_ugrid.exists()
127-
assert tmp_filename_ugrid.stat().st_size > 0
128-
assert tmp_filename_exodus.exists()
129-
assert tmp_filename_exodus.stat().st_size > 0
130-
131-
loaded_grid_ugrid = ux.open_grid(tmp_filename_ugrid)
132-
loaded_grid_exodus = ux.open_grid(tmp_filename_exodus)
133-
assert loaded_grid_ugrid.n_face == h.n_face
134-
assert loaded_grid_exodus.n_face == h.n_face
121+
_ = original_grid.node_lon
122+
_ = original_grid.node_lat
123+
124+
# Convert to xarray.Dataset objects in different formats
125+
ugrid_dataset = original_grid.to_xarray("UGRID")
126+
exodus_dataset = original_grid.to_xarray("Exodus")
127+
128+
# Define output file paths using tmp_path fixture
129+
ugrid_filepath = tmp_path / "healpix_test_ugrid.nc"
130+
exodus_filepath = tmp_path / "healpix_test_exodus.exo"
131+
132+
# Serialize datasets to disk
133+
ugrid_dataset.to_netcdf(ugrid_filepath)
134+
exodus_dataset.to_netcdf(exodus_filepath)
135+
136+
# Verify files were created successfully
137+
assert ugrid_filepath.exists()
138+
assert ugrid_filepath.stat().st_size > 0
139+
assert exodus_filepath.exists()
140+
assert exodus_filepath.stat().st_size > 0
141+
142+
# Reload grids from serialized files
143+
reloaded_ugrid = ux.open_grid(ugrid_filepath)
144+
reloaded_exodus = ux.open_grid(exodus_filepath)
145+
146+
# Validate topological consistency (face-node connectivity)
147+
# Integer connectivity arrays must be exactly preserved
148+
np.testing.assert_array_equal(
149+
original_grid.face_node_connectivity.values,
150+
reloaded_ugrid.face_node_connectivity.values,
151+
err_msg="UGRID face connectivity mismatch for HEALPix"
152+
)
153+
np.testing.assert_array_equal(
154+
original_grid.face_node_connectivity.values,
155+
reloaded_exodus.face_node_connectivity.values,
156+
err_msg="Exodus face connectivity mismatch for HEALPix"
157+
)
158+
159+
# Validate coordinate consistency with numerical tolerance
160+
# Coordinate transformations and I/O precision may introduce minor differences
161+
np.testing.assert_allclose(
162+
original_grid.node_lon.values,
163+
reloaded_ugrid.node_lon.values,
164+
err_msg="UGRID longitude mismatch for HEALPix",
165+
rtol=ERROR_TOLERANCE
166+
)
167+
np.testing.assert_allclose(
168+
original_grid.node_lon.values,
169+
reloaded_exodus.node_lon.values,
170+
err_msg="Exodus longitude mismatch for HEALPix",
171+
rtol=ERROR_TOLERANCE
172+
)
173+
np.testing.assert_allclose(
174+
original_grid.node_lat.values,
175+
reloaded_ugrid.node_lat.values,
176+
err_msg="UGRID latitude mismatch for HEALPix",
177+
rtol=ERROR_TOLERANCE
178+
)
179+
np.testing.assert_allclose(
180+
original_grid.node_lat.values,
181+
reloaded_exodus.node_lat.values,
182+
err_msg="Exodus latitude mismatch for HEALPix",
183+
rtol=ERROR_TOLERANCE
184+
)
185+
186+
# Validate grid dimensions are preserved
187+
assert reloaded_ugrid.n_face == original_grid.n_face
188+
assert reloaded_exodus.n_face == original_grid.n_face

test/test_scrip.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
gridfile_RLL1deg = current_path / "meshfiles" / "ugrid" / "outRLL1deg" / "outRLL1deg.ug"
2121
gridfile_RLL10deg_ne4 = current_path / "meshfiles" / "ugrid" / "ov_RLL10deg_CSne4" / "ov_RLL10deg_CSne4.ug"
2222
gridfile_exo_ne8 = current_path / "meshfiles" / "exodus" / "outCSne8" / "outCSne8.g"
23+
gridfile_scrip = current_path / "meshfiles" / "scrip" / "outCSne8" /"outCSne8.nc"
2324

2425
def test_read_ugrid():
2526
"""Reads a ugrid file."""
@@ -47,8 +48,20 @@ def test_read_ugrid():
4748

4849
def test_to_xarray_ugrid():
4950
"""Read an Exodus dataset and convert it to UGRID format using to_xarray."""
50-
ux_grid = ux.open_grid(gridfile_exo_ne8)
51-
ux_grid.to_xarray("UGRID")
51+
ux_grid = ux.open_grid(gridfile_scrip)
52+
xr_obj = ux_grid.to_xarray("UGRID")
53+
xr_obj.to_netcdf("scrip_ugrid_csne8.nc")
54+
reloaded_grid = ux.open_grid("scrip_ugrid_csne8.nc")
55+
# Check that the grid topology is perfectly preserved
56+
nt.assert_array_equal(ux_grid.face_node_connectivity.values,
57+
reloaded_grid.face_node_connectivity.values)
58+
59+
# Check that node coordinates are numerically close
60+
nt.assert_allclose(ux_grid.node_lon.values, reloaded_grid.node_lon.values)
61+
nt.assert_allclose(ux_grid.node_lat.values, reloaded_grid.node_lat.values)
62+
63+
# Cleanup
64+
os.remove("scrip_ugrid_csne8.nc")
5265

5366
def test_standardized_dtype_and_fill():
5467
"""Test to see if Mesh2_Face_Nodes uses the expected integer datatype

test/test_ugrid.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,19 @@ def test_read_ugrid():
4646
def test_to_xarray_ugrid():
4747
"""Read an Exodus dataset and convert it to UGRID format using to_xarray."""
4848
ux_grid = ux.open_grid(gridfile_exo_ne8)
49-
ux_grid.to_xarray("UGRID")
49+
xr_obj = ux_grid.to_xarray("UGRID")
50+
xr_obj.to_netcdf("ugrid_exo_csne8.nc")
51+
reloaded_grid = ux.open_grid("ugrid_exo_csne8.nc")
52+
# Check that the grid topology is perfectly preserved
53+
nt.assert_array_equal(ux_grid.face_node_connectivity.values,
54+
reloaded_grid.face_node_connectivity.values)
55+
56+
# Check that node coordinates are numerically close
57+
nt.assert_allclose(ux_grid.node_lon.values, reloaded_grid.node_lon.values)
58+
nt.assert_allclose(ux_grid.node_lat.values, reloaded_grid.node_lat.values)
59+
60+
# Cleanup
61+
os.remove("ugrid_exo_csne8.nc")
5062

5163
def test_standardized_dtype_and_fill():
5264
"""Test to see if Mesh2_Face_Nodes uses the expected integer datatype and expected fill value."""

uxarray/io/_exodus.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,18 +64,36 @@ def _read_exodus(ext_ds):
6464
data=ext_ds.coordz, dims=[ugrid.NODE_DIM], attrs=ugrid.NODE_Z_ATTRS
6565
)
6666
elif "connect" in key:
67-
# check if num face nodes is less than max.
68-
if value.data.shape[1] <= max_face_nodes:
69-
face_nodes = value
70-
else:
71-
raise RuntimeError("found face_nodes_dim greater than n_max_face_nodes")
67+
# This variable will be populated in the next step
68+
pass
7269

7370
# outside the k,v for loop
7471
# set the face nodes data compiled in "connect" section
72+
connect_list = []
73+
for key, value in ext_ds.variables.items():
74+
if "connect" in key:
75+
connect_list.append(value.data)
76+
77+
padded_blocks = []
78+
for block in connect_list:
79+
num_nodes = block.shape[1]
80+
pad_width = max_face_nodes - num_nodes
81+
82+
# Pad with 0, as Exodus uses 0 for non-existent nodes
83+
padded_block = np.pad(
84+
block, ((0, 0), (0, pad_width)), "constant", constant_values=0
85+
)
86+
padded_blocks.append(padded_block)
87+
88+
# Prevent error on empty grids
89+
if not padded_blocks:
90+
face_nodes = np.empty((0, max_face_nodes), dtype=INT_DTYPE)
91+
else:
92+
face_nodes = np.vstack(padded_blocks)
7593

7694
# standardize fill values and data type face nodes
7795
face_nodes = _replace_fill_values(
78-
grid_var=face_nodes[:] - 1,
96+
grid_var=xr.DataArray(face_nodes - 1), # Wrap numpy array in a DataArray
7997
original_fill=-1,
8098
new_fill=INT_FILL_VALUE,
8199
new_dtype=INT_DTYPE,
@@ -151,6 +169,9 @@ def _encode_exodus(ds, outfile=None):
151169

152170
exo_ds["time_whole"] = xr.DataArray(data=[], dims=["time_step"])
153171

172+
# --- Add num_elem dimension ---
173+
exo_ds.attrs["num_elem"] = ds.sizes["n_face"]
174+
154175
# --- QA Records ---
155176
ux_exodus_version = "1.0"
156177
qa_records = [["uxarray"], [ux_exodus_version], [date], [time]]
@@ -161,7 +182,6 @@ def _encode_exodus(ds, outfile=None):
161182

162183
# --- Node Coordinates ---
163184
if "node_x" not in ds:
164-
print("HERE", ds["node_lon"].values)
165185
node_lon_rad = np.deg2rad(ds["node_lon"].values)
166186
node_lat_rad = np.deg2rad(ds["node_lat"].values)
167187
x, y, z = _lonlat_rad_to_xyz(node_lon_rad, node_lat_rad)

0 commit comments

Comments
 (0)