Skip to content

Commit 8f4ed7c

Browse files
feat: ability for apply-mask to work on a single param (#251)
## Description Alternative implementation of #249. Add the ability for the `apply_mask` filter to work only on a single param. ***As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/*** By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement.](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md)
1 parent 4dbc7e5 commit 8f4ed7c

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

src/anemoi/transform/filters/fields/apply_mask.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,13 @@ class MaskVariable(SingleFieldFilter):
8484
# threshold: float | None,
8585
# threshold_operator: Literal["<", "<=", ">", ">=", "==", "!="],
8686
# rename: str | None,
87+
# param: str | None,
8788
optional_inputs = {
8889
"mask_value": None,
8990
"threshold": None,
9091
"threshold_operator": ">",
9192
"rename": None,
93+
"param": None,
9294
}
9395

9496
def prepare_filter(self):
@@ -120,6 +122,11 @@ def prepare_filter(self):
120122
else:
121123
self.mask = mask == self.mask_value
122124

125+
def forward_select(self):
126+
if self.param is not None:
127+
return {"param": self.param}
128+
return {}
129+
123130
def forward_transform(self, field: ekd.Field) -> ekd.Field:
124131
"""Apply the forward transformation to the field.
125132

tests/field_filters/test_apply_mask.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,43 @@ def test_apply_mask(source, ekd_from_source, mask_name, rename, threshold_option
102102
expected_values = input_field.to_numpy(flatten=True).copy()
103103
expected_values[expected_mask] = np.nan
104104
result = output_field.to_numpy(flatten=True)
105-
np.array_equal(expected_values, result, equal_nan=True)
105+
assert np.array_equal(expected_values, result, equal_nan=True)
106106
assert np.sum(np.isnan(result)) == expected_mask_count
107107

108108

109+
def test_apply_mask_only_single_param(source, ekd_from_source):
110+
apply_mask = filter_registry.create(
111+
"apply_mask",
112+
path="mixed_floats",
113+
threshold=0.5,
114+
threshold_operator=">",
115+
param="t",
116+
)
117+
ekd_from_source.assert_called_once_with("file", "mixed_floats")
118+
119+
pipeline = source | apply_mask
120+
121+
input_fields = collect_fields_by_param(source)
122+
output_fields = collect_fields_by_param(pipeline)
123+
124+
expected_mask = MASK_VALUES["mixed_floats"].copy().flatten()
125+
expected_mask = np.greater(expected_mask, 0.5)
126+
expected_mask_count = np.sum(expected_mask)
127+
128+
for param in DATA_VALUES.keys():
129+
assert param in output_fields
130+
for input_field, output_field in zip(input_fields[param], output_fields[param]):
131+
if param == "t":
132+
# only mask t
133+
expected_values = input_field.to_numpy(flatten=True).copy()
134+
expected_values[expected_mask] = np.nan
135+
result = output_field.to_numpy(flatten=True)
136+
assert np.array_equal(expected_values, result, equal_nan=True)
137+
assert np.sum(np.isnan(result)) == expected_mask_count
138+
else:
139+
assert np.array_equal(input_field.to_numpy(flatten=True), output_field.to_numpy(flatten=True))
140+
141+
109142
if __name__ == "__main__":
110143
"""Run all test functions that start with 'test_'."""
111144
for name, obj in list(globals().items()):

0 commit comments

Comments
 (0)