I am trying to test the custom convolutional layer presented in the PyTorch Geometric documentation: I made two similar models, one using the built-in GCNConv from PyTorch Geometric, the other using a custom conv based on their documentation. I'm testing these two models on the GNNBenchmarkDataset 'PATTERN' (node-level prediction). The built-in version works well on the validation set, but the custom one does not work at all (no positives predicted). The settings for the two tests are identical (I just copied the code and changed the model name). Am I missing something here? As additional information, I checked and the two models have the same number of parameters.

Built-in model:

```python
import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv


class Gcn(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        # Init parent
        super(Gcn, self).__init__()
        torch.manual_seed(42)
        self.learning_rate = 0.01
        # GCN layers
        self.initial_conv = GCNConv(input_dim, hidden_dim)
        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)
        # Output layer
        self.out = Linear(hidden_dim, output_dim)

    def forward(self, data):
        x, edge_index, batch_index = data.x, data.edge_index, data.batch
        # First conv layer
        hidden = self.initial_conv(x, edge_index)
        hidden = torch.tanh(hidden)
        # Other conv layers
        hidden = self.conv1(hidden, edge_index)
        hidden = torch.tanh(hidden)
        hidden = self.conv2(hidden, edge_index)
        hidden = torch.tanh(hidden)
        hidden = self.conv3(hidden, edge_index)
        hidden = torch.tanh(hidden)
        # Apply a final (linear) classifier.
        out = self.out(hidden)
        return out
```

Custom model:

```python
# Same imports as above; GcnConv is the custom layer defined below.
class Gcn(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        # Init parent
        super(Gcn, self).__init__()
        torch.manual_seed(42)
        self.learning_rate = 0.01
        # GCN layers
        self.initial_conv = GcnConv(input_dim, hidden_dim)
        self.conv1 = GcnConv(hidden_dim, hidden_dim)
        self.conv2 = GcnConv(hidden_dim, hidden_dim)
        self.conv3 = GcnConv(hidden_dim, hidden_dim)
        # Output layer
        self.out = Linear(hidden_dim, output_dim)

    def forward(self, data):
        x, edge_index, batch_index = data.x, data.edge_index, data.batch
        # First conv layer
        hidden = self.initial_conv(x, edge_index)
        hidden = torch.tanh(hidden)
        # Other conv layers
        hidden = self.conv1(hidden, edge_index)
        hidden = torch.tanh(hidden)
        hidden = self.conv2(hidden, edge_index)
        hidden = torch.tanh(hidden)
        hidden = self.conv3(hidden, edge_index)
        hidden = torch.tanh(hidden)
        # Apply a final (linear) classifier.
        out = self.out(hidden)
        return out
```
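Neither the training loop nor the dataset code was included in the post; the following is a minimal sketch of the kind of shared setup described above (hidden size, batch size, and optimizer are assumptions, not from the original):

```python
import torch
from torch_geometric.datasets import GNNBenchmarkDataset
from torch_geometric.loader import DataLoader

# PATTERN is a node-level binary classification benchmark.
train_set = GNNBenchmarkDataset(root='data', name='PATTERN', split='train')

model = Gcn(train_set.num_features, 64, train_set.num_classes)  # hidden_dim=64 is an assumption
optimizer = torch.optim.Adam(model.parameters(), lr=model.learning_rate)
criterion = torch.nn.CrossEntropyLoss()

model.train()
for batch in DataLoader(train_set, batch_size=32, shuffle=True):
    optimizer.zero_grad()
    out = model(batch)              # [num_nodes_in_batch, num_classes]
    loss = criterion(out, batch.y)  # node-level targets
    loss.backward()
    optimizer.step()
```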
The custom conv layer:

```python
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree


class GcnConv(MessagePassing):
    """Custom GCN layer following
    https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html
    """

    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # "Add" aggregation (Step 5).
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Steps 4-5: Start propagating messages.
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]
        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j
```

EDIT: Knowing all this, the custom conv should look more like this:
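A minimal sketch, assuming the fix is the one discussed in the reply below: make the linear transform bias-free and add a separate bias exactly once, after aggregation:

```python
import torch
from torch.nn import Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree


class GcnConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')
        # No bias here: it would otherwise be aggregated across neighbors.
        self.lin = torch.nn.Linear(in_channels, out_channels, bias=False)
        self.bias = Parameter(torch.zeros(out_channels))

    def forward(self, x, edge_index):
        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix (without bias).
        x = self.lin(x)

        # Step 3: Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Steps 4-5: Propagate messages, then apply the bias exactly once.
        out = self.propagate(edge_index, x=x, norm=norm)
        return out + self.bias

    def message(self, x_j, norm):
        # Normalize the (already transformed) neighbor features.
        return norm.view(-1, 1) * x_j
```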
Reply:
Yes, it looks like applying the bias beforehand leads to worse results, since it will get additionally aggregated across neighbors instead of being applied once. Do you mind fixing this in our documentation? :)
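To make the effect concrete: with add aggregation, node i receives ∑_j norm_ij · (W x_j + b), so the bias ends up scaled by ∑_j norm_ij (which varies per node) instead of being added once. A quick sketch with made-up shapes and values:

```python
import torch

torch.manual_seed(0)
W = torch.randn(4, 4)               # weight of the linear layer (hypothetical)
b = torch.randn(4)                  # its bias
x = torch.randn(3, 4)               # features of three neighbors (hypothetical)
norm = torch.tensor([0.5, 0.3, 0.2])  # per-edge normalization coefficients

# Bias applied before aggregation: it gets (weighted-)summed across neighbors.
before = (norm.view(-1, 1) * (x @ W.t() + b)).sum(dim=0)

# Bias applied once, after aggregation.
after = (norm.view(-1, 1) * (x @ W.t())).sum(dim=0) + b

# The two differ by (norm.sum() - 1) * b, which is nonzero in general.
print(torch.allclose(before - after, (norm.sum() - 1) * b))  # True
```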