Skip to content

Commit 3920c51

Browse files
test: replace MockField with a real ekd field in tests (#236)
## Description To avoid the problem of "mock skew" (i.e. the behaviour of a mock differing from the real object), the `MockField` in the tests has been replaced with a real earthkit data field (in memory) --------- Co-authored-by: Harrison Cook <Harrison.cook@ecmwf.int>
1 parent dff5eb2 commit 3920c51

File tree

4 files changed

+50
-95
lines changed

4 files changed

+50
-95
lines changed

tests/test_fields.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,7 @@
1313
from anemoi.transform.fields import new_field_with_metadata
1414
from src.anemoi.transform.fields import FieldSelection
1515

16-
17-
class MockField:
18-
def __init__(self, **metadata):
19-
self._metadata = metadata
20-
21-
# FieldSelection only sends the metadata message to objects
22-
def metadata(self, key):
23-
return self._metadata[key]
16+
from .utils import mock_field
2417

2518

2619
@pytest.fixture
@@ -74,7 +67,7 @@ def test_field_adding_metadata_updates_keys(sample_field):
7467

7568
def test_fieldselection_match_all():
7669
"""Test FieldSelection with no arguments matches all fields."""
77-
field = MockField(invalid_key="any_value")
70+
field = mock_field(invalid_key="any_value")
7871
selection = FieldSelection()
7972
assert selection.match(field)
8073

@@ -87,55 +80,55 @@ def test_fieldselection_invalid_key():
8780

8881
def test_fieldselection_match_fail_different_param():
8982
"""Test FieldSelection match fails when param is different."""
90-
field = MockField(param="2t")
83+
field = mock_field(param="2t")
9184
selection = FieldSelection(param="2z")
9285
assert not selection.match(field)
9386

9487

9588
def test_fieldselection_match_same_param():
9689
"""Test FieldSelection match succeeds when param is the same."""
97-
field = MockField(param="2t")
90+
field = mock_field(param="2t")
9891
selection = FieldSelection(param="2t")
9992
assert selection.match(field)
10093

10194

10295
def test_fieldselection_match_fail_missing_key():
10396
"""Test FieldSelection match fails when a selection key is missing on the field."""
104-
field = MockField(param="t")
97+
field = mock_field(param="t")
10598
selection = FieldSelection(param="t", levelist=850)
10699
assert not selection.match(field)
107100

108101

109102
def test_fieldselection_match_field_with_extra_metadata():
110103
"""Test FieldSelection match succeeds when the field has extra metadata."""
111-
field = MockField(param="t", levelist=850)
104+
field = mock_field(param="t", levelist=850)
112105
selection = FieldSelection(param="t")
113106
assert selection.match(field)
114107

115108

116109
def test_fieldselection_match_fail_same_param_different_level():
117110
"""Test FieldSelection match fails when param is the same but the levelist is different."""
118-
field = MockField(param="t", levelist=100)
111+
field = mock_field(param="t", levelist=100)
119112
selection = FieldSelection(param="t", levelist=850)
120113
assert not selection.match(field)
121114

122115

123116
def test_fieldselection_match_same_param_same_level():
124117
"""Test FieldSelection match succeeds when param and level are the same."""
125-
field = MockField(param="t", levelist=850)
118+
field = mock_field(param="t", levelist=850)
126119
selection = FieldSelection(param="t", levelist=850)
127120
assert selection.match(field)
128121

129122

130123
def test_fieldselection_match_is_subset():
131124
"""Test FieldSelection match succeeds when the field is a subset of the selection."""
132-
field = MockField(param="t", levelist=850)
125+
field = mock_field(param="t", levelist=850)
133126
selection = FieldSelection(param=["t", "q"], levelist=[850, 950])
134127
assert selection.match(field)
135128

136129

137130
def test_fieldselection_match_fail_different_param_same_level():
138131
"""Test FieldSelection match fails when the is on the same level but a different param."""
139-
field = MockField(param="t", levelist=850)
132+
field = mock_field(param="t", levelist=850)
140133
selection = FieldSelection(param="q", levelist=[850, 950])
141134
assert not selection.match(field)

tests/test_grouping.py

Lines changed: 5 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,9 @@
1+
import earthkit.data as ekd
12
import pytest
23

34
from anemoi.transform.grouping import GroupByParam
45

5-
6-
class MockField:
7-
def __init__(self, **metadata):
8-
self._metadata = metadata
9-
10-
def metadata(self, key=None, namespace=None, **kwargs):
11-
MARS_KEYS = [
12-
"domain",
13-
"levtype",
14-
"levelist",
15-
"date",
16-
"time",
17-
"step",
18-
"param",
19-
"class",
20-
"type",
21-
"stream",
22-
"expver",
23-
]
24-
if namespace and (key or kwargs):
25-
raise ValueError("Cannot specify both namespace and key, or namespace and kwargs")
26-
if key and kwargs:
27-
raise ValueError("Cannot specify both key and kwargs")
28-
29-
if namespace == "mars":
30-
return {k: self._metadata[k] for k in MARS_KEYS if k in self._metadata}
31-
elif namespace:
32-
raise ValueError(f"Unknown namespace {namespace}")
33-
if key:
34-
return self._metadata[key]
35-
return {k: self._metadata[k] for k in kwargs}
36-
37-
def __repr__(self):
38-
return f"MockField({self._metadata})"
6+
from .utils import mock_field
397

408

419
def field_generator(**metadata_values):
@@ -58,7 +26,7 @@ def field_generator(**metadata_values):
5826
combinations = itertools.product(*metadata_values.values())
5927
for values in combinations:
6028
metadata = MOCK_MARS_METADATA | dict(zip(metadata_values.keys(), values))
61-
fields.append(MockField(**metadata))
29+
fields.append(mock_field(**metadata))
6230
return fields
6331

6432

@@ -112,7 +80,7 @@ def test_group_by_param_vertical(sample_fields_vertical):
11280
from anemoi.transform.grouping import GroupByParamVertical
11381

11482
def get_param(f):
115-
if isinstance(f, MockField):
83+
if isinstance(f, ekd.Field):
11684
f = [f]
11785

11886
param = [x.metadata("param") for x in f]
@@ -131,7 +99,7 @@ def get_param(f):
13199
assert [get_param(field) for field in group] == match_params
132100
metadata = []
133101
for fields in group:
134-
if isinstance(fields, MockField):
102+
if isinstance(fields, ekd.Field):
135103
fields = [fields]
136104
for field in fields:
137105
num_matching += 1

tests/test_matching.py

Lines changed: 24 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -10,28 +10,13 @@
1010

1111
from collections.abc import Iterator
1212

13-
import numpy as np
13+
import earthkit.data as ekd
1414
import pytest
1515

1616
from anemoi.transform.filters.matching import MatchingFieldsFilter
1717
from anemoi.transform.filters.matching import matching
1818

19-
20-
class MockField:
21-
def __init__(self, param, **meta):
22-
self._param = param
23-
self._meta = meta
24-
self.values = np.array([1.0]) # dummy data
25-
26-
def metadata(self, namespace=None):
27-
if namespace == "mars":
28-
return dict(self._meta, param=self._param)
29-
return self._param
30-
31-
32-
class MockFieldList(list):
33-
def metadata(self, name):
34-
return [getattr(f, "metadata")("mars")[name] for f in self]
19+
from .utils import mock_field
3520

3621

3722
class AddFields(MatchingFieldsFilter):
@@ -41,15 +26,13 @@ def __init__(self, a, b, return_inputs="none"):
4126
self.b = b
4227
self.return_inputs = return_inputs
4328

44-
def forward_transform(self, a: MockField, b: MockField) -> Iterator[MockField]:
29+
def forward_transform(self, a: ekd.Field, b: ekd.Field) -> Iterator[ekd.Field]:
4530
result = a.values + b.values
4631
yield self.new_field_from_numpy(result, template=a, param="c")
4732

4833
def new_field_from_numpy(self, array, *, template, param):
49-
return MockField(param, **template._meta)
50-
51-
def new_fieldlist_from_list(self, fields):
52-
return MockFieldList(fields)
34+
metadata = dict(template.metadata()) | {"param": param}
35+
return mock_field(**metadata)
5336

5437

5538
def test_matching_decorator_initializes_correctly():
@@ -59,43 +42,43 @@ def test_matching_decorator_initializes_correctly():
5942

6043

6144
def test_forward_transform_adds_fields():
62-
a = MockField("a", step=0, level=850)
63-
b = MockField("b", step=0, level=850)
64-
data = MockFieldList([a, b])
45+
a = mock_field(param="a", step=0, level=850)
46+
b = mock_field(param="b", step=0, level=850)
47+
data = ekd.SimpleFieldList([a, b])
6548

6649
f = AddFields(a="a", b="b")
6750
result = f.forward(data)
6851
assert len(result) == 1
69-
assert isinstance(result[0], MockField)
70-
assert result[0]._param == "c"
52+
assert isinstance(result[0], ekd.Field)
53+
assert result[0].metadata("param") == "c"
7154

7255

7356
def test_return_inputs():
74-
a = MockField("a", step=0, level=850)
75-
b = MockField("b", step=0, level=850)
76-
data = MockFieldList([a, b])
57+
a = mock_field(param="a", step=0, level=850)
58+
b = mock_field(param="b", step=0, level=850)
59+
data = ekd.SimpleFieldList([a, b])
7760

7861
f = AddFields(a="a", b="b", return_inputs="all")
7962
result = f.forward(data)
8063
assert len(result) == 3
8164
for i in range(3):
82-
assert isinstance(result[i], MockField)
83-
assert {result[i]._param for i in range(2)} == {"a", "b"}
84-
assert result[2]._param == "c"
65+
assert isinstance(result[i], ekd.Field)
66+
assert {result[i].metadata("param") for i in range(2)} == {"a", "b"}
67+
assert result[2].metadata("param") == "c"
8568

8669
f = AddFields(a="a", b="b", return_inputs=["a"])
8770
result = f.forward(data)
8871
assert len(result) == 2
8972
for i in range(2):
90-
assert isinstance(result[i], MockField)
91-
assert result[0]._param == "a"
92-
assert result[1]._param == "c"
73+
assert isinstance(result[i], ekd.Field)
74+
assert result[0].metadata("param") == "a"
75+
assert result[1].metadata("param") == "c"
9376

9477

9578
def test_missing_component_raises():
96-
a = MockField("a", step=0, level=850)
79+
a = mock_field(param="a", step=0, level=850)
9780
# Missing 'b'
98-
data = MockFieldList([a])
81+
data = ekd.SimpleFieldList([a])
9982
f = AddFields(a="a", b="b")
10083

10184
with pytest.raises(ValueError):
@@ -113,9 +96,9 @@ def forward_transform(self, *args):
11396

11497

11598
def test_metadata_mismatch_warning(caplog):
116-
c = MockField("c", step=0, level=850)
117-
d = MockField("d", step=0, level=850)
118-
data = MockFieldList([c, d])
99+
c = mock_field(param="c", step=0, level=850)
100+
d = mock_field(param="d", step=0, level=850)
101+
data = ekd.SimpleFieldList([c, d])
119102

120103
f = AddFields(a="a", b="b")
121104

tests/utils/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
# nor does it submit to any jurisdiction.
99
from collections import defaultdict
1010

11+
import earthkit.data as ekd
1112
import numpy as np
1213

1314
from anemoi.transform.fields import new_fieldlist_from_list
@@ -104,5 +105,15 @@ def compare_npz_files(file1, file2):
104105
assert (data1[key] == data2[key]).all(), f"Data for key {key} does not match between {file1} and {file2}"
105106

106107

108+
def mock_field(**metadata):
109+
class MetadataOverride(ekd.core.metadata.RawMetadata):
110+
def as_namespace(self, namespace):
111+
if namespace != "mars":
112+
raise ValueError(f"Unknown namespace {namespace}")
113+
return dict(self)
114+
115+
return ekd.ArrayField(array=[1], metadata=MetadataOverride(**metadata))
116+
117+
107118
def create_tabular_filter(name, **kwargs):
108119
return create_filter(name, **kwargs)

0 commit comments

Comments
 (0)