
Commit 7b22238: adding adjacency matrix normalization

1 parent f7d5ae4 · commit 7b22238

5 files changed: +35 −5 lines
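
In graph terms, with adjacency matrix A and diagonal degree matrix D (here the in-degree, computed from the destination index), the commit adds an optional rescaling of the attention aggregation, in the notation the docstrings below use:

    rw:  A_norm = D^{-1} A
    sym: A_norm = D^{-1/2} A D^{-1/2}

The rescaling is applied per edge, as a weight multiplying the attended values in message().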

models/src/anemoi/models/layers/block.py (6 additions, 2 deletions)

@@ -12,7 +12,7 @@
 import os
 from abc import ABC
 from abc import abstractmethod
-from typing import Optional
+from typing import Literal, Optional
 from typing import Union
 
 import einops
@@ -443,6 +443,7 @@ def __init__(
         edge_dim: int,
         bias: bool = True,
         qk_norm: bool = False,
+        adj_norm: Literal['sym', 'rw'] | None = None,
         update_src_nodes: bool = False,
         layer_kernels: DotDict,
         graph_attention_backend: str = "triton",
@@ -464,6 +465,8 @@ def __init__(
             Add bias or not
         qk_norm : bool, by default False
             Normalize query and key
+        adj_norm : Literal['sym', 'rw'] | None, by default None
+            Normalize adjacency aggregation: D^{-1} A ('rw') or D^{-1/2} A D^{-1/2} ('sym')
         update_src_nodes: bool, by default False
             Update src if src and dst nodes are given
         layer_kernels : DotDict
@@ -479,6 +482,7 @@ def __init__(
         self.out_channels_conv = out_channels // num_heads
         self.num_heads = num_heads
         self.qk_norm = qk_norm
+        self.adj_norm = adj_norm
 
         Linear = layer_kernels.Linear
         LayerNorm = layer_kernels.LayerNorm
@@ -513,7 +517,7 @@ def __init__(
             self.conv = GraphTransformerFunction.apply
         else:
             LOGGER.warning(f"{self.__class__.__name__} using pyg graph attention backend, consider using 'triton'.")
-            self.conv = GraphTransformerConv(out_channels=self.out_channels_conv)
+            self.conv = GraphTransformerConv(out_channels=self.out_channels_conv, adj_norm=self.adj_norm)
 
     def run_node_dst_mlp(self, x, **layer_kwargs):
         return self.node_dst_mlp(self.layer_norm_mlp_dst(x, **layer_kwargs))
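
As the last hunk shows, only the pyg attention backend receives the new argument; the triton path (GraphTransformerFunction.apply) is unchanged. A minimal sketch of the resulting constructor call, with an illustrative channel count not taken from the commit:

    from anemoi.models.layers.conv import GraphTransformerConv

    # adj_norm is forwarded only when the pyg backend is selected;
    # out_channels=64 is a made-up value for illustration
    conv = GraphTransformerConv(out_channels=64, adj_norm="sym")  # or "rw", or None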

models/src/anemoi/models/layers/conv.py (25 additions, 3 deletions)

@@ -8,7 +8,7 @@
 # nor does it submit to any jurisdiction.
 
 
-from typing import Optional
+from typing import Literal, Optional
 
 import torch
 from torch import Tensor
@@ -18,6 +18,7 @@
 from torch_geometric.typing import OptPairTensor
 from torch_geometric.typing import OptTensor
 from torch_geometric.typing import Size
+from torch_geometric.utils import degree
 from torch_geometric.utils import scatter
 from torch_geometric.utils import softmax
 
@@ -83,17 +84,23 @@ class GraphTransformerConv(MessagePassing):
 
     Adapted from 'Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification'
     (https://arxiv.org/abs/2009.03509)
+
+    Edge normalization taken from 'Semi-Supervised Classification with Graph Convolutional Networks' (https://arxiv.org/abs/1609.02907)
+    Code inspired by https://pytorch-geometric.readthedocs.io/en/2.6.0/_modules/torch_geometric/utils/laplacian.html#get_laplacian
     """
 
     def __init__(
         self,
         out_channels: int,
         dropout: float = 0.0,
+        adj_norm: Literal["sym", "rw"] | None = None,
         **kwargs,
     ):
         kwargs.setdefault("aggr", "add")
         super().__init__(node_dim=0, **kwargs)
 
+        self.adj_norm = adj_norm
+
         self.out_channels = out_channels
         self.dropout = dropout
 
@@ -108,6 +115,19 @@ def forward(
     ):
         dim_size = query.shape[0]
         heads = query.shape[1]
+
+        edge_weights = torch.ones(edge_index.size(1), dtype=query.dtype, device=query.device)
+
+        if self.adj_norm is not None:
+            row, col = edge_index
+            deg = degree(col, dtype=query.dtype)
+
+            if self.adj_norm == "sym":
+                deg_inv_sqrt = deg.pow_(-0.5)
+                deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float("inf"), 0)
+                edge_weights = deg_inv_sqrt[row] * deg_inv_sqrt[col]
+            elif self.adj_norm == "rw":
+                edge_weights = deg.pow_(-1.0)[row] * edge_weights
 
         out = self.propagate(
             edge_index=edge_index,
@@ -118,8 +138,9 @@
             query=query,
             key=key,
             value=value,
+            edge_weights=edge_weights.repeat_interleave(heads),
         )
-
+
         return out
 
     def message(
@@ -128,6 +149,7 @@ def message(
         query_i: Tensor,
         key_j: Tensor,
         value_j: Tensor,
+        edge_weights: Tensor,
         edge_attr: OptTensor,
         index: Tensor,
         ptr: OptTensor,
@@ -141,4 +163,4 @@ def message(
         alpha = softmax(alpha, index, ptr, size_i)
         alpha = dropout(alpha, p=self.dropout, training=self.training)
 
-        return (value_j + edge_attr) * alpha.view(-1, heads, 1)
+        return edge_weights.view(-1, heads, 1) * (value_j + edge_attr) * alpha.view(-1, heads, 1)
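
For reference, a self-contained sketch (not part of the commit) of the two normalizations the docstring names, written in the same style as the hunks above: degrees are taken over the destination index (col) and turned into per-edge weights. The toy edge_index and num_nodes are made up for illustration, and an inf guard is added to the 'rw' branch in this sketch as a safeguard for isolated nodes:

    import torch
    from torch_geometric.utils import degree

    # toy graph: 4 directed edges over 3 nodes (illustrative values only)
    edge_index = torch.tensor([[0, 1, 2, 2],   # row: source of each edge
                               [1, 2, 0, 1]])  # col: destination of each edge
    num_nodes = 3

    row, col = edge_index
    deg = degree(col, num_nodes=num_nodes, dtype=torch.float)  # in-degree per node

    # 'sym': D^{-1/2} A D^{-1/2} -> one weight per edge from both endpoints
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float("inf"), 0.0)  # isolated nodes
    w_sym = deg_inv_sqrt[row] * deg_inv_sqrt[col]

    # 'rw': D^{-1} A -> inverse degree, indexed by row as in the hunk above
    deg_inv = deg.pow(-1.0)
    deg_inv.masked_fill_(deg_inv == float("inf"), 0.0)  # safeguard added in this sketch
    w_rw = deg_inv[row]

    # broadcasting one weight per edge across attention heads, matching
    # the view(-1, heads, 1) in message()
    heads = 4
    w_per_head = w_sym.repeat_interleave(heads).view(-1, heads, 1)

    print(w_sym, w_rw, w_per_head.shape)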

models/src/anemoi/models/schemas/processor.py (2 additions, 0 deletions)

@@ -42,6 +42,8 @@ class GraphTransformerProcessorSchema(TransformerModelComponent):
     "Number of chunks to divide the layer into. Default to 2."
     qk_norm: bool = Field(example=False)
     "Normalize the query and key vectors. Default to False."
+    adj_norm: Literal['sym', 'rw'] | None = Field(example=None)
+    "Normalize adjacency aggregation: D^{-1} A ('rw') or D^{-1/2} A D^{-1/2} ('sym'). Default to None."
 
     @model_validator(mode="after")
     def check_valid_extras(self) -> Any:
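
A minimal sketch of what the new field enforces, using a hypothetical stand-in model rather than the actual GraphTransformerProcessorSchema:

    from typing import Literal, Optional

    from pydantic import BaseModel, ValidationError

    class ProcessorConfig(BaseModel):  # hypothetical stand-in for the anemoi schema
        adj_norm: Optional[Literal["sym", "rw"]] = None

    print(ProcessorConfig(adj_norm="sym").adj_norm)  # sym
    print(ProcessorConfig().adj_norm)                # None
    try:
        ProcessorConfig(adj_norm="foo")              # anything else is rejected
    except ValidationError:
        print("adj_norm='foo' rejected")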

models/tests/layers/block/test_block_graphtransformer.py (1 addition, 0 deletions)

@@ -68,6 +68,7 @@ def block(init_proc):
         bias=bias,
         update_src_nodes=False,
         qk_norm=qk_norm,
+        adj_norm=None,
         graph_attention_backend=graph_attention_backend,
     )
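
The fixture only covers the default. A hypothetical extension (not in the commit) showing how all three settings could be exercised via pytest parametrization:

    import pytest

    @pytest.fixture(params=[None, "sym", "rw"])
    def adj_norm(request):
        # each test using this fixture runs once per normalization mode
        return request.param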

training/src/anemoi/training/config/model/graphtransformer.yaml (1 addition, 0 deletions)

@@ -32,6 +32,7 @@ processor:
   qk_norm: False
   cpu_offload: ${model.cpu_offload}
   layer_kernels: ${model.layer_kernels}
+  adj_norm: null
   graph_attention_backend: "triton" # Options: "triton", "pyg"
 
 encoder:
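
A short sketch, assuming OmegaConf (used by anemoi-training's Hydra configs), of switching the new default on; the config structure here is illustrative:

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({"processor": {"adj_norm": None}})  # the new default
    cfg.processor.adj_norm = "sym"  # e.g. set via a command-line or config override
    print(cfg.processor.adj_norm)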
