Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 59 additions & 16 deletions src/anemoi/transform/filters/fields/apply_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from anemoi.transform.fields import new_field_from_numpy
from anemoi.transform.fields import new_fieldlist_from_list
from anemoi.transform.filter import Filter
from anemoi.transform.filter import SingleFieldFilter
from anemoi.transform.filters import filter_registry

LOG = logging.getLogger(__name__)
Expand All @@ -36,8 +36,7 @@
}


@filter_registry.register("apply_mask")
class MaskVariable(Filter):
class ApplyMaskMixin:
"""A filter to mask variables using an external file.

The values of every filtered fields are set to NaN when they are either:
Expand Down Expand Up @@ -96,7 +95,7 @@ def __init__(
----------
path : str
Path to the external file containing the mask.
mask_value : int, optional
mask_value : float, optional
Value to be used for masking, by default 1.
threshold : float, optional
Threshold value for masking, by default None.
Expand All @@ -121,31 +120,75 @@ def __init__(

self._rename = rename

def forward(self, data: ekd.FieldList) -> ekd.FieldList:
def mask(self, data: ekd.Field) -> ekd.Field:
"""Apply the forward transformation to the data.

Parameters
----------
data : ekd.FieldList
data : ekd.Field
Input data to be transformed.

Returns
-------
ekd.FieldList
ekd.Field
Transformed data.
"""
result = []
extra = {}
for field in data:

values = field.to_numpy(flatten=True)
values[self._mask] = np.nan
values = data.to_numpy(flatten=True)
values[self._mask] = np.nan

if self._rename is not None:
param = data.metadata("param")
name = f"{param}_{self._rename}"
extra["param"] = name

if self._rename is not None:
param = field.metadata("param")
name = f"{param}_{self._rename}"
extra["param"] = name
return new_field_from_numpy(values, template=data, **extra)

result.append(new_field_from_numpy(values, template=field, **extra))

@filter_registry.register("apply_mask")
class MaskAllVariables(ApplyMaskMixin):
"""A filter to mask all variables using an external file."""

def forward(self, data: ekd.FieldList) -> ekd.FieldList:
result = []
for field in data:
masked_field = self.mask(field)
result.append(masked_field)
return new_fieldlist_from_list(result)


@filter_registry.register("apply_mask_to_param")
class ApplyMaskToParam(ApplyMaskMixin, SingleFieldFilter):
"""A filter to mask a specific variable using an external file."""

required_inputs = ["param"]

def __init__(
self,
*,
param: str,
path: str,
mask_value: float | None = None,
threshold: float | None = None,
threshold_operator: (
Literal["<"] | Literal["<="] | Literal[">"] | Literal[">="] | Literal["=="] | Literal["!="]
) = ">",
rename: str | None = None,
):
SingleFieldFilter.__init__(self, param=param)
ApplyMaskMixin.__init__(
self,
path=path,
mask_value=mask_value,
threshold=threshold,
threshold_operator=threshold_operator,
rename=rename,
)
self.param = param

def forward_select(self):
return {"param": self.param}

def forward_transform(self, param: ekd.Field) -> ekd.Field:
return self.mask(param)
73 changes: 67 additions & 6 deletions tests/field_filters/test_apply_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,70 @@ def test_apply_mask(source, ekd_from_source, mask_name, rename, threshold_option
assert np.sum(np.isnan(result)) == expected_mask_count


if __name__ == "__main__":
"""Run all test functions that start with 'test_'."""
for name, obj in list(globals().items()):
if name.startswith("test_") and callable(obj):
print(f"Running {name}...")
obj()
@pytest.mark.parametrize(
"threshold_options",
[
{"mask_value": 0.5},
{"mask_value": 1},
{"threshold": 0.5, "threshold_operator": ">"},
{"threshold": 0.5, "threshold_operator": "<"},
],
)
@pytest.mark.parametrize("rename", [None, "renamed"])
@pytest.mark.parametrize("mask_name", MASK_VALUES.keys())
@pytest.mark.parametrize("target_param", DATA_VALUES.keys())
def test_apply_mask_to_param(source, ekd_from_source, target_param, mask_name, rename, threshold_options):
apply_mask = filter_registry.create(
"apply_mask_to_param", param=target_param, path=mask_name, rename=rename, **threshold_options
)
ekd_from_source.assert_called_once_with("file", mask_name)

pipeline = source | apply_mask

input_fields = collect_fields_by_param(source)
output_fields = collect_fields_by_param(pipeline)

expected_mask = MASK_VALUES[mask_name].copy().flatten()
if "mask_value" in threshold_options:
expected_mask = expected_mask == threshold_options["mask_value"]
else:
operator = {"<": np.less, ">": np.greater}[threshold_options["threshold_operator"]]
expected_mask = operator(expected_mask, threshold_options["threshold"])
expected_mask_count = np.sum(expected_mask)

result_param = f"{target_param}_{rename}" if rename else target_param

# The target param should be masked (and possibly renamed)
assert result_param in output_fields
for input_field, output_field in zip(input_fields[target_param], output_fields[result_param]):
expected_values = input_field.to_numpy(flatten=True).copy()
expected_values[expected_mask] = np.nan
result = output_field.to_numpy(flatten=True)
np.array_equal(expected_values, result, equal_nan=True)
assert np.sum(np.isnan(result)) == expected_mask_count

# When renamed, the original param name must not appear in the output
if rename:
assert target_param not in output_fields

# All other params should pass through unchanged
for other_param in DATA_VALUES.keys():
if other_param == target_param:
continue
assert other_param in output_fields
for input_field, output_field in zip(input_fields[other_param], output_fields[other_param]):
np.testing.assert_array_equal(
input_field.to_numpy(flatten=True),
output_field.to_numpy(flatten=True),
)


def test_apply_mask_to_param_fails_without_param(ekd_from_source):
with pytest.raises((ValueError, TypeError)):
filter_registry.create("apply_mask_to_param", path="all_zeros", mask_value=1)


def test_apply_mask_to_param_fails_without_mask_options(ekd_from_source):
with pytest.raises(ValueError):
filter_registry.create("apply_mask_to_param", param="t", path="all_zeros")
ekd_from_source.assert_called_once_with("file", "all_zeros")
Loading