Skip to content

Commit 1955a87

Browse files
committed
o Fix exodus encoder bug, add tests
1 parent 1594ea1 commit 1955a87

File tree

3 files changed

+127
-15
lines changed

3 files changed

+127
-15
lines changed

test/test_exodus.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,26 @@ def test_mixed_exodus():
3333
"""Read/write an exodus file with two types of faces (triangle and quadrilaterals) and writes a ugrid file."""
3434
uxgrid = ux.open_grid(exo2_filename)
3535

36-
uxgrid.to_xarray("UGRID")
37-
uxgrid.to_xarray("Exodus")
38-
# Add assertions or checks as needed
36+
ugrid_obj = uxgrid.to_xarray("UGRID")
37+
exo_obj = uxgrid.to_xarray("Exodus")
38+
39+
ugrid_obj.to_netcdf("test_ugrid.nc")
40+
exo_obj.to_netcdf("test_exo.exo")
41+
42+
ugrid_load_saved = ux.open_grid("test_ugrid.nc")
43+
exodus_load_saved = ux.open_grid("test_exo.exo")
44+
45+
# Face node connectivity comparison
46+
assert np.array_equal(ugrid_load_saved.face_node_connectivity.values, uxgrid.face_node_connectivity.values)
47+
assert np.array_equal(uxgrid.face_node_connectivity.values, exodus_load_saved.face_node_connectivity.values)
48+
49+
# Node coordinates comparison
50+
assert np.array_equal(ugrid_load_saved.node_lon.values, uxgrid.node_lon.values)
51+
assert np.array_equal(uxgrid.node_lon.values, exodus_load_saved.node_lon.values)
52+
assert np.array_equal(ugrid_load_saved.node_lat.values, uxgrid.node_lat.values)
53+
# Cleanup
54+
os.remove("test_ugrid.nc")
55+
os.remove("test_exo.exo")
3956

4057
def test_standardized_dtype_and_fill():
4158
"""Test to see if Mesh2_Face_Nodes uses the expected integer datatype and expected fill value as set in constants.py."""

test/test_grid.py

Lines changed: 101 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -96,15 +96,107 @@ def test_grid_with_holes():
9696
assert grid_without_holes.global_sphere_coverage
9797

9898

