Skip to content

Commit 820eb39

Browse files
clairemerkerfrazanepre-commit-ci[bot]gmertes
authored
feat(apply mask): apply a boolean mask when opening a dataset (#496)
## Description This PR introduces `mask` as an argument to `open_dataset`. This argument allows to apply an arbitrary spatial boolean mask from a .npy file to the dataset. ## What problem does this change solve? It is a new feature for `open_dataset` that allows to select arbitrary spatial points of a dataset, it could be used e.g. to thin data on an irregular grid or select a region of interest. It was tested to reduce the density of grid points of the triangular grid of the ICON model. The plots show a zoom over Switzerland with the dense grid at approx. 1km resolution and the grid after applying a mask selecting only every second triangle. Original grid: <img width="500" alt="image" src="https://github.com/user-attachments/assets/9a8e1036-c77a-4f35-ab1f-1ef67014e363" /> Half the points masked using external mask: <img width="500" alt="image" src="https://github.com/user-attachments/assets/d3cfcb36-3f5f-43a6-aee0-c29b1d9f66d7" /> ## What issue or task does this change relate to? ## Additional notes ## ***As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/*** By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement.](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md) <!-- readthedocs-preview anemoi-datasets start --> ---- 📚 Documentation preview 📚: https://anemoi-datasets--496.org.readthedocs.build/en/496/ <!-- readthedocs-preview anemoi-datasets end --> --------- Co-authored-by: Francesco Zanetta <62377868+frazane@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Gert Mertes <13658335+gmertes@users.noreply.github.com>
1 parent 859ed19 commit 820eb39

File tree

4 files changed

+142
-0
lines changed

4 files changed

+142
-0
lines changed

docs/using/grids.rst

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,25 @@ dataset:
3737
:width: 75%
3838
:align: center
3939

40+
*********
41+
masking
42+
*********
43+
44+
You can apply an arbitrary spatial mask to a dataset by specifying the
45+
``mask`` parameter in the ``open_dataset`` function. The mask must be a
46+
NumPy .npy file containing a boolean array, where True indicates points
47+
to be kept and False indicates points to be removed.
48+
49+
.. code:: python
50+
51+
ds = open_dataset(dataset, mask="path/to/mask.npy")
52+
53+
The mask array must have the same total number of grid points and
54+
dimension as the dataset. Please note that this masking will not be
55+
automatically applied to the input data provided during inference.
56+
Users must ensure they are applying the mask explicitly, for instance
57+
via the use of pre-processors.
58+
4059
******
4160
area
4261
******

src/anemoi/datasets/data/dataset.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,12 @@ def __subset(self, **kwargs: Any) -> "Dataset":
245245

246246
return Statistics(self, open_dataset(statistics))._subset(**kwargs).mutate()
247247

248+
if "mask" in kwargs:
249+
from .masked import Masking
250+
251+
mask_file = kwargs.pop("mask")
252+
return Masking(self, mask_file)._subset(**kwargs).mutate()
253+
248254
# Note: trim_edge should go before thinning
249255
if "trim_edge" in kwargs:
250256
from .masked import TrimEdge

src/anemoi/datasets/data/masked.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import logging
1212
from functools import cached_property
13+
from pathlib import Path
1314
from typing import Any
1415

1516
import numpy as np
@@ -207,6 +208,60 @@ def field_shape(self) -> Shape:
207208
return x, y
208209

209210

211+
class Masking(Masked):
212+
"""A class that applies a precomputed boolean mask from a .npy file."""
213+
214+
def __init__(self, forward: Dataset, mask_file: str) -> None:
215+
"""Initialize the Masking class.
216+
217+
Parameters
218+
----------
219+
forward : Dataset
220+
The dataset to be masked.
221+
mask_file : str
222+
Path to a .npy file containing a boolean mask of same shape as fields.
223+
"""
224+
self.mask_file = mask_file
225+
226+
# Check path
227+
if not Path(self.mask_file).exists():
228+
raise FileNotFoundError(f"Mask file not found: {self.mask_file}")
229+
# Load mask
230+
try:
231+
mask = np.load(self.mask_file)
232+
except Exception as e:
233+
raise ValueError(f"Could not load data from {mask_file}: {e}")
234+
235+
if mask.dtype != bool:
236+
raise ValueError(f"Mask file {mask_file} does not contain boolean values.")
237+
if mask.shape != forward.field_shape:
238+
raise ValueError(f"Mask length {mask.shape} does not match field size {forward.field_shape}.")
239+
if sum(mask) == 0:
240+
LOG.warning(f"Mask in {mask_file} eliminates all points in field.")
241+
242+
super().__init__(forward, mask)
243+
244+
def tree(self) -> Node:
245+
"""Get the tree representation of the dataset.
246+
247+
Returns
248+
-------
249+
Node
250+
The tree representation of the dataset.
251+
"""
252+
return Node(self, [self.forward.tree()], mask_file=self.mask_file)
253+
254+
def forwards_subclass_metadata_specific(self) -> dict[str, Any]:
255+
"""Get the metadata specific to the Masking subclass.
256+
257+
Returns
258+
-------
259+
Dict[str, Any]
260+
The metadata specific to the Masking subclass.
261+
"""
262+
return dict(mask_file=self.mask_file)
263+
264+
210265
class Cropping(Masked):
211266
"""A class to represent a cropped dataset."""
212267

tests/test_data.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from anemoi.datasets.data.ensemble import Ensemble
3131
from anemoi.datasets.data.grids import GridsBase
3232
from anemoi.datasets.data.join import Join
33+
from anemoi.datasets.data.masked import Masking
3334
from anemoi.datasets.data.misc import as_first_date
3435
from anemoi.datasets.data.misc import as_last_date
3536
from anemoi.datasets.data.padded import Padded
@@ -1400,6 +1401,67 @@ def test_invalid_trim_edge() -> None:
14001401
)
14011402

