If you want to retrieve attention scores in a heterogeneous GNN, this is possible but a bit cumbersome. I suggest writing your own HeteroGATConv module that applies a separate GATConv to each edge type and collects the attention weights per edge type:

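from collections import defaultdict
import torch
from torch_geometric.nn import GATConv

class HeteroGATConv(torch.nn.Module):
    def __init__(self, node_types, edge_types, channels):
        super().__init__()
        # ModuleDict keys must be strings; add_self_loops=False for bipartite edges.
        self.convs = torch.nn.ModuleDict({
            '__'.join(et): GATConv((-1, -1), channels, add_self_loops=False)
            for et in edge_types})

    def forward(self, x_dict, edge_index_dict):
        out_dict = defaultdict(list)
        attention_dict = {}
        for edge_type, edge_index in edge_index_dict.items():
            src, _, dst = edge_type
            conv = self.convs['__'.join(edge_type)]
            # return_attention_weights=True also returns (edge_index, alpha)
            out, (_, att) = conv((x_dict[src], x_dict[dst]), edge_index,
                                 return_attention_weights=True)
            out_dict[dst].append(out)
            attention_dict[edge_type] = att
        # Sum the outputs of all edge types that share a destination node type
        out_dict = {k: torch.stack(v).sum(0) for k, v in out_dict.items()}
        return out_dict, attention_dict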
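To interpret the scores collected in `attention_dict`: GATConv normalizes its raw (post-LeakyReLU) edge scores with a softmax over the incoming edges of each destination node, and since every edge type above has its own GATConv, that normalization happens separately per edge type. The following is a minimal pure-Python sketch of that per-destination softmax, assuming a single attention head; `segment_softmax` and the example scores are hypothetical and this is not PyG's actual implementation.

```python
import math
from collections import defaultdict


def segment_softmax(scores, dst_index):
    """Hypothetical helper: softmax of raw edge scores, normalized over all
    edges that share the same destination node (single head assumed)."""
    # Subtract the per-destination maximum for numerical stability.
    max_per_dst = defaultdict(lambda: -math.inf)
    for s, d in zip(scores, dst_index):
        max_per_dst[d] = max(max_per_dst[d], s)
    exp = [math.exp(s - max_per_dst[d]) for s, d in zip(scores, dst_index)]
    # Sum the exponentials of the edges pointing to each destination node.
    denom = defaultdict(float)
    for e, d in zip(exp, dst_index):
        denom[d] += e
    return [e / denom[d] for e, d in zip(exp, dst_index)]


# Made-up raw scores for 4 edges of ONE edge type:
# destination node 0 receives edges 0-2, destination node 1 receives edge 3.
raw_scores = [0.5, 1.0, -0.2, 2.0]
dst = [0, 0, 0, 1]
alpha = segment_softmax(raw_scores, dst)
print([round(a, 3) for a in alpha])
```

Because the softmax runs per edge type, the coefficients a destination node receives sum to 1 within each edge type, not across all edge types, so scores from different relations do not form a single distribution.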

Replies: 1 comment, 5 replies between @xubingze and @rusty1s

Answer selected by xubingze
Category: Q&A
2 participants