
Commit 415b184

japols, dietervdb-meteo, ssmmnn11, pre-commit-ci[bot], and anaprietonem
authored and committed
feat(models): Triton GraphTransformer (#631)
## Description

Introduce a custom Triton kernel for a fast, memory-efficient and deterministic GraphTransformer.

Benchmarking results comparing the current GraphTransformer implementation (PyG) against compiled PyG and the Triton kernel on different graphs, measured in ms per fwd + bwd. All runs use era = n320, h = lvl6-multi-scale, dtype = fp16.

A100

| Graph type | PyG | PyG (compiled) | Triton |
|:-----------|-------:|---------------:|---------:|
| era_to_h | 53.09 | 15.86 | **6.81** |
| h_to_h | 20.32 | 4.45 | **1.82** |
| h_to_era | 100.44 | 20.53 | **11.19** |
| dop_enc | 60.44 | 17.90 | **9.08** |

H200

| Graph type | PyG | PyG (compiled) | Triton |
|:-----------|------:|---------------:|---------:|
| era_to_h | 23.97 | 6.02 | **2.77** |
| h_to_h | 9.64 | 1.69 | **0.83** |
| h_to_era | 46.99 | 8.16 | **4.87** |
| dop_enc | 26.62 | 6.85 | **3.86** |

Example impact on training for an o800 lvl7 GraphTransformer model in bf16 on H200s: PyG compiled: 3.90 s/it → Triton: 2.62 s/it (×1.5 speedup).

## What problem does this change solve?

The current PyG GraphTransformer suffers from a high memory footprint due to per-edge computations, multiple kernel launches, and stability issues with atomics in low precision. [Compiling](#181) can help address some of these issues, but a custom Triton kernel does better.

## Additional notes

Checked that outputs and gradients match the PyG version, plus a small training sanity check: [mlflow](https://mlflow.ecmwf.int/#/metric?runs=[%224803881259ff447cbbb4ff4436926823%22,%22b37aee50b3664afa9fc4aee6b557d71b%22]&metric=%22train_mse_loss_step%22&experiments=[%2245%22]&plot_metric_keys=%5B%22train_mse_loss_step%22%5D&plot_layout={%22autosize%22:true,%22xaxis%22:{%22autorange%22:true,%22type%22:%22linear%22,%22range%22:[-1035.9219986020375,960.4671241229228]},%22yaxis%22:{%22range%22:[-0.6010538821019717,6.501633332403755],%22autorange%22:true,%22type%22:%22linear%22}}&x_axis=step&y_axis_scale=linear&line_smoothness=1&show_point=false&deselected_curves=[]&last_linear_y_axis_range=[])

TODO

- [x] make GT backend configurable, pyg as fallback (see the configuration sketch below)
- [x] add pytests

***As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/***

By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement.](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md)

---------

Co-authored-by: Dieter Van den Bleeken <[email protected]>
Co-authored-by: Simon Lang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ana Prieto Nemesio <[email protected]>
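The new `graph_attention_backend` option is threaded through the GraphTransformer blocks, mappers and processor in the files below. As a hedged illustration only (not part of this commit): the `_target_` path and config layout here are assumptions mirroring `config/models/<model>.yaml`; the schema changes in this commit merely guarantee that `graph_attention_backend` is accepted as an extra string field.

```python
# Illustrative sketch: selecting the graph attention backend where the
# processor is configured. Keys other than "graph_attention_backend" are
# assumptions; the target path is assumed from the processor module location.
processor_config = {
    "_target_": "anemoi.models.layers.processor.GraphTransformerProcessor",  # assumed path
    "graph_attention_backend": "triton",  # or "pyg" to fall back to the previous implementation
}
```

The block constructor asserts that the value is either `"triton"` or `"pyg"`, so an invalid setting fails at model-construction time.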
1 parent c9bf182 commit 415b184

22 files changed: +695 -25 lines changed

.github/workflows/integration-tests-hpc.yml

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ jobs:
         pip install --upgrade pip
         pip install -e ./training[all,tests] -e ./models[all,tests] -e ./graphs[all,tests]
         python3 -m pytest -v training/tests/integration --slow
+        python3 -m pytest -v models/tests/integration --slow
         deactivate
         rm -rf $REPO_NAME
       sbatch_options: |

models/pytest.ini

Lines changed: 1 addition & 0 deletions
@@ -3,5 +3,6 @@ markers =
     data_dependent: marks tests depending on data (deselect with '-m "not data_dependent"')
     auth: marks tests that require authentication (deselect with '-m "not auth"')
     gpu: marks tests that require a GPU (deselect with '-m "not gpu"')
+    slow: mark test as slow (skipped unless --slow is used)

 tmp_path_retention_policy = none

models/src/anemoi/models/layers/block.py

Lines changed: 50 additions & 13 deletions
@@ -34,6 +34,8 @@
 from anemoi.models.layers.conv import GraphConv
 from anemoi.models.layers.conv import GraphTransformerConv
 from anemoi.models.layers.mlp import MLP
+from anemoi.models.triton.gt import GraphTransformerFunction
+from anemoi.models.triton.utils import edge_index_to_csc
 from anemoi.utils.config import DotDict

 LOGGER = logging.getLogger(__name__)
@@ -443,6 +445,7 @@ def __init__(
         qk_norm: bool = False,
         update_src_nodes: bool = False,
         layer_kernels: DotDict,
+        graph_attention_backend: str = "triton",
         **kwargs,
     ) -> None:
         """Initialize GraphTransformerBlock.
@@ -466,6 +469,8 @@ def __init__(
         layer_kernels : DotDict
             A dict of layer implementations e.g. layer_kernels.Linear = "torch.nn.Linear"
            Defined in config/models/<model>.yaml
+        graph_attention_backend: str, by default "triton"
+            Backend to use for graph transformer conv, options are "triton" and "pyg"
         """
         super().__init__(**kwargs)

@@ -483,8 +488,6 @@ def __init__(
         self.lin_self = Linear(in_channels, num_heads * self.out_channels_conv, bias=bias)
         self.lin_edge = Linear(edge_dim, num_heads * self.out_channels_conv)  # , bias=False)

-        self.conv = GraphTransformerConv(out_channels=self.out_channels_conv)
-
         self.projection = Linear(out_channels, out_channels)

         if self.qk_norm:
@@ -499,6 +502,19 @@ def __init__(
             Linear(hidden_dim, out_channels),
         )

+        self.graph_attention_backend = graph_attention_backend
+        assert self.graph_attention_backend in [
+            "triton",
+            "pyg",
+        ], f"Backend {self.graph_attention_backend} not supported for GraphTransformerBlock, valid options are 'triton' and 'pyg'"
+
+        if self.graph_attention_backend == "triton":
+            LOGGER.info(f"{self.__class__.__name__} using triton graph attention backend.")
+            self.conv = GraphTransformerFunction.apply
+        else:
+            LOGGER.warning(f"{self.__class__.__name__} using pyg graph attention backend, consider using 'triton'.")
+            self.conv = GraphTransformerConv(out_channels=self.out_channels_conv)
+
     def run_node_dst_mlp(self, x, **layer_kwargs):
         return self.node_dst_mlp(self.layer_norm_mlp_dst(x, **layer_kwargs))

@@ -555,37 +571,50 @@ def shard_qkve_heads(

         return query, key, value, edges

-    def attention_block(
+    def apply_gt(
         self,
         query: Tensor,
         key: Tensor,
         value: Tensor,
         edges: Tensor,
         edge_index: Adj,
         size: Union[int, tuple[int, int]],
-        num_chunks: int,
     ) -> Tensor:
         # self.conv requires size to be a tuple
         conv_size = (size, size) if isinstance(size, int) else size

+        if self.graph_attention_backend == "triton":
+            csc, perm, reverse = edge_index_to_csc(edge_index, num_nodes=conv_size, reverse=True)
+            edges_csc = edges.index_select(0, perm)
+            args_conv = (edges_csc, csc, reverse)
+        else:
+            args_conv = (edges, edge_index, conv_size)
+
+        return self.conv(query, key, value, *args_conv)
+
+    def attention_block(
+        self,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        edges: Tensor,
+        edge_index: Adj,
+        size: Union[int, tuple[int, int]],
+        num_chunks: int,
+    ) -> Tensor:
+        # split 1-hop edges into chunks, compute self.conv chunk-wise
         if num_chunks > 1:
-            # split 1-hop edges into chunks, compute self.conv chunk-wise
             edge_attr_list, edge_index_list = sort_edges_1hop_chunks(
                 num_nodes=size, edge_attr=edges, edge_index=edge_index, num_chunks=num_chunks
             )
             # shape: (num_nodes, num_heads, out_channels_conv)
             out = torch.zeros((*query.shape[:-1], self.out_channels_conv), device=query.device)
             for i in range(num_chunks):
-                out += self.conv(
-                    query=query,
-                    key=key,
-                    value=value,
-                    edge_attr=edge_attr_list[i],
-                    edge_index=edge_index_list[i],
-                    size=conv_size,
+                out += self.apply_gt(
+                    query=query, key=key, value=value, edges=edge_attr_list[i], edge_index=edge_index_list[i], size=size
                 )
         else:
-            out = self.conv(query=query, key=key, value=value, edge_attr=edges, edge_index=edge_index, size=conv_size)
+            out = self.apply_gt(query=query, key=key, value=value, edges=edges, edge_index=edge_index, size=size)

         return out

@@ -635,6 +664,7 @@ def __init__(
         update_src_nodes: bool = False,
         layer_kernels: DotDict,
         shard_strategy: str = "edges",
+        graph_attention_backend: str = "triton",
         **kwargs,
     ) -> None:
         """Initialize GraphTransformerBlock.
@@ -662,6 +692,8 @@ def __init__(
            Defined in config/models/<model>.yaml
         shard_strategy: str, by default "edges"
             Strategy to shard tensors
+        graph_attention_backend: str, by default "triton"
+            Backend to use for graph transformer conv, options are "triton" and "pyg"
         """

         super().__init__(
@@ -674,6 +706,7 @@ def __init__(
             bias=bias,
             qk_norm=qk_norm,
             update_src_nodes=update_src_nodes,
+            graph_attention_backend=graph_attention_backend,
             **kwargs,
         )

@@ -791,6 +824,7 @@ def __init__(
         qk_norm: bool = False,
         update_src_nodes: bool = False,
         layer_kernels: DotDict,
+        graph_attention_backend: str = "triton",
         **kwargs,
     ) -> None:
         """Initialize GraphTransformerBlock.
@@ -814,6 +848,8 @@ def __init__(
         layer_kernels : DotDict
             A dict of layer implementations e.g. layer_kernels.Linear = "torch.nn.Linear"
            Defined in config/models/<model>.yaml
+        graph_attention_backend: str, by default "triton"
+            Backend to use for graph transformer conv, options are "triton" and "pyg"
         """

         super().__init__(
@@ -826,6 +862,7 @@ def __init__(
             bias=bias,
             qk_norm=qk_norm,
             update_src_nodes=update_src_nodes,
+            graph_attention_backend=graph_attention_backend,
             **kwargs,
         )

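For orientation, here is a minimal parity sketch (not part of this commit) of the two code paths wired up in `apply_gt` above, using only the call signatures visible in this diff. The per-head tensor layout `(nodes or edges, heads, channels)`, the random graph, the CUDA device, and the tolerance are all assumptions; the Triton path requires a GPU.

```python
import torch

from anemoi.models.layers.conv import GraphTransformerConv
from anemoi.models.triton.gt import GraphTransformerFunction
from anemoi.models.triton.utils import edge_index_to_csc

num_nodes, num_edges, num_heads, head_dim = 128, 512, 4, 16
device = "cuda"  # assumption: the Triton kernel runs on a GPU

# Assumed per-head layout: (nodes or edges, heads, channels).
query = torch.randn(num_nodes, num_heads, head_dim, device=device)
key = torch.randn(num_nodes, num_heads, head_dim, device=device)
value = torch.randn(num_nodes, num_heads, head_dim, device=device)
edges = torch.randn(num_edges, num_heads, head_dim, device=device)
edge_index = torch.randint(0, num_nodes, (2, num_edges), device=device)
size = (num_nodes, num_nodes)

# PyG reference path (the previous implementation).
conv = GraphTransformerConv(out_channels=head_dim)
out_pyg = conv(query=query, key=key, value=value, edge_attr=edges, edge_index=edge_index, size=size)

# Triton path: convert the COO edge_index to CSC and permute edge features to match.
csc, perm, reverse = edge_index_to_csc(edge_index, num_nodes=size, reverse=True)
edges_csc = edges.index_select(0, perm)
out_triton = GraphTransformerFunction.apply(query, key, value, edges_csc, csc, reverse)

print(torch.allclose(out_pyg, out_triton, atol=1e-3))  # tolerance is an assumption
```

A gradient comparison along the lines of the mlflow check in the description could be added by marking the inputs with `requires_grad=True` and comparing the backward passes of the two outputs.
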
models/src/anemoi/models/layers/mapper.py

Lines changed: 12 additions & 0 deletions
@@ -216,6 +216,7 @@ def __init__(
         cpu_offload: bool = False,
         layer_kernels: DotDict = None,
         shard_strategy: str = "edges",
+        graph_attention_backend: str = "triton",
     ) -> None:
         """Initialize GraphTransformerBaseMapper.

@@ -254,6 +255,8 @@ def __init__(
            Defined in config/models/<model>.yaml
         shard_strategy : str, optional
             Strategy to shard tensors, by default "edges"
+        graph_attention_backend: str, by default "triton"
+            Backend to use for graph transformer conv, options are "triton" and "pyg"
         """
         super().__init__(
             in_channels_src=in_channels_src,
@@ -282,6 +285,7 @@ def __init__(
             qk_norm=qk_norm,
             layer_kernels=self.layer_factory,
             shard_strategy=shard_strategy,
+            graph_attention_backend=graph_attention_backend,
         )

         self.offload_layers(cpu_offload)
@@ -539,6 +543,7 @@ def __init__(
         cpu_offload: bool = False,
         layer_kernels: DotDict = None,
         shard_strategy: str = "edges",
+        graph_attention_backend: str = "triton",
     ) -> None:
         """Initialize GraphTransformerForwardMapper.

@@ -574,6 +579,8 @@ def __init__(
            A dict of layer implementations e.g. layer_kernels.Linear = "torch.nn.Linear"
         shard_strategy : str, optional
             Strategy to shard tensors, by default "edges"
+        graph_attention_backend: str, by default "triton"
+            Backend to use for graph transformer conv, options are "triton" and "pyg"
         """
         super().__init__(
             in_channels_src=in_channels_src,
@@ -592,6 +599,7 @@ def __init__(
             dst_grid_size=dst_grid_size,
             layer_kernels=layer_kernels,
             shard_strategy=shard_strategy,
+            graph_attention_backend=graph_attention_backend,
         )

         self.emb_nodes_src = self.layer_factory.Linear(self.in_channels_src, self.hidden_dim)
@@ -643,6 +651,7 @@ def __init__(
         cpu_offload: bool = False,
         layer_kernels: DotDict = None,
         shard_strategy: str = "edges",
+        graph_attention_backend: str = "triton",
     ) -> None:
         """Initialize GraphTransformerBackwardMapper.

@@ -683,6 +692,8 @@ def __init__(
            Defined in config/models/<model>.yaml
         shard_strategy : str, optional
             Strategy to shard tensors, by default "edges"
+        graph_attention_backend: str, by default "triton"
+            Backend to use for graph transformer conv, options are "triton" and "pyg"
         """
         super().__init__(
             in_channels_src=in_channels_src,
@@ -701,6 +712,7 @@ def __init__(
             dst_grid_size=dst_grid_size,
             layer_kernels=layer_kernels,
             shard_strategy=shard_strategy,
+            graph_attention_backend=graph_attention_backend,
         )

         self.node_data_extractor = nn.Sequential(

models/src/anemoi/models/layers/processor.py

Lines changed: 4 additions & 0 deletions
@@ -408,6 +408,7 @@ def __init__(
         qk_norm: bool = False,
         cpu_offload: bool = False,
         layer_kernels: DotDict,
+        graph_attention_backend: str = "triton",
         **kwargs,
     ) -> None:
         """Initialize GraphTransformerProcessor.
@@ -441,6 +442,8 @@ def __init__(
         layer_kernels : DotDict
             A dict of layer implementations e.g. layer_kernels.Linear = "torch.nn.Linear"
            Defined in config/models/<model>.yaml
+        graph_attention_backend: str, by default "triton"
+            Backend to use for graph transformer conv, options are "triton" and "pyg"
         """
         super().__init__(
             num_channels=num_channels,
@@ -465,6 +468,7 @@ def __init__(
             edge_dim=self.edge_dim,
             layer_kernels=self.layer_factory,
             qk_norm=qk_norm,
+            graph_attention_backend=graph_attention_backend,
         )

         self.offload_layers(cpu_offload)

models/src/anemoi/models/schemas/decoder.py

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ class GraphTransformerDecoderSchema(TransformerModelComponent):
     @model_validator(mode="after")
     def check_valid_extras(self) -> Any:
         # This is a check to allow backwards compatibilty of the configs, as the extra fields are not required.
-        allowed_extras = {"shard_strategy": str}
+        allowed_extras = {"shard_strategy": str, "graph_attention_backend": str}
         extras = getattr(self, "__pydantic_extra__", {}) or {}
         for extra_field, value in extras.items():
             if extra_field not in allowed_extras:

models/src/anemoi/models/schemas/encoder.py

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ class GraphTransformerEncoderSchema(TransformerModelComponent):
     @model_validator(mode="after")
     def check_valid_extras(self) -> Any:
         # This is a check to allow backwards compatibilty of the configs, as the extra fields are not required.
-        allowed_extras = {"shard_strategy": str}
+        allowed_extras = {"shard_strategy": str, "graph_attention_backend": str}
         extras = getattr(self, "__pydantic_extra__", {}) or {}
         for extra_field, value in extras.items():
             if extra_field not in allowed_extras:

models/src/anemoi/models/schemas/processor.py

Lines changed: 15 additions & 0 deletions
@@ -43,6 +43,21 @@ class GraphTransformerProcessorSchema(TransformerModelComponent):
     qk_norm: bool = Field(example=False)
     "Normalize the query and key vectors. Default to False."

+    @model_validator(mode="after")
+    def check_valid_extras(self) -> Any:
+        # This is a check to allow backwards compatibilty of the configs, as the extra fields are not required.
+        allowed_extras = {"graph_attention_backend": str}
+        extras = getattr(self, "__pydantic_extra__", {}) or {}
+        for extra_field, value in extras.items():
+            if extra_field not in allowed_extras:
+                msg = f"Extra field '{extra_field}' is not allowed. Allowed fields are: {list(allowed_extras.keys())}."
+                raise ValueError(msg)
+            if not isinstance(value, allowed_extras[extra_field]):
+                msg = f"Extra field '{extra_field}' must be of type {allowed_extras[extra_field].__name__}."
+                raise TypeError(msg)
+
+        return self
+

 class TransformerProcessorSchema(TransformerModelComponent):
     target_: Literal["anemoi.models.layers.processor.TransformerProcessor"] = Field(..., alias="_target_")
