diff --git a/CHANGELOG.md b/CHANGELOG.md index a590afe5e190..7ddf176898ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added support for Relative Temporal Encoding (RTE) in `HGTConv` to handle dynamic heterogeneous graphs ([#10469](https://github.com/pyg-team/pytorch_geometric/pull/10469)) - Added llm generated explanations to `TAGDataset` ([#9918](https://github.com/pyg-team/pytorch_geometric/pull/9918)) - Added `torch_geometric.llm` and its examples ([#10436](https://github.com/pyg-team/pytorch_geometric/pull/10436)) - Added support for negative weights in `sparse_cross_entropy` ([#10432](https://github.com/pyg-team/pytorch_geometric/pull/10432)) diff --git a/test/nn/conv/test_hgt_conv.py b/test/nn/conv/test_hgt_conv.py index 541dacafef73..050447e1f2a0 100644 --- a/test/nn/conv/test_hgt_conv.py +++ b/test/nn/conv/test_hgt_conv.py @@ -1,4 +1,7 @@ +import pytest import torch +from torch.nn import CrossEntropyLoss, Linear +from torch.optim import Adam import torch_geometric.typing from torch_geometric.data import HeteroData @@ -234,6 +237,285 @@ def test_hgt_conv_missing_edge_type(): assert 'university' not in out_dict +def test_rte_on_vs_off(): + """Test whether RTE has an effect when enabled vs. disabled.""" + data = HeteroData() + data['author'].x = torch.randn(4, 16) + data['paper'].x = torch.randn(6, 32) + data['university'].x = torch.randn(10, 32) + + awp_edge = data['author', 'writes', 'paper'] + awp_edge.edge_index = get_random_edge_index(4, 6, 20) + awp_edge.time_diff = torch.randint(0, 100, (awp_edge.num_edges, )) + + uea_edge = data['university', 'employs', 'author'] + uea_edge.edge_index = get_random_edge_index(10, 4, 15) + uea_edge.time_diff = torch.zeros(uea_edge.num_edges, dtype=torch.long) + + metadata = data.metadata() + + torch.manual_seed(42) + conv_with_rte = HGTConv(-1, 64, metadata, heads=2, use_RTE=True) + + torch.manual_seed(42) + conv_without_rte = HGTConv(-1, 64, metadata, heads=2, use_RTE=False) + + out_dict_with_rte = conv_with_rte(data.x_dict, data.edge_index_dict, + data.time_diff_dict) + out_dict_without_rte = conv_without_rte(data.x_dict, data.edge_index_dict) + + author_out_with_rte = out_dict_with_rte['author'] + author_out_without_rte = out_dict_without_rte['author'] + + assert not torch.allclose(author_out_with_rte, author_out_without_rte) + + +def test_rte_sensitivity_to_time_values(): + """Tests the sensitivity of the HGTConv layer to its temporal inputs. + + This test ensures that when the `edge_time_diff_dict` values are + modified, the output embeddings of the HGTConv layer with RTE enabled + also change. + """ + data = HeteroData() + data['author'].x = torch.randn(4, 16) + data['paper'].x = torch.randn(6, 32) + data['university'].x = torch.randn(10, 32) + + awp_edge = data['author', 'writes', 'paper'] + awp_edge.edge_index = get_random_edge_index(4, 6, 20) + awp_edge.time_diff = torch.randint(0, 100, (awp_edge.num_edges, )) + + uae_edge = data['university', 'employs', 'author'] + uae_edge.edge_index = get_random_edge_index(10, 4, 15) + uae_edge.time_diff = torch.zeros(uae_edge.num_edges, dtype=torch.long) + + metadata = data.metadata() + torch.manual_seed(42) + conv = HGTConv(-1, 64, metadata, heads=2, use_RTE=True) + + out_dict_1 = conv(data.x_dict, data.edge_index_dict, data.time_diff_dict) + author_out_1 = out_dict_1['author'] + + data_alt_time = data.clone() + for edge_type in data.edge_types: + if 'time_diff' in data[edge_type]: + data_alt_time[ + edge_type].time_diff = data[edge_type].time_diff + 100 + + out_dict_2 = conv(data.x_dict, data.edge_index_dict, + data_alt_time.time_diff_dict) + author_out_2 = out_dict_2['author'] + + assert not torch.allclose(author_out_1, author_out_2) + + +def test_rte_zero_time_diff(): + """Tests that a zero time difference produces a different output. + + This test ensures that the output of the HGTConv layer with RTE is + different when given zero time differences compared to when RTE is + set to false. + """ + data = HeteroData() + data['author'].x = torch.randn(4, 16) + data['paper'].x = torch.randn(6, 32) + data['university'].x = torch.randn(10, 32) + + uea_edge = data['university', 'employs', 'author'] + uea_edge.edge_index = get_random_edge_index(10, 4, 15) + uea_edge.time_diff = torch.zeros(uea_edge.num_edges, dtype=torch.long) + + awp_edge = data['author', 'writes', 'paper'] + awp_edge.edge_index = get_random_edge_index(4, 6, 20) + awp_edge.time_diff = torch.zeros(awp_edge.num_edges, dtype=torch.long) + + metadata = data.metadata() + torch.manual_seed(42) + conv_with_rte = HGTConv(-1, 64, metadata, heads=2, use_RTE=True) + + out_dict_zero = conv_with_rte(data.x_dict, data.edge_index_dict, + data.time_diff_dict) + author_out_zero = out_dict_zero['author'] + + torch.manual_seed(42) + conv_without_rte = HGTConv(-1, 64, metadata, heads=2, use_RTE=False) + out_dict_without_rte = conv_without_rte(data.x_dict, data.edge_index_dict) + author_out_without_rte = out_dict_without_rte['author'] + + assert not torch.allclose(author_out_zero, author_out_without_rte) + + +def test_rte_raises_error_if_time_is_missing(): + """Tests that a ValueError is raised if RTE is on but no time is given.""" + data = HeteroData() + data['author'].x = torch.randn(4, 16) + data['paper'].x = torch.randn(6, 16) + + awp_edge = data['author', 'writes', 'paper'] + awp_edge.edge_index = get_random_edge_index(4, 6, 20) + + metadata = data.metadata() + conv = HGTConv(-1, 32, metadata, heads=2, use_RTE=True) + + with pytest.raises(ValueError, match="RTE enabled, but no"): + conv(data.x_dict, data.edge_index_dict) + + +def test_rte_warns_if_time_is_provided_but_unused(): + """Tests that a warning is raised if time is given but RTE deactivated.""" + data = HeteroData() + data['author'].x = torch.randn(4, 16) + data['paper'].x = torch.randn(6, 16) + awp_edge = data['author', 'writes', 'paper'] + awp_edge.edge_index = get_random_edge_index(4, 6, 20) + awp_edge.time_diff = torch.randint(0, 100, (awp_edge.num_edges, )) + + metadata = data.metadata() + conv = HGTConv(-1, 32, metadata, heads=2, use_RTE=False) + + with pytest.warns(UserWarning, match="'use_RTE' is False, but"): + conv(data.x_dict, data.edge_index_dict, data.time_diff_dict) + + +def test_rte_raises_error_if_time_key_is_missing(): + """Tests ValueError is raised if time for one edge type is missing.""" + data = HeteroData() + data['author'].x = torch.randn(4, 16) + data['paper'].x = torch.randn(6, 32) + data['university'].x = torch.randn(10, 32) + + uea_edge = data['university', 'employs', 'author'] + uea_edge.edge_index = get_random_edge_index(10, 4, 15) + uea_edge.time_diff = torch.randint(0, 100, (uea_edge.num_edges, )) + + awp_edge = data['author', 'writes', 'paper'] + awp_edge.edge_index = get_random_edge_index(4, 6, 20) + + metadata = data.metadata() + torch.manual_seed(42) + + conv = HGTConv(-1, 32, metadata, heads=2, use_RTE=True) + + with pytest.raises(ValueError, match="'time_diff' missing for edge type"): + conv(data.x_dict, data.edge_index_dict, data.time_diff_dict) + + +def test_hgt_conv_rte_behavioral(): + """Tests if HGTConv with RTE can learn a purely time-dependent rule. + + Each 'source' node has two outgoing edges. The edge with the smaller + `time_diff` is labeled as correct (1). + + The test asserts that the model with `use_RTE=True` successfully + learns this rule (high accuracy), while the model with `use_RTE=False` + fails (accuracy near random chance of 0.5). + """ + num_source_nodes = 50 + data = HeteroData() + data['source'].x = torch.randn(num_source_nodes, 16) + data['target'].x = torch.randn(num_source_nodes * 2, 16) + + source_indices = [] + target_indices = [] + time_list = [] + label_list = [] + + for i in range(num_source_nodes): + target1 = i * 2 + target2 = i * 2 + 1 + + identical_target_features = torch.randn(1, 16) + data['target'].x[target1] = identical_target_features + data['target'].x[target2] = identical_target_features + + # Randomly decide which target receives the "fast" edge + if torch.rand(1) > 0.5: + # In this case, the edge to target1 is the faster one + time_for_target1 = 5.0 + label_for_target1 = 1 + + time_for_target2 = 50.0 + label_for_target2 = 0 + else: + # In this case, the edge to target2 is the faster one + time_for_target1 = 50.0 + label_for_target1 = 0 + + time_for_target2 = 5.0 + label_for_target2 = 1 + + source_indices.extend([i, i]) + target_indices.extend([target1, target2]) + time_list.extend([time_for_target1, time_for_target2]) + label_list.extend([label_for_target1, label_for_target2]) + + edge_index = torch.tensor([source_indices, target_indices]) + data['source', 'to', 'target'].edge_index = edge_index + data['source', 'to', 'target'].time_diff = torch.tensor(time_list) + data['source', 'to', 'target'].y = torch.tensor(label_list) + + data['target', 'rev_to', 'source'].edge_index = edge_index.flip(0) + data['target', 'rev_to', 'source'].time_diff = torch.zeros(len(time_list)) + + metadata = data.metadata() + + class HGTEdgeClassifier(torch.nn.Module): + def __init__(self, out_channels, use_rte=True): + super().__init__() + self.conv = HGTConv(-1, out_channels, metadata, heads=2, + use_RTE=use_rte) + self.classifier = Linear(out_channels * 2, 2) + + def forward(self, x_dict, edge_index_dict, edge_time_diff_dict, + edge_label_index): + x_dict = self.conv(x_dict, edge_index_dict, edge_time_diff_dict) + src_emb = x_dict['source'][edge_label_index[0]] + dst_emb = x_dict['target'][edge_label_index[1]] + edge_emb = torch.cat([src_emb, dst_emb], dim=-1) + return self.classifier(edge_emb) + + def train_and_test(use_rte): + torch.manual_seed(42) + model = HGTEdgeClassifier(out_channels=16, use_rte=use_rte) + optimizer = Adam(model.parameters(), lr=0.01) + criterion = CrossEntropyLoss() + + args = [ + data.x_dict, data.edge_index_dict, data.time_diff_dict, + data['source', 'to', 'target'].edge_index + ] + edge_data = data['source', 'to', 'target'] + + for _ in range(20): + optimizer.zero_grad() + + if not use_rte: + with pytest.warns(UserWarning, match="'use_RTE' is False"): + logits = model(*args) + else: + logits = model(*args) + + loss = criterion(logits, edge_data.y) + loss.backward() + optimizer.step() + + with torch.no_grad(): + if not use_rte: + with pytest.warns(UserWarning, match="'use_RTE' is False"): + pred = model(*args).argmax(dim=-1) + else: + pred = model(*args).argmax(dim=-1) + + return (pred == edge_data.y).float().mean().item() + + acc_with_rte = train_and_test(use_rte=True) + assert acc_with_rte >= 0.95 + + acc_without_rte = train_and_test(use_rte=False) + assert acc_without_rte <= 0.6 + + if __name__ == '__main__': import argparse diff --git a/torch_geometric/nn/conv/hgt_conv.py b/torch_geometric/nn/conv/hgt_conv.py index 2e4101ef7003..3fff7878c9c9 100644 --- a/torch_geometric/nn/conv/hgt_conv.py +++ b/torch_geometric/nn/conv/hgt_conv.py @@ -1,4 +1,5 @@ import math +import warnings from typing import Dict, List, Optional, Tuple, Union import torch @@ -7,6 +8,7 @@ from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense import HeteroDictLinear, HeteroLinear +from torch_geometric.nn.encoding import PositionalEncoding from torch_geometric.nn.inits import ones from torch_geometric.nn.parameter_dict import ParameterDict from torch_geometric.typing import Adj, EdgeType, Metadata, NodeType @@ -37,6 +39,9 @@ class HGTConv(MessagePassing): information. heads (int, optional): Number of multi-head-attentions. (default: :obj:`1`) + use_RTE (bool, optional): If set to :obj:`True`, the layer uses + Relative Temporal Encoding (RTE). + (default: :obj:`False`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. """ @@ -46,6 +51,7 @@ def __init__( out_channels: int, metadata: Metadata, heads: int = 1, + use_RTE: bool = False, **kwargs, ): super().__init__(aggr='add', node_dim=0, **kwargs) @@ -60,6 +66,7 @@ def __init__( self.in_channels = in_channels self.out_channels = out_channels self.heads = heads + self.use_RTE = use_RTE self.node_types = metadata[0] self.edge_types = metadata[1] self.edge_types_map = { @@ -93,6 +100,9 @@ def __init__( edge_type = '__'.join(edge_type) self.p_rel[edge_type] = Parameter(torch.empty(1, heads)) + if self.use_RTE: + self.rte = PositionalEncoding(self.out_channels) + self.reset_parameters() def reset_parameters(self): @@ -153,10 +163,35 @@ def _construct_src_node_feat( return k, v, offset + def _validate_inputs( + self, + edge_index_dict: Dict[EdgeType, Adj], + edge_time_diff_dict: Optional[Dict[EdgeType, Tensor]], + ) -> None: + """Helper function to validate inputs for temporal encoding.""" + if not self.use_RTE and edge_time_diff_dict is not None: + warnings.warn( + "'use_RTE' is False, but 'edge_time_diff_dict' was provided. " + "The temporal data will be ignored.", stacklevel=2) + return + + if self.use_RTE: + if edge_time_diff_dict is None: + raise ValueError( + "RTE enabled, but no 'edge_time_diff_dict' was provided.") + + for edge_type in edge_index_dict.keys(): + if edge_type not in edge_time_diff_dict: + raise ValueError( + "RTE enabled, but 'time_diff' missing for edge type: " + f"{edge_type}. " + "All edge types must have a time_diff attribute.") + def forward( self, x_dict: Dict[NodeType, Tensor], - edge_index_dict: Dict[EdgeType, Adj] # Support both. + edge_index_dict: Dict[EdgeType, Adj], # Support both. + edge_time_diff_dict: Optional[Dict[EdgeType, Tensor]] = None, ) -> Dict[NodeType, Optional[Tensor]]: r"""Runs the forward pass of the module. @@ -168,12 +203,19 @@ def forward( individual edge type, either as a :class:`torch.Tensor` of shape :obj:`[2, num_edges]` or a :class:`torch_sparse.SparseTensor`. + edge_time_diff_dict (Dict[EdgeType, torch.Tensor], optional): + A dictionary holding time differences (∆T) for each + individual edge type. Each entry must be a 1D + :class:`torch.Tensor` of shape :obj:`[num_edges]`. It must be + provided if :obj:`use_RTE=True`. (default: :obj:`None`) :rtype: :obj:`Dict[str, Optional[torch.Tensor]]` - The output node embeddings for each node type. In case a node type does not receive any message, its output will be set to :obj:`None`. """ + self._validate_inputs(edge_index_dict, edge_time_diff_dict) + F = self.out_channels H = self.heads D = F // H @@ -196,7 +238,16 @@ def forward( edge_index_dict, src_offset, dst_offset, edge_attr_dict=self.p_rel, num_nodes=k.size(0)) - out = self.propagate(edge_index, k=k, q=q, v=v, edge_attr=edge_attr) + temporal_features = None + if self.use_RTE: + _, edge_time_diff = construct_bipartite_edge_index( + edge_index_dict, src_offset, dst_offset, + edge_attr_dict=edge_time_diff_dict, num_nodes=k.size(0)) + + temporal_features = self.rte(edge_time_diff).view(-1, H, D) + + out = self.propagate(edge_index, k=k, q=q, v=v, edge_attr=edge_attr, + temporal_features=temporal_features) # Reconstruct output node embeddings dict: for node_type, start_offset in dst_offset.items(): @@ -224,7 +275,11 @@ def forward( def message(self, k_j: Tensor, q_i: Tensor, v_j: Tensor, edge_attr: Tensor, index: Tensor, ptr: Optional[Tensor], + temporal_features: Optional[Tensor], size_i: Optional[int]) -> Tensor: + if temporal_features is not None: + k_j = k_j + temporal_features + v_j = v_j + temporal_features alpha = (q_i * k_j).sum(dim=-1) * edge_attr alpha = alpha / math.sqrt(q_i.size(-1)) alpha = softmax(alpha, index, ptr, size_i)