14021403

1404+
@mockup_open_zarr
1405+
def test_masking() -> None:
1406+
"""Test masking a dataset."""
1407+
test_mask = np.array([True, False, True, True, True, True, False, False, True, False])
1408+
with (
1409+
patch("anemoi.datasets.data.masked.np.load", return_value=test_mask),
1410+
patch("anemoi.datasets.data.masked.Path.exists", return_value=True),
1411+
):
1412+
1413+
test = DatasetTester("test-2021-2022-6h-o96-abcd", mask="./test_mask.npy")
1414+
test.run(
1415+
expected_class=Masking,
1416+
expected_length=365 * 2 * 4,
1417+
date_to_row=lambda date: simple_row(date, "abcd")[..., test_mask],
1418+
start_date=datetime.datetime(2021, 1, 1),
1419+
time_increment=datetime.timedelta(hours=6),
1420+
expected_shape=(365 * 2 * 4, 4, 1, sum(test_mask)),
1421+
expected_variables="abcd",
1422+
expected_name_to_index="abcd",
1423+
statistics_reference_dataset="test-2021-2022-6h-o96-abcd",
1424+
statistics_reference_variables="abcd",
1425+
)
1426+
return
1427+
1428+
1429+
@mockup_open_zarr
1430+
def test_masking_wrong_mask_dims() -> None:
1431+
"""Test masking a dataset (wrong dims in mask)."""
1432+
test_mask = np.array([True, False, True, True, True, True, False, False, True])
1433+
with (
1434+
patch("anemoi.datasets.data.masked.np.load", return_value=test_mask),
1435+
patch("anemoi.datasets.data.masked.Path.exists", return_value=True),
1436+
):
1437+
with pytest.raises(ValueError):
1438+
_ = DatasetTester("test-2021-2022-6h-o96-abcd", mask="./test_mask.npy")
1439+
return
1440+
1441+
1442+
@mockup_open_zarr
1443+
def test_masking_mask_file_not_found() -> None:
1444+
"""Test masking a dataset (mask file not found)."""
1445+
test_mask = np.array([True, False, True, True, True, True, False, False, True, False])
1446+
with patch("anemoi.datasets.data.masked.np.load", return_value=test_mask):
1447+
with pytest.raises(FileNotFoundError):
1448+
_ = DatasetTester("test-2021-2022-6h-o96-abcd", mask="./test_mask.npy")
1449+
return
1450+
1451+
1452+
@mockup_open_zarr
1453+
def test_masking_wrong_dtype() -> None:
1454+
"""Test masking a dataset (mask file not found)."""
1455+
test_mask = np.array([1, 0, 1, 1, 1, 1, 0, 0, 1, 0])
1456+
with (
1457+
patch("anemoi.datasets.data.masked.np.load", return_value=test_mask),
1458+
patch("anemoi.datasets.data.masked.Path.exists", return_value=True),
1459+
):
1460+
with pytest.raises(ValueError):
1461+
_ = DatasetTester("test-2021-2022-6h-o96-abcd", mask="./test_mask.npy")
1462+
return
1463+
1464+
14031465
def test_save_dataset() -> None:
14041466
"""Test save datasets."""
14051467

0 commit comments

Comments
 (0)