From e34eab1691123e8b0efa3b065bf049882dc18dff Mon Sep 17 00:00:00 2001 From: AMMAS1 Date: Tue, 11 Nov 2025 14:52:31 -0800 Subject: [PATCH 1/4] added kge/simplE.py --- torch_geometric/contrib/nn/__init__.py | 1 + torch_geometric/contrib/nn/kge/simplE.py | 1 + 2 files changed, 2 insertions(+) create mode 100644 torch_geometric/contrib/nn/kge/simplE.py diff --git a/torch_geometric/contrib/nn/__init__.py b/torch_geometric/contrib/nn/__init__.py index 04d6ebb6056c..228482ed3acf 100644 --- a/torch_geometric/contrib/nn/__init__.py +++ b/torch_geometric/contrib/nn/__init__.py @@ -1,4 +1,5 @@ from .conv import * # noqa from .models import * # noqa +from .kge import * # noqa __all__ = [] diff --git a/torch_geometric/contrib/nn/kge/simplE.py b/torch_geometric/contrib/nn/kge/simplE.py new file mode 100644 index 000000000000..37415df4280c --- /dev/null +++ b/torch_geometric/contrib/nn/kge/simplE.py @@ -0,0 +1 @@ +# simplE implementation \ No newline at end of file From 00dd6cfaa288e2c2c405665f2c8f9feb48ffee8b Mon Sep 17 00:00:00 2001 From: laraselinseyahi Date: Tue, 11 Nov 2025 14:56:48 -0800 Subject: [PATCH 2/4] test --- torch_geometric/contrib/nn/kge/simplE.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch_geometric/contrib/nn/kge/simplE.py b/torch_geometric/contrib/nn/kge/simplE.py index 37415df4280c..27ddc95fcf15 100644 --- a/torch_geometric/contrib/nn/kge/simplE.py +++ b/torch_geometric/contrib/nn/kge/simplE.py @@ -1 +1,3 @@ -# simplE implementation \ No newline at end of file +# simplE implementation + +# test From b503a247d53c6881070aaf07de993550f3006794 Mon Sep 17 00:00:00 2001 From: AMMAS1 Date: Tue, 11 Nov 2025 15:21:14 -0800 Subject: [PATCH 3/4] added the implementation file for simplE --- torch_geometric/contrib/nn/kge/__init__.py | 6 ++ torch_geometric/contrib/nn/kge/simplE.py | 105 ++++++++++++++++++++- 2 files changed, 109 insertions(+), 2 deletions(-) create mode 100644 torch_geometric/contrib/nn/kge/__init__.py diff --git 
a/torch_geometric/contrib/nn/kge/__init__.py b/torch_geometric/contrib/nn/kge/__init__.py new file mode 100644 index 000000000000..1778f57ad4db --- /dev/null +++ b/torch_geometric/contrib/nn/kge/__init__.py @@ -0,0 +1,6 @@ +from .simplE import SimplE + +__all__ = classes = [ + 'SimplE', +] + diff --git a/torch_geometric/contrib/nn/kge/simplE.py b/torch_geometric/contrib/nn/kge/simplE.py index 27ddc95fcf15..2e2a962a4f83 100644 --- a/torch_geometric/contrib/nn/kge/simplE.py +++ b/torch_geometric/contrib/nn/kge/simplE.py @@ -1,3 +1,104 @@ -# simplE implementation +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.nn import Embedding -# test +from torch_geometric.nn.kge import KGEModel + + +class SimplE(KGEModel): + r"""The SimplE model from the `"SimplE Embedding for Link Prediction in + Knowledge Graphs" + <https://arxiv.org/abs/1802.04868>`_ paper. + + :class:`SimplE` addresses the independence of the two embedding vectors + for each entity in CP decomposition by using the inverse of relations. + The scoring function for a triple :math:`(h, r, t)` is defined as: + + .. math:: + d(h, r, t) = \frac{1}{2}(\langle \mathbf{e}^{H}_h, \mathbf{v}_r, + \mathbf{e}^{T}_t \rangle + \langle \mathbf{e}^{H}_t, + \mathbf{v}_{r^{-1}}, \mathbf{e}^{T}_h \rangle) + + where :math:`\langle \cdot, \cdot, \cdot \rangle` denotes the element-wise + product followed by sum, :math:`\mathbf{v}_{r^{-1}}` embeds the inverse + relation, and superscripts :math:`H`/:math:`T` mark head/tail-role vectors. + + .. note:: + + For an example of using the :class:`SimplE` model, see + `examples/contrib/simplE.py + `_. [TODO: ADD THE SIMPLE EXAMPLE] + + Args: + num_nodes (int): The number of nodes/entities in the graph. + num_relations (int): The number of relations in the graph. + hidden_channels (int): The hidden embedding size. + sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. + the embedding matrices will be sparse.
(default: :obj:`False`) + """ + def __init__( + self, + num_nodes: int, + num_relations: int, + hidden_channels: int, + sparse: bool = False, + ): + super().__init__(num_nodes, num_relations, hidden_channels, sparse) + + # Additional embeddings for tail entities and inverse relations + self.node_emb_tail = Embedding(num_nodes, hidden_channels, sparse=sparse) + self.rel_emb_inv = Embedding(num_relations, hidden_channels, sparse=sparse) + + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.xavier_uniform_(self.node_emb.weight) + torch.nn.init.xavier_uniform_(self.node_emb_tail.weight) + torch.nn.init.xavier_uniform_(self.rel_emb.weight) + torch.nn.init.xavier_uniform_(self.rel_emb_inv.weight) + + def forward( + self, + head_index: Tensor, + rel_type: Tensor, + tail_index: Tensor, + ) -> Tensor: + + # Get embeddings + head = self.node_emb(head_index) # h_{e_i} + tail = self.node_emb_tail(tail_index) # t_{e_j} + rel = self.rel_emb(rel_type) # v_r + rel_inv = self.rel_emb_inv(rel_type) # v_{r^{-1}} + + # Get tail entity head embedding and head entity tail embedding + # for the inverse part + tail_head = self.node_emb(tail_index) # h_{e_j} + head_tail = self.node_emb_tail(head_index) # t_{e_i} + + # Compute the two CP scores + # Score 1: ⟨h_{e_i}, v_r, t_{e_j}⟩ + score1 = (head * rel * tail).sum(dim=-1) + + # Score 2: ⟨h_{e_j}, v_{r^{-1}}, t_{e_i}⟩ + score2 = (tail_head * rel_inv * head_tail).sum(dim=-1) + + # SimplE score is the average of the two + return 0.5 * (score1 + score2) + + def loss( + self, + head_index: Tensor, + rel_type: Tensor, + tail_index: Tensor, + ) -> Tensor: + + pos_score = self(head_index, rel_type, tail_index) + neg_score = self(*self.random_sample(head_index, rel_type, tail_index)) + scores = torch.cat([pos_score, neg_score], dim=0) + + pos_target = torch.ones_like(pos_score) + neg_target = torch.zeros_like(neg_score) + target = torch.cat([pos_target, neg_target], dim=0) + + return F.binary_cross_entropy_with_logits(scores, 
target) From 99bed8b32f0fa594729f74eecf4cfba269a3cbc5 Mon Sep 17 00:00:00 2001 From: AMMAS1 Date: Tue, 11 Nov 2025 15:50:16 -0800 Subject: [PATCH 4/4] added an example and tested it + modified simpleE function desc to include the example file --- examples/contrib/README.md | 1 + examples/contrib/simple_fb15k_237.py | 81 ++++++++++++++++++++++++ torch_geometric/contrib/nn/kge/simplE.py | 5 +- 3 files changed, 85 insertions(+), 2 deletions(-) create mode 100644 examples/contrib/simple_fb15k_237.py diff --git a/examples/contrib/README.md b/examples/contrib/README.md index 3a45117741d2..44d037f72841 100644 --- a/examples/contrib/README.md +++ b/examples/contrib/README.md @@ -10,3 +10,4 @@ Modules included here might be moved to the main library in the future. | [`rbcd_attack_poisoning.py`](./rbcd_attack_poisoning.py) | An example of the RBCD (Resource-Based Critical Data) attack with data poisoning strategies | | [`pgm_explainer_node_classification.py`](./pgm_explainer_node_classification.py) | An example of the PGM (Probabilistic Graphical Model) explainer for node classification | | [`pgm_explainer_graph_classification.py`](./pgm_explainer_graph_classification.py) | An example of the PGM (Probabilistic Graphical Model) explainer for graph classification | +| [`simple_fb15k_237.py`](./simple_fb15k_237.py) | An example of the SimplE knowledge graph embedding model on the FB15k-237 dataset | diff --git a/examples/contrib/simple_fb15k_237.py b/examples/contrib/simple_fb15k_237.py new file mode 100644 index 000000000000..f1f8eadd289f --- /dev/null +++ b/examples/contrib/simple_fb15k_237.py @@ -0,0 +1,81 @@ +import argparse +import os.path as osp + +import torch +import torch.optim as optim + +from torch_geometric.datasets import FB15k_237 +from torch_geometric.contrib.nn import SimplE + +parser = argparse.ArgumentParser() +parser.add_argument('--hidden_channels', type=int, default=200, + help='Hidden embedding size (default: 200)') +parser.add_argument('--batch_size',
type=int, default=1000, + help='Batch size (default: 1000)') +parser.add_argument('--lr', type=float, default=0.05, + help='Learning rate (default: 0.05)') +parser.add_argument('--epochs', type=int, default=500, + help='Number of epochs (default: 500)') +args = parser.parse_args() + +device = 'cuda' if torch.cuda.is_available() else 'cpu' +path = osp.join(osp.dirname(osp.realpath(__file__)), '..', '..', 'data', 'FB15k') + +train_data = FB15k_237(path, split='train')[0].to(device) +val_data = FB15k_237(path, split='val')[0].to(device) +test_data = FB15k_237(path, split='test')[0].to(device) + +model = SimplE( + num_nodes=train_data.num_nodes, + num_relations=train_data.num_edge_types, + hidden_channels=args.hidden_channels, +).to(device) + +loader = model.loader( + head_index=train_data.edge_index[0], + rel_type=train_data.edge_type, + tail_index=train_data.edge_index[1], + batch_size=args.batch_size, + shuffle=True, +) + +# Use Adagrad optimizer as recommended in the SimplE paper +optimizer = optim.Adagrad(model.parameters(), lr=args.lr, weight_decay=1e-6) + + +def train(): + model.train() + total_loss = total_examples = 0 + for head_index, rel_type, tail_index in loader: + optimizer.zero_grad() + loss = model.loss(head_index, rel_type, tail_index) + loss.backward() + optimizer.step() + total_loss += float(loss) * head_index.numel() + total_examples += head_index.numel() + return total_loss / total_examples + + +@torch.no_grad() +def test(data): + model.eval() + return model.test( + head_index=data.edge_index[0], + rel_type=data.edge_type, + tail_index=data.edge_index[1], + batch_size=20000, + k=10, + ) + + +for epoch in range(1, args.epochs + 1): + loss = train() + print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}') + if epoch % 25 == 0: + rank, mrr, hits = test(val_data) + print(f'Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, ' + f'Val MRR: {mrr:.4f}, Val Hits@10: {hits:.4f}') + +rank, mrr, hits_at_10 = test(test_data) +print(f'Test Mean Rank: {rank:.2f}, Test MRR: 
{mrr:.4f}, ' + f'Test Hits@10: {hits_at_10:.4f}') diff --git a/torch_geometric/contrib/nn/kge/simplE.py b/torch_geometric/contrib/nn/kge/simplE.py index 2e2a962a4f83..bb0d463bf218 100644 --- a/torch_geometric/contrib/nn/kge/simplE.py +++ b/torch_geometric/contrib/nn/kge/simplE.py @@ -27,8 +27,9 @@ class SimplE(KGEModel): .. note:: For an example of using the :class:`SimplE` model, see - `examples/contrib/simplE.py - `_. [TODO: ADD THE SIMPLE EXAMPLE] + `examples/contrib/simple_fb15k_237.py + <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/ + contrib/simple_fb15k_237.py>`_. Args: num_nodes (int): The number of nodes/entities in the graph.