Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 37 additions & 15 deletions src/anemoi/transform/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import numpy as np
from numpy.typing import NDArray

from anemoi.transform.constants import L_1_degree_earth_arc_length_km
from anemoi.transform.constants import R_earth_km

LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -298,9 +299,15 @@ def cutout_mask(
cropping_distance: float = 2.0,
neighbours: int = 5,
min_distance_km: int | float = None,
max_distance_km: int | float = None,
plot: str = None,
) -> NDArray[Any]:
"""Return a mask for the points in [global_lats, global_lons] that are inside of [lats, lons].
"""Return a mask for the points in [global_lats, global_lons] to mask out.

This may be because these points are :
- inside of [lats, lons]
- too close to it (if min_distance_km is set)
- too far from it (if max_distance_km is set)

Parameters
----------
Expand All @@ -318,6 +325,9 @@ def cutout_mask(
Number of neighbours. Defaults to 5.
min_distance_km : int | float, optional
Minimum distance in kilometers. Defaults to None.
max_distance_km : Optional[Union[int, float]], optional
Maximum distance in kilometers. Points further than this distance from the LAM
region will be excluded from the mask. Defaults to None.
plot : str, optional
Path for saving the plot. Defaults to None.

Expand All @@ -326,6 +336,11 @@ def cutout_mask(
NDArray[Any]
Mask array.
"""
assert cropping_distance >= 0.0, "cropping_distance must be non-negative"
assert min_distance_km is None or min_distance_km >= 0.0, "min_distance_km must be non-negative"
assert max_distance_km is None or max_distance_km >= 0.0, "max_distance_km must be non-negative"
assert neighbours > 0, "neighbours must be positive"

from scipy.spatial import cKDTree

# TODO: transform min_distance from lat/lon to xyz
Expand All @@ -337,14 +352,20 @@ def cutout_mask(
west = np.amin(lons)

# Reduce the global grid to the area of interest
effective_cropping_distance = cropping_distance
if max_distance_km is not None:
# If max_distance_km is specified, ensure that cropping_mask() will contain
# only point too far
max_distance_degrees = max_distance_km / (1.1 * L_1_degree_earth_arc_length_km)
effective_cropping_distance = max(cropping_distance, max_distance_degrees)

mask = cropping_mask(
global_lats,
global_lons,
np.min([90.0, north + cropping_distance]),
west - cropping_distance,
np.max([-90.0, south - cropping_distance]),
east + cropping_distance,
np.min([90.0, north + effective_cropping_distance]),
west - effective_cropping_distance,
np.max([-90.0, south - effective_cropping_distance]),
east + effective_cropping_distance,
)

# return mask
Expand Down Expand Up @@ -376,7 +397,7 @@ def cutout_mask(

for i, (global_point, distance, index) in enumerate(zip(global_points, distances, indices)):

# We check more than one triangle in case te global point
# We check more than one triangle in case the global point
# is near the edge of triangle, (the lam point and global points are colinear)

inside = False
Expand All @@ -390,18 +411,19 @@ def cutout_mask(

close = np.min(distance) <= min_distance

inside_lam.append(inside or close)
too_far = False
if max_distance_km is not None:
too_far = np.min(distance) > (max_distance_km / R_earth_km)

j = 0
inside_lam_array = np.array(inside_lam)
for i, m in enumerate(mask):
if not m:
continue
inside_lam.append(inside or close or too_far)

mask[i] = inside_lam_array[j]
j += 1
# Apply max_distance_km filter if specified
too_far = False
if isinstance(max_distance_km, (int, float)):
too_far = ~mask.copy() # all points outside the cropping area are too far

assert j == len(inside_lam_array)
mask[mask] = inside_lam
mask[too_far] = True

# Invert the mask, so we have only the points outside the cutout
mask = ~mask
Expand Down
145 changes: 145 additions & 0 deletions test_spatial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# (C) Copyright 2026 Anemoi contributors.

#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import numpy as np
import pytest

from anemoi.transform.spatial import cutout_mask


@pytest.mark.parametrize("cropping_distance", [1.0, 3.0, 5.0])
def test_cutout_mask_with_max_distance(cropping_distance: float):
    """Check that ``max_distance_km`` drops far-away global points.

    The same expected mask is asserted for every parametrized
    ``cropping_distance``, i.e. the outcome must not depend on it.
    """
    # LAM footprint: 11x11 points spanning 44..46 degrees latitude
    # and 0..2 degrees longitude.
    lat_mesh, lon_mesh = np.meshgrid(
        np.linspace(44.0, 46.0, 11),
        np.linspace(0.0, 2.0, 11),
    )
    lam_lats = lat_mesh.flatten()
    lam_lons = lon_mesh.flatten()

    # Global probe points as (lat, lon) pairs, at increasing latitude.
    probes = [
        (43.1, 359.1),
        (44.0, 359.5),
        (45.0, 0.0),
        (45.5, 1.0),
        (46.0, 2.0),
        (50.0, 0.0),
    ]
    global_lats = np.array([lat for lat, _ in probes])
    global_lons = np.array([lon for _, lon in probes])

    mask = cutout_mask(
        lam_lats,
        lam_lons,
        global_lats,
        global_lons,
        cropping_distance=cropping_distance,
        max_distance_km=250.0,  # anything further than 250 km is dropped
    )

    # Only the first probe (lat=43.1) is expected to survive: the four
    # middle probes sit on or right next to the LAM footprint, and the
    # last one (lat=50.0) is beyond the 250 km limit.
    expected = np.array([True] + [False] * 5)

    assert isinstance(mask, np.ndarray)
    assert mask.shape == global_lats.shape
    assert np.array_equal(mask, expected)


def test_cutout_mask_with_min_distance():
    """Check that ``min_distance_km`` also drops points just outside the LAM."""
    # LAM footprint: 11x11 points spanning 44..46 degrees latitude
    # and 0..2 degrees longitude.
    lat_mesh, lon_mesh = np.meshgrid(
        np.linspace(44.0, 46.0, 11),
        np.linspace(0.0, 2.0, 11),
    )
    lam_lats = lat_mesh.flatten()
    lam_lons = lon_mesh.flatten()

    # Global probe points as (lat, lon) pairs.
    probes = [
        (44.0, 0.0),
        (45.0, 1.0),
        (46.0, 2.0),
        (46.1, -0.1),
        (47.5, -1.5),
    ]
    global_lats = np.array([lat for lat, _ in probes])
    global_lons = np.array([lon for _, lon in probes])

    mask = cutout_mask(
        lam_lats,
        lam_lons,
        global_lats,
        global_lons,
        min_distance_km=100.0,
    )

    # Expected outcome:
    #   - probes 1-3 lie on the LAM footprint            -> excluded
    #   - probe 4 (lat=46.1) is within 100 km of the LAM -> excluded
    #   - probe 5 (lat=47.5) is the only one kept
    expected = np.array([False] * 4 + [True])

    assert isinstance(mask, np.ndarray)
    assert mask.shape == global_lats.shape
    assert np.array_equal(mask, expected)


def test_cutout_mask_array_shapes():
    """Check that two-dimensional LAM coordinate arrays are rejected."""
    # Deliberately 2x2 (not flattened) LAM coordinates.
    lats_2d = np.array([[45.0, 45.0], [46.0, 46.0]])
    lons_2d = np.array([[0.0, 1.0], [0.0, 1.0]])

    # A single, valid, one-dimensional global point.
    one_lat = np.array([45.0])
    one_lon = np.array([0.0])

    # The call is expected to fail with an AssertionError because of
    # the 2D inputs.
    with pytest.raises(AssertionError):
        cutout_mask(lats_2d, lons_2d, one_lat, one_lon)


def test_cutout_mask_parameter_types():
    """Check that ``max_distance_km`` accepts both ``int`` and ``float``.

    Besides checking that each call returns an array of the right shape,
    the two results are compared: ``100`` and ``100.0`` describe the same
    distance, so the resulting masks must be identical.
    """
    # LAM footprint: 11x11 points spanning 44..46 degrees latitude
    # and 0..2 degrees longitude.
    lam_lat_range = np.linspace(44.0, 46.0, 11)
    lam_lon_range = np.linspace(0.0, 2.0, 11)
    lam_lats, lam_lons = np.meshgrid(lam_lat_range, lam_lon_range)
    lam_lats = lam_lats.flatten()
    lam_lons = lam_lons.flatten()

    global_lats = np.array([45.0, 46.0])
    global_lons = np.array([0.0, 2.0])

    # Same distance, once as an int ...
    mask_int = cutout_mask(lam_lats, lam_lons, global_lats, global_lons, max_distance_km=100)
    assert isinstance(mask_int, np.ndarray)
    assert mask_int.shape == global_lats.shape

    # ... and once as a float.
    mask_float = cutout_mask(lam_lats, lam_lons, global_lats, global_lons, max_distance_km=100.0)
    assert isinstance(mask_float, np.ndarray)
    assert mask_float.shape == global_lats.shape

    # The numeric type of the argument must not change the result.
    assert np.array_equal(mask_int, mask_float)


def test_cutout_mask_large_grid():
    """Smoke-test ``cutout_mask`` on a larger grid with both distance limits set."""
    # LAM footprint: 21x21 points spanning 40..50 degrees latitude
    # and 0..10 degrees longitude.
    lam_lat_mesh, lam_lon_mesh = np.meshgrid(
        np.linspace(40.0, 50.0, 21),
        np.linspace(0.0, 10.0, 21),
    )
    lam_lats = lam_lat_mesh.flatten()
    lam_lons = lam_lon_mesh.flatten()

    # Surrounding "global" grid: 31x31 points spanning 30..60 degrees
    # latitude and -10..20 degrees longitude.
    glob_lat_mesh, glob_lon_mesh = np.meshgrid(
        np.linspace(30.0, 60.0, 31),
        np.linspace(-10.0, 20.0, 31),
    )
    global_lats = glob_lat_mesh.flatten()
    global_lons = glob_lon_mesh.flatten()

    mask = cutout_mask(
        lam_lats,
        lam_lons,
        global_lats,
        global_lons,
        min_distance_km=150.0,
        max_distance_km=300.0,
    )

    # Structural checks only: one boolean per global point (31 * 31 = 961).
    assert isinstance(mask, np.ndarray)
    assert mask.shape == (961,)
    assert mask.dtype == bool

    # The mask must be neither all-False nor all-True: some global
    # points are selected and some are not.
    assert mask.any()
    assert not mask.all()
Loading