Batching possible in link prediction? #3151
Unanswered
paulilioaica asked this question in Q&A
Replies: 1 comment, 4 replies
I think your code looks correct, so this is likely a problem with your data. You should check that it holds: for example, every edge_index should stay within the number of nodes of its own graph.
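A quick way to run that kind of check, sketched here as a suggestion rather than something taken from the original reply (MUTAG stands in for your own dataset):

```python
from torch_geometric.datasets import TUDataset

dataset = TUDataset('/tmp/TU', name='MUTAG')
for i, data in enumerate(dataset):
    # Every edge endpoint must stay below the node count of its own graph.
    assert data.edge_index.max().item() < data.num_nodes, f'graph {i}: edge index out of range'
    # RandomLinkSplit(is_undirected=True) expects undirected input graphs.
    assert data.is_undirected(), f'graph {i}: graph is not undirected'
```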
The following code works for me:

```python
import torch
from torch_geometric.nn import SAGEConv
import torch_geometric.transforms as T
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.utils import batched_negative_sampling

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# RandomLinkSplit turns every graph into a (train, val, test) triple of Data
# objects; negatives for the training split are sampled on the fly below.
transform = T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True,
                              add_negative_train_samples=False)
dataset = TUDataset('/tmp/TU', name='MUTAG', transform=transform)

# Unzip the per-graph triples into three tuples of graphs and batch each split.
train_dataset, val_dataset, test_dataset = zip(*dataset)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128)
test_loader = DataLoader(test_dataset, batch_size=128)


class GNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

    def decode(self, z, edge_label_index):
        # Score a candidate edge by the dot product of its endpoint embeddings.
        return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)


model = GNN(dataset.num_features, 128, 64).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()


def train():
    model.train()
    loss_epoch = 0
    for data in train_loader:
        data = data.to(device)  # keep the batch on the same device as the model
        optimizer.zero_grad()
        z = model.encode(data.x, data.edge_index)

        # Sample negative edges per graph, respecting graph boundaries in the
        # batch via the `batch` assignment vector.
        neg_edge_index = batched_negative_sampling(
            data.edge_index, batch=data.batch,
            num_neg_samples=data.edge_label_index.size(1), method='sparse')

        edge_label_index = torch.cat(
            [data.edge_label_index, neg_edge_index],
            dim=-1,
        )
        edge_label = torch.cat([
            data.edge_label,
            data.edge_label.new_zeros(neg_edge_index.size(1)),
        ], dim=0)

        out = model.decode(z, edge_label_index).view(-1)
        loss = criterion(out, edge_label)
        loss.backward()
        loss_epoch += loss.item()
        optimizer.step()
    return loss_epoch / len(train_loader)


for epoch in range(1, 101):
    loss = train()
    print(loss)
```
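The reply above only reports the training loss. If you also want to evaluate the held-out splits under the same batching scheme, a sketch along the following lines should work, continuing from the snippet above; the `test()` helper and the use of scikit-learn's `roc_auc_score` are my own additions, not part of the original answer:

```python
from sklearn.metrics import roc_auc_score

@torch.no_grad()
def test(loader):
    model.eval()
    aucs = []
    for data in loader:
        data = data.to(device)
        z = model.encode(data.x, data.edge_index)
        # The val/test splits produced by RandomLinkSplit already carry both
        # positive and negative edges in edge_label_index / edge_label,
        # so no extra negative sampling is needed here.
        out = model.decode(z, data.edge_label_index).view(-1).sigmoid()
        aucs.append(roc_auc_score(data.edge_label.cpu().numpy(),
                                  out.cpu().numpy()))
    # Mean of per-batch AUCs; collect all predictions first if you need an
    # exact dataset-level score.
    return sum(aucs) / len(aucs)

print(f'val AUC: {test(val_loader):.4f}, test AUC: {test(test_loader):.4f}')
```

Calling `test(val_loader)` inside the epoch loop gives a validation curve during training as well.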
The original question, from paulilioaica:
I am following the notebook example for link prediction, and as far as I have seen, all of the link-prediction examples focus on a single graph via
data = dataset[0]
(the sketch just below shows that pattern). I am unsure whether training link prediction on graph batches is possible.
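For reference, the single-graph pattern those examples use looks roughly like this; the Planetoid/Cora dataset here is an assumption about which notebook is meant, based on PyG's official link-prediction example:

```python
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid

# With RandomLinkSplit applied as a transform, dataset[0] yields a single
# graph already split into train/val/test edge sets.
transform = T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True,
                              add_negative_train_samples=False)
dataset = Planetoid('/tmp/Planetoid', name='Cora', transform=transform)
train_data, val_data, test_data = dataset[0]
```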
I have tried the following:
where the dataloaders are created by
This yields the following error:
I thought this was related to an error in graph creation where the edge index exceeded the number of nodes, but I have checked this and I am sure that is not the issue.
Any suggestions?
Thanks