Hello, I have encountered a mysterious error while trying to run graph classification with one of the examples given in "Heterogeneous GNN - Automatically Converting GNN Models" (GAT):

```python
import torch
from torch_geometric.nn import GATConv, Linear, to_hetero

class GAT(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # (-1, -1): source/target input sizes are inferred lazily on the first forward pass
        self.conv1 = GATConv((-1, -1), hidden_channels, add_self_loops=False)
        self.lin1 = Linear(-1, hidden_channels)
        self.conv2 = GATConv((-1, -1), out_channels, add_self_loops=False)
        self.lin2 = Linear(-1, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index) + self.lin1(x)
        x = x.relu()
        x = self.conv2(x, edge_index) + self.lin2(x)
        return x

model = GAT(hidden_channels=64, out_channels=dataset.num_classes)
# Convert to a heterogeneous model using the node and edge types in data.metadata()
model = to_hetero(model, data.metadata(), aggr='sum')
```

I am training it like this:

```python
def train():
    model.train()
    for data_train in train_dataset:  # Iterate over the training graphs one at a time.
        # Forward pass; the converted model returns a dict keyed by node type.
        out = model(data_train.x_dict, data_train.edge_index_dict)
        loss = criterion(out, data_train.y)  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        optimizer.zero_grad()  # Clear gradients.

def test(loader):
    model.eval()
    correct = 0
    for data_test in loader:  # Iterate over the training/test graphs.
        out = model(data_test.x_dict, data_test.edge_index_dict)
        pred = out.argmax(dim=1)  # Use the class with highest probability.
        correct += int((pred == data_test.y).sum())  # Check against ground-truth labels.
    return correct / len(loader)  # Derive ratio of correct predictions.

for epoch in range(1, 15):
    train()
    train_acc = test(train_dataset)
    test_acc = test(test_dataset)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')
```

Since my dataset consists of only 29 graphs, I don't batch them, but I get this mysterious error:
I am genuinely confused as to why the … Help would be greatly appreciated!

Best regards,
Zytrus
Thanks for reporting, and sorry for this mysterious error message. My assumption is that the metadata does not match across all your graphs. That is a requirement, because `to_hetero` builds the model from the node and edge types in the single `data.metadata()` you pass to it. That means the following should pass:

```python
for data in dataset:
    assert data.metadata() == dataset[0].metadata()
```
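If that assertion fails, it helps to see *which* graphs differ and *how*. The following is a minimal sketch, not part of PyG: `diff_metadata` is a hypothetical helper that takes the `(node_types, edge_types)` tuples returned by `data.metadata()`, so it runs without PyG installed. The node and edge type names in the usage example are made up. Because `metadata()` returns lists, the plain `==` check above is order-sensitive. This helper compares types as sets instead, so a graph that only lists the same types in a different order is not flagged. A graph that is missing a node type, for example because it was built without any nodes of that type, is flagged.

```python
from typing import Dict, List, Tuple

# Shapes mirror HeteroData.metadata(): (list of node types, list of edge types),
# where an edge type is a (src_type, relation, dst_type) triple.
EdgeType = Tuple[str, str, str]
Metadata = Tuple[List[str], List[EdgeType]]


def diff_metadata(metadatas: List[Metadata]) -> List[Dict[str, object]]:
    """Hypothetical helper: report graphs whose types differ from graph 0.

    Types are compared as sets, so pure ordering differences are ignored.
    """
    ref_nodes = set(metadatas[0][0])
    ref_edges = set(metadatas[0][1])
    problems = []
    for index, (node_types, edge_types) in enumerate(metadatas):
        nodes, edges = set(node_types), set(edge_types)
        if nodes == ref_nodes and edges == ref_edges:
            continue  # same node and edge types as the first graph
        problems.append({
            "index": index,
            "missing_node_types": sorted(ref_nodes - nodes),
            "extra_node_types": sorted(nodes - ref_nodes),
            "missing_edge_types": sorted(ref_edges - edges),
            "extra_edge_types": sorted(edges - ref_edges),
        })
    return problems


if __name__ == "__main__":
    # Made-up metadata for three graphs; the third has no 'author' nodes.
    example = [
        (["author", "paper"], [("author", "writes", "paper")]),
        (["paper", "author"], [("author", "writes", "paper")]),
        (["paper"], []),
    ]
    for problem in diff_metadata(example):
        print(problem)
```

Here only the third graph (index 2) is reported, with `'author'` and the `('author', 'writes', 'paper')` edge type listed as missing. The second graph lists the same types in a different order and is not reported. With a PyG dataset, the call would be `diff_metadata([data.metadata() for data in dataset])`. An empty result means every graph has the same node and edge types as the first one.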