Skip to content

Commit 77c164d

Browse files
committed
feat: Add mask per param filter
1 parent 1b6358c commit 77c164d

File tree

2 files changed

+126
-22
lines changed

2 files changed

+126
-22
lines changed

src/anemoi/transform/filters/fields/apply_mask.py

Lines changed: 59 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from anemoi.transform.fields import new_field_from_numpy
1717
from anemoi.transform.fields import new_fieldlist_from_list
18-
from anemoi.transform.filter import Filter
18+
from anemoi.transform.filter import SingleFieldFilter
1919
from anemoi.transform.filters import filter_registry
2020

2121
LOG = logging.getLogger(__name__)
@@ -36,8 +36,7 @@
3636
}
3737

3838

39-
@filter_registry.register("apply_mask")
40-
class MaskVariable(Filter):
39+
class ApplyMaskMixin:
4140
"""A filter to mask variables using an external file.
4241
4342
The values of every filtered fields are set to NaN when they are either:
@@ -96,7 +95,7 @@ def __init__(
9695
----------
9796
path : str
9897
Path to the external file containing the mask.
99-
mask_value : int, optional
98+
mask_value : float, optional
10099
Value to be used for masking, by default 1.
101100
threshold : float, optional
102101
Threshold value for masking, by default None.
@@ -121,31 +120,75 @@ def __init__(
121120

122121
self._rename = rename
123122

124-
def forward(self, data: ekd.FieldList) -> ekd.FieldList:
123+
def mask(self, data: ekd.Field) -> ekd.Field:
125124
"""Apply the forward transformation to the data.
126125
127126
Parameters
128127
----------
129-
data : ekd.FieldList
128+
data : ekd.Field
130129
Input data to be transformed.
131130
132131
Returns
133132
-------
134-
ekd.FieldList
133+
ekd.Field
135134
Transformed data.
136135
"""
137-
result = []
138136
extra = {}
139-
for field in data:
140137

141-
values = field.to_numpy(flatten=True)
142-
values[self._mask] = np.nan
138+
values = data.to_numpy(flatten=True)
139+
values[self._mask] = np.nan
140+
141+
if self._rename is not None:
142+
param = data.metadata("param")
143+
name = f"{param}_{self._rename}"
144+
extra["param"] = name
143145

144-
if self._rename is not None:
145-
param = field.metadata("param")
146-
name = f"{param}_{self._rename}"
147-
extra["param"] = name
146+
return new_field_from_numpy(values, template=data, **extra)
148147

149-
result.append(new_field_from_numpy(values, template=field, **extra))
150148

149+
@filter_registry.register("apply_mask")
150+
class MaskAllVariables(ApplyMaskMixin):
151+
"""A filter to mask all variables using an external file."""
152+
153+
def forward(self, data: ekd.FieldList) -> ekd.FieldList:
154+
result = []
155+
for field in data:
156+
masked_field = self.mask(field)
157+
result.append(masked_field)
151158
return new_fieldlist_from_list(result)
159+
160+
161+
@filter_registry.register("apply_mask_to_param")
162+
class ApplyMaskToParam(ApplyMaskMixin, SingleFieldFilter):
163+
"""A filter to mask a specific variable using an external file."""
164+
165+
required_inputs = ["param"]
166+
167+
def __init__(
168+
self,
169+
*,
170+
param: str,
171+
path: str,
172+
mask_value: float | None = None,
173+
threshold: float | None = None,
174+
threshold_operator: (
175+
Literal["<"] | Literal["<="] | Literal[">"] | Literal[">="] | Literal["=="] | Literal["!="]
176+
) = ">",
177+
rename: str | None = None,
178+
):
179+
SingleFieldFilter.__init__(self, param=param)
180+
ApplyMaskMixin.__init__(
181+
self,
182+
path=path,
183+
mask_value=mask_value,
184+
threshold=threshold,
185+
threshold_operator=threshold_operator,
186+
rename=rename,
187+
)
188+
self.param = param
189+
190+
def forward_select(self):
191+
return {"param": self.param}
192+
193+
def forward_transform(self, param: ekd.Field) -> ekd.Field:
194+
return self.mask(param)

tests/field_filters/test_apply_mask.py

Lines changed: 67 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,70 @@ def test_apply_mask(source, ekd_from_source, mask_name, rename, threshold_option
106106
assert np.sum(np.isnan(result)) == expected_mask_count
107107

108108

109-
if __name__ == "__main__":
110-
"""Run all test functions that start with 'test_'."""
111-
for name, obj in list(globals().items()):
112-
if name.startswith("test_") and callable(obj):
113-
print(f"Running {name}...")
114-
obj()
109+
@pytest.mark.parametrize(
110+
"threshold_options",
111+
[
112+
{"mask_value": 0.5},
113+
{"mask_value": 1},
114+
{"threshold": 0.5, "threshold_operator": ">"},
115+
{"threshold": 0.5, "threshold_operator": "<"},
116+
],
117+
)
118+
@pytest.mark.parametrize("rename", [None, "renamed"])
119+
@pytest.mark.parametrize("mask_name", MASK_VALUES.keys())
120+
@pytest.mark.parametrize("target_param", DATA_VALUES.keys())
121+
def test_apply_mask_to_param(source, ekd_from_source, target_param, mask_name, rename, threshold_options):
122+
apply_mask = filter_registry.create(
123+
"apply_mask_to_param", param=target_param, path=mask_name, rename=rename, **threshold_options
124+
)
125+
ekd_from_source.assert_called_once_with("file", mask_name)
126+
127+
pipeline = source | apply_mask
128+
129+
input_fields = collect_fields_by_param(source)
130+
output_fields = collect_fields_by_param(pipeline)
131+
132+
expected_mask = MASK_VALUES[mask_name].copy().flatten()
133+
if "mask_value" in threshold_options:
134+
expected_mask = expected_mask == threshold_options["mask_value"]
135+
else:
136+
operator = {"<": np.less, ">": np.greater}[threshold_options["threshold_operator"]]
137+
expected_mask = operator(expected_mask, threshold_options["threshold"])
138+
expected_mask_count = np.sum(expected_mask)
139+
140+
result_param = f"{target_param}_{rename}" if rename else target_param
141+
142+
# The target param should be masked (and possibly renamed)
143+
assert result_param in output_fields
144+
for input_field, output_field in zip(input_fields[target_param], output_fields[result_param]):
145+
expected_values = input_field.to_numpy(flatten=True).copy()
146+
expected_values[expected_mask] = np.nan
147+
result = output_field.to_numpy(flatten=True)
148+
np.array_equal(expected_values, result, equal_nan=True)
149+
assert np.sum(np.isnan(result)) == expected_mask_count
150+
151+
# When renamed, the original param name must not appear in the output
152+
if rename:
153+
assert target_param not in output_fields
154+
155+
# All other params should pass through unchanged
156+
for other_param in DATA_VALUES.keys():
157+
if other_param == target_param:
158+
continue
159+
assert other_param in output_fields
160+
for input_field, output_field in zip(input_fields[other_param], output_fields[other_param]):
161+
np.testing.assert_array_equal(
162+
input_field.to_numpy(flatten=True),
163+
output_field.to_numpy(flatten=True),
164+
)
165+
166+
167+
def test_apply_mask_to_param_fails_without_param(ekd_from_source):
168+
with pytest.raises((ValueError, TypeError)):
169+
filter_registry.create("apply_mask_to_param", path="all_zeros", mask_value=1)
170+
171+
172+
def test_apply_mask_to_param_fails_without_mask_options(ekd_from_source):
173+
with pytest.raises(ValueError):
174+
filter_registry.create("apply_mask_to_param", param="t", path="all_zeros")
175+
ekd_from_source.assert_called_once_with("file", "all_zeros")

0 commit comments

Comments
 (0)