99-
def test_grid_to_xarray():
100-
"""Reads a ugrid file and encodes it as `xarray.Dataset` in various types."""
101-
grid_CSne30.to_xarray("UGRID")
102-
grid_RLL1deg.to_xarray("UGRID")
103-
grid_RLL10deg_CSne4.to_xarray("UGRID")
104-
105-
grid_CSne30.to_xarray("Exodus")
106-
grid_RLL1deg.to_xarray("Exodus")
107-
grid_RLL10deg_CSne4.to_xarray("Exodus")
99+
def test_grid_ugrid_exodus_roundtrip():
100+
"""Test round-trip serialization of grid objects through UGRID and Exodus xarray formats.
101+
102+
Validates that grid objects can be successfully converted to xarray.Dataset
103+
objects in both UGRID and Exodus formats, serialized to disk, and reloaded
104+
while maintaining numerical accuracy and topological integrity.
105+
106+
The test verifies:
107+
- Successful conversion to UGRID and Exodus xarray formats
108+
- File I/O round-trip consistency
109+
- Preservation of face-node connectivity (exact)
110+
- Preservation of node coordinates (within numerical tolerance)
111+
112+
Raises:
113+
AssertionError: If any round-trip validation fails
114+
"""
115+
116+
# Convert grids to xarray.Dataset objects in different formats
117+
ugrid_datasets = {
118+
'CSne30': grid_CSne30.to_xarray("UGRID"),
119+
'RLL1deg': grid_RLL1deg.to_xarray("UGRID"),
120+
'RLL10deg_CSne4': grid_RLL10deg_CSne4.to_xarray("UGRID")
121+
}
122+
123+
exodus_datasets = {
124+
'CSne30': grid_CSne30.to_xarray("Exodus"),
125+
'RLL1deg': grid_RLL1deg.to_xarray("Exodus"),
126+
'RLL10deg_CSne4': grid_RLL10deg_CSne4.to_xarray("Exodus")
127+
}
128+
129+
# Define test cases with corresponding grid objects
130+
test_grids = {
131+
'CSne30': grid_CSne30,
132+
'RLL1deg': grid_RLL1deg,
133+
'RLL10deg_CSne4': grid_RLL10deg_CSne4
134+
}
135+
136+
# Perform round-trip validation for each grid type
137+
test_files = []
138+
139+
for grid_name in test_grids.keys():
140+
ugrid_dataset = ugrid_datasets[grid_name]
141+
exodus_dataset = exodus_datasets[grid_name]
142+
original_grid = test_grids[grid_name]
143+
144+
# Define output file paths
145+
ugrid_filepath = f"test_ugrid_{grid_name}.nc"
146+
exodus_filepath = f"test_exodus_{grid_name}.exo"
147+
148+
# Serialize datasets to disk
149+
ugrid_dataset.to_netcdf(ugrid_filepath)
150+
exodus_dataset.to_netcdf(exodus_filepath)
151+
152+
# Reload grids from serialized files
153+
reloaded_ugrid = ux.open_grid(ugrid_filepath)
154+
reloaded_exodus = ux.open_grid(exodus_filepath)
155+
156+
# Validate topological consistency (face-node connectivity)
157+
# Integer connectivity arrays must be exactly preserved
158+
np.testing.assert_array_equal(
159+
original_grid.face_node_connectivity.values,
160+
reloaded_ugrid.face_node_connectivity.values,
161+
err_msg=f"UGRID face connectivity mismatch for {grid_name}"
162+
)
163+
np.testing.assert_array_equal(
164+
original_grid.face_node_connectivity.values,
165+
reloaded_exodus.face_node_connectivity.values,
166+
err_msg=f"Exodus face connectivity mismatch for {grid_name}"
167+
)
168+
169+
# Validate coordinate consistency with numerical tolerance
170+
# Coordinate transformations and I/O precision may introduce minor differences
171+
np.testing.assert_allclose(
172+
original_grid.node_lon.values,
173+
reloaded_ugrid.node_lon.values,
174+
err_msg=f"UGRID longitude mismatch for {grid_name}",
175+
rtol=ERROR_TOLERANCE
176+
)
177+
np.testing.assert_allclose(
178+
original_grid.node_lon.values,
179+
reloaded_exodus.node_lon.values,
180+
err_msg=f"Exodus longitude mismatch for {grid_name}",
181+
rtol=ERROR_TOLERANCE
182+
)
183+
np.testing.assert_allclose(
184+
original_grid.node_lat.values,
185+
reloaded_ugrid.node_lat.values,
186+
err_msg=f"UGRID latitude mismatch for {grid_name}",
187+
rtol=ERROR_TOLERANCE
188+
)
189+
np.testing.assert_allclose(
190+
original_grid.node_lat.values,
191+
reloaded_exodus.node_lat.values,
192+
err_msg=f"Exodus latitude mismatch for {grid_name}",
193+
rtol=ERROR_TOLERANCE
194+
)
195+
196+
# Clean up temporary test files
197+
for filepath in test_files:
198+
if os.path.exists(filepath):
199+
os.remove(filepath)
108200

109201

110202
def test_grid_init_verts():

uxarray/io/_exodus.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,12 @@ def _read_exodus(ext_ds):
5656
)
5757
elif key == "coordy":
5858
ds["node_y"] = xr.DataArray(
59-
data=ext_ds.coordx, dims=[ugrid.NODE_DIM], attrs=ugrid.NODE_Y_ATTRS
59+
data=ext_ds.coordy, dims=[ugrid.NODE_DIM], attrs=ugrid.NODE_Y_ATTRS
6060
)
6161
elif key == "coordz":
6262
if ext_ds.sizes["num_dim"] > 2:
6363
ds["node_z"] = xr.DataArray(
64-
data=ext_ds.coordx, dims=[ugrid.NODE_DIM], attrs=ugrid.NODE_Z_ATTRS
64+
data=ext_ds.coordz, dims=[ugrid.NODE_DIM], attrs=ugrid.NODE_Z_ATTRS
6565
)
6666
elif "connect" in key:
6767
# check if num face nodes is less than max.
@@ -161,7 +161,10 @@ def _encode_exodus(ds, outfile=None):
161161

162162
# --- Node Coordinates ---
163163
if "node_x" not in ds:
164-
x, y, z = _lonlat_rad_to_xyz(ds["node_lon"].values, ds["node_lat"].values)
164+
print("HERE", ds["node_lon"].values)
165+
node_lon_rad = np.deg2rad(ds["node_lon"].values)
166+
node_lat_rad = np.deg2rad(ds["node_lat"].values)
167+
x, y, z = _lonlat_rad_to_xyz(node_lon_rad, node_lat_rad)
165168
c_data = np.array([x, y, z])
166169
else:
167170
c_data = np.array(

0 commit comments

Comments
 (0)