Skip to content

Commit 06864b7

Browse files
committed
feat: add xarray grid coordinate indexing (#207)
- Added x, y, z world coordinates to StructuredGrid - Coordinates computed from grid geometry (delr, delc, top, botm) - Enables coordinate-based selection: grid.botm.sel(x=100, method='nearest') - Added comprehensive tests for coordinate indexing - All grid data arrays now have spatial coordinates attached
1 parent fe39f4b commit 06864b7

File tree

2 files changed

+169
-4
lines changed

2 files changed

+169
-4
lines changed

flopy4/mf6/utils/grid.py

Lines changed: 76 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,21 @@ def __init__(self, *args, **kwargs):
2323
"ncol": "j",
2424
"nodes": "node",
2525
}
26+
27+
# Compute world coordinates (x, y, z)
28+
world_coords = self._compute_world_coordinates()
29+
30+
# Index coordinates
2631
self._coords = {
2732
"k": xr.DataArray(np.arange(self.nlay, dtype=int), dims=("nlay",)),
2833
"i": xr.DataArray(np.arange(self.nrow, dtype=int), dims=("nrow",)),
2934
"j": xr.DataArray(np.arange(self.ncol, dtype=int), dims=("ncol",)),
3035
"node": xr.DataArray(np.arange(self.nnodes, dtype=int), dims=("nodes",)),
3136
}
37+
38+
# Add world coordinates
39+
self._coords.update(world_coords)
40+
3241
data_vars = {
3342
"delr": self.delr,
3443
"delc": self.delc,
@@ -43,8 +52,56 @@ def __init__(self, *args, **kwargs):
4352
.set_xindex("i", PandasIndex)
4453
.set_xindex("j", PandasIndex)
4554
.set_xindex("node", PandasIndex)
55+
.set_xindex("x", PandasIndex)
56+
.set_xindex("y", PandasIndex)
57+
.set_xindex("z", PandasIndex)
4658
)
4759

60+
def _compute_world_coordinates(self) -> dict:
61+
"""
62+
Compute x, y, z world coordinates from grid geometry.
63+
64+
Returns
65+
-------
66+
dict
67+
Dictionary with 'x', 'y', 'z' coordinate DataArrays
68+
"""
69+
# Get grid extent
70+
xmin, xmax, ymin, ymax = self.extent
71+
72+
# Compute x coordinates (cell centers)
73+
delr = np.atleast_1d(self.delr[0] if hasattr(self.delr, '__getitem__') and not isinstance(self.delr, np.ndarray) else self.delr)
74+
if delr.size == 1:
75+
delr = np.full(self.ncol, delr[0])
76+
x = xmin + np.cumsum(delr) - 0.5 * delr
77+
78+
# Compute y coordinates (cell centers)
79+
delc = np.atleast_1d(self.delc[0] if hasattr(self.delc, '__getitem__') and not isinstance(self.delc, np.ndarray) else self.delc)
80+
if delc.size == 1:
81+
delc = np.full(self.nrow, delc[0])
82+
y = ymax - np.cumsum(delc) + 0.5 * delc
83+
84+
# Compute z coordinates (layer centers)
85+
# Use top and botm to compute layer centers
86+
top_2d = np.atleast_2d(self.top)
87+
botm_3d = np.atleast_3d(self.botm).reshape(self.nlay, self.nrow, self.ncol)
88+
89+
# Layer center z coordinates: average of top and bottom per layer
90+
z = np.zeros(self.nlay)
91+
for k in range(self.nlay):
92+
if k == 0:
93+
layer_top = top_2d.mean()
94+
else:
95+
layer_top = botm_3d[k - 1].mean()
96+
layer_bot = botm_3d[k].mean()
97+
z[k] = (layer_top + layer_bot) / 2.0
98+
99+
return {
100+
"x": xr.DataArray(x, dims=("ncol",)),
101+
"y": xr.DataArray(y, dims=("nrow",)),
102+
"z": xr.DataArray(z, dims=("nlay",)),
103+
}
104+
48105
@property
49106
def dataset(self) -> xr.Dataset:
50107
return self._dataset
@@ -55,21 +112,21 @@ def delc(self):
55112
return None
56113
dims = ("ncol",)
57114
coord_name = self._dims_coords[dims[0]]
58-
coords = coords = {coord_name: self._coords[coord_name]}
115+
coords = {coord_name: self._coords[coord_name], "x": self._coords["x"]}
59116
return xr.DataArray(super().delc, coords=coords, dims=dims).set_xindex(
60117
coord_name, PandasIndex
61-
)
118+
).set_xindex("x", PandasIndex)
62119

