Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ dependencies = [
"earthkit-regrid>=0.4",
]

optional-dependencies.all = [ "anemoi-transform" ]
optional-dependencies.all = [ "anemoi-transform[plots]" ]
optional-dependencies.dev = [ "anemoi-transform[all,docs,tests]" ]

optional-dependencies.docs = [
Expand All @@ -62,7 +62,9 @@ optional-dependencies.docs = [
"termcolor",
]

optional-dependencies.tests = [ "pytest" ]
optional-dependencies.plots = [ "matplotlib" ]

optional-dependencies.tests = [ "pytest", "pytest-skip-slow" ]

urls.Documentation = "https://anemoi-transform.readthedocs.io/"
urls.Homepage = "https://github.com/ecmwf/anemoi-transform/"
Expand Down
92 changes: 59 additions & 33 deletions src/anemoi/transform/commands/make-regrid-file.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,44 +12,54 @@
import logging
import os

import numpy as np

from anemoi.transform.commands import Command
from anemoi.transform.constants import L_1_degree_earth_arc_length_km as L_1d_km

LOG = logging.getLogger(__name__)


def _ds_to_lat_lon(ds):
def _xr_ds_lat_lon(path: str, lat_name: str, lon_name: str) -> tuple[np.ndarray, np.ndarray]:
import xarray as xr

ds = xr.open_dataset(path)
lat = ds[lat_name].values.flatten()
lon = ds[lon_name].values.flatten()
return lat, lon


def _ds_to_lat_lon(path: str) -> tuple[np.ndarray, np.ndarray]:
import earthkit.data as ekd

try:
ds = ekd.from_source("file", path)
return ds[0].grid_points()
except TypeError:
# This is a workaround for datasets that do not have data variables,
# but have latitude and longitude coordinates.
import xarray as xr
# but have "latitude" and "longitude" coordinates.
return _xr_ds_lat_lon(path, "latitude", "longitude")

ds = xr.open_dataset(ds.path)
lat = ds["latitude"].values.flatten()
lon = ds["longitude"].values.flatten()
return lat, lon


def _path_to_lat_lon(path):
def _path_to_lat_lon(path: str) -> tuple[np.ndarray, np.ndarray]:
"""Extract latitudes and longitudes from a file path."""
import earthkit.data as ekd
import numpy as np

if path.endswith(".npz"):
data = np.load(path)
return data["latitudes"], data["longitudes"]
if path.endswith(".zarr"):
from anemoi.datasets import open_dataset
# assume anemoi-dataset first - load with xarray
try:
return _xr_ds_lat_lon(path, "latitudes", "longitudes")
except KeyError:
pass

dataset = open_dataset(path)
return dataset.latitudes, dataset.longitudes
ds = ekd.from_source("file", path)
return _ds_to_lat_lon(ds)
# fallback to earthkit-data
return _ds_to_lat_lon(path)


def check_duplicate_latlons(input_file, latitudes, longitudes):
def check_duplicate_latlons(input_file: str, latitudes: np.ndarray, longitudes: np.ndarray) -> None:
LOG.info(f"Checking for duplicate lat/lon pairs in {input_file}...")
seen = set()
for lat, lon in zip(latitudes, longitudes):
Expand All @@ -58,21 +68,13 @@ def check_duplicate_latlons(input_file, latitudes, longitudes):
seen.add((lat, lon))


def round_lat_lon(latitudes, longitudes, rounding):
def round_lat_lon(latitudes: np.ndarray, longitudes: np.ndarray, num_decimals: int) -> tuple[np.ndarray, np.ndarray]:
import numpy as np

LOG.info(f"Rounding latitudes and longitudes to {rounding} decimal places ({L_1d_km / ( 10 ) ** rounding} m).")
return np.round(latitudes, rounding), np.round(longitudes, rounding)


def _lat_lon_plot(lat, lon, plot: str) -> None:
import matplotlib.pyplot as plt
import numpy as np

lon = np.where(lon >= 180, lon - 360, lon)
plt.figure(figsize=(10, 5))
plt.scatter(lon, lat, s=0.1, c="k")
plt.savefig(plot)
LOG.info(
f"Rounding latitudes and longitudes to {num_decimals} decimal places ({L_1d_km / (10) ** num_decimals} m)."
)
return np.round(latitudes, num_decimals), np.round(longitudes, num_decimals)


class MakeMIRMatrix:
Expand Down Expand Up @@ -201,19 +203,43 @@ def run(self, args: argparse.Namespace) -> None:
global_lat, global_lon = _path_to_lat_lon(args.global_grid)

MakeGlobalOnLamMask.make_global_on_lam_mask(
lam_lat, lam_lon, global_lat, global_lon, output=args.output, plot=args.plot, distance_km=args.distance_km
lam_lat,
lam_lon,
global_lat,
global_lon,
output=args.output,
plot_path=args.plot,
distance_km=args.distance_km,
)

@staticmethod
def make_global_on_lam_mask(lam_lat, lam_lon, global_lat, global_lon, output, plot: str | None = None, **kwargs):
def _lat_lon_plot(lat: np.ndarray, lon: np.ndarray, plot_path: str) -> None:
import matplotlib.pyplot as plt
import numpy as np

lon = np.where(lon >= 180, lon - 360, lon)
plt.figure(figsize=(10, 5))
plt.scatter(lon, lat, s=0.1, c="k")
plt.savefig(plot_path)

@staticmethod
def make_global_on_lam_mask(
lam_lat: np.ndarray,
lam_lon: np.ndarray,
global_lat: np.ndarray,
global_lon: np.ndarray,
output: str,
plot_path: str | None = None,
**kwargs,
) -> None:
import numpy as np

from anemoi.transform.spatial import global_on_lam_mask

mask = global_on_lam_mask(lam_lat, lam_lon, global_lat, global_lon, **kwargs)
np.savez(output, mask=mask)
if plot:
_plot_lat_lon(global_lat[mask], global_lons[mask], plot)
if plot_path:
MakeGlobalOnLamMask._lat_lon_plot(global_lat[mask], global_lon[mask], plot_path)


OPTIONS = {
Expand Down
18 changes: 8 additions & 10 deletions src/anemoi/transform/filters/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
LOG = logging.getLogger(__name__)


def as_gridspec(grid: str | dict[str, Any] | None = None) -> dict[str, Any] | None:
def as_gridspec(grid: str | dict[str, Any] | None) -> dict[str, Any] | None:
"""Convert grid specification to a dictionary format.

Parameters
Expand All @@ -36,7 +36,7 @@ def as_gridspec(grid: str | dict[str, Any] | None = None) -> dict[str, Any] | No

Returns
-------
dict[str, Any]
dict[str, Any] | None
The grid specification as a dictionary.
"""
if grid is None:
Expand All @@ -48,7 +48,7 @@ def as_gridspec(grid: str | dict[str, Any] | None = None) -> dict[str, Any] | No
return grid


def as_griddata(grid: str | Field | dict[str, Any] | None = None) -> dict[str, Any] | None:
def as_griddata(grid: str | Field | dict[str, Any] | None) -> dict[str, Any] | None:
"""Convert grid data to a dictionary format.

Parameters
Expand All @@ -58,7 +58,7 @@ def as_griddata(grid: str | Field | dict[str, Any] | None = None) -> dict[str, A

Returns
-------
dict[str, Any]
dict[str, Any] | None
The grid data as a dictionary.
"""
if grid is None:
Expand Down Expand Up @@ -267,7 +267,7 @@ def __init__(self, *, matrix: str, check: bool) -> None:
-------------
matrix : str
The regrid matrix file path.
check : bool, default = False
check : bool
Whether to perform checks.
"""
import numpy as np
Expand Down Expand Up @@ -317,9 +317,7 @@ class ScipyKDTreeNearestNeighbours:

nearest_grid_points = None

def __init__(
self, *, in_grid: Any, out_grid: Any, method: str, matrix: str | None = None, check: bool = False
) -> None:
def __init__(self, *, in_grid: Any, out_grid: Any, method: str, check: bool = False) -> None:
"""Parameters
-------------
in_grid : Any
Expand Down Expand Up @@ -394,7 +392,7 @@ def __init__(self, *, mask: str, check: bool) -> None:
-------------
mask : str
The mask file path.
check : bool, default = False
check : bool
Whether to perform checks.
"""

Expand Down Expand Up @@ -492,7 +490,7 @@ def make_interpolator(
The regrid matrix file path.
mask : str, optional
The mask file path.
check : bool, optional, default = False
check : bool, optional
Whether to perform checks.

Returns
Expand Down
39 changes: 18 additions & 21 deletions src/anemoi/transform/grids/named.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@


@cached(collection="grids", encoding="npz")
def _grids(name: str | list[float] | tuple[float, ...]) -> bytes:
"""Get grid data by name.
def _get_grid_data(grid_id: str | list[float] | tuple[float, float]) -> bytes:
"""Get grid data from the grid registry by identifier.

Parameters
----------
name : str
The name of the grid
grid_id : str | list[float] | tuple[float, float]
The identifier of the grid, either a string like "o96" or a tuple/list
of two numbers (describing the resolution).
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For info: renamed function and added more info to docstring

Copy link
Contributor

@yoel-zerah yoel-zerah Jan 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

id is a bad argument name, because it is a python built-in function, I propose grid_id.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed to grid_id. Thanks!


Returns
-------
Expand All @@ -40,30 +41,29 @@ def _grids(name: str | list[float] | tuple[float, ...]) -> bytes:
"""
from anemoi.utils.config import load_config

if isinstance(name, (tuple, list)):
assert len(name) == 2, "Grid name must be a list or a tuple of length 2"
assert all(isinstance(i, (int, float)) for i in name), "Grid name must be a list or a tuple of numbers"
if name[0] == name[1]:
name = str(float(name[0]))
if isinstance(grid_id, (tuple, list)):
assert len(grid_id) == 2, "Grid name must be a list or a tuple of length 2"
assert all(isinstance(i, (int, float)) for i in grid_id), "Grid name must be a list or a tuple of numbers"
if grid_id[0] == grid_id[1]:
grid_id = str(float(grid_id[0]))
else:
name = str(float(name[0])) + "x" + str(float(name[1]))
name = name.replace(".", "p")
grid_id = str(float(grid_id[0])) + "x" + str(float(grid_id[1]))
grid_id = grid_id.replace(".", "p")

user_path = load_config().get("utils", {}).get("grids_path")
if user_path:
path = os.path.expanduser(os.path.join(user_path, f"grid-{name}.npz"))
path = os.path.expanduser(os.path.join(user_path, f"grid-{grid_id}.npz"))
if os.path.exists(path):
LOG.warning("Loading grids from custom user path %s", path)
with open(path, "rb") as f:
return f.read()
else:
LOG.warning("Custom user path %s does not exist", path)

# To add a grid
# To generate a grid
# anemoi-transform get-grid --source mars grid=o400,levtype=sfc,param=2t grid-o400.npz
# nexus-cli -u xxxx -p yyyy -s GET_INSTANCE --repository anemoi upload --remote-path grids --local-path grid-o400.npz

url = GRIDS_URL_PATTERN.format(name=name.lower())
url = GRIDS_URL_PATTERN.format(name=grid_id.lower())
LOG.warning("Downloading grids from %s", url)
response = requests.get(url)
response.raise_for_status()
Expand All @@ -83,9 +83,6 @@ def lookup(name: str | list[float] | tuple[float, ...]) -> dict:
dict
The grid data
"""
if isinstance(name, str) and name.endswith(".npz"):
return dict(np.load(name))

data = _grids(name)
npz = np.load(BytesIO(data))
return dict(npz)
is_npz_file = isinstance(name, str) and name.endswith(".npz")
data = name if is_npz_file else BytesIO(_get_grid_data(name))
return dict(np.load(data))
26 changes: 0 additions & 26 deletions src/anemoi/transform/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,29 +596,3 @@ def nearest_grid_points(
else:
_, indices = cKDTree(source_points).query(target_points, k=1, distance_upper_bound=max_distance)
return indices


if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)

import earthkit.data as ekd

glob = ekd.from_source("file", "tmp/era5.grib")
lam = ekd.from_source("file", "tmp/carra-east.grib")

global_lats, global_lons = glob[0].grid_points()

lats, lons = lam[0].grid_points()

mask = global_on_lam_mask(lats, lons, global_lats, global_lons)
print(len(mask))
import matplotlib.pyplot as plt

lat = global_lats[mask]
lon = global_lons[mask]
lon = np.where(lon >= 180, lon - 360, lon)

fig = plt.figure(figsize=(10, 5))
plt.scatter(lon, lat, s=0.1, c="k")
# plt.scatter(lons, lats, s=0.01)
plt.savefig("cutout.png")
7 changes: 3 additions & 4 deletions tests/test_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import logging

import pytest
from anemoi.utils.testing import cli_testing
from anemoi.utils.testing import skip_if_missing_command
from anemoi.utils.testing import skip_if_offline
Expand All @@ -25,11 +24,10 @@

from .utils import compare_npz_files

LOG = logging.getLogger(__name__)


@skip_if_offline
@skip_if_missing_command("mir")
@pytest.mark.slow
def test_make_regrid_matrix(get_test_data):
era5 = get_test_data("anemoi-transform/filters/regrid/2t-ea.grib")
carra = get_test_data("anemoi-transform/filters/regrid/2t-rr.grib")
Expand Down Expand Up @@ -61,6 +59,7 @@ def test_regrid_matrix(get_test_data, test_source):


@skip_if_offline
@pytest.mark.slow
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where are we using this? as in do those run then at PR level, nightly or which is the frequency?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By default we won't be running them in the CI, which is as the same as in anemoi-datasets. I run these locally before pushing up a PR (as with datasets)

def test_make_regrid_mask(get_test_data):
era5 = get_test_data("anemoi-transform/filters/regrid/2t-ea.grib")
carra = get_test_data("anemoi-transform/filters/regrid/2t-rr.grib")
Expand Down
3 changes: 0 additions & 3 deletions tests/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,13 @@
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
import logging
from collections import defaultdict

import numpy as np

from anemoi.transform.fields import new_fieldlist_from_list
from anemoi.transform.source import Source

LOG = logging.getLogger(__name__)


def collect_fields_by_param(pipeline):
fields = defaultdict(list)
Expand Down
Loading