diff --git a/src/anemoi/transform/filters/fields/apply_mask.py b/src/anemoi/transform/filters/fields/apply_mask.py index ac6d0602..ef781b5a 100644 --- a/src/anemoi/transform/filters/fields/apply_mask.py +++ b/src/anemoi/transform/filters/fields/apply_mask.py @@ -15,7 +15,7 @@ from anemoi.transform.fields import new_field_from_numpy from anemoi.transform.fields import new_fieldlist_from_list -from anemoi.transform.filter import Filter +from anemoi.transform.filter import SingleFieldFilter from anemoi.transform.filters import filter_registry LOG = logging.getLogger(__name__) @@ -36,8 +36,7 @@ } -@filter_registry.register("apply_mask") -class MaskVariable(Filter): +class ApplyMaskMixin: """A filter to mask variables using an external file. The values of every filtered fields are set to NaN when they are either: @@ -96,7 +95,7 @@ def __init__( ---------- path : str Path to the external file containing the mask. - mask_value : int, optional + mask_value : float, optional Value to be used for masking, by default 1. threshold : float, optional Threshold value for masking, by default None. @@ -121,31 +120,75 @@ def __init__( self._rename = rename - def forward(self, data: ekd.FieldList) -> ekd.FieldList: + def mask(self, data: ekd.Field) -> ekd.Field: """Apply the forward transformation to the data. Parameters ---------- - data : ekd.FieldList + data : ekd.Field Input data to be transformed. Returns ------- - ekd.FieldList + ekd.Field Transformed data. """ - result = [] extra = {} - for field in data: - values = field.to_numpy(flatten=True) - values[self._mask] = np.nan + values = data.to_numpy(flatten=True) + values[self._mask] = np.nan + + if self._rename is not None: + param = data.metadata("param") + name = f"{param}_{self._rename}" + extra["param"] = name - if self._rename is not None: - param = field.metadata("param") - name = f"{param}_{self._rename}" - extra["param"] = name + return new_field_from_numpy(values, template=data, **extra) - result.append(new_field_from_numpy(values, template=field, **extra)) +@filter_registry.register("apply_mask") +class MaskAllVariables(ApplyMaskMixin): + """A filter to mask all variables using an external file.""" + + def forward(self, data: ekd.FieldList) -> ekd.FieldList: + result = [] + for field in data: + masked_field = self.mask(field) + result.append(masked_field) return new_fieldlist_from_list(result) + + +@filter_registry.register("apply_mask_to_param") +class ApplyMaskToParam(ApplyMaskMixin, SingleFieldFilter): + """A filter to mask a specific variable using an external file.""" + + required_inputs = ["param"] + + def __init__( + self, + *, + param: str, + path: str, + mask_value: float | None = None, + threshold: float | None = None, + threshold_operator: ( + Literal["<"] | Literal["<="] | Literal[">"] | Literal[">="] | Literal["=="] | Literal["!="] + ) = ">", + rename: str | None = None, + ): + SingleFieldFilter.__init__(self, param=param) + ApplyMaskMixin.__init__( + self, + path=path, + mask_value=mask_value, + threshold=threshold, + threshold_operator=threshold_operator, + rename=rename, + ) + self.param = param + + def forward_select(self): + return {"param": self.param} + + def forward_transform(self, param: ekd.Field) -> ekd.Field: + return self.mask(param) diff --git a/tests/field_filters/test_apply_mask.py b/tests/field_filters/test_apply_mask.py index 54dfabe8..1850b8c5 100644 --- a/tests/field_filters/test_apply_mask.py +++ b/tests/field_filters/test_apply_mask.py @@ -106,9 +106,70 @@ def test_apply_mask(source, ekd_from_source, mask_name, rename, threshold_option assert np.sum(np.isnan(result)) == expected_mask_count -if __name__ == "__main__": - """Run all test functions that start with 'test_'.""" - for name, obj in list(globals().items()): - if name.startswith("test_") and callable(obj): - print(f"Running {name}...") - obj() +@pytest.mark.parametrize( + "threshold_options", + [ + {"mask_value": 0.5}, + {"mask_value": 1}, + {"threshold": 0.5, "threshold_operator": ">"}, + {"threshold": 0.5, "threshold_operator": "<"}, + ], +) +@pytest.mark.parametrize("rename", [None, "renamed"]) +@pytest.mark.parametrize("mask_name", MASK_VALUES.keys()) +@pytest.mark.parametrize("target_param", DATA_VALUES.keys()) +def test_apply_mask_to_param(source, ekd_from_source, target_param, mask_name, rename, threshold_options): + apply_mask = filter_registry.create( + "apply_mask_to_param", param=target_param, path=mask_name, rename=rename, **threshold_options + ) + ekd_from_source.assert_called_once_with("file", mask_name) + + pipeline = source | apply_mask + + input_fields = collect_fields_by_param(source) + output_fields = collect_fields_by_param(pipeline) + + expected_mask = MASK_VALUES[mask_name].copy().flatten() + if "mask_value" in threshold_options: + expected_mask = expected_mask == threshold_options["mask_value"] + else: + operator = {"<": np.less, ">": np.greater}[threshold_options["threshold_operator"]] + expected_mask = operator(expected_mask, threshold_options["threshold"]) + expected_mask_count = np.sum(expected_mask) + + result_param = f"{target_param}_{rename}" if rename else target_param + + # The target param should be masked (and possibly renamed) + assert result_param in output_fields + for input_field, output_field in zip(input_fields[target_param], output_fields[result_param]): + expected_values = input_field.to_numpy(flatten=True).copy() + expected_values[expected_mask] = np.nan + result = output_field.to_numpy(flatten=True) + np.array_equal(expected_values, result, equal_nan=True) + assert np.sum(np.isnan(result)) == expected_mask_count + + # When renamed, the original param name must not appear in the output + if rename: + assert target_param not in output_fields + + # All other params should pass through unchanged + for other_param in DATA_VALUES.keys(): + if other_param == target_param: + continue + assert other_param in output_fields + for input_field, output_field in zip(input_fields[other_param], output_fields[other_param]): + np.testing.assert_array_equal( + input_field.to_numpy(flatten=True), + output_field.to_numpy(flatten=True), + ) + + +def test_apply_mask_to_param_fails_without_param(ekd_from_source): + with pytest.raises((ValueError, TypeError)): + filter_registry.create("apply_mask_to_param", path="all_zeros", mask_value=1) + + +def test_apply_mask_to_param_fails_without_mask_options(ekd_from_source): + with pytest.raises(ValueError): + filter_registry.create("apply_mask_to_param", param="t", path="all_zeros") + ekd_from_source.assert_called_once_with("file", "all_zeros")