Skip to content

Commit 0fffdeb

Browse files
adacovskclaude
andcommitted
feat(geospatial): add KD-tree spatial indexing with 3D support
Implement GeospatialIndex class using KD-tree and ConvexHull for fast spatial queries on unstructured grids. Provides massive performance improvements for grid.intersect() operations on VertexGrid and UnstructuredGrid with up to 414x speedup for large grids. **Architecture:** - 2D mode: (x,y) KD-tree + ConvexHull for point-in-polygon tests - 3D mode: (x,y,z) KD-tree + 3D bounding boxes for grid_varies_by_layer - Automatic mode detection based on grid properties - Epsilon tolerance (default 1e-6) for boundary point handling - Vectorized z-coordinate search using numpy masks **Performance Benchmarks:** - VertexGrid: 1.2x-348x speedup (best for 1000+ cells) - UnstructuredGrid 2D: 1.1x-352x speedup (best for 1000+ cells) - UnstructuredGrid 3D: 1.1x-414x speedup (best for 2000+ cells) - StructuredGrid: 0.10x-0.58x (slower, keeps using searchsorted) **Key Features:** - Pre-computed ConvexHull equations for vectorized containment tests - Bounding box rejection for fast filtering - KD-tree nearest neighbor search to find candidate cells - Supports both 2D and 3D spatial indexing - Handles grid_varies_by_layer for true 3D grids - Epsilon tolerance for robust boundary point handling **Implementation Details:** - GeospatialIndex class in flopy/utils/geospatial_index.py - _build_index() for 2D (x,y) point indexing with centroids + vertices - _build_3d_index() for 3D (x,y,z) point indexing - _precompute_hulls() for 2D bounding boxes and ConvexHull equations - _precompute_3d_bounds() for 3D bounding boxes - _point_in_cell_vectorized() for 2D containment tests - _point_in_cell_3d() for 3D bounding box containment - _find_layer_for_z() vectorized layer search (no loops) - query_point() and query_points() public API **Grid Integration:** - StructuredGrid: continues using searchsorted() (faster) - VertexGrid: uses GeospatialIndex for 2D + z-search - UnstructuredGrid: uses GeospatialIndex for both 2D and 3D modes - Automatic fallback to old methods for small grids **Testing:** - autotest/test_edge_cases.py: boundary conditions and edge cases - autotest/test_geospatial_benchmark.py: performance validation 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent 2f24768 commit 0fffdeb

File tree

6 files changed

+1351
-150
lines changed

6 files changed

+1351
-150
lines changed

