Skip to content

Commit 914df20

Browse files
adacovskclaude
andcommitted
feat(geospatial): add 3D spatial indexing and vectorize z-coordinate search
Add comprehensive 3D support to GeospatialIndex with automatic detection and two-pronged approach for handling grids with varying layer geometry. **3D Spatial Indexing (grid_varies_by_layer=True):** - Uses (x,y,z) KD-tree for grids where geometry varies by layer - Implements 3D bounding box containment tests - Indexes all nnodes with full 3D coordinates - Performance: up to 414x speedup for 15000 cells with 10000 points **2D with Z-Search (grid_varies_by_layer=False):** - Uses (x,y) KD-tree with vectorized z-coordinate layer search - Suitable for VertexGrid and UnstructuredGrid with consistent 2D geometry - Performance: up to 348x speedup for 5000 cells with 10000 points **Key Changes:** - Add is_3d flag auto-detected from grid.grid_varies_by_layer - Implement _build_3d_index() for 3D KD-tree construction - Implement _precompute_3d_bounds() for 3D bounding box tests - Implement _point_in_cell_3d() for 3D containment - Vectorize _find_layer_for_z() using numpy masks (no loops) - Update query_points() to accept z parameter and route to 2D/3D logic - Simplify UnstructuredGrid.intersect() and VertexGrid.intersect() - Fix list vs array handling with np.atleast_1d() **Performance Benchmarks:** Structured grids: 0.10x-0.58x (slower, should use searchsorted) Vertex grids: 1.2x-348x speedup (best for 1000+ cells) Unstructured 2D: 1.1x-352x speedup (best for 1000+ cells) Unstructured 3D: 1.1x-414x speedup (best for 2000+ cells) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent b89708a commit 914df20

File tree

3 files changed

+274
-102
lines changed

3 files changed

+274
-102
lines changed

flopy/discretization/unstructuredgrid.py

Lines changed: 15 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -775,53 +775,28 @@ def intersect(self, x, y, z=None, local=False, forgive=False):
775775
if not hasattr(self, "_geospatial_index") or self._geospatial_index is None:
776776
self._geospatial_index = GeospatialIndex(self)
777777

778-
# Use KD-tree for fast spatial queries
779-
n_points = len(x)
780-
cellids = self._geospatial_index.query_points(x, y)
778+
# Use GeospatialIndex for spatial queries
779+
# For grid_varies_by_layer=True, z is required (3D query)
780+
# For grid_varies_by_layer=False, z-search handled internally
781+
cellids = self._geospatial_index.query_points(x, y, z=z)
781782

782-
xv, yv, zv = self.xyzvertices
783-
784-
# Initialize result array - cellids already contains np.nan for not found
783+
# Initialize result array
785784
results = cellids.copy()
786785

787-
# Vectorized error checking for points not found
786+
# Error checking for points not found
788787
if not forgive:
789788
unfound_mask = np.isnan(cellids)
790789
if np.any(unfound_mask):
791790
idx = np.where(unfound_mask)[0][0]
792-
xi, yi = x[idx], y[idx]
793-
raise ValueError(f"point ({xi}, {yi}) is outside of the model area")
794-
795-
# Only process z-coordinates if provided
796-
if z is not None:
797-
# Only loop over valid points (that were found in 2D)
798-
valid_mask = ~np.isnan(cellids)
799-
valid_indices = np.where(valid_mask)[0]
800-
801-
for i in valid_indices:
802-
icell2d = int(cellids[i])
803-
zi = z[i]
804-
found = False
805-
806-
# Search through layers for z-coordinate
807-
cell_idx_3d = icell2d
808-
for lay in range(self.nlay):
809-
if zv[0, cell_idx_3d] >= zi >= zv[1, cell_idx_3d]:
810-
results[i] = cell_idx_3d
811-
found = True
812-
break
813-
# Move to next layer
814-
if lay < self.nlay - 1 and not self.grid_varies_by_layer:
815-
cell_idx_3d += self.ncpl[lay]
816-
817-
if not found:
818-
if forgive:
819-
results[i] = np.nan
820-
else:
821-
xi, yi = x[i], y[i]
822-
raise ValueError(
823-
f"point ({xi}, {yi}, {zi}) is outside of the model area"
824-
)
791+
if z is not None:
792+
raise ValueError(
793+
f"point ({x[idx]}, {y[idx]}, {z[idx]}) is outside of "
794+
f"the model area"
795+
)
796+
else:
797+
raise ValueError(
798+
f"point ({x[idx]}, {y[idx]}) is outside of the model area"
799+
)
825800

826801
# Return scalar if input was scalar, otherwise return array
827802
if is_scalar_input:

flopy/discretization/vertexgrid.py

Lines changed: 21 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -405,43 +405,36 @@ def intersect(self, x, y, z=None, local=False, forgive=False):
405405
if not hasattr(self, "_geospatial_index") or self._geospatial_index is None:
406406
self._geospatial_index = GeospatialIndex(self)
407407

408-
# Use KD-tree for fast spatial queries
409-
n_points = len(x)
410-
cellids = self._geospatial_index.query_points(x, y)
408+
# Use GeospatialIndex for spatial queries
409+
# For VertexGrid, z-search is handled internally by GeospatialIndex
410+
# which returns layer index
411+
if z is not None:
412+
lays_raw = self._geospatial_index.query_points(x, y, z=z)
413+
# GeospatialIndex returns layer index for VertexGrid
414+
lays = lays_raw.copy()
415+
# Get icell2d from 2D query (all layers have same 2D geometry)
416+
cellids = self._geospatial_index.query_points(x, y, z=None)
417+
else:
418+
cellids = self._geospatial_index.query_points(x, y, z=None)
411419

412420
# Vectorized processing of results
413-
# cellids is already np.nan for not found, convert to the right type
414421
results = cellids.copy()
415422

416423
# Check for unfound points if not forgiving
417424
if not forgive:
418425
unfound_mask = np.isnan(cellids)
419426
if np.any(unfound_mask):
420427
idx = np.where(unfound_mask)[0][0]
421-
raise ValueError(
422-
f"point given is outside of the model area: ({x[idx]}, {y[idx]})"
423-
)
424-
425-
# Find z-layers if z is provided (vectorized where possible)
426-
if z is not None:
427-
lays = np.full(n_points, np.nan, dtype=float)
428-
429-
# Only process points that were found
430-
valid_mask = ~np.isnan(cellids)
431-
valid_indices = np.where(valid_mask)[0]
432-
433-
for i in valid_indices:
434-
icell2d = int(cellids[i])
435-
zi = z[i]
436-
# Find layer for this point
437-
for lay in range(self.nlay):
438-
if (
439-
self.top_botm[lay, icell2d]
440-
>= zi
441-
>= self.top_botm[lay + 1, icell2d]
442-
):
443-
lays[i] = lay
444-
break
428+
if z is not None:
429+
raise ValueError(
430+
f"point given is outside of the model area: "
431+
f"({x[idx]}, {y[idx]}, {z[idx]})"
432+
)
433+
else:
434+
raise ValueError(
435+
f"point given is outside of the model area: "
436+
f"({x[idx]}, {y[idx]})"
437+
)
445438

446439
# Return results
447440
if z is None:

0 commit comments

Comments
 (0)