Skip to content
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `EGTConv` layer and example ([#8280](https://github.com/pyg-team/pytorch_geometric/pull/8280))
- Added llm generated explanations to `TAGDataset` ([#9918](https://github.com/pyg-team/pytorch_geometric/pull/9918))
- Added `torch_geometric.llm` and its examples ([#10436](https://github.com/pyg-team/pytorch_geometric/pull/10436))
- Added support for negative weights in `sparse_cross_entropy` ([#10432](https://github.com/pyg-team/pytorch_geometric/pull/10432))
Expand Down
122 changes: 122 additions & 0 deletions examples/egt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import os.path as osp

import torch
import torch.nn.functional as F
from ogb.graphproppred import Evaluator
from ogb.graphproppred import PygGraphPropPredDataset as OGBG
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
from torch.nn import BatchNorm1d, Linear, ReLU, Sequential
from torch.optim.lr_scheduler import ReduceLROnPlateau

import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader
from torch_geometric.nn import EGTConv, global_mean_pool
from torch_geometric.typing import WITH_TORCH_SPARSE
from torch_geometric.utils import to_dense_adj, to_edge_index

# The `ToSparseTensor` pre-transform below requires the optional
# 'torch-sparse' package.
if not WITH_TORCH_SPARSE:
    quit("This example requires 'torch-sparse'")

# Download (if necessary) and load the ogbg-molhiv graph classification
# dataset, storing each molecule's adjacency as a sparse tensor.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'OGB')
dataset = OGBG('ogbg-molhiv', path, pre_transform=T.ToSparseTensor())
evaluator = Evaluator('ogbg-molhiv')  # Official OGB ROC-AUC evaluator.

# Use the official OGB split indices for train/validation/test.
split_idx = dataset.get_idx_split()
train_dataset = dataset[split_idx['train']]
val_dataset = dataset[split_idx['valid']]
test_dataset = dataset[split_idx['test']]

train_loader = DataLoader(train_dataset, batch_size=32, num_workers=4,
                          shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=256)
test_loader = DataLoader(test_dataset, batch_size=256)


class Net(torch.nn.Module):
    """A stack of :class:`EGTConv` layers with a mean-pool readout and an
    MLP head, producing one logit per graph.
    """
    def __init__(self, hidden_channels, num_layers):
        super().__init__()

        # Learned embeddings for OGB atom/bond categorical features.
        self.atom_encoder = AtomEncoder(hidden_channels)
        self.bond_encoder = BondEncoder(hidden_channels)

        self.convs = torch.nn.ModuleList([
            EGTConv(channels=hidden_channels, edge_dim=hidden_channels,
                    edge_update=False, heads=4, attn_dropout=0.3,
                    num_virtual_nodes=0) for _ in range(num_layers)
        ])

        half = hidden_channels // 2
        quarter = hidden_channels // 4
        self.mlp = Sequential(
            Linear(hidden_channels, half, bias=False),
            BatchNorm1d(half),
            ReLU(inplace=True),
            Linear(half, quarter, bias=False),
            BatchNorm1d(quarter),
            ReLU(inplace=True),
            Linear(quarter, 1),
        )

    def forward(self, x, adj_t, edge_attr, batch):
        # Embed atom and bond features into the hidden dimension.
        h = self.atom_encoder(x)
        e = self.bond_encoder(edge_attr)

        # EGTConv consumes a dense [B, N, N, F] edge-attribute tensor.
        edge_index, _ = to_edge_index(adj_t)
        dense_edge_attr = to_dense_adj(edge_index, batch, e)

        for conv in self.convs:
            h = conv(h, dense_edge_attr, batch)

        # Graph-level readout followed by the prediction head.
        return self.mlp(global_mean_pool(h, batch))


# Train on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net(hidden_channels=128, num_layers=8).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Halve the learning rate whenever the validation ROC-AUC plateaus
# (mode='max' since higher ROC-AUC is better).
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=20,
                              min_lr=1e-5)


def train():
    """Run one training epoch and return the mean per-graph BCE loss."""
    model.train()

    loss_sum, graph_count = 0.0, 0
    for batch_data in train_loader:
        batch_data = batch_data.to(device)
        optimizer.zero_grad()

        logits = model(batch_data.x, batch_data.adj_t, batch_data.edge_attr,
                       batch_data.batch)
        target = batch_data.y.to(torch.float)
        loss = F.binary_cross_entropy_with_logits(logits, target)
        loss.backward()
        optimizer.step()

        # Weight each mini-batch loss by its number of graphs.
        loss_sum += float(loss) * batch_data.num_graphs
        graph_count += batch_data.num_graphs

    return loss_sum / graph_count


@torch.no_grad()
def evaluate(loader):
    """Return the OGB ROC-AUC of the model over all graphs in ``loader``."""
    model.eval()

    preds, targets = [], []
    for batch_data in loader:
        batch_data = batch_data.to(device)
        out = model(batch_data.x, batch_data.adj_t, batch_data.edge_attr,
                    batch_data.batch)
        preds.append(out.cpu())
        targets.append(batch_data.y.cpu())

    result = evaluator.eval({
        'y_true': torch.cat(targets, dim=0),
        'y_pred': torch.cat(preds, dim=0),
    })
    return result['rocauc']


# Main loop: train, evaluate on validation/test, and adapt the learning
# rate based on validation ROC-AUC.
for epoch in range(1, 31):
    loss = train()
    val_rocauc = evaluate(val_loader)
    test_rocauc = evaluate(test_loader)
    scheduler.step(val_rocauc)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_rocauc:.4f}, '
          f'Test: {test_rocauc:.4f}')
27 changes: 27 additions & 0 deletions test/nn/conv/test_egt_conv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytest
import torch

from torch_geometric.nn import EGTConv
from torch_geometric.utils import to_dense_adj


@pytest.mark.parametrize('edge_update', [True, False])
def test_egt_conv(edge_update):
    """Smoke-test EGTConv output shapes with and without edge updates."""
    num_nodes, channels, edge_dim = 4, 16, 8
    x = torch.randn(num_nodes, channels)
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])
    batch = torch.tensor([0, 0, 1, 1])
    dense_edge_attr = to_dense_adj(
        edge_index, batch,
        edge_attr=torch.randn(edge_index.size(1), edge_dim))

    conv = EGTConv(channels, edge_dim=edge_dim, edge_update=edge_update,
                   num_virtual_nodes=4, heads=4)
    conv.reset_parameters()
    assert str(conv) == 'EGTConv(16, heads=4, num_virtual_nodes=4)'

    out = conv(x, dense_edge_attr, batch)
    if edge_update:
        out_x, out_edge_attr = out
        assert out_edge_attr.size() == dense_edge_attr.size()
    else:
        out_x = out
    assert out_x.size() == (num_nodes, channels)
2 changes: 2 additions & 0 deletions torch_geometric/nn/conv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from .antisymmetric_conv import AntiSymmetricConv
from .dir_gnn_conv import DirGNNConv
from .mixhop_conv import MixHopConv
from .egt_conv import EGTConv
from .meshcnn_conv import MeshCNNConv

import torch_geometric.nn.conv.utils # noqa
Expand Down Expand Up @@ -132,6 +133,7 @@
'AntiSymmetricConv',
'DirGNNConv',
'MixHopConv',
'EGTConv',
'MeshCNNConv',
]

Expand Down
143 changes: 143 additions & 0 deletions torch_geometric/nn/conv/egt_conv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Dropout, LayerNorm, Linear, Sequential

from torch_geometric.nn.inits import reset
from torch_geometric.nn.resolver import activation_resolver
from torch_geometric.utils import to_dense_batch


class EGTConv(torch.nn.Module):
    r"""The Edge-augmented Graph Transformer (EGT) from the
    `"Global Self-Attention as a Replacement for Graph Convolution"
    <https://arxiv.org/abs/2108.03348>`_ paper.

    Args:
        channels (int): Size of each input sample.
        edge_dim (int): Edge feature dimensionality.
        num_virtual_nodes (int): Number of virtual nodes.
        edge_update (bool, optional): Whether to update the edge embedding.
            (default: :obj:`True`)
        heads (int, optional): Number of attention heads,
            by which :attr:`channels` is divisible. (default: :obj:`1`)
        dropout (float, optional): Dropout probability of intermediate
            embeddings. (default: :obj:`0.`)
        act (str or Callable, optional): The non-linear activation function to
            use. (default: :obj:`"elu"`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        attn_dropout (float, optional): Attention dropout probability
            of intermediate embeddings. (default: :obj:`0.`)
    """
    def __init__(
        self,
        channels: int,
        edge_dim: int,
        num_virtual_nodes: int,
        edge_update: bool = True,
        heads: int = 1,
        dropout: float = 0.0,
        act: str = 'elu',
        act_kwargs: Optional[Dict[str, Any]] = None,
        attn_dropout: float = 0.0,
    ):
        super().__init__()

        # Raise instead of `assert`: asserts are stripped under `python -O`.
        if channels % heads != 0:
            raise ValueError(f"'channels' (got {channels}) must be divisible "
                             f"by 'heads' (got {heads})")

        self.channels = channels
        self.num_virtual_nodes = num_virtual_nodes
        self.heads = heads
        self.edge_update = edge_update
        # Per-head channel dimensionality:
        self.hidden_channels = channels // heads

        # Pre-attention layer norms and projections:
        self.mha_ln_h = LayerNorm(channels)
        self.mha_ln_e = LayerNorm(edge_dim)
        self.mha_dropout_h = Dropout(dropout)
        self.edge_input = Linear(edge_dim, heads)  # Per-head attention bias.
        self.qkv_proj = Linear(channels, channels * 3)
        self.gate = Linear(edge_dim, heads)  # Per-head sigmoid gates.
        self.attn_dropout = Dropout(attn_dropout)
        self.node_output = Linear(channels, channels)

        # Position-wise feed-forward block for node embeddings.
        self.node_ffn = Sequential(
            LayerNorm(channels),
            Linear(channels, channels),
            activation_resolver(act, **(act_kwargs or {})),
            Linear(channels, channels),
            Dropout(dropout),
        )

        if self.edge_update:
            # Edge output head and feed-forward block, only needed when the
            # layer also returns updated edge embeddings.
            self.edge_output = Linear(heads, edge_dim)
            self.mha_dropout_e = Dropout(dropout)
            self.edge_ffn = Sequential(
                LayerNorm(edge_dim), Linear(edge_dim, edge_dim),
                activation_resolver(act, **(act_kwargs or {})),
                Linear(edge_dim, edge_dim), Dropout(dropout))

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        self.mha_ln_h.reset_parameters()
        self.mha_ln_e.reset_parameters()
        self.edge_input.reset_parameters()
        self.qkv_proj.reset_parameters()
        self.gate.reset_parameters()
        self.node_output.reset_parameters()
        reset(self.node_ffn)
        if self.edge_update:
            self.edge_output.reset_parameters()
            reset(self.edge_ffn)

    def forward(
        self,
        x: Tensor,
        edge_attr: Tensor,
        batch: Optional[Tensor] = None,
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        r"""Runs the forward pass of the module.

        Args:
            x (torch.Tensor): The node features of shape
                :obj:`[num_nodes, channels]`.
            edge_attr (torch.Tensor): The dense edge features of shape
                :obj:`[batch_size, max_num_nodes, max_num_nodes, edge_dim]`.
            batch (torch.Tensor, optional): The batch vector mapping each
                node to its graph. (default: :obj:`None`)

        Returns:
            The updated node features and, if :obj:`edge_update` is set to
            :obj:`True`, the updated dense edge features as well.
        """
        h, mask = to_dense_batch(x, batch)  # [B, N, channels]
        e = edge_attr  # [B, N, N, edge_dim]

        h_ln = self.mha_ln_h(h)  # [B, N, channels]
        e_ln = self.mha_ln_e(e)  # [B, N, N, edge_dim]
        qkv = self.qkv_proj(h_ln)  # [B, N, channels * 3]
        e_bias = self.edge_input(e_ln)  # [B, N, N, heads]
        gates = self.gate(e_ln)  # [B, N, N, heads]
        B, N, _ = qkv.size()
        q_h, k_h, v_h = qkv.view(B, N, -1, self.heads).split(
            self.hidden_channels,
            dim=2)  # each should be [B, N, hidden_channels, heads]
        attn_hat = torch.einsum("bldh,bmdh->blmh", q_h, k_h)
        # Clamp raw attention logits for numerical stability before adding
        # the edge-derived bias.
        attn_hat = attn_hat.clamp(-5, 5) + e_bias  # [B, N, N, heads]

        # `torch.sigmoid` replaces the deprecated `F.sigmoid` alias.
        gates = torch.sigmoid(gates)  # [B, N, N, heads]
        attn_tild = F.softmax(attn_hat, dim=2) * gates  # [B, N, N, heads]
        attn_tild = self.attn_dropout(attn_tild)

        v_attn = torch.einsum("blmh,bmkh->blkh", attn_tild, v_h)

        # Scale the aggregated values by the log of the gated degree;
        # virtual nodes (the first `num_virtual_nodes` rows) stay unscaled.
        degrees = torch.sum(gates, dim=2, keepdim=True)  # [B, N, 1, heads]
        degree_scalers = torch.log1p(degrees)
        degree_scalers[:, :self.num_virtual_nodes] = 1.0
        v_attn = v_attn * degree_scalers
        v_attn = v_attn.reshape(B, N, -1)  # [B, N, channels]

        # Residual attention block followed by a residual feed-forward block.
        out1_h = self.mha_dropout_h(self.node_output(v_attn)) + h
        out2_h = self.node_ffn(out1_h) + out1_h  # [B, N, channels]

        if self.edge_update:
            out1_e = self.mha_dropout_e(self.edge_output(attn_hat)) + e
            out2_e = self.edge_ffn(out1_e) + out1_e  # [B, N, N, edge_dim]
            return out2_h[mask], out2_e

        return out2_h[mask]

    def __repr__(self) -> str:
        return (
            f'{self.__class__.__name__}({self.channels}, '
            f'heads={self.heads}, num_virtual_nodes={self.num_virtual_nodes})')
Loading