Link prediction on in-memory heterogeneous graph dataset #3633
-
# Hi, I am modifying link_pred.py to support heterogeneous graphs for link
# prediction. The dataset I am using has three entities ('molecule', 'disease',
# 'protein') for which I have explicitly listed the relationships when building
# the `HeteroData` object:


def _edge_index(src, dst):
    """Stack parallel source/destination index lists into a ``[2, E]`` tensor.

    The result is forced to ``torch.long``: ``torch.from_numpy`` keeps the NumPy
    dtype, and ``np.array`` of Python ints is int32 on Windows with NumPy < 2,
    whereas PyG expects int64 edge indices.
    """
    return torch.from_numpy(np.vstack([np.array(src), np.array(dst)])).long()


# `.float()` matches the float32 weights of the model's Linear layers; it is a
# no-op when the array is already float32. NOTE(review): assumes mol_data /
# protein_data / dis_data are numeric NumPy arrays (often float64) — confirm.
data['molecule'].x = torch.from_numpy(mol_data).float()
data['protein'].x = torch.from_numpy(protein_data).float()
data['disease'].x = torch.from_numpy(dis_data).float()

data['protein', 'interacts with', 'protein'].edge_index = _edge_index(p2p_edge_1, p2p_edge_2)

# Each two-way relation is built once; its reverse is the same tensor with the
# two rows swapped, so both directions always describe the same pairs.
# m2p_edge_1 holds molecule indices, m2p_edge_2 protein indices.
mol_to_prot = _edge_index(m2p_edge_1, m2p_edge_2)
data['protein', 'is targeted by', 'molecule'].edge_index = mol_to_prot.flip(0)
data['molecule', 'is targeted by', 'protein'].edge_index = mol_to_prot

# d2p_edge_1 holds protein indices, d2p_edge_2 disease indices.
prot_to_dis = _edge_index(d2p_edge_1, d2p_edge_2)
data['disease', 'is associated with', 'protein'].edge_index = prot_to_dis.flip(0)
data['protein', 'is associated with', 'disease'].edge_index = prot_to_dis

# The code I wrote for the model uses:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import DBLP
from torch_geometric.nn import Linear, HGTConv
class HGT(torch.nn.Module):
    """Heterogeneous Graph Transformer encoder with a dot-product link decoder.

    Every node type is first projected by its own lazily initialized
    ``Linear(-1, hidden_channels)`` followed by ReLU, then passed through
    ``num_layers`` ``HGTConv`` layers. ``decode`` scores (protein, molecule)
    pairs by the inner product of their embeddings.

    Because the input projections use ``-1`` (lazy) input sizes, run one
    forward pass before creating the optimizer so their parameters exist.

    Args:
        hidden_channels: Width of the per-type projections and of every
            ``HGTConv`` layer; also the width of the embeddings ``encode``
            returns.
        out_channels: Output width of ``self.lin``. NOTE(review): ``self.lin``
            is never called by ``encode`` or ``decode``, so this value does not
            affect the returned embeddings. It is kept so existing call sites
            and saved ``state_dict`` keys still match.
        num_heads: Attention heads per ``HGTConv`` layer.
        num_layers: Number of stacked ``HGTConv`` layers.
        metadata: ``(node_types, edge_types)`` of the graph, e.g.
            ``data.metadata()``. When omitted, the metadata of a module-level
            ``data`` object is used, which is what this class always did.
    """

    def __init__(self, hidden_channels, out_channels, num_heads, num_layers,
                 *, metadata=None):
        super().__init__()
        if metadata is None:
            # Backward-compatible fallback: the constructor used to read the
            # global `data` directly, tying the model to one module-level graph.
            metadata = data.metadata()
        node_types = metadata[0]

        self.lin_dict = torch.nn.ModuleDict()
        for node_type in node_types:
            self.lin_dict[node_type] = Linear(-1, hidden_channels)

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(
                HGTConv(hidden_channels, hidden_channels, metadata,
                        num_heads, group='sum'))

        # NOTE(review): unused by encode/decode (see class docstring).
        self.lin = Linear(hidden_channels, out_channels)

    def encode(self, x_dict, edge_index_dict):
        """Embed all nodes and return the protein and molecule embeddings.

        Args:
            x_dict: Mapping node type -> input feature matrix.
            edge_index_dict: Mapping edge type -> ``edge_index``.

        Returns:
            ``(protein_embeddings, molecule_embeddings)``, each
            ``hidden_channels`` wide. Embeddings of other node types (e.g.
            'disease') are computed for message passing but not returned.
        """
        # Build a new dict instead of overwriting the caller's entries, so the
        # caller's raw features are left untouched.
        h_dict = {
            node_type: self.lin_dict[node_type](x).relu_()
            for node_type, x in x_dict.items()
        }
        for conv in self.convs:
            h_dict = conv(h_dict, edge_index_dict)
        return h_dict['protein'], h_dict['molecule']

    def decode(self, prot, mol, edge_label_index):
        """Score (protein, molecule) pairs by the dot product of their embeddings.

        Args:
            prot: Protein embeddings as returned by ``encode``.
            mol: Molecule embeddings as returned by ``encode``.
            edge_label_index: ``[2, E]`` index tensor; row 0 indexes ``prot``
                and row 1 indexes ``mol``.

        Returns:
            One raw (unnormalized) score per pair, shape ``[E]``.
        """
        return (prot[edge_label_index[0]] * mol[edge_label_index[1]]).sum(dim=-1)
