Skip to content

Commit c1d3133

Browse files
feat: add return_distances option to nearest grid point lookup function (#237)
## Description Adds an (optional) flag to allow `nearest_grid_points` to also return the distances. The default behaviour is to return only the indices, so this should not break the existing use cases (the only instance that uses the distances is `ComplementNearest` in `anemoi-datasets`, as far as I can tell). Partially addresses ecmwf/anemoi-datasets#534 ***As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/*** By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement.](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md)
1 parent 548e2fa commit c1d3133

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

src/anemoi/transform/spatial.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,8 @@ def nearest_grid_points(
591591
target_longitudes: NDArray[Any],
592592
max_distance: float = None,
593593
num_neighbours_to_return: int = 1,
594-
) -> NDArray[Any]:
594+
return_distances: bool = False,
595+
) -> NDArray[Any] | tuple[NDArray[Any], NDArray[Any]]:
595596
"""Find the nearest grid points from source to target coordinates.
596597
597598
Parameters
@@ -609,10 +610,12 @@ def nearest_grid_points(
609610
For example, 1e-3 is 1 km.
610611
num_neighbours_to_return : int, optional
611612
Number of nearest neighbours to return. Defaults to 1.
613+
return_distances : bool, optional
614+
Whether to return distances along with indices. Defaults to False.
612615
Returns
613616
-------
614-
NDArray[Any]
615-
Indices of the nearest grid points.
617+
NDArray[Any] | tuple[NDArray[Any], NDArray[Any]]
618+
Indices of the nearest grid points, or a tuple of indices and distances.
616619
"""
617620
from scipy.spatial import cKDTree
618621

@@ -622,9 +625,11 @@ def nearest_grid_points(
622625
target_xyz = latlon_to_xyz(target_latitudes, target_longitudes)
623626
target_points = np.array(target_xyz).transpose()
624627
if max_distance is None:
625-
_, indices = cKDTree(source_points).query(target_points, k=num_neighbours_to_return)
628+
distances, indices = cKDTree(source_points).query(target_points, k=num_neighbours_to_return)
626629
else:
627-
_, indices = cKDTree(source_points).query(
630+
distances, indices = cKDTree(source_points).query(
628631
target_points, k=num_neighbours_to_return, distance_upper_bound=max_distance
629632
)
633+
if return_distances:
634+
return indices, distances
630635
return indices

0 commit comments

Comments
 (0)