Skip to content

Commit 42ecd3b

Browse files
refactor: rescale/convert filters to SingleFieldFilter (#226)
## Description Converts rescale and convert filters to SingleFieldFilters, changes tests to ensure non-matching fields are not processed, and breaks up the class hierarchy to avoid instantiating non-leaf class . ***As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/*** By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement.](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md)
1 parent 10a9cbc commit 42ecd3b

File tree

2 files changed

+82
-99
lines changed

2 files changed

+82
-99
lines changed

src/anemoi/transform/filters/rescale.py

Lines changed: 57 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -6,80 +6,64 @@
66
# In applying this licence, ECMWF does not waive the privileges and immunities
77
# granted to it by virtue of its status as an intergovernmental organisation
88
# nor does it submit to any jurisdiction.
9-
10-
11-
from collections.abc import Iterator
9+
from abc import ABC
10+
from abc import abstractmethod
11+
from typing import Callable
1212

1313
import earthkit.data as ekd
1414

15+
from anemoi.transform.filter import SingleFieldFilter
1516
from anemoi.transform.filters import filter_registry
16-
from anemoi.transform.filters.matching import MatchingFieldsFilter
17-
from anemoi.transform.filters.matching import matching
1817

1918

20-
class Rescale(MatchingFieldsFilter):
19+
class Rescaler:
20+
def __init__(self, scale, offset):
21+
self.scale = scale
22+
self.offset = offset
23+
24+
def forward(self, x):
25+
return x * self.scale + self.offset
26+
27+
def backward(self, x):
28+
return (x - self.offset) / self.scale
29+
30+
31+
class RescaleMixin(ABC):
32+
# inheriting classes should define required_inputs (which must include param)
33+
# and must define self.rescaler in prepare_filter
34+
param: str
35+
rescaler: Rescaler
36+
# intended to be inherited from SingleFieldFilter
37+
new_field_from_numpy: Callable
38+
39+
@abstractmethod
40+
def prepare_filter(self):
41+
raise NotImplementedError("prepare_filter must be implemented by subclasses.")
42+
43+
def forward_select(self):
44+
return {"param": self.param}
45+
46+
def forward_transform(self, param: ekd.Field) -> ekd.Field:
47+
"""Apply the forward transformation (x to ax+b)."""
48+
rescaled = self.rescaler.forward(param.to_numpy())
49+
return self.new_field_from_numpy(rescaled, template=param, param=self.param)
50+
51+
def backward_transform(self, param: ekd.Field) -> ekd.Field:
52+
"""Apply the backward transformation (ax+b to x)."""
53+
descaled = self.rescaler.backward(param.to_numpy())
54+
return self.new_field_from_numpy(descaled, template=param, param=self.param)
55+
56+
57+
class Rescale(RescaleMixin, SingleFieldFilter):
2158
"""A filter to rescale a parameter from a scale and an offset, and back."""
2259

23-
@matching(
24-
select="param",
25-
forward=("param",),
26-
backward=("param",),
27-
)
28-
def __init__(
29-
self,
30-
*,
31-
scale: float,
32-
offset: float,
33-
param: str,
34-
) -> None:
35-
"""Parameters
36-
-------------
37-
scale : float
38-
The scale factor.
39-
offset : float
40-
The offset value.
41-
param : str
42-
The parameter to be rescaled.
43-
"""
60+
required_inputs = ("scale", "offset", "param")
4461

45-
self.scale = scale
46-
self.offset = offset
47-
self.param = param
48-
49-
def forward_transform(self, param: ekd.Field) -> Iterator[ekd.Field]:
50-
"""Apply the forward transformation (x to ax+b).
51-
52-
Parameters
53-
----------
54-
param : ekd.Field
55-
The input data to be transformed.
56-
57-
Returns
58-
-------
59-
Iterator[ekd.Field]
60-
A generator yielding the transformed data.
61-
"""
62-
rescaled = param.to_numpy() * self.scale + self.offset
63-
yield self.new_field_from_numpy(rescaled, template=param, param=self.param)
64-
65-
def backward_transform(self, param: ekd.Field) -> Iterator[ekd.Field]:
66-
"""Apply the backward transformation (ax+b to x).
67-
68-
Parameters
69-
----------
70-
param : ekd.Field
71-
The input data to be transformed.
72-
73-
Returns
74-
-------
75-
Iterator[ekd.Field]
76-
A generator yielding the transformed data.
77-
"""
78-
descaled = (param.to_numpy() - self.offset) / self.scale
79-
yield self.new_field_from_numpy(descaled, template=param, param=self.param)
80-
81-
82-
class Convert(Rescale):
62+
def prepare_filter(self):
63+
self.rescaler = Rescaler(self.scale, self.offset)
64+
65+
66+
class Convert(RescaleMixin, SingleFieldFilter):
8367
"""A filter to convert a parameter in a given unit to another unit, and back.
8468
8569
This filter uses :mod:`cfunits` (see the `cfunits documentation <https://ncas-cms.github.io/cfunits/>`_)
@@ -102,26 +86,19 @@ class Convert(Rescale):
10286
10387
"""
10488

105-
def __init__(self, *, unit_in: str, unit_out: str, param: str) -> None:
106-
"""Parameters
107-
-------------
108-
unit_in : str
109-
The input unit.
110-
unit_out : str
111-
The output unit.
112-
param : str
113-
The parameter to be converted.
114-
"""
89+
required_inputs = ("unit_in", "unit_out", "param")
90+
91+
def prepare_filter(self):
11592
from cfunits import Units
11693

117-
u0 = Units(unit_in)
118-
u1 = Units(unit_out)
94+
u0 = Units(self.unit_in)
95+
u1 = Units(self.unit_out)
11996
x1, x2 = 0.0, 1.0
12097
y1, y2 = Units.conform([x1, x2], u0, u1)
121-
a = (y2 - y1) / (x2 - x1)
122-
b = y1 - a * x1
98+
scale = (y2 - y1) / (x2 - x1)
99+
offset = y1 - scale * x1
123100

124-
super().__init__(scale=a, offset=b, param=param)
101+
self.rescaler = Rescaler(scale, offset)
125102

126103

127104
filter_registry.register("rescale", Rescale)

tests/test_rescale.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,8 @@
1111
import numpy.testing as npt
1212
import pytest
1313
from anemoi.utils.testing import skip_if_offline
14-
from pytest import approx
1514

16-
from anemoi.transform.filters.rescale import Convert
17-
from anemoi.transform.filters.rescale import Rescale
15+
from anemoi.transform.filters import filter_registry
1816

1917

2018
def skip_missing_udunits2():
@@ -40,15 +38,24 @@ def test_rescale(fieldlist: ekd.FieldList) -> None:
4038
The fieldlist to use for testing.
4139
"""
4240

43-
fieldlist = fieldlist.sel(param="2t")
41+
before_filter = {field.metadata("param"): field.to_numpy().copy() for field in fieldlist}
42+
4443
# rescale from K to °C
45-
k_to_deg = Rescale(scale=1.0, offset=-273.15, param="2t")
44+
k_to_deg = filter_registry.create("rescale", scale=1.0, offset=-273.15, param="2t")
4645
rescaled = k_to_deg.forward(fieldlist)
46+
after_forward = {field.metadata("param"): field.to_numpy().copy() for field in rescaled}
4747

48-
npt.assert_allclose(rescaled[0].to_numpy(), fieldlist[0].to_numpy() - 273.15)
4948
# and back
5049
rescaled_back = k_to_deg.backward(rescaled)
51-
npt.assert_allclose(rescaled_back[0].to_numpy(), fieldlist[0].to_numpy())
50+
after_backward = {field.metadata("param"): field.to_numpy().copy() for field in rescaled_back}
51+
52+
for param in ("2t", "sp"):
53+
npt.assert_allclose(before_filter[param], after_backward[param])
54+
55+
if param == "2t":
56+
npt.assert_allclose(before_filter[param] - 273.15, after_forward[param])
57+
else:
58+
npt.assert_allclose(before_filter[param], after_forward[param])
5259

5360

5461
@skip_missing_udunits2()
@@ -61,21 +68,20 @@ def test_convert(fieldlist: ekd.FieldList) -> None:
6168
fieldlist : ekd.FieldList
6269
The fieldlist to use for testing.
6370
"""
71+
before_filter = {field.metadata("param"): field.to_numpy().copy() for field in fieldlist}
6472
# rescale from K to °C
65-
fieldlist = fieldlist.sel(param="2t")
66-
k_to_deg = Convert(unit_in="K", unit_out="degC", param="2t")
73+
k_to_deg = filter_registry.create("convert", unit_in="K", unit_out="degC", param="2t")
6774
rescaled = k_to_deg.forward(fieldlist)
68-
assert rescaled[0].values.min() == fieldlist.values.min() - 273.15
69-
assert rescaled[0].values.std() == approx(fieldlist.values.std())
75+
after_forward = {field.metadata("param"): field.to_numpy().copy() for field in rescaled}
76+
7077
# and back
7178
rescaled_back = k_to_deg.backward(rescaled)
72-
assert rescaled_back[0].values.min() == fieldlist.values.min()
73-
assert rescaled_back[0].values.std() == approx(fieldlist.values.std())
79+
after_backward = {field.metadata("param"): field.to_numpy().copy() for field in rescaled_back}
7480

81+
for param in ("2t", "sp"):
82+
npt.assert_allclose(before_filter[param], after_backward[param])
7583

76-
if __name__ == "__main__":
77-
"""Run all test functions that start with 'test_'."""
78-
for name, obj in list(globals().items()):
79-
if name.startswith("test_") and callable(obj):
80-
print(f"Running {name}...")
81-
obj()
84+
if param == "2t":
85+
npt.assert_allclose(before_filter[param] - 273.15, after_forward[param])
86+
else:
87+
npt.assert_allclose(before_filter[param], after_forward[param])

0 commit comments

Comments
 (0)