Skip to content

Commit 2f24768

Browse files
authored
feat(grid): add array support to .intersect() (#2646)
Add vectorized array support to the intersect() method on all grid types
1 parent 1aae548 commit 2f24768

File tree

4 files changed

+472
-129
lines changed

4 files changed

+472
-129
lines changed

autotest/test_grid.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from matplotlib import pyplot as plt
1212
from modflow_devtools.markers import requires_exe, requires_pkg
1313
from modflow_devtools.misc import has_pkg
14+
from scipy.spatial import Delaunay
1415

1516
from autotest.test_dis_cases import case_dis, case_disv
1617
from autotest.test_grid_cases import GridCases
@@ -77,6 +78,17 @@ def minimal_vertex_grid_info(minimal_unstructured_grid_info):
7778
return d
7879

7980

81+
@pytest.fixture
82+
def simple_structured_grid():
83+
"""Create a simple 10x10 structured grid for testing."""
84+
nrow, ncol = 10, 10
85+
delr = np.ones(ncol) * 10.0
86+
delc = np.ones(nrow) * 10.0
87+
top = np.ones((nrow, ncol)) * 10.0
88+
botm = np.zeros((1, nrow, ncol))
89+
return StructuredGrid(delr=delr, delc=delc, top=top, botm=botm)
90+
91+
8092
def test_rotation():
8193
m = Modflow(rotation=20.0)
8294
dis = ModflowDis(
@@ -419,6 +431,127 @@ def test_unstructured_xyz_intersect(example_data_path):
419431
raise AssertionError("Unstructured grid intersection failed")
420432

421433

434+
def test_structured_grid_intersect_array(simple_structured_grid):
435+
"""Test StructuredGrid.intersect() with array inputs."""
436+
grid = simple_structured_grid
437+
438+
# Test array input
439+
x = np.array([25.0, 50.0, 75.0])
440+
y = np.array([25.0, 50.0, 75.0])
441+
rows, cols = grid.intersect(x, y, forgive=True)
442+
443+
# Verify array output
444+
assert isinstance(rows, np.ndarray)
445+
assert isinstance(cols, np.ndarray)
446+
assert len(rows) == 3
447+
assert len(cols) == 3
448+
449+
# Verify equivalence with scalar calls
450+
for i in range(len(x)):
451+
row_scalar, col_scalar = grid.intersect(x[i], y[i], forgive=True)
452+
assert rows[i] == row_scalar
453+
assert cols[i] == col_scalar
454+
455+
# Test out-of-bounds with forgive
456+
x_mixed = np.array([50.0, 150.0])
457+
y_mixed = np.array([50.0, 50.0])
458+
rows_mixed, cols_mixed = grid.intersect(x_mixed, y_mixed, forgive=True)
459+
assert not np.isnan(rows_mixed[0]) # First point valid
460+
assert np.isnan(rows_mixed[1]) # Second point out of bounds
461+
462+
463+
def test_vertex_grid_intersect_array():
464+
"""Test VertexGrid.intersect() with array inputs."""
465+
# Create a simple vertex grid using Delaunay triangulation
466+
np.random.seed(42)
467+
n_points = 50
468+
x_verts = np.random.uniform(0, 100, n_points)
469+
y_verts = np.random.uniform(0, 100, n_points)
470+
points = np.column_stack([x_verts, y_verts])
471+
472+
tri = Delaunay(points)
473+
vertices = [[i, x_verts[i], y_verts[i]] for i in range(len(x_verts))]
474+
cell2d = [[i] + list(tri.simplices[i]) for i in range(len(tri.simplices))]
475+
476+
ncells = len(cell2d)
477+
top = np.ones(ncells) * 10.0
478+
botm = np.zeros(ncells)
479+
grid = VertexGrid(vertices=vertices, cell2d=cell2d, top=top, botm=botm)
480+
481+
# Test array input
482+
x = np.array([25.0, 50.0, 75.0])
483+
y = np.array([25.0, 50.0, 75.0])
484+
results = grid.intersect(x, y, forgive=True)
485+
486+
# Verify array output
487+
assert isinstance(results, np.ndarray)
488+
assert len(results) == 3
489+
490+
# Verify equivalence with scalar calls
491+
for i in range(len(x)):
492+
result_scalar = grid.intersect(x[i], y[i], forgive=True)
493+
if np.isnan(results[i]) and np.isnan(result_scalar):
494+
continue
495+
assert results[i] == result_scalar
496+
497+
# Test out-of-bounds with forgive
498+
x_mixed = np.array([50.0, 200.0])
499+
y_mixed = np.array([50.0, 50.0])
500+
results_mixed = grid.intersect(x_mixed, y_mixed, forgive=True)
501+
assert np.isnan(results_mixed[1]) # Second point out of bounds
502+
503+
504+
def test_unstructured_grid_intersect_array():
505+
"""Test UnstructuredGrid.intersect() with array inputs."""
506+
# Create a simple unstructured grid using Delaunay triangulation
507+
np.random.seed(42)
508+
n_points = 50
509+
x_verts = np.random.uniform(0, 100, n_points)
510+
y_verts = np.random.uniform(0, 100, n_points)
511+
points = np.column_stack([x_verts, y_verts])
512+
513+
tri = Delaunay(points)
514+
vertices = [[i, x_verts[i], y_verts[i]] for i in range(len(x_verts))]
515+
iverts = tri.simplices.tolist()
516+
517+
xcenters = np.mean(points[tri.simplices], axis=1)[:, 0]
518+
ycenters = np.mean(points[tri.simplices], axis=1)[:, 1]
519+
520+
ncells = len(iverts)
521+
top = np.ones(ncells) * 10.0
522+
botm = np.zeros(ncells)
523+
grid = UnstructuredGrid(
524+
vertices=vertices,
525+
iverts=iverts,
526+
xcenters=xcenters,
527+
ycenters=ycenters,
528+
top=top,
529+
botm=botm,
530+
)
531+
532+
# Test array input
533+
x = np.array([25.0, 50.0, 75.0])
534+
y = np.array([25.0, 50.0, 75.0])
535+
results = grid.intersect(x, y, forgive=True)
536+
537+
# Verify array output
538+
assert isinstance(results, np.ndarray)
539+
assert len(results) == 3
540+
541+
# Verify equivalence with scalar calls
542+
for i in range(len(x)):
543+
result_scalar = grid.intersect(x[i], y[i], forgive=True)
544+
if np.isnan(results[i]) and np.isnan(result_scalar):
545+
continue
546+
assert results[i] == result_scalar
547+
548+
# Test out-of-bounds with forgive
549+
x_mixed = np.array([50.0, 200.0])
550+
y_mixed = np.array([50.0, 50.0])
551+
results_mixed = grid.intersect(x_mixed, y_mixed, forgive=True)
552+
assert np.isnan(results_mixed[1]) # Second point out of bounds
553+
554+
422555
@pytest.mark.parametrize("spc_file", ["grd.spc", "grdrot.spc"])
423556
def test_structured_from_gridspec(example_data_path, spc_file):
424557
fn = example_data_path / "specfile" / spc_file
@@ -1522,3 +1655,20 @@ def test_unstructured_iverts_cleanup():
15221655

15231656
if clean_ugrid.nvert != cleaned_vert_num:
15241657
raise AssertionError("Improper number of vertices for cleaned 'shared' iverts")
1658+
1659+
1660+
def test_structured_grid_intersect_edge_case(simple_structured_grid):
1661+
"""Test StructuredGrid.intersect() edge case: out-of-bounds x,y with z.
1662+
1663+
This tests the specific case where x,y are out of bounds and z is provided.
1664+
The expected behavior is to return (None, nan, nan).
1665+
"""
1666+
grid = simple_structured_grid
1667+
1668+
# Test out-of-bounds x,y with z coordinate
1669+
lay, row, col = grid.intersect(-50.0, -50.0, z=5.0, forgive=True)
1670+
1671+
# Expected behavior: lay=None, row=nan, col=nan
1672+
assert lay is None, f"Expected lay=None, got {lay}"
1673+
assert np.isnan(row), f"Expected row=nan, got {row}"
1674+
assert np.isnan(col), f"Expected col=nan, got {col}"

flopy/discretization/structuredgrid.py

Lines changed: 126 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -971,14 +971,16 @@ def intersect(self, x, y, z=None, local=False, forgive=False):
971971
When the point is on the edge of two cells, the cell with the lowest
972972
row or column is returned.
973973
974+
Supports both scalar and array inputs for vectorized operations.
975+
974976
Parameters
975977
----------
976-
x : float
977-
The x-coordinate of the requested point
978-
y : float
979-
The y-coordinate of the requested point
980-
z : float
981-
Optional z-coordinate of the requested point (will return layer,
978+
x : float or array-like
979+
The x-coordinate(s) of the requested point(s)
980+
y : float or array-like
981+
The y-coordinate(s) of the requested point(s)
982+
z : float, array-like, or None
983+
Optional z-coordinate(s) of the requested point(s) (will return layer,
982984
row, column) if supplied
983985
local: bool (optional)
984986
If True, x and y are in local coordinates (defaults to False)
@@ -988,59 +990,135 @@ def intersect(self, x, y, z=None, local=False, forgive=False):
988990
989991
Returns
990992
-------
991-
row : int
992-
The row number
993-
col : int
994-
The column number
993+
row : int or ndarray
994+
The row number(s). Returns int for scalar input, ndarray for array input.
995+
col : int or ndarray
996+
The column number(s). Returns int for scalar input, ndarray for array input.
997+
lay : int or ndarray (only if z is provided)
998+
The layer number(s). Returns int for scalar input, ndarray for array input.
999+
1000+
"""
1001+
# Check if inputs are scalar
1002+
x_is_scalar = np.isscalar(x)
1003+
y_is_scalar = np.isscalar(y)
1004+
z_is_scalar = z is None or np.isscalar(z)
1005+
is_scalar_input = x_is_scalar and y_is_scalar and z_is_scalar
1006+
1007+
# Convert to arrays for uniform processing
1008+
x = np.atleast_1d(x)
1009+
y = np.atleast_1d(y)
1010+
if z is not None:
1011+
z = np.atleast_1d(z)
1012+
1013+
# Validate array shapes
1014+
if len(x) != len(y):
1015+
raise ValueError("x and y must have the same length")
1016+
if z is not None and len(z) != len(x):
1017+
raise ValueError("z must have the same length as x and y")
9951018

996-
"""
9971019
# transform x and y to local coordinates
998-
x, y = super().intersect(x, y, local, forgive)
1020+
if not local:
1021+
x, y = self.get_local_coords(x, y)
9991022

10001023
# get the cell edges in local coordinates
10011024
xe, ye = self.xyedges
10021025

1003-
xcomp = x > xe
1004-
if np.all(xcomp) or not np.any(xcomp):
1005-
if forgive:
1006-
col = np.nan
1026+
# Vectorized row/col calculation
1027+
n_points = len(x)
1028+
rows = np.full(n_points, np.nan if forgive else -1, dtype=float)
1029+
cols = np.full(n_points, np.nan if forgive else -1, dtype=float)
1030+
1031+
for i in range(n_points):
1032+
xi, yi = x[i], y[i]
1033+
1034+
# Find column
1035+
xcomp = xi > xe
1036+
if np.all(xcomp) or not np.any(xcomp):
1037+
if forgive:
1038+
cols[i] = np.nan
1039+
else:
1040+
raise ValueError(
1041+
f"x, y point given is outside of the model area: ({xi}, {yi})"
1042+
)
10071043
else:
1008-
raise Exception("x, y point given is outside of the model area")
1009-
else:
1010-
col = np.asarray(xcomp).nonzero()[0][-1]
1044+
cols[i] = np.asarray(xcomp).nonzero()[0][-1]
10111045

1012-
ycomp = y < ye
1013-
if np.all(ycomp) or not np.any(ycomp):
1014-
if forgive:
1015-
row = np.nan
1046+
# Find row
1047+
ycomp = yi < ye
1048+
if np.all(ycomp) or not np.any(ycomp):
1049+
if forgive:
1050+
rows[i] = np.nan
1051+
else:
1052+
raise ValueError(
1053+
f"x, y point given is outside of the model area: ({xi}, {yi})"
1054+
)
10161055
else:
1017-
raise Exception("x, y point given is outside of the model area")
1018-
else:
1019-
row = np.asarray(ycomp).nonzero()[0][-1]
1020-
if np.any(np.isnan([row, col])):
1021-
row = col = np.nan
1022-
if z is not None:
1023-
return None, row, col
1056+
rows[i] = np.asarray(ycomp).nonzero()[0][-1]
1057+
1058+
# If either row or col is NaN, set both to NaN
1059+
invalid_mask = np.isnan(rows) | np.isnan(cols)
1060+
rows[invalid_mask] = np.nan
1061+
cols[invalid_mask] = np.nan
1062+
1063+
# Convert to int where valid
1064+
valid_mask = ~invalid_mask
1065+
if np.any(valid_mask):
1066+
rows[valid_mask] = rows[valid_mask].astype(int)
1067+
cols[valid_mask] = cols[valid_mask].astype(int)
10241068

10251069
if z is None:
1026-
return row, col
1027-
1028-
lay = np.nan
1029-
for layer in range(self.__nlay):
1030-
if (
1031-
self.top_botm[layer, row, col]
1032-
>= z
1033-
>= self.top_botm[layer + 1, row, col]
1034-
):
1035-
lay = layer
1036-
break
1037-
1038-
if np.any(np.isnan([lay, row, col])):
1039-
lay = row = col = np.nan
1040-
if not forgive:
1041-
raise Exception("point given is outside the model area")
1042-
1043-
return lay, row, col
1070+
# Return results
1071+
if is_scalar_input:
1072+
row, col = rows[0], cols[0]
1073+
if not np.isnan(row) and not np.isnan(col):
1074+
row, col = int(row), int(col)
1075+
return row, col
1076+
else:
1077+
return rows.astype(int) if np.all(valid_mask) else rows, cols.astype(
1078+
int
1079+
) if np.all(valid_mask) else cols
1080+
1081+
# Handle z-coordinate
1082+
lays = np.full(n_points, np.nan if forgive else -1, dtype=float)
1083+
1084+
for i in range(n_points):
1085+
if np.isnan(rows[i]) or np.isnan(cols[i]):
1086+
continue
1087+
1088+
row, col = int(rows[i]), int(cols[i])
1089+
zi = z[i]
1090+
1091+
for layer in range(self.__nlay):
1092+
if (
1093+
self.top_botm[layer, row, col]
1094+
>= zi
1095+
>= self.top_botm[layer + 1, row, col]
1096+
):
1097+
lays[i] = layer
1098+
break
1099+
1100+
if np.isnan(lays[i]) and not forgive:
1101+
raise ValueError(
1102+
f"point given is outside the model area: ({x[i]}, {y[i]}, {zi})"
1103+
)
1104+
1105+
# Return results
1106+
if is_scalar_input:
1107+
lay, row, col = lays[0], rows[0], cols[0]
1108+
if not np.isnan(lay):
1109+
lay, row, col = int(lay), int(row), int(col)
1110+
else:
1111+
# When x,y are out of bounds: lay=None, row/col keep NaN
1112+
lay = None
1113+
# row and col already have their NaN values
1114+
return lay, row, col
1115+
else:
1116+
valid_3d = ~np.isnan(lays) & ~np.isnan(rows) & ~np.isnan(cols)
1117+
return (
1118+
lays.astype(int) if np.all(valid_3d) else lays,
1119+
rows.astype(int) if np.all(valid_3d) else rows,
1120+
cols.astype(int) if np.all(valid_3d) else cols,
1121+
)
10441122

10451123
def _cell_vert_list(self, i, j):
10461124
"""Get vertices for a single cell or sequence of i, j locations."""

0 commit comments

Comments
 (0)