Skip to content

Commit 4dbc7e5

Browse files
refactor: convert apply-mask filter to SingleFieldFilter (#250)
## Description Converts the `apply-mask` filter to be a `SingleFieldFilter` since it works only on a single field at a time. ## Additional notes ## (Makes the implementation of #249 much more straightforward) ***As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/*** By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement.](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md)
1 parent 1b6358c commit 4dbc7e5

File tree

1 file changed

+40
-46
lines changed

1 file changed

+40
-46
lines changed

src/anemoi/transform/filters/fields/apply_mask.py

Lines changed: 40 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,11 @@
88
# nor does it submit to any jurisdiction.
99

1010
import logging
11-
from typing import Literal
1211

1312
import earthkit.data as ekd
1413
import numpy as np
1514

16-
from anemoi.transform.fields import new_field_from_numpy
17-
from anemoi.transform.fields import new_fieldlist_from_list
18-
from anemoi.transform.filter import Filter
15+
from anemoi.transform.filter import SingleFieldFilter
1916
from anemoi.transform.filters import filter_registry
2017

2118
LOG = logging.getLogger(__name__)
@@ -37,7 +34,7 @@
3734

3835

3936
@filter_registry.register("apply_mask")
40-
class MaskVariable(Filter):
37+
class MaskVariable(SingleFieldFilter):
4138
"""A filter to mask variables using an external file.
4239
4340
The values of every filtered fields are set to NaN when they are either:
@@ -81,19 +78,23 @@ class MaskVariable(Filter):
8178
8279
"""
8380

84-
def __init__(
85-
self,
86-
*,
87-
path: str,
88-
mask_value: float | None = None,
89-
threshold: float | None = None,
90-
threshold_operator: Literal["<", "<=", ">", ">=", "==", "!="] = ">",
91-
rename: str | None = None,
92-
):
93-
"""Initialize the MaskVariable filter.
94-
95-
Parameters
96-
----------
81+
# path: str
82+
required_inputs = ("path",)
83+
# mask_value: float | None,
84+
# threshold: float | None,
85+
# threshold_operator: Literal["<", "<=", ">", ">=", "==", "!="],
86+
# rename: str | None,
87+
optional_inputs = {
88+
"mask_value": None,
89+
"threshold": None,
90+
"threshold_operator": ">",
91+
"rename": None,
92+
}
93+
94+
def prepare_filter(self):
95+
"""Setup the MaskVariable filter.
96+
97+
Note:
9798
path : str
9899
Path to the external file containing the mask.
99100
mask_value : int, optional
@@ -104,48 +105,41 @@ def __init__(
104105
New name for the masked variable, by default None.
105106
"""
106107

107-
mask = ekd.from_source("file", path)[0].to_numpy(flatten=True)
108+
mask = ekd.from_source("file", self.path)[0].to_numpy(flatten=True)
108109

109-
if mask_value is None and threshold is None:
110+
if self.mask_value is None and self.threshold is None:
110111
raise ValueError("Either `mask_value` or `threshold` must be provided.")
111112

112-
if threshold is not None:
113-
if threshold_operator not in OPERATORS:
113+
if self.threshold is not None:
114+
if self.threshold_operator not in OPERATORS:
114115
raise ValueError(
115-
f"Invalid threshold operator: {threshold_operator}. "
116+
f"Invalid threshold operator: {self.threshold_operator}. "
116117
f"Valid operators are: {', '.join(OPERATORS.keys())}."
117118
)
118-
self._mask = OPERATORS[threshold_operator](mask, threshold)
119+
self.mask = OPERATORS[self.threshold_operator](mask, self.threshold)
119120
else:
120-
self._mask = mask == mask_value
121+
self.mask = mask == self.mask_value
121122

122-
self._rename = rename
123-
124-
def forward(self, data: ekd.FieldList) -> ekd.FieldList:
125-
"""Apply the forward transformation to the data.
123+
def forward_transform(self, field: ekd.Field) -> ekd.Field:
124+
"""Apply the forward transformation to the field.
126125
127126
Parameters
128127
----------
129-
data : ekd.FieldList
130-
Input data to be transformed.
128+
field : ekd.Field
129+
Input field to be transformed.
131130
132131
Returns
133132
-------
134-
ekd.FieldList
135-
Transformed data.
133+
ekd.Field
134+
Transformed field.
136135
"""
137-
result = []
138-
extra = {}
139-
for field in data:
140-
141-
values = field.to_numpy(flatten=True)
142-
values[self._mask] = np.nan
143-
144-
if self._rename is not None:
145-
param = field.metadata("param")
146-
name = f"{param}_{self._rename}"
147-
extra["param"] = name
136+
metadata = {}
137+
values = field.to_numpy(flatten=True)
138+
values[self.mask] = np.nan
148139

149-
result.append(new_field_from_numpy(values, template=field, **extra))
140+
if self.rename is not None:
141+
param = field.metadata("param")
142+
name = f"{param}_{self.rename}"
143+
metadata["param"] = name
150144

151-
return new_fieldlist_from_list(result)
145+
return self.new_field_from_numpy(values, template=field, **metadata)

0 commit comments

Comments
 (0)