-
Notifications
You must be signed in to change notification settings - Fork 19
Expand file tree
/
Copy path__init__.py
More file actions
169 lines (137 loc) · 5.6 KB
/
__init__.py
File metadata and controls
169 lines (137 loc) · 5.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
import logging
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
from earthkit.data import SimpleFieldList
# Module-level logger; the grouping classes below use it to report which parameters were grouped.
LOG = logging.getLogger(__name__)
def _lost(f: Any) -> None:
"""Raise a ValueError indicating a lost field.
Parameters
----------
f : Any
The lost field.
"""
raise ValueError(f"Lost field {f}")
def _flatten(params: list[Any] | tuple[Any, ...]) -> list[str]:
"""Flatten a list of parameters.
Parameters
----------
params : list or tuple of Any
List or tuple of parameters to flatten.
Returns
-------
list of str
Flattened list of parameters.
"""
flat = []
for p in params:
if isinstance(p, (list, tuple)):
flat.extend(_flatten(p))
else:
flat.append(p)
return flat
class GroupByParam:
    """Group fields that differ only by their parameter name.

    Two fields land in the same group when their grouping key is equal. The
    key is the field's ``mars`` metadata namespace (or, when that is empty,
    all of its metadata except ``latitudes``, ``longitudes`` and ``values``)
    with ``param`` extracted and ``variable`` removed.

    Parameters
    ----------
    params : str or list of str
        Parameter name(s) to group by. Nested lists/tuples are flattened and
        a single non-list value is wrapped in a list. The order of this list
        defines the order of the fields in each tuple yielded by ``iterate``.
    """

    def __init__(self, params: list[str]) -> None:
        if not isinstance(params, (list, tuple)):
            params = [params]
        self.params = _flatten(params)

    @staticmethod
    def _get_grouping_key(
        field: Any,
        extract_keys: list[str] | None = None,
        remove_from_key: list[str] | None = None,
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Build the grouping key of a field and pull selected values out of it.

        Parameters
        ----------
        field : Any
            Field exposing a ``metadata()`` method (called with no argument,
            with a key name, and with ``namespace="mars"``).
        extract_keys : list of str, optional
            Keys removed from the grouping key and returned separately. A key
            missing from the grouping key is looked up in the field's full
            metadata, defaulting to ``None``. ``None`` means no keys.
        remove_from_key : list of str, optional
            Keys dropped from the grouping key. ``None`` means no keys.

        Returns
        -------
        tuple of (dict, dict)
            The grouping key and the extracted ``{key: value}`` mapping.

        Raises
        ------
        NotImplementedError
            If the field has neither MARS metadata nor any other usable metadata.
        ValueError
            If ``extract_keys`` contains duplicate names.
        """
        # The defaults are None; treat them as "nothing to extract/remove" instead of crashing.
        extract_keys = [] if extract_keys is None else extract_keys
        remove_from_key = [] if remove_from_key is None else remove_from_key

        # Copy so that popping entries below never mutates a dict the field may hand out again.
        key = dict(field.metadata(namespace="mars") or {})
        if not key:
            # No MARS namespace: fall back to all metadata except the bulky array entries.
            keys = [k for k in field.metadata().keys() if k not in ("latitudes", "longitudes", "values")]
            if not keys:
                raise NotImplementedError(f"GroupByParam: {field} has no sufficient metadata")
            key = {k: field.metadata(k) for k in keys}

        extract = {}
        for k in extract_keys:
            if k in key:
                extract[k] = key.pop(k)
            else:
                # Only query the full metadata when the grouping key does not carry the value.
                extract[k] = field.metadata().get(k, default=None)

        for k in remove_from_key:
            key.pop(k, None)

        # `extract` holds one entry per distinct name, so this only fires on duplicate names.
        if len(extract) != len(extract_keys):
            raise ValueError(f"Expected {extract_keys} keys to extract, got {extract}")
        return key, extract

    def _get_groups(self, data: list[Any], *, other: Callable[[Any], None] = _lost) -> None:
        """Fill ``self.groups`` and ``self.groups_params`` from ``data``.

        ``self.groups`` maps each grouping key (a frozenset of metadata items)
        to a ``{param: field}`` dict. Fields whose ``param`` is not in
        ``self.params`` are passed to ``other``.

        Raises
        ------
        ValueError
            If the same parameter appears twice under one grouping key.
        """
        assert callable(other), type(other)
        self.groups: dict[frozenset[tuple[str, Any]], dict[str, Any]] = defaultdict(dict)
        self.groups_params = set()

        for f in data:
            key, extras = self._get_grouping_key(f, extract_keys=["param"], remove_from_key=["variable"])
            param = extras["param"]
            if param not in self.params:
                other(f)
                continue

            # Hashable, order-independent form of the metadata, usable as a dict key.
            key = frozenset(key.items())
            group = self.groups[key]
            if param in group:
                raise ValueError(f"Duplicate component {param} for {key}")
            group[param] = f
            self.groups_params.add(param)

        LOG.info("Params groups: %s", self.groups_params)

    def iterate(self, data: list[Any], *, other: Callable[[Any], None] = _lost) -> Iterator[tuple[Any, ...]]:
        """Iterate over the data and group fields by parameters.

        Grouping runs when iteration starts (this is a generator).

        Parameters
        ----------
        data : list of Any
            List of data fields to group.
        other : callable, optional
            Called with each field whose parameter is not requested; by
            default ``_lost``, which raises ``ValueError``.

        Returns
        -------
        Iterator[Tuple[Any, ...]]
            One tuple per group, with fields ordered as ``self.params``.

        Raises
        ------
        ValueError
            If a group lacks one of the requested parameters.
        """
        self._get_groups(data, other=other)
        for key, group in self.groups.items():
            missing = [p for p in self.params if p not in group]
            if missing:
                # Report the offending group via logging instead of printing every input field.
                LOG.error("Incomplete group %s: missing %s", dict(key), missing)
                raise ValueError(f"Missing component. Want {sorted(self.params)}, got {sorted(group.keys())}")
            yield tuple(group[p] for p in self.params)
class GroupByParamVertical(GroupByParam):
    """Group fields by parameter name, stacking vertical levels together.

    Like ``GroupByParam``, but ``levelist`` is extracted and ``levtype`` is
    removed from the grouping key, so fields on different levels share a
    group. Within a group, a parameter's fields that carry a ``levelist`` are
    collected into one ``SimpleFieldList`` (one field per level); a field
    without ``levelist`` is stored on its own.
    """

    def _get_groups(self, data: list[Any], *, other: Callable[[Any], None] = _lost) -> None:
        """Fill ``self.groups`` and ``self.groups_params``, collecting levels per group.

        Fields whose ``param`` is not in ``self.params`` are passed to ``other``.

        Raises
        ------
        ValueError
            If a level-less parameter appears twice in a group, if the same
            level of a parameter appears twice in a group, or if one group
            mixes fields with and without ``levelist`` for a parameter.
        """
        assert callable(other), type(other)
        self.groups: dict[frozenset[tuple[str, Any]], dict[str, Any]] = defaultdict(dict)
        self.groups_params = set()
        # Levels already stored, per (group key, param). Tracking them per param alone
        # would wrongly flag the same level appearing in two different groups as a duplicate.
        levels: dict[tuple[frozenset[tuple[str, Any]], Any], list[Any]] = defaultdict(list)

        for f in data:
            key, extras = self._get_grouping_key(
                f, extract_keys=["param", "levelist"], remove_from_key=["variable", "levtype"]
            )
            param = extras["param"]
            level = extras["levelist"]
            if param not in self.params:
                other(f)
                continue

            key = frozenset(key.items())
            group = self.groups[key]
            if level is None:
                if param in group:
                    raise ValueError(f"Duplicate component {param} for {key}")
                group[param] = f
            else:
                seen = levels[(key, param)]
                if level in seen:
                    raise ValueError(f"Duplicate component {param} for {key} and level {level}")
                if param not in group:
                    group[param] = SimpleFieldList()
                elif not seen:
                    # The group already holds a single level-less field for this param;
                    # appending to it would fail with an AttributeError.
                    raise ValueError(f"Cannot mix fields with and without level for {param} and {key}")
                group[param].append(f)
                seen.append(level)
            self.groups_params.add(param)

        LOG.info("Params groups: %s", self.groups_params)