Skip to content

Commit aeaf00b

Browse files
JPXKQX, matschreiner, HCookie, anaprietonem
authored
feat(models): add configurable residual connections in enc-proc-dec (#670)
Abstracts the residual connection around a new class, `BaseResidualConnection`. It currently supports 3 options: `SkipConnection`, `TruncatedConnection`, and `NoConnection`. Creates the `SparseProjector` class which handles the projection logic. This class could be used in the future for the multi-scale loss or other use cases. Co-authored-by: Mathias Schreiner <[email protected]> Co-authored-by: Harrison Cook <[email protected]> Co-authored-by: Ana Prieto Nemesio <[email protected]>
1 parent 969b787 commit aeaf00b

36 files changed

+735
-363
lines changed

models/docs/modules/residual.rst

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
######################
2+
Residual connections
3+
######################
4+
5+
Residual connections are a key architectural feature in Anemoi's
6+
encoder-processor-decoder models, enabling more effective information
7+
flow and gradient propagation across network layers. Residual
8+
connections help mitigate issues such as vanishing gradients and support
9+
the training of deeper and more expressive models.
10+
11+
In Anemoi, the type of residual connection used in a model is specified
12+
under the `residual` key in the model configuration YAML. This modular
13+
approach allows users to select and customize the residual strategy best
14+
suited for their forecasting task, whether it be a standard skip
15+
connection, no connection, or a truncated connection.
16+
17+
The following classes implement the available residual connection types
18+
in Anemoi.
19+
20+
*****************
21+
Skip Connection
22+
*****************
23+
24+
.. autoclass:: anemoi.models.layers.residual.SkipConnection
25+
:members:
26+
:no-undoc-members:
27+
:show-inheritance:
28+
29+
**********************
30+
Truncated Connection
31+
**********************
32+
33+
.. autoclass:: anemoi.models.layers.residual.TruncatedConnection
34+
:members:
35+
:no-undoc-members:
36+
:show-inheritance:

models/docs/usage/create_model.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ First, let's take the model configuration ``transformer.yaml``:
4949
mlp_hidden_ratio: 4
5050
num_heads: 16
5151
52+
residual:
53+
_target_: anemoi.models.layers.residual.SkipConnection
54+
5255
attributes:
5356
edges:
5457
- edge_length

models/src/anemoi/models/interface/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ def __init__(
6363
metadata: dict,
6464
statistics_tendencies: dict = None,
6565
supporting_arrays: dict = None,
66-
truncation_data: dict,
6766
) -> None:
6867
super().__init__()
6968
self.config = config
@@ -72,7 +71,6 @@ def __init__(
7271
self.graph_data = graph_data
7372
self.statistics = statistics
7473
self.statistics_tendencies = statistics_tendencies
75-
self.truncation_data = truncation_data
7674
self.metadata = metadata
7775
self.supporting_arrays = supporting_arrays if supporting_arrays is not None else {}
7876
self.data_indices = data_indices
@@ -112,7 +110,6 @@ def _build_model(self) -> None:
112110
data_indices=self.data_indices,
113111
statistics=self.statistics,
114112
graph_data=self.graph_data,
115-
truncation_data=self.truncation_data,
116113
_recursive_=False, # Disables recursive instantiation by Hydra
117114
)
118115

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
# (C) Copyright 2025 Anemoi contributors.
2+
#
3+
# This software is licensed under the terms of the Apache Licence Version 2.0
4+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5+
#
6+
# In applying this licence, ECMWF does not waive the privileges and immunities
7+
# granted to it by virtue of its status as an intergovernmental organisation
8+
# nor does it submit to any jurisdiction.
9+
10+
11+
from abc import ABC
12+
from abc import abstractmethod
13+
from typing import Optional
14+
15+
import einops
16+
import torch
17+
from torch import nn
18+
from torch_geometric.data import HeteroData
19+
20+
from anemoi.models.distributed.graph import gather_channels
21+
from anemoi.models.distributed.graph import shard_channels
22+
from anemoi.models.distributed.shapes import apply_shard_shapes
23+
from anemoi.models.layers.sparse_projector import build_sparse_projector
24+
25+
26+
class BaseResidualConnection(nn.Module, ABC):
    """Abstract base class for residual connection modules.

    Concrete subclasses implement :meth:`forward`, which receives the model
    input and returns the tensor used as the residual term.

    Parameters
    ----------
    graph : HeteroData, optional
        Accepted so that every residual connection can be constructed with the
        same keyword; it is not used by the base class.
    """

    def __init__(self, graph: Optional[HeteroData] = None) -> None:
        # ``Optional[...]`` (rather than ``X | None``) matches the rest of this
        # module and is valid on interpreters that predate PEP 604.
        super().__init__()

    @abstractmethod
    def forward(self, x: torch.Tensor, grid_shard_shapes=None, model_comm_group=None) -> torch.Tensor:
        """Define the residual connection operation.

        Should be overridden by subclasses.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.
        grid_shard_shapes : optional
            Shard shapes of the grid dimension; ``None`` when not provided.
        model_comm_group : optional
            Communication group forwarded to the resharding helpers.

        Returns
        -------
        torch.Tensor
            The residual term.
        """
        pass
39+
40+
41+
class SkipConnection(BaseResidualConnection):
    """Skip connection module.

    Forwards one timestep of the input sequence unchanged, bypassing the
    processing layers.

    Parameters
    ----------
    step : int, default -1
        Index along the time dimension (dim 1) of the input to return. The
        default of ``-1`` selects the most recent timestep.
    **_
        Any additional keyword arguments are accepted and ignored.
    """

    def __init__(self, step: int = -1, **_) -> None:
        super().__init__()
        self.step = step

    def forward(self, x: torch.Tensor, grid_shard_shapes=None, model_comm_group=None) -> torch.Tensor:
        """Return the timestep ``self.step`` of the input sequence.

        ``grid_shard_shapes`` and ``model_comm_group`` are unused here.
        """
        # x shape: (batch, time, ens, nodes, features)
        time_index = (slice(None), self.step, Ellipsis)
        return x[time_index]
56+
57+
58+
class TruncatedConnection(BaseResidualConnection):
    """Truncated skip connection

    This connection applies a coarse-graining and reconstruction of input features using sparse
    projections to truncate high frequency features.

    This module uses two projection operators: one to map features from the full-resolution
    grid to a truncated (coarse) grid, and another to project back to the original resolution.

    The projection matrices are built either from the edges of ``graph`` or from two ``.npz``
    files. File mode is selected only when *both* file paths are given; in that mode the node
    and edge-weight attribute names must not be provided.

    Parameters
    ----------
    graph : HeteroData, optional
        The graph containing the subgraphs for down and up projections.
    data_nodes : str, optional
        Name of the nodes representing the data nodes.
    truncation_nodes : str, optional
        Name of the nodes representing the truncated (coarse) nodes.
    edge_weight_attribute : str, optional
        Name of the edge attribute to use as weights for the projections.
    src_node_weight_attribute : str, optional
        Name of the source node attribute to use as weights for the projections.
    autocast : bool, default False
        Whether to use automatic mixed precision for the projections.
    truncation_up_file_path : str, optional
        File path (.npz) to load the up-projection matrix from.
    truncation_down_file_path : str, optional
        File path (.npz) to load the down-projection matrix from.

    Example
    -------
    >>> from torch_geometric.data import HeteroData
    >>> import torch
    >>> # Assume graph is a HeteroData object with the required edges and node types
    >>> graph = HeteroData()
    >>> # ...populate graph with nodes and edges for 'data' and 'int'...
    >>> # Example creating the projection matrices from the graph
    >>> conn = TruncatedConnection(
    ...     graph=graph,
    ...     data_nodes="data",
    ...     truncation_nodes="int",
    ...     edge_weight_attribute="gauss_weight",
    ... )
    >>> x = torch.randn(2, 4, 1, 40192, 44)  # (batch, time, ens, nodes, features)
    >>> out = conn(x)
    >>> print(out.shape)
    torch.Size([2, 4, 1, 40192, 44])

    >>> # Example specifying .npz files for projection matrices
    >>> conn = TruncatedConnection(
    ...     truncation_down_file_path="n320_to_o96.npz",
    ...     truncation_up_file_path="o96_to_n320.npz",
    ... )
    >>> x = torch.randn(2, 4, 1, 40192, 44)
    >>> out = conn(x)
    >>> print(out.shape)
    torch.Size([2, 4, 1, 40192, 44])
    """

    def __init__(
        self,
        graph: Optional[HeteroData] = None,
        data_nodes: Optional[str] = None,
        truncation_nodes: Optional[str] = None,
        edge_weight_attribute: Optional[str] = None,
        src_node_weight_attribute: Optional[str] = None,
        truncation_up_file_path: Optional[str] = None,
        truncation_down_file_path: Optional[str] = None,
        autocast: bool = False,
    ) -> None:
        super().__init__()
        # Validate the argument combination and resolve the edge types (None in file mode).
        up_edges, down_edges = self._get_edges_name(
            graph,
            data_nodes,
            truncation_nodes,
            truncation_up_file_path,
            truncation_down_file_path,
            edge_weight_attribute,
        )

        # data grid -> truncated grid
        self.project_down = build_sparse_projector(
            graph=graph,
            edges_name=down_edges,
            edge_weight_attribute=edge_weight_attribute,
            src_node_weight_attribute=src_node_weight_attribute,
            file_path=truncation_down_file_path,
            autocast=autocast,
        )

        # truncated grid -> data grid
        self.project_up = build_sparse_projector(
            graph=graph,
            edges_name=up_edges,
            edge_weight_attribute=edge_weight_attribute,
            src_node_weight_attribute=src_node_weight_attribute,
            file_path=truncation_up_file_path,
            autocast=autocast,
        )

    def _get_edges_name(
        self,
        graph,
        data_nodes,
        truncation_nodes,
        truncation_up_file_path,
        truncation_down_file_path,
        edge_weight_attribute,
    ):
        """Validate the constructor arguments and return the projection edge types.

        Returns
        -------
        tuple
            ``(up_edges, down_edges)``. In graph mode these are
            ``(truncation_nodes, "to", data_nodes)`` and ``(data_nodes, "to", truncation_nodes)``;
            in file mode both are ``None``.

        Raises
        ------
        AssertionError
            If the graph-mode requirements are not met, or if node/attribute names are
            supplied together with both file paths.
        """
        # File mode requires BOTH paths; a single path falls through to the graph-mode checks.
        are_files_specified = truncation_up_file_path is not None and truncation_down_file_path is not None
        if not are_files_specified:
            assert graph is not None, "graph must be provided if file paths are not specified."
            assert data_nodes is not None, "data nodes name must be provided if file paths are not specified."
            assert (
                truncation_nodes is not None
            ), "truncation nodes name must be provided if file paths are not specified."
            up_edges = (truncation_nodes, "to", data_nodes)
            down_edges = (data_nodes, "to", truncation_nodes)
            assert up_edges in graph.edge_types, f"Graph must contain edges {up_edges} for up-projection."
            assert down_edges in graph.edge_types, f"Graph must contain edges {down_edges} for down-projection."
        else:
            # None of the names may be given in file mode, so every one of them must be None.
            # (Joining these checks with ``or`` would only reject the case where all three are set.)
            assert (
                data_nodes is None and truncation_nodes is None and edge_weight_attribute is None
            ), "If file paths are specified, node and attribute names should not be provided."
            up_edges = down_edges = None  # Not used when loading from files
        return up_edges, down_edges

    def forward(self, x: torch.Tensor, grid_shard_shapes=None, model_comm_group=None) -> torch.Tensor:
        """Apply truncated skip connection.

        Takes the last timestep of ``x``, merges the batch and ensemble dimensions, projects
        down and back up, and restores the batch/ensemble layout. When ``grid_shard_shapes``
        is given, the tensor is passed through ``shard_channels`` before the projections and
        through ``gather_channels`` afterwards.
        """
        batch_size = x.shape[0]
        x = x[:, -1, ...]  # pick latest step
        shard_shapes = apply_shard_shapes(x, 0, grid_shard_shapes) if grid_shard_shapes is not None else None

        x = einops.rearrange(x, "batch ensemble grid features -> (batch ensemble) grid features")
        x = self._to_channel_shards(x, shard_shapes, model_comm_group)
        x = self.project_down(x)
        x = self.project_up(x)
        x = self._to_grid_shards(x, shard_shapes, model_comm_group)
        x = einops.rearrange(x, "(batch ensemble) grid features -> batch ensemble grid features", batch=batch_size)

        return x

    def _to_channel_shards(self, x, shard_shapes=None, model_comm_group=None):
        """Apply ``shard_channels`` to ``x`` when ``shard_shapes`` is given; otherwise return ``x``."""
        return self._reshard(x, shard_channels, shard_shapes, model_comm_group)

    def _to_grid_shards(self, x, shard_shapes=None, model_comm_group=None):
        """Apply ``gather_channels`` to ``x`` when ``shard_shapes`` is given; otherwise return ``x``."""
        return self._reshard(x, gather_channels, shard_shapes, model_comm_group)

    def _reshard(self, x, fn, shard_shapes=None, model_comm_group=None):
        """Call ``fn(x, shard_shapes, model_comm_group)`` unless ``shard_shapes`` is ``None``."""
        if shard_shapes is not None:
            x = fn(x, shard_shapes, model_comm_group)
        return x

0 commit comments

Comments
 (0)