Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 40 additions & 46 deletions src/anemoi/transform/filters/fields/apply_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,11 @@
# nor does it submit to any jurisdiction.

import logging
from typing import Literal

import earthkit.data as ekd
import numpy as np

from anemoi.transform.fields import new_field_from_numpy
from anemoi.transform.fields import new_fieldlist_from_list
from anemoi.transform.filter import Filter
from anemoi.transform.filter import SingleFieldFilter
from anemoi.transform.filters import filter_registry

LOG = logging.getLogger(__name__)
Expand All @@ -37,7 +34,7 @@


@filter_registry.register("apply_mask")
class MaskVariable(Filter):
class MaskVariable(SingleFieldFilter):
"""A filter to mask variables using an external file.

The values of every filtered fields are set to NaN when they are either:
Expand Down Expand Up @@ -81,19 +78,23 @@ class MaskVariable(Filter):

"""

def __init__(
self,
*,
path: str,
mask_value: float | None = None,
threshold: float | None = None,
threshold_operator: Literal["<", "<=", ">", ">=", "==", "!="] = ">",
rename: str | None = None,
):
"""Initialize the MaskVariable filter.

Parameters
----------
# path: str
required_inputs = ("path",)
# mask_value: float | None,
# threshold: float | None,
# threshold_operator: Literal["<", "<=", ">", ">=", "==", "!="],
# rename: str | None,
optional_inputs = {
"mask_value": None,
"threshold": None,
"threshold_operator": ">",
"rename": None,
}

def prepare_filter(self):
"""Setup the MaskVariable filter.

Note:
path : str
Path to the external file containing the mask.
mask_value : int, optional
Expand All @@ -104,48 +105,41 @@ def __init__(
New name for the masked variable, by default None.
"""

mask = ekd.from_source("file", path)[0].to_numpy(flatten=True)
mask = ekd.from_source("file", self.path)[0].to_numpy(flatten=True)

if mask_value is None and threshold is None:
if self.mask_value is None and self.threshold is None:
raise ValueError("Either `mask_value` or `threshold` must be provided.")

if threshold is not None:
if threshold_operator not in OPERATORS:
if self.threshold is not None:
if self.threshold_operator not in OPERATORS:
raise ValueError(
f"Invalid threshold operator: {threshold_operator}. "
f"Invalid threshold operator: {self.threshold_operator}. "
f"Valid operators are: {', '.join(OPERATORS.keys())}."
)
self._mask = OPERATORS[threshold_operator](mask, threshold)
self.mask = OPERATORS[self.threshold_operator](mask, self.threshold)
else:
self._mask = mask == mask_value
self.mask = mask == self.mask_value

self._rename = rename

def forward(self, data: ekd.FieldList) -> ekd.FieldList:
"""Apply the forward transformation to the data.
def forward_transform(self, field: ekd.Field) -> ekd.Field:
"""Apply the forward transformation to the field.

Parameters
----------
data : ekd.FieldList
Input data to be transformed.
field : ekd.Field
Input field to be transformed.

Returns
-------
ekd.FieldList
Transformed data.
ekd.Field
Transformed field.
"""
result = []
extra = {}
for field in data:

values = field.to_numpy(flatten=True)
values[self._mask] = np.nan

if self._rename is not None:
param = field.metadata("param")
name = f"{param}_{self._rename}"
extra["param"] = name
metadata = {}
values = field.to_numpy(flatten=True)
values[self.mask] = np.nan

result.append(new_field_from_numpy(values, template=field, **extra))
if self.rename is not None:
param = field.metadata("param")
name = f"{param}_{self.rename}"
metadata["param"] = name

return new_fieldlist_from_list(result)
return self.new_field_from_numpy(values, template=field, **metadata)
Loading