63120
@property
64121
def delr(self):
65122
if self.__delr is None:
66123
return None
67124
dims = ("nrow",)
68125
coord_name = self._dims_coords[dims[0]]
69-
coords = {coord_name: self._coords[coord_name]}
126+
coords = {coord_name: self._coords[coord_name], "y": self._coords["y"]}
70127
return xr.DataArray(super().delr, coords=coords, dims=dims).set_xindex(
71128
coord_name, PandasIndex
72-
)
129+
).set_xindex("y", PandasIndex)
73130

74131
@property
75132
def delz(self):
@@ -80,22 +137,29 @@ def delz(self):
80137
self._dims_coords[dims[2]],
81138
)
82139
coords = {coord_name: self._coords[coord_name] for coord_name in coord_names}
140+
coords.update({"x": self._coords["x"], "y": self._coords["y"], "z": self._coords["z"]})
83141
return (
84142
xr.DataArray(super().delz, coords=coords, dims=dims)
85143
.set_xindex(coord_names[0], PandasIndex)
86144
.set_xindex(coord_names[1], PandasIndex)
87145
.set_xindex(coord_names[2], PandasIndex)
146+
.set_xindex("x", PandasIndex)
147+
.set_xindex("y", PandasIndex)
148+
.set_xindex("z", PandasIndex)
88149
)
89150

90151
@property
91152
def top(self):
92153
dims = ("nrow", "ncol")
93154
coord_names = (self._dims_coords[dims[0]], self._dims_coords[dims[1]])
94155
coords = {coord_name: self._coords[coord_name] for coord_name in coord_names}
156+
coords.update({"x": self._coords["x"], "y": self._coords["y"]})
95157
return (
96158
xr.DataArray(super().top, coords=coords, dims=dims)
97159
.set_xindex(coord_names[0], PandasIndex)
98160
.set_xindex(coord_names[1], PandasIndex)
161+
.set_xindex("x", PandasIndex)
162+
.set_xindex("y", PandasIndex)
99163
)
100164

101165
@property
@@ -107,11 +171,15 @@ def botm(self):
107171
self._dims_coords[dims[2]],
108172
)
109173
coords = {coord_name: self._coords[coord_name] for coord_name in coord_names}
174+
coords.update({"x": self._coords["x"], "y": self._coords["y"], "z": self._coords["z"]})
110175
return (
111176
xr.DataArray(super().botm, coords=coords, dims=dims)
112177
.set_xindex(coord_names[0], PandasIndex)
113178
.set_xindex(coord_names[1], PandasIndex)
114179
.set_xindex(coord_names[2], PandasIndex)
180+
.set_xindex("x", PandasIndex)
181+
.set_xindex("y", PandasIndex)
182+
.set_xindex("z", PandasIndex)
115183
)
116184

117185
@property
@@ -123,11 +191,15 @@ def idomain(self):
123191
self._dims_coords[dims[2]],
124192
)
125193
coords = {coord_name: self._coords[coord_name] for coord_name in coord_names}
194+
coords.update({"x": self._coords["x"], "y": self._coords["y"], "z": self._coords["z"]})
126195
return (
127196
xr.DataArray(super().idomain, coords=coords, dims=dims)
128197
.set_xindex(coord_names[0], PandasIndex)
129198
.set_xindex(coord_names[1], PandasIndex)
130199
.set_xindex(coord_names[2], PandasIndex)
200+
.set_xindex("x", PandasIndex)
201+
.set_xindex("y", PandasIndex)
202+
.set_xindex("z", PandasIndex)
131203
)
132204

133205

