Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ full = [
"torch_geometric[graphgym, modelhub]",
"torchmetrics",
"trimesh",
"gfn-layer"
]

[project.urls]
Expand Down
41 changes: 41 additions & 0 deletions test/nn/unpool/test_gfn_unpooling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import torch

from torch_geometric.nn import GFNUnpooling
from torch_geometric.testing import withPackage


def _setup_gfn_unpooling():
weight = torch.tensor(
[[1, 10, 100], [2, 20, 200], [3, 30, 300], [4, 40, 400]], dtype=float)
bias = torch.tensor([1, 10, 100, 1000], dtype=float)

out_graph = torch.tensor([(-3, 2), (1, 0), (2, 1), (3, 2)])

unpool = GFNUnpooling(3, out_graph)

with torch.no_grad():
unpool._gfn.weight.copy_(weight)
unpool._gfn.bias.copy_(bias)

return unpool


@withPackage('gfn')
def test_gfn_unpooling():
unpool = _setup_gfn_unpooling()

x = torch.tensor([[1.0], [10.0], [100.0], [-1.0], [-10.0], [-100.0]])
pos_y = torch.tensor([
[-1.0, -1.0],
[1.0, 1.0],
[-2.0, -2.0],
[2.0, 2.0],
])
batch_x = torch.tensor([0, 0, 0, 1, 1, 1])
batch_y = torch.tensor([0, 0, 1, 1])

y = unpool(x, pos_y, batch_x, batch_y)

expected = torch.tensor([[15157.], [30673.], [-15146.], [-29933.]])

torch.testing.assert_close(y, expected)
5 changes: 2 additions & 3 deletions torch_geometric/nn/unpool/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
r"""Unpooling package."""

from .knn_interpolate import knn_interpolate
from .gfn import GFNUnpooling

__all__ = [
'knn_interpolate',
]
__all__ = ['knn_interpolate', 'GFNUnpooling']

classes = __all__
88 changes: 88 additions & 0 deletions torch_geometric/nn/unpool/gfn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import torch


class GFNUnpooling(torch.nn.Module):
r"""The Graph Feedforward Network unpooling layer from
`"GFN: A graph feedforward network for resolution-invariant
reduced operator learning in multifidelity applications"
<https://doi.org/10.1016/j.cma.2024.117458>`_.

The GFN unpooling equation is given by:

.. math::
:nowrap:

\begin{equation*}
\begin{aligned}
\tilde{W}_{i_{\mathcal{M}_{n}}j} &= \underset{\forall
k_{\mathcal{M}_{o}} \text{ s.t } k_{\mathcal{M}_{o}}
{\leftarrow}\!{\backslash}\!{\rightarrow}
i_{\mathcal{M}_{n}}}{\operatorname{mean}}
{W}_{k_{\mathcal{M}_{o}}j}, \\
\tilde{b}_{i_{\mathcal{M}_{n}}} &= \underset{\forall
k_{\mathcal{M}_{o}} \text{ s.t } k_{\mathcal{M}_{o}}
{\leftarrow}\!{\backslash}\!{\rightarrow}
i_{\mathcal{M}_{n}}}{\operatorname{mean}}
{b}_{k_{\mathcal{M}_{o}}}.
\end{aligned}
\end{equation*}

where:

- :math:`\mathcal{M}_{o}` is the original output graph,
- :math:`\mathcal{M}_{n}` is the new output graph,
- :math:`W` and :math:`b` are the weights and biases
associated to the original graph,
- :math:`\tilde{W}` and :math:`\tilde{b}` are the new
weights and biases associated to the new graph,
- :math:`i_{\mathcal{M}_o} {\leftarrow}\!{\backslash}\!
{\rightarrow} j_{\mathcal{M}_n}` indicates that either
node :math:`i` in graph :math:`\mathcal{M}_o` is the
nearest neighbor of node :math:`j` in graph
:math:`\mathcal{M}_n` or vice versa.

Args:
in_size (int): Size of the input vector.
pos_y (torch.tensor): Original output graph (node position matrix)
:math:`\in \mathbb{R}^{M^{\prime} \times d}`.
**kwargs (optional): Additional arguments of :class:`gfn.GFN`.
"""
def __init__(self, in_size: int, pos_y: torch.Tensor, **kwargs):
try:
import gfn # noqa
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
"GFNUnpooling requires `gfn` to be installed. "
"Please install it via `pip install gfn-layer`.") from e
super().__init__()
self._gfn = gfn.GFN(in_features=in_size, out_features=pos_y, **kwargs)

def forward(self, x: torch.Tensor, pos_y: torch.Tensor,
batch_x: torch.Tensor = None, batch_y: torch.Tensor = None):
r"""Runs the forward pass of the module.

Args:
x (torch.Tensor): Node feature matrix
:math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
pos_y (torch.Tensor): New output graph (new node position matrix)
:math:`\in \mathbb{R}^{M \times d}`.
batch_x (torch.Tensor, optional): Batch vector
:math:`\mathbf{b_x} \in {\{ 0, \ldots, B-1\}}^{N}`, which
assigns each node from :math:`\mathbf{X}` to a specific
example. (default: :obj:`None`)
batch_y (torch.Tensor, optional): Batch vector
:math:`\mathbf{b_y} \in {\{ 0, \ldots, B-1\}}^{M}`, which
assigns each node from :math:`\mathbf{Y}` to a specific
example. (default: :obj:`None`)

:rtype: :class:`torch.Tensor`
"""
out = torch.empty((batch_y.shape[0]), *x.shape[1:],
dtype=self._gfn.weight.dtype,
device=self._gfn.weight.device)
for batch_label in batch_x.unique():
mask = batch_y == batch_label
pos = pos_y[mask, ...]
x_ = x[batch_x == batch_label, ...]
out[mask] = self._gfn(x_.T, out_graph=pos).T
return out
4 changes: 3 additions & 1 deletion torch_geometric/nn/unpool/knn_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,15 @@ def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor,
each node from :math:`\mathbf{X}` to a specific example.
(default: :obj:`None`)
batch_y (torch.Tensor, optional): Batch vector
:math:`\mathbf{b_y} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
:math:`\mathbf{b_y} \in {\{ 0, \ldots, B-1\}}^M`, which assigns
each node from :math:`\mathbf{Y}` to a specific example.
(default: :obj:`None`)
k (int, optional): Number of neighbors. (default: :obj:`3`)
num_workers (int, optional): Number of workers to use for computation.
Has no effect in case :obj:`batch_x` or :obj:`batch_y` is not
:obj:`None`, or the input lies on the GPU. (default: :obj:`1`)

:rtype: :class:`torch.Tensor`
"""
with torch.no_grad():
assign_index = knn(pos_x, pos_y, k, batch_x=batch_x, batch_y=batch_y,
Expand Down
Loading