Hi, I'm currently implementing a GNN using PyG's `MessagePassing` class, and I ran into a problem while implementing a GATConv module. Here's the GATConv code I've written with `MessagePassing`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as gnn
from torch_geometric.utils import softmax


class GATConv(gnn.MessagePassing):
    def __init__(self, in_dim, out_dim, concat=True, n_heads=1):
        super().__init__()
        self.out_dim = out_dim
        self.concat = concat
        self.n_heads = n_heads
        self.linear = nn.Linear(in_dim, out_dim * n_heads)
        self.attn = nn.Linear(out_dim, 1)

    def forward(self, x, edge_index):
        h = self.linear(x).view(-1, self.n_heads, self.out_dim)
        attn = self.attn(h)
        print("edge_index", edge_index)
        alpha = self.edge_updater(edge_index, attn=attn)
        out = self.propagate(edge_index, h=h, alpha=alpha)
        if self.concat:
            out = out.view(-1, self.out_dim * self.n_heads)
        else:
            out = out.mean(dim=1)
        return out

    def edge_update(self, attn_j, index, ptr):
        print(attn_j.shape)
        alpha = F.leaky_relu(attn_j)
        alpha = softmax(alpha, index, ptr)
        return alpha

    def message(self, h_j, alpha):
        return alpha * h_j
```

I also have an example graph for testing:

```python
from torch_geometric.data import Data

edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4],
                           [1, 0, 2, 1, 3, 2, 4, 3]], dtype=torch.long)
x = torch.randn(5, 16)
graph_data = Data(x=x, edge_index=edge_index)
```

However, when I run the code, I get the following error:

```
IndexError: Encountered an index error. Please ensure that all indices in 'edge_index' point to valid indices in the interval [0, 0] (got interval [0, 4])
```

I'm puzzled about why this error occurs. I've verified that the `edge_index` values lie in the range [0, 4], which matches the 5 nodes in `x`. Can you help me understand the reason behind this issue?
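You need to set `super().__init__(node_dim=0)`.

`MessagePassing` defaults to `node_dim=-2`, i.e. it treats the second-to-last dimension of each node-level tensor as the node axis. Your `h` and `attn` tensors have shape `[num_nodes, n_heads, ...]`, so dimension -2 is the heads axis. With `n_heads=1`, PyG therefore infers a graph with a single node, which is why the error reports the valid interval as [0, 0].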