Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions src/anemoi/transform/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,75 @@ def _repr_specific(self):
return f"(metadata={self.kwargs})"


# TODO: remove alongside earthkit data wrappers
class HiddenMetadataField(WrappedField):
    """A field which hides specified keys from its metadata.

    Hidden keys behave as if they were absent from the wrapped field:
    they are omitted from key listings and namespace dictionaries, a
    lookup with a default returns that default, and a lookup without one
    raises ``KeyError``.
    """

    def __init__(self, field: ekd.Field, hidden: list[str]):
        """Wrap ``field`` and hide the metadata keys listed in ``hidden``.

        Parameters
        ----------
        field : ekd.Field
            The field to wrap.
        hidden : list[str]
            The metadata keys to hide. A bare string is treated as a
            single key.
        """
        super().__init__(field)
        # Testing ``key in "some_string"`` would be a substring match, so a
        # bare string is wrapped before the set of hidden keys is built.
        self._hidden = frozenset([hidden] if isinstance(hidden, str) else hidden)

    def metadata(self, *args: Any, **kwargs: Any) -> Any:
        """Get the metadata of the field.

        Parameters
        ----------
        *args : Any
            Metadata keys to look up.
        **kwargs : Any
            Additional keyword arguments, forwarded to the wrapped field's
            ``metadata`` method. ``namespace`` selects the namespace branch;
            ``default`` is returned for hidden keys instead of raising.

        Returns
        -------
        Any
            With no arguments, an object exposing ``get``, ``keys``,
            ``__getitem__``, ``override`` and ``geography`` with the hidden
            keys filtered out. With ``namespace``, a copy of the wrapped
            result with the hidden keys deleted. Otherwise the value of the
            requested key, or a tuple of values when several keys are given.

        Raises
        ------
        KeyError
            If a hidden key is requested and no ``default`` was supplied.
        """
        hidden = self._hidden
        wrapped = self._field

        if not args and not kwargs:

            class MD:
                # Evaluated once, when this class body runs.
                geography = wrapped.metadata().geography

                def get(self, key, default=None):
                    if key in hidden:
                        return default
                    return wrapped.metadata().get(key, default)

                def keys(self):
                    return [k for k in wrapped.metadata().keys() if k not in hidden]

                def __getitem__(self, key):
                    if key in hidden:
                        raise KeyError(f"Key '{key}' is hidden and cannot be accessed.")
                    return wrapped.metadata()[key]

                def override(self, *args, **kwargs):
                    return wrapped.metadata().override(*args, **kwargs)

            return MD()

        if kwargs.get("namespace"):
            assert len(args) == 0, (args, kwargs)
            visible = wrapped.metadata(**kwargs).copy()
            # Snapshot the keys first: the mapping is mutated while looping.
            for k in list(visible.keys()):
                if k in hidden:
                    del visible[k]
            return visible

        def _val(key):
            if key in hidden:
                # A hidden key is treated like a missing one: honour an
                # explicit default (as ``MD.get`` does), otherwise raise.
                if "default" in kwargs:
                    return kwargs["default"]
                raise KeyError(f"Key '{key}' is hidden and cannot be accessed.")
            return wrapped.metadata(key, **kwargs)

        result = tuple(_val(a) for a in args)
        if len(result) == 1:
            return result[0]

        return result


class NewFlavouredField(_NewMetadataField):
def __init__(self, field: Any, flavour: Flavour) -> None:
super().__init__(field)
Expand Down
39 changes: 39 additions & 0 deletions src/anemoi/transform/filters/remove_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# (C) Copyright 2026- Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import earthkit.data as ekd

from anemoi.transform.fields import HiddenMetadataField
from anemoi.transform.filter import SingleFieldFilter
from anemoi.transform.filters import filter_registry


@filter_registry.register("remove_metadata")
class RemoveMetadata(SingleFieldFilter):
"""A filter to remove (hide) metadata keys from fields.

Matching fields are wrapped in a ``HiddenMetadataField`` so that the
listed keys are no longer visible in their metadata.

Inputs
------
keys : str or list of str
The metadata key(s) to hide. A single string is accepted and is
converted to a one-element tuple in ``prepare_filter``.
param : optional
When given, it is used as the field selection ``{"param": param}``,
so only fields with a matching ``param`` have their keys hidden
(the tests pass either a single name or a list of names). When left
as ``None`` the selection is empty and every field is processed.
"""