model = HGT(hidden_channels=256, out_channels=64, num_heads=2, num_layers=2)

# Dry run: the `Linear(-1, ...)` input projections are lazily initialized, so
# one forward pass is needed before their parameters exist (create the
# optimizer after this). No gradients are needed for it.
with torch.no_grad():
    prot_rep, mol_rep = model.encode(train_data.x_dict, train_data.edge_index_dict)
    out = model.decode(prot_rep, mol_rep, train_data['protein', 'is targeted by', 'molecule'].edge_label_index)

# Then I use negative sampling restricted to the 'target' edge class:


def train():
    """Run one full-batch optimization step on protein-molecule target edges.

    Uses the module-level ``model``, ``optimizer``, ``criterion`` and
    ``train_data``. NOTE(review): ``criterion`` is presumably
    ``BCEWithLogitsLoss``, since ``decode`` returns raw scores — confirm.

    Returns:
        The loss tensor for this step.
    """
    model.train()
    optimizer.zero_grad()

    target = train_data['protein', 'is targeted by', 'molecule']

    # Message passing must use the *training* graph only. Encoding the full
    # `data` would let messages flow over validation/test edges (label leakage).
    prot_rep, mol_rep = model.encode(train_data.x_dict, train_data.edge_index_dict)

    # Fresh negatives every step: protein x molecule pairs not present in the
    # training edges, one per positive supervision edge. NOTE(review): assumes
    # the split was made with add_negative_train_samples=False (as in
    # link_pred.py), so `edge_label_index` holds positives only — confirm.
    neg_edge_index = negative_sampling(
        edge_index=target.edge_index,
        num_nodes=(train_data['protein'].num_nodes, train_data['molecule'].num_nodes),
        num_neg_samples=target.edge_label_index.size(1),
        method='sparse')

    edge_label_index = torch.cat(
        [target.edge_label_index, neg_edge_index],
        dim=-1,
    )
    edge_label = torch.cat([
        target.edge_label,
        target.edge_label.new_zeros(neg_edge_index.size(1)),
    ], dim=0)

    out = model.decode(prot_rep, mol_rep, edge_label_index).view(-1)
    loss = criterion(out, edge_label)
    loss.backward()
    optimizer.step()
    return loss


# I am wondering if this reasoning makes sense or there is another route I can
# take to perform link prediction on heterogeneous graphs.
Beta Was this translation helpful? Give feedback.
Replies: 1 comment
-
This looks absolutely perfect :) One thing I need to tell you is that …
Edit: We also provide a …
Beta Was this translation helpful? Give feedback.
This looks absolutely perfect :)
One thing I need to tell you is that `HGTConv` is sadly bugged in 2.0.2, but it's already fixed in master (see 5f8e99d). You may need to install from master to fix this.
Edit: We also provide a `hetero_link_pred` example here, which looks mostly similar.