
Commit 9ac9aa3

Adding PatchTransformerAggregation (#9487)
Adding the patch transformer architecture from the dynamic graph learning literature as a PyG aggregation module, mostly based on "Simplifying Temporal Heterogeneous Network for Continuous-Time Link Prediction".

- [x] Added the patch transformer aggregation module in `torch_geometric/nn/aggr/patch_transformer.py`
- [x] Added all related util functions in `pt_utils.py`
- [x] Passed pre-commit hooks
- [x] Added `test_patch_transformer_aggr.py` in `test/nn/aggr/`; test passed
- [x] Added documentation for helper functions
- [ ] Integrate into workflow
- [x] Tested on a dataset

Notes:

- `PatchTransformerAggregation` requires edge features (`edge_feats`) and edge timestamps (`edge_ts`).
- `PatchTransformerAggregation` primarily focuses on aggregating edge information for each node based on its past interactions: the edge embeddings are fed into the patch transformer blocks, which learn the interactions between these edges, and pooling over multiple patches yields the representation of each node. Lastly, this representation is merged with the input node features (`x`) to produce the output.
- `PatchTransformerAggregation` does not run `self.to_dense_batch()` at the moment and thus does not reshape the input `x` tensor.
- The output dimension is set via the `out_dim` argument.

---------

Co-authored-by: rusty1s <[email protected]>
1 parent fbafbc4 commit 9ac9aa3
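
For reference, a minimal usage sketch mirroring the test added in this commit (the hyperparameters are the ones used in `test/nn/aggr/test_patch_transformer_aggr.py`):

import torch
from torch_geometric.nn import PatchTransformerAggregation

aggr = PatchTransformerAggregation(
    in_channels=16,
    out_channels=32,
    patch_size=2,
    hidden_channels=8,
    num_transformer_blocks=1,
    heads=2,
    dropout=0.2,
    aggr=['sum', 'mean', 'min', 'max', 'var', 'std'],
)

index = torch.tensor([0, 0, 1, 1, 1, 2])  # 6 elements grouped into 3 nodes
x = torch.randn(index.size(0), 16)        # one 16-dimensional feature per element

out = aggr(x, index)  # -> shape [3, 32], one output vector per node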

4 files changed (+173, -0 lines)


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

+- Added the `PatchTransformerAggregation` layer ([#9487](https://github.com/pyg-team/pytorch_geometric/pull/9487))
 - Added the `nn.nlp.LLM` model ([#9462](https://github.com/pyg-team/pytorch_geometric/pull/9462))
 - Added an example of training GNNs for a graph-level regression task ([#9070](https://github.com/pyg-team/pytorch_geometric/pull/9070))
 - Added `utils.from_rdmol`/`utils.to_rdmol` functionality ([#9452](https://github.com/pyg-team/pytorch_geometric/pull/9452))
test/nn/aggr/test_patch_transformer_aggr.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
import torch

from torch_geometric.nn import PatchTransformerAggregation
from torch_geometric.testing import withCUDA


@withCUDA
def test_patch_transformer_aggregation(device: torch.device) -> None:
    aggr = PatchTransformerAggregation(
        in_channels=16,
        out_channels=32,
        patch_size=2,
        hidden_channels=8,
        num_transformer_blocks=1,
        heads=2,
        dropout=0.2,
        aggr=['sum', 'mean', 'min', 'max', 'var', 'std'],
    ).to(device)
    aggr.reset_parameters()
    assert str(aggr) == 'PatchTransformerAggregation(16, 32, patch_size=2)'

    index = torch.tensor([0, 0, 1, 1, 1, 2], device=device)
    x = torch.randn(index.size(0), 16, device=device)

    out = aggr(x, index)
    assert out.device == device
    assert out.size() == (3, aggr.out_channels)

torch_geometric/nn/aggr/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -25,6 +25,7 @@
 from .set_transformer import SetTransformerAggregation
 from .lcm import LCMAggregation
 from .variance_preserving import VariancePreservingAggregation
+from .patch_transformer import PatchTransformerAggregation

 __all__ = classes = [
     'Aggregation',
@@ -53,4 +54,5 @@
     'SetTransformerAggregation',
     'LCMAggregation',
     'VariancePreservingAggregation',
+    'PatchTransformerAggregation',
 ]
torch_geometric/nn/aggr/patch_transformer.py

Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
import math
from typing import List, Optional, Union

import torch
from torch import Tensor

from torch_geometric.experimental import disable_dynamic_shapes
from torch_geometric.nn.aggr import Aggregation
from torch_geometric.nn.aggr.utils import MultiheadAttentionBlock
from torch_geometric.nn.encoding import PositionalEncoding
from torch_geometric.utils import scatter


class PatchTransformerAggregation(Aggregation):
    r"""Performs patch transformer aggregation in which the elements to
    aggregate are processed by multi-head attention blocks across patches, as
    described in the `"Simplifying Temporal Heterogeneous Network for
    Continuous-Time Link Prediction"
    <https://dl.acm.org/doi/pdf/10.1145/3583780.3615059>`_ paper.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        patch_size (int): Number of elements in a patch.
        hidden_channels (int): Intermediate size of each sample.
        num_transformer_blocks (int, optional): Number of transformer blocks.
            (default: :obj:`1`)
        heads (int, optional): Number of attention heads.
            (default: :obj:`1`)
        dropout (float, optional): Dropout probability of attention weights.
            (default: :obj:`0.0`)
        aggr (str or list[str], optional): The aggregation scheme(s) to use,
            *e.g.*, :obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
            :obj:`"var"`, :obj:`"std"`. (default: :obj:`"mean"`)
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        patch_size: int,
        hidden_channels: int,
        num_transformer_blocks: int = 1,
        heads: int = 1,
        dropout: float = 0.0,
        aggr: Union[str, List[str]] = 'mean',
    ) -> None:
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.patch_size = patch_size
        self.aggrs = [aggr] if isinstance(aggr, str) else aggr

        assert len(self.aggrs) > 0
        for aggr in self.aggrs:
            assert aggr in ['sum', 'mean', 'min', 'max', 'var', 'std']

        self.lin = torch.nn.Linear(in_channels, hidden_channels)
        self.pad_projector = torch.nn.Linear(
            patch_size * hidden_channels,
            hidden_channels,
        )
        self.pe = PositionalEncoding(hidden_channels)

        self.blocks = torch.nn.ModuleList([
            MultiheadAttentionBlock(
                channels=hidden_channels,
                heads=heads,
                layer_norm=True,
                dropout=dropout,
            ) for _ in range(num_transformer_blocks)
        ])

        self.fc = torch.nn.Linear(
            hidden_channels * len(self.aggrs),
            out_channels,
        )

    def reset_parameters(self) -> None:
        self.lin.reset_parameters()
        self.pad_projector.reset_parameters()
        self.pe.reset_parameters()
        for block in self.blocks:
            block.reset_parameters()
        self.fc.reset_parameters()

    @disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements'])
    def forward(
        self,
        x: Tensor,
        index: Tensor,
        ptr: Optional[Tensor] = None,
        dim_size: Optional[int] = None,
        dim: int = -2,
        max_num_elements: Optional[int] = None,
    ) -> Tensor:

        if max_num_elements is None:
            if ptr is not None:
                count = ptr.diff()
            else:
                count = scatter(torch.ones_like(index), index, dim=0,
                                dim_size=dim_size, reduce='sum')
            max_num_elements = int(count.max()) + 1

        # Set `max_num_elements` to a multiple of `patch_size`:
        max_num_elements = (math.floor(max_num_elements / self.patch_size) *
                            self.patch_size)

        x = self.lin(x)

        # TODO If groups are heavily unbalanced, this will create a lot of
        # "empty" patches. Try to figure out a way to fix this.
        # [batch_size, num_patches * patch_size, hidden_channels]
        x, _ = self.to_dense_batch(x, index, ptr, dim_size, dim,
                                   max_num_elements=max_num_elements)

        # [batch_size, num_patches, patch_size * hidden_channels]
        x = x.view(x.size(0), max_num_elements // self.patch_size,
                   self.patch_size * x.size(-1))

        # [batch_size, num_patches, hidden_channels]
        x = self.pad_projector(x)

        x = x + self.pe(torch.arange(x.size(1), device=x.device))

        # [batch_size, num_patches, hidden_channels]
        for block in self.blocks:
            x = block(x, x)

        # [batch_size, hidden_channels]
        outs: List[Tensor] = []
        for aggr in self.aggrs:
            out = getattr(torch, aggr)(x, dim=1)
            outs.append(out[0] if isinstance(out, tuple) else out)
        out = torch.cat(outs, dim=1) if len(outs) > 1 else outs[0]

        # [batch_size, out_channels]
        return self.fc(out)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, patch_size={self.patch_size})')
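
To illustrate the reshaping that `forward()` performs after `to_dense_batch()`, here is a small standalone sketch (not part of the commit; the sizes match the test above, where `hidden_channels=8`, `patch_size=2`, and the padded length becomes `max_num_elements=4` for element counts `[2, 3, 1]`):

import torch

batch_size, max_num_elements, hidden_channels, patch_size = 3, 4, 8, 2

# Padded per-node element sequences, as returned by `to_dense_batch()`:
x = torch.randn(batch_size, max_num_elements, hidden_channels)

# Split each padded sequence into patches and flatten each patch:
# [batch_size, num_patches, patch_size * hidden_channels]
patches = x.view(batch_size, max_num_elements // patch_size,
                 patch_size * hidden_channels)
assert patches.shape == (3, 2, 16)

After the transformer blocks, each reduction in `aggr` pools over the patch dimension and the pooled vectors are concatenated, giving `hidden_channels * len(aggrs)` features per node (48 in the test) before `self.fc` maps them to `out_channels`.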