# Input names read by this filter; presumably consumed by SingleFieldFilter
# to populate ``self.keys`` / ``self.param`` - confirm against the base class.
required_inputs = ("keys",)
optional_inputs = {"param": None}
Comment on lines +22 to +23
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm starting to think this class could have just been a dataclass?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does way too much stuff to be a dataclass in my opinion. It only has two class level variables – I don't think that's too many

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(There's an argument for some of the class configuration itself being a dataclass, which is exactly what I've done for a WIP refactor of the MatchingFieldsFilter base class, but I'm reluctant to make further changes to SingleFieldFilter until that refactor goes in... Eventually these will all align, but we need to let the design settle a bit first - otherwise we risk having the wrong abstractions)

Comment on lines +22 to +23
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's param here used for?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update the docstring to make it clearer

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(NB: Not 100% sure if we need this filter yet...)


def prepare_filter(self):
    """Validate ``self.keys`` and normalise it to a tuple of strings.

    A single string becomes a one-element tuple; a list or tuple is
    accepted only if every element is a string.

    Raises
    ------
    TypeError
        If ``self.keys`` is neither a string nor a list/tuple of strings.
    """
    keys = self.keys
    if isinstance(keys, str):
        keys = (keys,)
    elif not isinstance(keys, (list, tuple)) or not all(isinstance(key, str) for key in keys):
        # Non-string elements could never match a metadata key, so reject
        # them here instead of silently hiding nothing.
        raise TypeError("Keys must be a string or list of strings.")
    self.keys = tuple(keys)

def forward_select(self):
    """Build the selection of fields this filter applies to.

    Returns
    -------
    dict
        ``{"param": self.param}`` when ``self.param`` is set, otherwise an
        empty dictionary.
    """
    selection = {}
    if self.param is not None:
        selection["param"] = self.param
    return selection

def forward_transform(self, field: ekd.Field) -> ekd.Field:
    """Wrap ``field`` so that the keys in ``self.keys`` are hidden from its metadata.

    Parameters
    ----------
    field : ekd.Field
        The field to wrap.

    Returns
    -------
    ekd.Field
        A ``HiddenMetadataField`` wrapping ``field``.
    """
    # TODO: ideally should be field.clone(metadata=...) - but wrappers prevent this for now
    hidden_keys = self.keys
    return HiddenMetadataField(field, hidden=hidden_keys)
107 changes: 107 additions & 0 deletions tests/test_remove_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# (C) Copyright 2026- Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from anemoi.transform.filters import filter_registry

from .utils import collect_fields_by_param


@pytest.fixture
def source(test_source):
    """Load the ``anemoi-filters/2t-sp.grib`` test data via the ``test_source`` fixture."""
    grib_path = "anemoi-filters/2t-sp.grib"
    return test_source(grib_path)


def test_remove_mars_metadata_all_params(source):
    """Without ``param``, the listed keys are hidden on every field."""
    keys = ["domain", "step"]
    remove_filter = filter_registry.create("remove_metadata", keys=keys)
    pipeline = source | remove_filter

    before = collect_fields_by_param(source)
    after = collect_fields_by_param(pipeline)

    # the filter must neither add nor drop fields
    assert set(before) == {"2t", "sp"}
    assert set(before) == set(after)
    for param, fields in before.items():
        assert len(fields) == len(after[param])

    # every key is present in the untouched input
    for fields in before.values():
        for field in fields:
            mars = field.metadata(namespace="mars")
            assert all(k in mars for k in keys)

    # no key survives in the filtered output
    for fields in after.values():
        for field in fields:
            mars = field.metadata(namespace="mars")
            assert all(k not in mars for k in keys)


def test_remove_mars_metadata_single_param(source):
    """With ``param="2t"``, keys are hidden on ``2t`` fields only."""
    keys = ["domain", "step"]
    remove_filter = filter_registry.create("remove_metadata", param="2t", keys=keys)
    pipeline = source | remove_filter

    before = collect_fields_by_param(source)
    after = collect_fields_by_param(pipeline)

    # the filter must neither add nor drop fields
    assert set(before) == {"2t", "sp"}
    assert set(before) == set(after)
    for param, fields in before.items():
        assert len(fields) == len(after[param])

    # every key is present in the untouched input
    for fields in before.values():
        for field in fields:
            mars = field.metadata(namespace="mars")
            assert all(k in mars for k in keys)

    # keys vanish for the selected param and stay everywhere else
    for param, fields in after.items():
        for field in fields:
            mars = field.metadata(namespace="mars")
            if param == "2t":
                assert all(k not in mars for k in keys)
            else:
                assert all(k in mars for k in keys)


def test_remove_mars_metadata_list_params(source):
    """With a list of params, keys are hidden on every listed param."""
    keys = ["domain", "step"]
    params = ["2t", "sp"]
    remove_filter = filter_registry.create("remove_metadata", param=params, keys=keys)
    pipeline = source | remove_filter

    before = collect_fields_by_param(source)
    after = collect_fields_by_param(pipeline)

    # the filter must neither add nor drop fields
    assert set(before) == {"2t", "sp"}
    assert set(before) == set(after)
    for param, fields in before.items():
        assert len(fields) == len(after[param])

    # every key is present in the untouched input
    for fields in before.values():
        for field in fields:
            mars = field.metadata(namespace="mars")
            assert all(k in mars for k in keys)

    # keys vanish for both selected params; nothing else should appear
    for param, fields in after.items():
        for field in fields:
            mars = field.metadata(namespace="mars")
            if param not in params:
                raise ValueError(f"Unexpected param {param}")
            assert all(k not in mars for k in keys)
Loading