Skip to content

Commit 1cc2d9c

Browse files
authored
Improve Point in Face Function Signature and Helper (#1272)
* update helpers and improve point in face function signature * update example
1 parent d37fbc1 commit 1cc2d9c

File tree

3 files changed

+102
-63
lines changed

3 files changed

+102
-63
lines changed

test/test_point_in_face.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def test_face_centers(grid):
3535

3636
for fid, center in enumerate(centers_xyz):
3737
hits = grid.get_faces_containing_point(
38-
point_xyz=center,
38+
points=center,
3939
return_counts=False
4040
)
4141
assert isinstance(hits, list)
@@ -50,7 +50,7 @@ def test_face_centers(grid):
5050

5151
for fid, (lon, lat) in enumerate(centers_lonlat):
5252
hits = grid.get_faces_containing_point(
53-
point_lonlat=(lon, lat),
53+
points=(lon, lat),
5454
return_counts=False
5555
)
5656
assert hits[0] == [fid]
@@ -61,7 +61,6 @@ def test_node_corners(grid):
6161
Cartesian and spherical (lon/lat) returns exactly the faces sharing it.
6262
"""
6363

64-
print(grid.max_face_radius)
6564

6665
node_coords = np.vstack([
6766
grid.node_x.values,
@@ -77,7 +76,7 @@ def test_node_corners(grid):
7776
expected = conn[nid, :counts[nid]].tolist()
7877

7978
hits_xyz = grid.get_faces_containing_point(
80-
point_xyz=(x, y, z),
79+
points=(x, y, z),
8180
return_counts=False
8281
)[0]
8382
assert set(hits_xyz) == set(expected)
@@ -94,7 +93,7 @@ def test_node_corners(grid):
9493
expected = conn[nid, :counts[nid]].tolist()
9594

9695
hits_ll = grid.get_faces_containing_point(
97-
point_lonlat=(lon, lat),
96+
points=(lon, lat),
9897
return_counts=False
9998
)[0]
10099
assert set(hits_ll) == set(expected)
@@ -110,24 +109,24 @@ def test_number_of_faces_found():
110109
# For a face center only one face should be found
111110
point_xyz = np.array([grid.face_x[100].values, grid.face_y[100].values, grid.face_z[100].values], dtype=np.float64)
112111

113-
assert len(grid.get_faces_containing_point(point_xyz=point_xyz, return_counts=False)[0]) == 1
112+
assert len(grid.get_faces_containing_point(point_xyz, return_counts=False)[0]) == 1
114113

115114
# For an edge two faces should be found
116115
point_xyz = np.array([grid.edge_x[100].values, grid.edge_y[100].values, grid.edge_z[100].values], dtype=np.float64)
117116

118-
assert len(grid.get_faces_containing_point(point_xyz=point_xyz, return_counts=False)[0]) == 2
117+
assert len(grid.get_faces_containing_point(point_xyz, return_counts=False)[0]) == 2
119118

120119
# For a node three faces should be found
121120
point_xyz = np.array([grid.node_x[100].values, grid.node_y[100].values, grid.node_z[100].values], dtype=np.float64)
122121

123-
assert len(grid.get_faces_containing_point(point_xyz=point_xyz, return_counts=False)[0]) == 3
122+
assert len(grid.get_faces_containing_point(point_xyz, return_counts=False)[0]) == 3
124123

125124
partial_grid.normalize_cartesian_coordinates()
126125

127126
# Test for a node on the edge where only 2 faces should be found
128127
point_xyz = np.array([partial_grid.node_x[1].values, partial_grid.node_y[1].values, partial_grid.node_z[1].values], dtype=np.float64)
129128

130-
assert len(partial_grid.get_faces_containing_point(point_xyz=point_xyz, return_counts=False)[0]) == 2
129+
assert len(partial_grid.get_faces_containing_point(point_xyz, return_counts=False)[0]) == 2
131130

132131
def test_point_along_arc():
133132
node_lon = np.array([-40, -40, 40, 40])
@@ -137,9 +136,24 @@ def test_point_along_arc():
137136
uxgrid = ux.Grid.from_topology(node_lon, node_lat, face_node_connectivity)
138137

139138
# point at exactly 20 degrees latitude
140-
out1 = uxgrid.get_faces_containing_point(point_lonlat=np.array([0, 20], dtype=np.float64), return_counts=False)
139+
out1 = uxgrid.get_faces_containing_point(np.array([0, 20], dtype=np.float64), return_counts=False)
141140

142141
# point at 25.41 degrees latitude (max along the great circle arc)
143-
out2 = uxgrid.get_faces_containing_point(point_lonlat=np.array([0, 25.41], dtype=np.float64), return_counts=False)
142+
out2 = uxgrid.get_faces_containing_point(np.array([0, 25.41], dtype=np.float64), return_counts=False)
144143

145144
nt.assert_array_equal(out1[0], out2[0])
145+
146+
147+
def test_coordinates(grid):
148+
149+
lonlat = np.vstack([grid.node_lon.values, grid.node_lat.values]).T
150+
xyz = np.vstack([grid.node_x.values, grid.node_y.values, grid.node_z.values]).T
151+
152+
faces_from_lonlat, _ = grid.get_faces_containing_point(points=lonlat)
153+
faces_from_xyz, _ = grid.get_faces_containing_point(points=xyz)
154+
155+
nt.assert_array_equal(faces_from_lonlat, faces_from_xyz)
156+
157+
with pytest.raises(ValueError):
158+
dummy_points = np.ones((10, 4))
159+
faces_query_both, _ = grid.get_faces_containing_point(points=dummy_points)

uxarray/grid/coordinates.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -868,19 +868,39 @@ def prepare_points(points, normalize):
868868
return np.vstack([x, y, z]).T
869869

870870

871-
def _prepare_points_for_kdtree(lonlat, xyz):
872-
if (lonlat is not None) and (xyz is not None):
871+
def points_atleast_2d_xyz(points):
872+
"""
873+
Ensure the input is at least 2D and return Cartesian (x, y, z) coordinates.
874+
875+
Parameters
876+
----------
877+
points : array_like, shape (N, 2) or (N, 3)
878+
- If shape is (N, 2), interpreted as [longitude, latitude] in degrees.
879+
- If shape is (N, 3), interpreted as Cartesian [x, y, z] coordinates.
880+
881+
Returns
882+
-------
883+
points_xyz : ndarray, shape (N, 3)
884+
Cartesian coordinates [x, y, z] for each input point.
885+
886+
Raises
887+
------
888+
ValueError
889+
If `points` (after `np.atleast_2d`) does not have 2 or 3 columns.
890+
891+
"""
892+
893+
points = np.atleast_2d(points)
894+
895+
if points.shape[1] == 2:
896+
points_lonlat_rad = np.deg2rad(points)
897+
x, y, z = _lonlat_rad_to_xyz(points_lonlat_rad[:, 0], points_lonlat_rad[:, 1])
898+
points_xyz = np.vstack([x, y, z]).T
899+
elif points.shape[1] == 3:
900+
points_xyz = points
901+
else:
873902
raise ValueError(
874-
"Both Cartesian (xyz) and Spherical (lonlat) coordinates were provided. One one can be "
875-
"provided at a time."
903+
"Points are neither Cartesian (shape N x 3) nor Spherical (shape N x 2)."
876904
)
877905

878-
# Convert to cartesian if points are spherical
879-
if xyz is None:
880-
lon, lat = map(np.deg2rad, lonlat)
881-
xyz = _lonlat_rad_to_xyz(lon, lat)
882-
pts = np.asarray(xyz, dtype=np.float64)
883-
if pts.ndim == 1:
884-
pts = pts[np.newaxis, :]
885-
886-
return pts
906+
return points_xyz

uxarray/grid/grid.py

Lines changed: 44 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@
3636
_populate_face_centroids,
3737
_populate_node_latlon,
3838
_populate_node_xyz,
39-
_prepare_points_for_kdtree,
4039
_set_desired_longitude_range,
40+
points_atleast_2d_xyz,
4141
prepare_points,
4242
)
4343
from uxarray.grid.dual import construct_dual
@@ -2560,72 +2560,77 @@ def get_faces_between_latitudes(self, lats: Tuple[float, float]):
25602560

25612561
def get_faces_containing_point(
25622562
self,
2563-
*,
2564-
point_lonlat: Sequence[float] | np.ndarray = None,
2565-
point_xyz: Sequence[float] | np.ndarray = None,
2563+
points: Sequence[float] | np.ndarray,
25662564
return_counts: bool = True,
25672565
) -> Tuple[np.ndarray, np.ndarray] | List[List[int]]:
25682566
"""
2569-
Identify which faces on the grid contain the given point(s).
2570-
2571-
Exactly one of `point_lonlat` or `point_xyz` must be provided.
2567+
Identify which grid faces contain the given point(s).
25722568
25732569
Parameters
25742570
----------
2575-
point_lonlat : array_like of shape (2,), optional
2576-
Longitude and latitude in **degrees**: (lon, lat).
2577-
point_xyz : array_like of shape (3,), optional
2578-
Cartesian coordinates on the unit sphere: (x, y, z).
2571+
points : array_like, shape (N, 2) or (2,) or shape (N, 3) or (3,)
2572+
Query point(s) to locate on the grid.
2573+
- If last dimension is 2, interpreted as (longitude, latitude) in **degrees**.
2574+
- If last dimension is 3, interpreted as Cartesian coordinates on the unit sphere: (x, y, z).
2575+
You may pass a single point (shape `(2,)` or `(3,)`) or multiple points (shape `(N, 2)` or `(N, 3)`).
25792576
return_counts : bool, default=True
2580-
- If True (default), returns a tuple `(face_indices, counts)`.
2581-
- If False, returns a `list` of per-point lists of face indices.
2577+
- If True, returns a tuple `(face_indices, counts)`.
2578+
- If False, returns a `list` of per-point lists of face indices (no padding).
25822579
25832580
Returns
25842581
-------
25852582
If `return_counts=True`:
2586-
face_indices : np.ndarray, shape (N, M) or (N,)
2587-
- 2D array of face indices with unused slots are filled with `INT_FILL_VALUE`.
2583+
face_indices : np.ndarray, shape (N, M) or (N, 1)
2584+
2D array of face indices. Rows are padded with `INT_FILL_VALUE` when a point
2585+
lies on corners of multiple faces. If every queried point falls in exactly
2586+
one face, the result has shape `(N, 1)`.
25882587
counts : np.ndarray, shape (N,)
2589-
Number of valid face indices in each row of `face_indices`.
2588+
Number of valid face indices in each row of `face_indices`.
25902589
25912590
If `return_counts=False`:
25922591
List[List[int]]
2593-
A Python list of length `N`, where each element is the
2594-
list of face indices (no padding) for that query point.
2592+
Python list of length `N`, where each element is the list of face
2593+
indices for that point (no padding, in natural order).
25952594
25962595
Notes
25972596
-----
2598-
- **Most** points will lie strictly inside exactly one face:
2599-
in that common case, `counts == 1` and the returned
2600-
`face_indices` can be collapsed to shape `(N,)`, with no padding.
2601-
- If a point falls exactly on a corner shared by multiple faces,
2602-
multiple face indices will appear in the first columns of each row;
2603-
the remainder of each row is filled with `INT_FILL_VALUE`.
2597+
- Most points will lie strictly inside exactly one face; in that case,
2598+
`counts == 1` and `face_indices` has one column.
2599+
- Points that lie exactly on a vertex or edge shared by multiple faces
2600+
return multiple indices in the first `counts[i]` columns of row `i`,
2601+
with any remaining columns filled by `INT_FILL_VALUE`.
2602+
26042603
26052604
Examples
26062605
--------
2607-
>>> import uxarray as ux
2608-
>>> grid_path = "/path/to/grid.nc"
2609-
>>> uxgrid = ux.open_grid(grid_path)
26102606
2611-
# 1. Query a Spherical (lon/lat) point
2612-
>>> indices, counts = uxgrid.get_faces_containing_point(point_lonlat=(0.0, 0.0))
2607+
Query a single spherical point
26132608
2614-
# 2. Query a Cartesian (xyz) point
2615-
>>> indices, counts = uxgrid.get_faces_containing_point(
2616-
... point_xyz=(0.0, 0.0, 1.0)
2617-
... )
2609+
>>> face_indices, counts = uxgrid.get_faces_containing_point(points=(0.0, 0.0))
2610+
2611+
Query a single Cartesian point
2612+
2613+
>>> face_indices, counts = uxgrid.get_faces_containing_point(
2614+
... points=[0.0, 0.0, 1.0]
2615+
... )
26182616
2619-
# 3. Return indices as a list of lists (no counts nor padding)
2620-
>>> indices = uxgrid.get_faces_containing_point(
2621-
... point_xyz=(0.0, 0.0, 1.0), return_counts=False
2617+
Query multiple points at once
2618+
2619+
>>> pts = [(0.0, 0.0), (10.0, 20.0)]
2620+
>>> face_indices, counts = uxgrid.get_faces_containing_point(points=pts)
2621+
2622+
Return a list of lists
2623+
2624+
>>> face_indices_list = uxgrid.get_faces_containing_point(
2625+
... points=[0.0, 0.0, 1.0], return_counts=False
26222626
... )
2623-
"""
26242627
2625-
pts = _prepare_points_for_kdtree(point_lonlat, point_xyz)
2628+
"""
26262629

26272630
# Determine faces containing points
2628-
face_indices, counts = _point_in_face_query(source_grid=self, points=pts)
2631+
face_indices, counts = _point_in_face_query(
2632+
source_grid=self, points=points_atleast_2d_xyz(points)
2633+
)
26292634

26302635
# Return a list of lists if counts are not desired
26312636
if not return_counts:

0 commit comments

Comments
 (0)