test/test_mf6_component.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,3 +627,96 @@ def test_to_xarray_on_context(function_tmpdir):
627627
assert np.array_equal(dt.per, [0])
628628
assert dt.attrs["filename"] == "mfsim.nam"
629629
assert dt.attrs["workspace"] == Path(function_tmpdir)
630+
631+
632+
def test_grid_coordinate_indexing():
633+
"""Test that grid data can be indexed using x, y, z world coordinates."""
634+
# Create a simple structured grid
635+
grid = StructuredGrid(
636+
nlay=3,
637+
nrow=10,
638+
ncol=10,
639+
delr=100.0, # 100 m cell width
640+
delc=100.0, # 100 m cell height
641+
top=100.0,
642+
botm=[0.0, -50.0, -100.0],
643+
)
644+
645+
# Test that world coordinates are present
646+
assert "x" in grid.dataset.coords
647+
assert "y" in grid.dataset.coords
648+
assert "z" in grid.dataset.coords
649+
650+
# Test coordinate shapes
651+
assert len(grid.dataset.coords["x"]) == 10 # ncol
652+
assert len(grid.dataset.coords["y"]) == 10 # nrow
653+
assert len(grid.dataset.coords["z"]) == 3 # nlay
654+
655+
# Test coordinate values
656+
# X coordinates should be cell centers: 50, 150, 250, ..., 950
657+
expected_x = np.array([50.0, 150.0, 250.0, 350.0, 450.0, 550.0, 650.0, 750.0, 850.0, 950.0])
658+
np.testing.assert_allclose(grid.dataset.coords["x"].values, expected_x)
659+
660+
# Y coordinates should be cell centers from top: 950, 850, 750, ..., 50
661+
expected_y = np.array([950.0, 850.0, 750.0, 650.0, 550.0, 450.0, 350.0, 250.0, 150.0, 50.0])
662+
np.testing.assert_allclose(grid.dataset.coords["y"].values, expected_y)
663+
664+
# Z coordinates should be layer centers
665+
# Layer 0: (100 + 0) / 2 = 50
666+
# Layer 1: (0 + (-50)) / 2 = -25
667+
# Layer 2: (-50 + (-100)) / 2 = -75
668+
expected_z = np.array([50.0, -25.0, -75.0])
669+
np.testing.assert_allclose(grid.dataset.coords["z"].values, expected_z)
670+
671+
# Test coordinate-based selection using nearest
672+
# Select near x=250 (should get col=2, which has x=250)
673+
botm_at_x250 = grid.botm.sel(x=250.0, method="nearest")
674+
assert botm_at_x250.shape == (3, 10) # (nlay, nrow)
675+
676+
# Select near y=650 (should get row=3, which has y=650)
677+
botm_at_y650 = grid.botm.sel(y=650.0, method="nearest")
678+
assert botm_at_y650.shape == (3, 10) # (nlay, ncol)
679+
680+
# Select near z=50 (should get lay=0, which has z=50)
681+
botm_at_z50 = grid.botm.sel(z=50.0, method="nearest")
682+
assert botm_at_z50.shape == (10, 10) # (nrow, ncol)
683+
684+
# Test combined selection
685+
botm_at_point = grid.botm.sel(x=250.0, y=650.0, z=50.0, method="nearest")
686+
assert botm_at_point.shape == () # scalar
687+
688+
# Verify the value makes sense
689+
# This should be layer 0, row 3, col 2 -> botm[0]
690+
expected_value = 0.0 # botm[0] = 0.0
691+
assert float(botm_at_point) == expected_value
692+
693+
694+
def test_grid_coordinate_indexing_in_dis():
695+
"""Test that Dis package data also has coordinate indexing."""
696+
time = Time(perlen=[1.0], nstp=[1])
697+
dis = Dis(
698+
nlay=2,
699+
nrow=5,
700+
ncol=5,
701+
delr=10.0,
702+
delc=10.0,
703+
top=10.0,
704+
botm=[0.0, -10.0],
705+
)
706+
707+
# Convert to grid to access coordinates
708+
grid = dis.to_grid()
709+
710+
# Test that coordinates are available
711+
assert "x" in grid.dataset.coords
712+
assert "y" in grid.dataset.coords
713+
assert "z" in grid.dataset.coords
714+
715+
# Test coordinate-based selection on botm
716+
# X coordinates: 5, 15, 25, 35, 45
717+
botm_near_x15 = grid.botm.sel(x=15.0, method="nearest")
718+
assert botm_near_x15.shape == (2, 5) # (nlay, nrow)
719+
720+
# Y coordinates: 45, 35, 25, 15, 5
721+
botm_near_y25 = grid.botm.sel(y=25.0, method="nearest")
722+
assert botm_near_y25.shape == (2, 5) # (nlay, ncol)

0 commit comments

Comments
 (0)