I am trying to use RGCN for link prediction with the DistMult interaction model, as was done in the original paper, and just wanted to sanity check my implementation. DGL has an implementation using this approach, which I based my own on. It is available here: https://github.com/dmlc/dgl/blob/master/examples/pytorch/rgcn/link_predict.py

I have implemented the DistMult score function in PyG as follows:

```python
out = model(data.edge_index, data.edge_type)  # Create the embeddings

# DistMult
s = out[data.edge_index[0, :]]
r = model.conv2.weight[data.edge_type.long(), :, 0]
o = out[data.edge_index[1, :]]
pos_score = torch.sum(s * r * o, dim=1)
```

I also compute the score for some negative edges and then calculate the loss as follows:

```python
labels = torch.cat((torch.ones_like(pos_score), torch.zeros_like(neg_score)), dim=0)
preds = torch.cat((pos_score, neg_score), dim=0)
loss = F.binary_cross_entropy_with_logits(preds, labels)
```

where the model is defined as:

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import RGCNConv


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # `data` and `args` are globals from the surrounding training script
        self.conv1 = RGCNConv(data.num_nodes, args.hidden_units, data.num_relations)
        self.conv2 = RGCNConv(args.hidden_units, args.hidden_units, data.num_relations)

    def forward(self, edge_index, edge_type):
        # x=None: featureless input, so conv1's weights act as per-node embeddings
        x = F.relu(self.conv1(None, edge_index, edge_type))
        x = self.conv2(x, edge_index, edge_type)
        return x
```

Does this look roughly correct? Is `conv2.weight` indeed what I should be using to get the relation embedding? If I were to persist with this, could it be an example I submit as a pull request?
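One way to sanity check what `model.conv2.weight[data.edge_type.long(), :, 0]` actually returns is to reproduce the indexing on plain tensors. The sketch below assumes that `RGCNConv` without basis or block decomposition stores its relation weights as a tensor of shape `[num_relations, in_channels, out_channels]`; random tensors stand in for `conv2.weight`, the model output and the graph. Under that assumption the slice picks the first output column of each relation's transform matrix, giving one `in_channels`-sized vector per edge, and the DistMult score `sum(s * r * o)` is the bilinear form `s^T diag(r) o`.

```python
import torch

torch.manual_seed(0)
num_nodes, num_relations, hidden = 5, 3, 6

# Stand-in for model.conv2.weight, assuming the layout
# [num_relations, in_channels, out_channels] (no basis decomposition).
weight = torch.randn(num_relations, hidden, hidden)

# Stand-ins for the model output and the graph.
out = torch.randn(num_nodes, hidden)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
edge_type = torch.tensor([0, 1, 2, 0])

s = out[edge_index[0]]       # [num_edges, hidden] head embeddings
o = out[edge_index[1]]       # [num_edges, hidden] tail embeddings
r = weight[edge_type, :, 0]  # [num_edges, hidden]: column 0 of each relation matrix

# DistMult score: sum_k s_k * r_k * o_k
pos_score = torch.sum(s * r * o, dim=1)

# The same score written as the bilinear form s^T diag(r) o, edge by edge.
bilinear = torch.stack(
    [s[i] @ torch.diag(r[i]) @ o[i] for i in range(edge_type.numel())]
)

print(r.shape)
print(torch.allclose(pos_score, bilinear, atol=1e-6))
```

Because `r` here is just a slice of a message-passing weight, the "relation embedding" is tied to how `conv2` transforms messages rather than being a parameter trained only for scoring.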
This looks correct to me. Instead of using `model.conv2.weight`, I suggest using a distinct parameter for this (as DGL does, too):

```python
rel_weight = torch.nn.Parameter(torch.Tensor(num_relations, hidden_channels))
# Initialize `rel_weight`
score = torch.sum(s * rel_weight[edge_type] * o, dim=-1)
```

@rusty1s
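Following the suggestion of a distinct relation parameter, here is a minimal sketch of a DistMult decoder that owns its relation embeddings, combined with the binary cross-entropy loss from the question. `DistMultDecoder` and `negative_sample` are hypothetical names, the Xavier initialization is one reasonable choice for "Initialize `rel_weight`", and negatives are produced by replacing each positive edge's tail with a uniformly random node. That corruption scheme is an assumption for illustration, not necessarily what the DGL example does. A random tensor `z` stands in for the RGCN output `out`.

```python
import torch
import torch.nn.functional as F


class DistMultDecoder(torch.nn.Module):
    """Hypothetical DistMult decoder with its own relation embeddings,
    kept separate from the RGCN message-passing weights."""

    def __init__(self, num_relations: int, hidden_channels: int):
        super().__init__()
        self.rel_weight = torch.nn.Parameter(torch.empty(num_relations, hidden_channels))
        torch.nn.init.xavier_uniform_(self.rel_weight)

    def forward(self, z, edge_index, edge_type):
        s = z[edge_index[0]]            # head embeddings
        o = z[edge_index[1]]            # tail embeddings
        r = self.rel_weight[edge_type]  # one diagonal relation vector per edge
        return torch.sum(s * r * o, dim=-1)


def negative_sample(edge_index, num_nodes):
    """Corrupt the tail of every positive edge with a uniformly random node.
    This simple scheme may occasionally produce an edge that is actually true."""
    neg = edge_index.clone()
    neg[1] = torch.randint(0, num_nodes, (edge_index.size(1),))
    return neg


torch.manual_seed(0)
num_nodes, num_relations, hidden, num_edges = 10, 3, 8, 20
z = torch.randn(num_nodes, hidden)  # stand-in for the RGCN output `out`
edge_index = torch.randint(0, num_nodes, (2, num_edges))
edge_type = torch.randint(0, num_relations, (num_edges,))

decoder = DistMultDecoder(num_relations, hidden)
pos_score = decoder(z, edge_index, edge_type)
neg_score = decoder(z, negative_sample(edge_index, num_nodes), edge_type)

# Same loss as in the question: positives labelled 1, negatives labelled 0.
labels = torch.cat([torch.ones_like(pos_score), torch.zeros_like(neg_score)])
preds = torch.cat([pos_score, neg_score])
loss = F.binary_cross_entropy_with_logits(preds, labels)
loss.backward()
print(loss.item(), decoder.rel_weight.grad.shape)
```

In a full model, `decoder.rel_weight` would simply be registered on `Net` (or trained alongside it), so the relation vectors are learned specifically for scoring instead of being borrowed from `conv2`.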