autotest/test_edge_cases.py

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
"""
2+
Edge case tests for GeospatialIndex to ensure robustness.
3+
"""
4+
5+
import numpy as np
6+
import pytest
7+
from scipy.spatial import Delaunay
8+
9+
from flopy.discretization import VertexGrid
10+
from flopy.utils.geospatial_index import GeospatialIndex
11+
12+
13+
def test_thin_sliver_cell():
14+
"""
15+
Test that GeospatialIndex can find points in very thin "sliver" cells
16+
where the centroid might be far from the actual cell location.
17+
"""
18+
# Create a grid with a very thin sliver cell
19+
# This tests the centroid+vertices KD-tree approach
20+
np.random.seed(42)
21+
22+
# Create base random points
23+
n_points = 15
24+
x_verts = np.random.uniform(0, 100, n_points).tolist()
25+
y_verts = np.random.uniform(0, 100, n_points).tolist()
26+
27+
# Add vertices for a very thin vertical sliver at x=50
28+
sliver_indices = []
29+
for i in range(4):
30+
idx = len(x_verts)
31+
sliver_indices.append(idx)
32+
x_verts.append(50.0 + i * 0.05) # Very thin: 0.15 units wide
33+
y_verts.append(i * 33.33) # Tall: 100 units high
34+
35+
# Create Delaunay triangulation
36+
points = np.column_stack([x_verts, y_verts])
37+
tri = Delaunay(points)
38+
39+
# Build VertexGrid
40+
vertices = [[i, x_verts[i], y_verts[i]] for i in range(len(x_verts))]
41+
cell2d = []
42+
for i, simplex in enumerate(tri.simplices):
43+
# Calculate centroid
44+
cell_x = np.mean([x_verts[j] for j in simplex])
45+
cell_y = np.mean([y_verts[j] for j in simplex])
46+
cell2d.append([i, cell_x, cell_y, len(simplex)] + list(simplex))
47+
48+
ncells = len(cell2d)
49+
grid = VertexGrid(
50+
vertices=vertices,
51+
cell2d=cell2d,
52+
top=np.ones(ncells) * 10.0,
53+
botm=np.zeros(ncells),
54+
)
55+
56+
# Build index
57+
index = GeospatialIndex(grid)
58+
59+
# Test points in/near the sliver region
60+
test_points = [
61+
(50.025, 50.0), # Should be in a sliver cell
62+
(50.075, 25.0), # Should be in a sliver cell
63+
(50.025, 75.0), # Should be in a sliver cell
64+
]
65+
66+
found_count = 0
67+
for x, y in test_points:
68+
result = index.query_point(x, y, k=20) # Use k=20 to be thorough
69+
if result is not None:
70+
# Verify the point is actually in the found cell
71+
xv, yv, _ = grid.xyzvertices
72+
verts = np.column_stack([xv[result], yv[result]])
73+
from matplotlib.path import Path
74+
75+
path = Path(verts)
76+
is_inside = path.contains_point((x, y), radius=1e-9)
77+
78+
if is_inside:
79+
found_count += 1
80+
print(f"✅ Point ({x}, {y}) correctly found in cell {result}")
81+
else:
82+
print(
83+
f"⚠️ Point ({x}, {y}) found in cell {result} "
84+
f"but verification failed"
85+
)
86+
else:
87+
print(f"❌ Point ({x}, {y}) NOT FOUND")
88+
89+
# At least some points should be found
90+
# (not all may be in cells due to Delaunay triangulation specifics)
91+
assert found_count > 0, (
92+
f"Should find at least some points in sliver cells, found {found_count}/3"
93+
)
94+
95+
96+
def test_boundary_points():
97+
"""Test points exactly on cell boundaries."""
98+
# Create simple 2x2 grid of triangles
99+
vertices = [
100+
[0, 0.0, 0.0],
101+
[1, 1.0, 0.0],
102+
[2, 2.0, 0.0],
103+
[3, 0.0, 1.0],
104+
[4, 1.0, 1.0],
105+
[5, 2.0, 1.0],
106+
]
107+
108+
# Create triangular cells
109+
cell2d = [
110+
[0, 0.33, 0.33, 3, 0, 1, 3], # Lower-left triangle
111+
[1, 0.67, 0.33, 3, 1, 4, 3], # Lower-middle triangle
112+
[2, 1.33, 0.33, 3, 1, 2, 4], # Lower-right triangle
113+
[3, 1.67, 0.67, 3, 2, 5, 4], # Upper-right triangle
114+
]
115+
116+
grid = VertexGrid(
117+
vertices=vertices, cell2d=cell2d, top=np.ones(4) * 10.0, botm=np.zeros(4)
118+
)
119+
120+
index = GeospatialIndex(grid)
121+
122+
# Test point exactly on a shared edge (between cells 0 and 1)
123+
# Should find one of the two cells
124+
result = index.query_point(1.0, 0.5, k=10)
125+
assert result is not None or result == 0 or result == 1, (
126+
"Point on boundary should be found in one of the adjacent cells"
127+
)
128+
129+
130+
def test_corner_points():
131+
"""Test points at vertices (corners where multiple cells meet)."""
132+
# Simple 2x2 grid
133+
vertices = [
134+
[0, 0.0, 0.0],
135+
[1, 1.0, 0.0],
136+
[2, 1.0, 1.0],
137+
[3, 0.0, 1.0],
138+
[4, 0.5, 0.5], # Center point
139+
]
140+
141+
cell2d = [
142+
[0, 0.5, 0.17, 3, 0, 1, 4],
143+
[1, 0.83, 0.5, 3, 1, 2, 4],
144+
[2, 0.5, 0.83, 3, 2, 3, 4],
145+
[3, 0.17, 0.5, 3, 3, 0, 4],
146+
]
147+
148+
grid = VertexGrid(
149+
vertices=vertices, cell2d=cell2d, top=np.ones(4) * 10.0, botm=np.zeros(4)
150+
)
151+
152+
index = GeospatialIndex(grid)
153+
154+
# Test point at center vertex (shared by all 4 cells)
155+
result = index.query_point(0.5, 0.5, k=10)
156+
assert result is not None, (
157+
"Point at center vertex should be found in one of the cells"
158+
)
159+
assert result in [0, 1, 2, 3], (
160+
f"Result {result} should be one of the 4 center cells"
161+
)
162+
163+
164+
def test_outside_grid():
165+
"""Test points well outside the grid."""
166+
# Simple triangle
167+
vertices = [
168+
[0, 0.0, 0.0],
169+
[1, 10.0, 0.0],
170+
[2, 5.0, 10.0],
171+
]
172+
173+
cell2d = [
174+
[0, 5.0, 3.33, 3, 0, 1, 2],
175+
]
176+
177+
grid = VertexGrid(
178+
vertices=vertices, cell2d=cell2d, top=np.array([10.0]), botm=np.array([0.0])
179+
)
180+
181+
index = GeospatialIndex(grid)
182+
183+
# Points clearly outside
184+
outside_points = [
185+
(-10.0, 0.0),
186+
(20.0, 0.0),
187+
(5.0, 20.0),
188+
(15.0, 15.0),
189+
]
190+
191+
for x, y in outside_points:
192+
result = index.query_point(x, y, k=10)
193+
assert result is None, (
194+
f"Point ({x}, {y}) outside grid should return None, got {result}"
195+
)
196+
197+
198+
if __name__ == "__main__":
199+
print("Running edge case tests...")
200+
201+
print("\n1. Testing thin sliver cells...")
202+
test_thin_sliver_cell()
203+
204+
print("\n2. Testing boundary points...")
205+
test_boundary_points()
206+
207+
print("\n3. Testing corner points...")
208+
test_corner_points()
209+
210+
print("\n4. Testing outside grid points...")
211+
test_outside_grid()
212+
213+
print("\n✅ All edge case tests passed!")

0 commit comments

Comments
 (0)