Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).

### Added

- Added `EGT` model and example ([#8280](https://github.com/pyg-team/pytorch_geometric/pull/8280))
- Added llm generated explanations to `TAGDataset` ([#9918](https://github.com/pyg-team/pytorch_geometric/pull/9918))
- Added `torch_geometric.llm` and its examples ([#10436](https://github.com/pyg-team/pytorch_geometric/pull/10436))
- Added support for negative weights in `sparse_cross_entropy` ([#10432](https://github.com/pyg-team/pytorch_geometric/pull/10432))
Expand Down
1 change: 1 addition & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ For examples on [Open Graph Benchmark](https://ogb.stanford.edu/) datasets, see
- Uses SGFormer (a kind of GraphTransformer) by default.
- [SGFormer Paper](https://arxiv.org/pdf/2306.10759)
- [Polynormer](https://arxiv.org/pdf/2403.01232)
- [EGT](https://arxiv.org/abs/2108.03348)
- [Kumo.ai x NVIDIA x Stanford Graph Transformer Webinar](https://www.youtube.com/watch?v=wAYryx3GjLw)
- [`ogbn_proteins_deepgcn.py`](./ogbn_proteins_deepgcn.py) is an example to showcase how to train deep GNNs on the `ogbn-proteins` dataset.
- [`ogbn_train_perforatedai.py`](https://github.com/PerforatedAI/PerforatedAI-Examples/tree/master/otherExamples/torch_geometric/OGBNProducts) shows how to optimize the `ogbn_train.py` workflow using [Perforated AI](https://github.com/PerforatedAI/PerforatedAI-API). Perforated AI provides a PyTorch add-on which increases network accuracy by empowering each artificial neuron with artificial dendrites.
Expand Down
40 changes: 30 additions & 10 deletions examples/ogbn_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from torch_geometric import seed_everything
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn.models import GAT, GraphSAGE, Polynormer, SGFormer
from torch_geometric.nn.models import EGT, GAT, GraphSAGE, Polynormer, SGFormer
from torch_geometric.utils import (
add_self_loops,
remove_self_loops,
Expand All @@ -37,7 +37,7 @@
"--model",
type=str.lower,
default='SGFormer',
choices=['sage', 'gat', 'sgformer', 'polynormer'],
choices=['sage', 'gat', 'sgformer', 'polynormer', 'egt'],
help="Model used for training",
)

Expand Down Expand Up @@ -103,6 +103,13 @@
data.edge_index, _ = add_self_loops(data.edge_index,
num_nodes=data.num_nodes)

if args.model == 'egt':
    # EGT consumes edge features; OGB node-prediction datasets provide none,
    # so synthesize a 2-dim edge feature from the mean of the endpoint
    # node features (source mean, destination mean).
    print("EGT model requires edge features, using node features "
          "to initialize edge features.")
    row, col = data.edge_index
    data.edge_attr = torch.stack(
        [data.x[row].mean(-1), data.x[col].mean(-1)], dim=1)

data.to(device, 'x', 'y')


Expand Down Expand Up @@ -132,16 +139,18 @@ def train(epoch: int) -> tuple[Tensor, float]:

total_loss = total_correct = 0
for batch in train_loader:
batch = batch.to(device)
optimizer.zero_grad()
if args.model in ['sgformer', 'polynormer']:
if args.model == 'polynormer' and epoch == args.local_epochs:
print('start global attention')
model._global = True
out = model(batch.x, batch.edge_index.to(device),
batch.batch.to(device))[:batch.batch_size]
out = model(batch.x, batch.edge_index, batch.batch)
elif args.model in ['egt']:
out = model(batch.x, batch.edge_index, batch.edge_attr)
else:
out = model(batch.x,
batch.edge_index.to(device))[:batch.batch_size]
out = model(batch.x, batch.edge_index)
out = out[:batch.batch_size]
y = batch.y[:batch.batch_size].squeeze().to(torch.long)
loss = F.cross_entropy(out, y)
loss.backward()
Expand All @@ -166,11 +175,12 @@ def test(loader: NeighborLoader) -> float:
batch = batch.to(device)
batch_size = batch.num_sampled_nodes[0]
if args.model in ['sgformer', 'polynormer']:
out = model(batch.x, batch.edge_index,
batch.batch)[:batch.batch_size]
out = model(batch.x, batch.edge_index, batch.batch)
elif args.model in ['egt']:
out = model(batch.x, batch.edge_index, batch.edge_attr)
else:
out = model(batch.x, batch.edge_index)[:batch_size]
pred = out.argmax(dim=-1)
out = model(batch.x, batch.edge_index)
pred = out[:batch_size].argmax(dim=-1)
y = batch.y[:batch_size].view(-1).to(torch.long)

total_correct += int((pred == y).sum())
Expand Down Expand Up @@ -214,6 +224,16 @@ def get_model(model_name: str) -> torch.nn.Module:
out_channels=dataset.num_classes,
local_layers=num_layers,
)
elif model_name == 'egt':
model = EGT(
node_channels=dataset.num_features,
edge_channels=2,
out_channels=dataset.num_classes,
edge_update=False,
num_layers=num_layers,
num_heads=args.num_heads,
dropout=args.dropout,
)
else:
raise ValueError(f'Unsupported model type: {model_name}')

Expand Down
27 changes: 27 additions & 0 deletions test/nn/conv/test_egt_conv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytest
import torch

from torch_geometric.nn import EGTConv
from torch_geometric.utils import to_dense_adj


@pytest.mark.parametrize('edge_update', [True, False])
def test_egt_conv(edge_update):
    # Two graphs of two nodes each, with dense (per-graph) edge features:
    num_nodes, channels, edge_dim = 4, 16, 8
    x = torch.randn(num_nodes, channels)
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])
    batch = torch.tensor([0, 0, 1, 1])
    dense_edge_attr = to_dense_adj(
        edge_index, batch,
        edge_attr=torch.randn(edge_index.size(1), edge_dim))

    conv = EGTConv(channels, edge_dim=edge_dim, edge_update=edge_update,
                   num_virtual_nodes=4, heads=4)
    conv.reset_parameters()
    assert str(conv) == 'EGTConv(16, heads=4, num_virtual_nodes=4)'

    out = conv(x, dense_edge_attr, batch)
    if edge_update:
        # With edge updates enabled, both node and edge embeddings return:
        node_out, edge_out = out
        assert node_out.size() == (num_nodes, channels)
        assert edge_out.size() == dense_edge_attr.size()
    else:
        assert out.size() == (num_nodes, channels)
27 changes: 27 additions & 0 deletions test/nn/models/test_egt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytest
import torch

from torch_geometric.nn.models import EGT


@pytest.mark.parametrize('edge_update', [True, False])
def test_egt(edge_update):
    # A 10-node graph made of two directed 5-cycles:
    num_nodes, in_channels, num_classes = 10, 16, 40
    x = torch.randn(num_nodes, in_channels)
    edge_index = torch.tensor([
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [1, 2, 3, 4, 0, 6, 7, 8, 9, 5],
    ])
    edge_attr = torch.randn(edge_index.size(1), in_channels)

    model = EGT(
        node_channels=in_channels,
        edge_channels=in_channels,
        out_channels=num_classes,
        edge_update=edge_update,
        num_layers=3,
        num_heads=4,
        dropout=0.3,
    )

    assert model(x, edge_index, edge_attr).size() == (num_nodes, num_classes)
2 changes: 2 additions & 0 deletions torch_geometric/nn/conv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from .antisymmetric_conv import AntiSymmetricConv
from .dir_gnn_conv import DirGNNConv
from .mixhop_conv import MixHopConv
from .egt_conv import EGTConv
from .meshcnn_conv import MeshCNNConv

import torch_geometric.nn.conv.utils # noqa
Expand Down Expand Up @@ -132,6 +133,7 @@
'AntiSymmetricConv',
'DirGNNConv',
'MixHopConv',
'EGTConv',
'MeshCNNConv',
]

Expand Down
143 changes: 143 additions & 0 deletions torch_geometric/nn/conv/egt_conv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Dropout, LayerNorm, Linear, Sequential

from torch_geometric.nn.inits import reset
from torch_geometric.nn.resolver import activation_resolver
from torch_geometric.utils import to_dense_batch


class EGTConv(torch.nn.Module):
    r"""The Edge-augmented Graph Transformer (EGT) from the
    `"Global Self-Attention as a Replacement for Graph Convolution"
    <https://arxiv.org/abs/2108.03348>`_ paper.

    Args:
        channels (int): Size of each input sample.
        edge_dim (int): Edge feature dimensionality.
        num_virtual_nodes (int): Number of virtual nodes.
        edge_update (bool, optional): Whether to update the edge embedding.
            (default: :obj:`True`)
        heads (int, optional): Number of attention heads,
            by which :attr:`channels` is divisible. (default: :obj:`1`)
        dropout (float, optional): Dropout probability of intermediate
            embeddings. (default: :obj:`0.`)
        act (str or Callable, optional): The non-linear activation function to
            use. (default: :obj:`"elu"`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        attn_dropout (float, optional): Attention dropout probability
            of intermediate embeddings. (default: :obj:`0.`)
    """
    def __init__(
        self,
        channels: int,
        edge_dim: int,
        num_virtual_nodes: int,
        edge_update: bool = True,
        heads: int = 1,
        dropout: float = 0.0,
        act: str = 'elu',
        act_kwargs: Optional[Dict[str, Any]] = None,
        attn_dropout: float = 0.0,
    ):
        super().__init__()

        assert channels % heads == 0, "channels must be divisible by heads"

        self.channels = channels
        self.num_virtual_nodes = num_virtual_nodes
        self.heads = heads
        self.edge_update = edge_update
        # Per-head channel dimension:
        self.hidden_channels = channels // heads
        # Pre-attention layer norms for node and edge streams:
        self.mha_ln_h = LayerNorm(channels)
        self.mha_ln_e = LayerNorm(edge_dim)
        self.mha_dropout_h = Dropout(dropout)
        # Projects edge features to a per-head additive attention bias:
        self.edge_input = Linear(edge_dim, heads)
        # Fused query/key/value projection:
        self.qkv_proj = Linear(channels, channels * 3)
        # Projects edge features to per-head sigmoid gates:
        self.gate = Linear(edge_dim, heads)
        self.attn_dropout = Dropout(attn_dropout)
        self.node_output = Linear(channels, channels)

        # Node feed-forward block (applied with a residual connection):
        self.node_ffn = Sequential(
            LayerNorm(channels),
            Linear(channels, channels),
            activation_resolver(act, **(act_kwargs or {})),
            Linear(channels, channels),
            Dropout(dropout),
        )

        if self.edge_update:
            # Edge stream: maps per-head attention logits back to edge_dim,
            # followed by an edge feed-forward block with residual:
            self.edge_output = Linear(heads, edge_dim)
            self.mha_dropout_e = Dropout(dropout)
            self.edge_ffn = Sequential(
                LayerNorm(edge_dim), Linear(edge_dim, edge_dim),
                activation_resolver(act, **(act_kwargs or {})),
                Linear(edge_dim, edge_dim), Dropout(dropout))

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        self.mha_ln_h.reset_parameters()
        self.mha_ln_e.reset_parameters()
        self.edge_input.reset_parameters()
        self.qkv_proj.reset_parameters()
        self.gate.reset_parameters()
        self.node_output.reset_parameters()
        reset(self.node_ffn)
        if self.edge_update:
            self.edge_output.reset_parameters()
            reset(self.edge_ffn)

    def forward(
        self,
        x: Tensor,
        edge_attr: Tensor,
        batch: Optional[Tensor] = None,
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        r"""Runs the forward pass of the module.

        Args:
            x (torch.Tensor): The node features of shape
                :obj:`[num_nodes, channels]`.
            edge_attr (torch.Tensor): The dense edge features of shape
                :obj:`[batch_size, num_nodes, num_nodes, edge_dim]`.
            batch (torch.Tensor, optional): The batch vector assigning each
                node to a graph. (default: :obj:`None`)

        Returns:
            The updated node features, or a tuple of updated node and edge
            features if :obj:`edge_update` is set to :obj:`True`.
        """
        h, mask = to_dense_batch(x, batch)  # [B, N, channels]
        e = edge_attr  # [B, N, N, edge_dim]

        h_ln = self.mha_ln_h(h)  # [B, N, channels]
        e_ln = self.mha_ln_e(e)  # [B, N, N, edge_dim]
        qkv = self.qkv_proj(h_ln)  # [B, N, channels * 3]
        e_bias = self.edge_input(e_ln)  # [B, N, N, heads]
        gates = self.gate(e_ln)  # [B, N, N, heads]
        B, N, _ = qkv.size()
        q_h, k_h, v_h = qkv.view(B, N, -1, self.heads).split(
            self.hidden_channels,
            dim=2)  # each should be [B, N, hidden_channels, heads]
        attn_hat = torch.einsum("bldh,bmdh->blmh", q_h, k_h)
        # Clamp raw logits for numerical stability before adding edge bias:
        attn_hat = attn_hat.clamp(-5, 5) + e_bias  # [B, N, N, heads]

        # `F.sigmoid` is deprecated; use `torch.sigmoid` instead:
        gates = torch.sigmoid(gates)  # [B, N, N, heads]
        attn_tild = F.softmax(attn_hat, dim=2) * gates  # [B, N, N, heads]
        attn_tild = self.attn_dropout(attn_tild)

        v_attn = torch.einsum("blmh,bmkh->blkh", attn_tild, v_h)

        # Scale the aggregated values by degree.
        degrees = torch.sum(gates, dim=2, keepdim=True)  # [B, N, 1, heads]
        degrees_scalers = torch.log(1 + degrees)
        # Virtual nodes are not degree-scaled:
        degrees_scalers[:, :self.num_virtual_nodes] = 1.0
        v_attn = v_attn * degrees_scalers
        v_attn = v_attn.reshape(B, N, -1)  # [B, N, channels]

        # Residual attention block followed by residual feed-forward block:
        out1_h = self.mha_dropout_h(self.node_output(v_attn)) + h
        out2_h = self.node_ffn(out1_h) + out1_h  # [B, N, channels]

        if self.edge_update:
            out1_e = self.mha_dropout_e(self.edge_output(attn_hat)) + e
            out2_e = self.edge_ffn(out1_e) + out1_e  # [B, N, N, edge_dim]
            return out2_h[mask], out2_e

        return out2_h[mask]

    def __repr__(self) -> str:
        return (
            f'{self.__class__.__name__}({self.channels}, '
            f'heads={self.heads}, num_virtual_nodes={self.num_virtual_nodes})')
2 changes: 2 additions & 0 deletions torch_geometric/nn/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .visnet import ViSNet
from .lpformer import LPFormer
from .sgformer import SGFormer
from .egt import EGT

from .polynormer import Polynormer
# Deprecated:
Expand Down Expand Up @@ -86,4 +87,5 @@
'SGFormer',
'Polynormer',
'ARLinkPredictor',
'EGT',
]
Loading
Loading