Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/contrib/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ Modules included here might be moved to the main library in the future.
| [`rbcd_attack_poisoning.py`](./rbcd_attack_poisoning.py) | An example of the RBCD (Resource-Based Critical Data) attack with data poisoning strategies |
| [`pgm_explainer_node_classification.py`](./pgm_explainer_node_classification.py) | An example of the PGM (Probabilistic Graphical Model) explainer for node classification |
| [`pgm_explainer_graph_classification.py`](./pgm_explainer_graph_classification.py) | An example of the PGM (Probabilistic Graphical Model) explainer for graph classification |
| [`simple_fb15k_237.py`](./simple_fb15k_237.py) | An example of the SimplE knowledge graph embedding model on FB15k-237 dataset |
81 changes: 81 additions & 0 deletions examples/contrib/simple_fb15k_237.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import argparse
import os.path as osp

import torch
import torch.optim as optim

from torch_geometric.datasets import FB15k_237
from torch_geometric.contrib.nn import SimplE

parser = argparse.ArgumentParser()
parser.add_argument('--hidden_channels', type=int, default=200,
help='Hidden embedding size (default: 200)')
parser.add_argument('--batch_size', type=int, default=1000,
help='Batch size (default: 1000)')
parser.add_argument('--lr', type=float, default=0.05,
help='Learning rate (default: 0.05)')
parser.add_argument('--epochs', type=int, default=500,
help='Number of epochs (default: 500)')
args = parser.parse_args()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', '..', 'data', 'FB15k')

train_data = FB15k_237(path, split='train')[0].to(device)
val_data = FB15k_237(path, split='val')[0].to(device)
test_data = FB15k_237(path, split='test')[0].to(device)

model = SimplE(
num_nodes=train_data.num_nodes,
num_relations=train_data.num_edge_types,
hidden_channels=args.hidden_channels,
).to(device)

loader = model.loader(
head_index=train_data.edge_index[0],
rel_type=train_data.edge_type,
tail_index=train_data.edge_index[1],
batch_size=args.batch_size,
shuffle=True,
)

# Use Adagrad optimizer as recommended in the SimplE paper
optimizer = optim.Adagrad(model.parameters(), lr=args.lr, weight_decay=1e-6)


def train():
model.train()
total_loss = total_examples = 0
for head_index, rel_type, tail_index in loader:
optimizer.zero_grad()
loss = model.loss(head_index, rel_type, tail_index)
loss.backward()
optimizer.step()
total_loss += float(loss) * head_index.numel()
total_examples += head_index.numel()
return total_loss / total_examples


@torch.no_grad()
def test(data):
model.eval()
return model.test(
head_index=data.edge_index[0],
rel_type=data.edge_type,
tail_index=data.edge_index[1],
batch_size=20000,
k=10,
)


for epoch in range(1, args.epochs + 1):
loss = train()
print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
if epoch % 25 == 0:
rank, mrr, hits = test(val_data)
print(f'Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, '
f'Val MRR: {mrr:.4f}, Val Hits@10: {hits:.4f}')

rank, mrr, hits_at_10 = test(test_data)
print(f'Test Mean Rank: {rank:.2f}, Test MRR: {mrr:.4f}, '
f'Test Hits@10: {hits_at_10:.4f}')
1 change: 1 addition & 0 deletions torch_geometric/contrib/nn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .conv import * # noqa
from .models import * # noqa
from .kge import * # noqa

__all__ = []
6 changes: 6 additions & 0 deletions torch_geometric/contrib/nn/kge/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .simplE import SimplE

__all__ = classes = [
'SimplE',
]

105 changes: 105 additions & 0 deletions torch_geometric/contrib/nn/kge/simplE.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Embedding

from torch_geometric.nn.kge import KGEModel


class SimplE(KGEModel):
r"""The SimplE model from the `"SimplE Embedding for Link Prediction in
Knowledge Graphs" <https://proceedings.neurips.cc/paper/2018/file/
b2ab001909a8a6f04b51920306046ce5-Paper.pdf>`_ paper.

:class:`SimplE` addresses the independence of the two embedding vectors
for each entity in CP decomposition by using the inverse of relations.
The scoring function for a triple :math:`(h, r, t)` is defined as:

.. math::
d(h, r, t) = \frac{1}{2}(\langle \mathbf{e}_h, \mathbf{v}_r,
\mathbf{e}_t \rangle + \langle \mathbf{e}_t, \mathbf{v}_{r^{-1}},
\mathbf{e}_h \rangle)

where :math:`\langle \cdot, \cdot, \cdot \rangle` denotes the element-wise
product followed by sum, and :math:`\mathbf{v}_{r^{-1}}` is the embedding
for the inverse relation.

.. note::

For an example of using the :class:`SimplE` model, see
`examples/contrib/simple_fb15k_237.py
<https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
contrib/simple_fb15k_237.py>`_.

Args:
num_nodes (int): The number of nodes/entities in the graph.
num_relations (int): The number of relations in the graph.
hidden_channels (int): The hidden embedding size.
sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. to
the embedding matrices will be sparse. (default: :obj:`False`)
"""
def __init__(
self,
num_nodes: int,
num_relations: int,
hidden_channels: int,
sparse: bool = False,
):
super().__init__(num_nodes, num_relations, hidden_channels, sparse)

# Additional embeddings for tail entities and inverse relations
self.node_emb_tail = Embedding(num_nodes, hidden_channels, sparse=sparse)
self.rel_emb_inv = Embedding(num_relations, hidden_channels, sparse=sparse)

self.reset_parameters()

def reset_parameters(self):
torch.nn.init.xavier_uniform_(self.node_emb.weight)
torch.nn.init.xavier_uniform_(self.node_emb_tail.weight)
torch.nn.init.xavier_uniform_(self.rel_emb.weight)
torch.nn.init.xavier_uniform_(self.rel_emb_inv.weight)

def forward(
self,
head_index: Tensor,
rel_type: Tensor,
tail_index: Tensor,
) -> Tensor:

# Get embeddings
head = self.node_emb(head_index) # h_{e_i}
tail = self.node_emb_tail(tail_index) # t_{e_j}
rel = self.rel_emb(rel_type) # v_r
rel_inv = self.rel_emb_inv(rel_type) # v_{r^{-1}}

# Get tail entity head embedding and head entity tail embedding
# for the inverse part
tail_head = self.node_emb(tail_index) # h_{e_j}
head_tail = self.node_emb_tail(head_index) # t_{e_i}

# Compute the two CP scores
# Score 1: ⟨h_{e_i}, v_r, t_{e_j}⟩
score1 = (head * rel * tail).sum(dim=-1)

# Score 2: ⟨h_{e_j}, v_{r^{-1}}, t_{e_i}⟩
score2 = (tail_head * rel_inv * head_tail).sum(dim=-1)

# SimplE score is the average of the two
return 0.5 * (score1 + score2)

def loss(
self,
head_index: Tensor,
rel_type: Tensor,
tail_index: Tensor,
) -> Tensor:

pos_score = self(head_index, rel_type, tail_index)
neg_score = self(*self.random_sample(head_index, rel_type, tail_index))
scores = torch.cat([pos_score, neg_score], dim=0)

pos_target = torch.ones_like(pos_score)
neg_target = torch.zeros_like(neg_score)
target = torch.cat([pos_target, neg_target], dim=0)

return F.binary_cross_entropy_with_logits(scores, target)