@@ -34,6 +34,8 @@
 from anemoi.models.layers.conv import GraphConv
 from anemoi.models.layers.conv import GraphTransformerConv
 from anemoi.models.layers.mlp import MLP
+from anemoi.models.triton.gt import GraphTransformerFunction
+from anemoi.models.triton.utils import edge_index_to_csc
 from anemoi.utils.config import DotDict

 LOGGER = logging.getLogger(__name__)
@@ -443,6 +445,7 @@ def __init__(
         qk_norm: bool = False,
         update_src_nodes: bool = False,
         layer_kernels: DotDict,
+        graph_attention_backend: str = "triton",
         **kwargs,
     ) -> None:
         """Initialize GraphTransformerBlock.
@@ -466,6 +469,8 @@ def __init__(
         layer_kernels : DotDict
             A dict of layer implementations e.g. layer_kernels.Linear = "torch.nn.Linear"
             Defined in config/models/<model>.yaml
+        graph_attention_backend : str, by default "triton"
+            Backend to use for the graph transformer convolution; options are "triton" and "pyg".
         """
         super().__init__(**kwargs)

@@ -483,8 +488,6 @@ def __init__(
         self.lin_self = Linear(in_channels, num_heads * self.out_channels_conv, bias=bias)
         self.lin_edge = Linear(edge_dim, num_heads * self.out_channels_conv)  # , bias=False)

-        self.conv = GraphTransformerConv(out_channels=self.out_channels_conv)
-
         self.projection = Linear(out_channels, out_channels)

         if self.qk_norm:
@@ -499,6 +502,19 @@ def __init__(
             Linear(hidden_dim, out_channels),
         )

+        self.graph_attention_backend = graph_attention_backend
+        assert self.graph_attention_backend in [
+            "triton",
+            "pyg",
+        ], f"Backend {self.graph_attention_backend} not supported for GraphTransformerBlock, valid options are 'triton' and 'pyg'"
+
+        if self.graph_attention_backend == "triton":
+            LOGGER.info(f"{self.__class__.__name__} using triton graph attention backend.")
+            self.conv = GraphTransformerFunction.apply
+        else:
+            LOGGER.warning(f"{self.__class__.__name__} using pyg graph attention backend, consider using 'triton'.")
+            self.conv = GraphTransformerConv(out_channels=self.out_channels_conv)
+
     def run_node_dst_mlp(self, x, **layer_kwargs):
         return self.node_dst_mlp(self.layer_norm_mlp_dst(x, **layer_kwargs))

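A note on the dispatch above: `GraphTransformerFunction` is presumably a custom autograd function (hence storing the bare `.apply` as `self.conv`), so unlike the `GraphTransformerConv` module it is not instantiated and expects positional arguments with CSC-ordered edge inputs. A minimal sketch of the two resulting call conventions, mirroring `apply_gt` further down in this diff (nothing here is new API):

```python
# Sketch only: argument names are taken from apply_gt below; the tensors
# (query, key, value, edges, edges_csc, edge_index, conv_size, csc, reverse)
# are assumed to be prepared exactly as in the diff.
if self.graph_attention_backend == "triton":
    # autograd-function entry point: positional args, CSC-ordered edges
    out = self.conv(query, key, value, edges_csc, csc, reverse)
else:
    # PyG-style module: positional args, COO edge_index plus a (src, dst) size
    out = self.conv(query, key, value, edges, edge_index, conv_size)
```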
@@ -555,37 +571,50 @@ def shard_qkve_heads(

         return query, key, value, edges

-    def attention_block(
+    def apply_gt(
         self,
         query: Tensor,
         key: Tensor,
         value: Tensor,
         edges: Tensor,
         edge_index: Adj,
         size: Union[int, tuple[int, int]],
-        num_chunks: int,
     ) -> Tensor:
         # self.conv requires size to be a tuple
         conv_size = (size, size) if isinstance(size, int) else size

+        if self.graph_attention_backend == "triton":
+            csc, perm, reverse = edge_index_to_csc(edge_index, num_nodes=conv_size, reverse=True)
+            edges_csc = edges.index_select(0, perm)
+            args_conv = (edges_csc, csc, reverse)
+        else:
+            args_conv = (edges, edge_index, conv_size)
+
+        return self.conv(query, key, value, *args_conv)
+
+    def attention_block(
+        self,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        edges: Tensor,
+        edge_index: Adj,
+        size: Union[int, tuple[int, int]],
+        num_chunks: int,
+    ) -> Tensor:
+        # split 1-hop edges into chunks, compute self.conv chunk-wise
         if num_chunks > 1:
-            # split 1-hop edges into chunks, compute self.conv chunk-wise
             edge_attr_list, edge_index_list = sort_edges_1hop_chunks(
                 num_nodes=size, edge_attr=edges, edge_index=edge_index, num_chunks=num_chunks
             )
             # shape: (num_nodes, num_heads, out_channels_conv)
             out = torch.zeros((*query.shape[:-1], self.out_channels_conv), device=query.device)
             for i in range(num_chunks):
-                out += self.conv(
-                    query=query,
-                    key=key,
-                    value=value,
-                    edge_attr=edge_attr_list[i],
-                    edge_index=edge_index_list[i],
-                    size=conv_size,
+                out += self.apply_gt(
+                    query=query, key=key, value=value, edges=edge_attr_list[i], edge_index=edge_index_list[i], size=size
                 )
         else:
-            out = self.conv(query=query, key=key, value=value, edge_attr=edges, edge_index=edge_index, size=conv_size)
+            out = self.apply_gt(query=query, key=key, value=value, edges=edges, edge_index=edge_index, size=size)

         return out

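The Triton path in `apply_gt` needs the graph in CSC (compressed sparse column) layout; `edge_index_to_csc` provides that together with a permutation `perm` used to reorder the per-edge features, plus (requested via `reverse=True`) what is presumably a mapping back to the original edge order. A rough illustration of the COO-to-CSC step, assuming destination-sorted columns; this is not the actual `anemoi.models.triton.utils` implementation:

```python
import torch


def coo_to_csc_sketch(edge_index: torch.Tensor, num_dst_nodes: int):
    """Illustrative COO -> CSC conversion (simplified, assumed behaviour).

    Returns (colptr, row, perm); `perm` puts per-edge tensors such as edge
    features into the same destination-sorted order.
    """
    src, dst = edge_index  # COO layout, shape (2, num_edges)
    perm = torch.argsort(dst)  # group edges by destination node
    row = src[perm]  # source indices in CSC order
    counts = torch.bincount(dst, minlength=num_dst_nodes)
    colptr = torch.cat([counts.new_zeros(1), counts.cumsum(0)])  # column pointers
    return colptr, row, perm


# Per-edge features are then reordered the same way the diff does it:
#   edges_csc = edges.index_select(0, perm)
```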
@@ -635,6 +664,7 @@ def __init__(
         update_src_nodes: bool = False,
         layer_kernels: DotDict,
         shard_strategy: str = "edges",
+        graph_attention_backend: str = "triton",
         **kwargs,
     ) -> None:
         """Initialize GraphTransformerBlock.
@@ -662,6 +692,8 @@ def __init__(
             Defined in config/models/<model>.yaml
         shard_strategy : str, by default "edges"
             Strategy to shard tensors
+        graph_attention_backend : str, by default "triton"
+            Backend to use for the graph transformer convolution; options are "triton" and "pyg".
         """

         super().__init__(
@@ -674,6 +706,7 @@ def __init__(
             bias=bias,
             qk_norm=qk_norm,
             update_src_nodes=update_src_nodes,
+            graph_attention_backend=graph_attention_backend,
             **kwargs,
         )

@@ -791,6 +824,7 @@ def __init__(
         qk_norm: bool = False,
         update_src_nodes: bool = False,
         layer_kernels: DotDict,
+        graph_attention_backend: str = "triton",
         **kwargs,
     ) -> None:
         """Initialize GraphTransformerBlock.
@@ -814,6 +848,8 @@ def __init__(
         layer_kernels : DotDict
             A dict of layer implementations e.g. layer_kernels.Linear = "torch.nn.Linear"
             Defined in config/models/<model>.yaml
+        graph_attention_backend : str, by default "triton"
+            Backend to use for the graph transformer convolution; options are "triton" and "pyg".
         """

         super().__init__(
@@ -826,6 +862,7 @@ def __init__(
             bias=bias,
             qk_norm=qk_norm,
             update_src_nodes=update_src_nodes,
+            graph_attention_backend=graph_attention_backend,
             **kwargs,
         )

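One more note, on the chunked path in `attention_block` above: accumulating per-chunk results with `out +=` is only equivalent to a single full pass if every destination node keeps all of its incoming edges within one chunk, so that the attention softmax over each neighbourhood is computed in one piece. A rough, destination-aligned chunking sketch under that assumption (an assumption about what `sort_edges_1hop_chunks` guarantees, not its verified behaviour):

```python
import torch


def chunk_edges_by_destination(edge_attr, edge_index, num_nodes, num_chunks):
    """Illustrative only: split edges so each destination node's incoming
    edges stay together in a single chunk."""
    dst = edge_index[1]
    perm = torch.argsort(dst)  # destination-sorted edge order
    sorted_dst = dst[perm]
    # roughly equal ranges of destination nodes per chunk
    bounds = torch.linspace(0, num_nodes, num_chunks + 1).long()
    edge_attr_list, edge_index_list = [], []
    for lo, hi in zip(bounds[:-1], bounds[1:]):
        sel = perm[(sorted_dst >= lo) & (sorted_dst < hi)]
        edge_attr_list.append(edge_attr[sel])
        edge_index_list.append(edge_index[:, sel])
    return edge_attr_list, edge_index_list
```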