Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions pina/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch
from torch_geometric.data import Data, Batch
from torch_geometric.utils import to_undirected
from torch_geometric.utils.loop import remove_self_loops
from .label_tensor import LabelTensor
from .utils import check_consistency, is_function

Expand Down Expand Up @@ -209,6 +210,7 @@ def __new__(
x=None,
edge_attr=False,
custom_edge_func=None,
loop=True,
**kwargs,
):
"""
Expand All @@ -224,18 +226,19 @@ def __new__(
:param x: Optional tensor of node features of shape ``(N, F)``, where
``F`` is the number of features per node.
:type x: torch.Tensor | LabelTensor, optional
:param edge_attr: Optional tensor of edge attributes of shape ``(E, F)``
, where ``F`` is the number of features per edge.
:type edge_attr: torch.Tensor, optional
:param bool edge_attr: Whether to compute the edge attributes.
:param custom_edge_func: A custom function to compute edge attributes.
If provided, overrides ``edge_attr``.
:type custom_edge_func: Callable, optional
:param bool loop: Whether to include self-loops.
:param kwargs: Additional keyword arguments passed to the
:class:`~pina.graph.Graph` class constructor.
:return: A :class:`~pina.graph.Graph` instance constructed using the
provided information.
:rtype: Graph
"""
if not loop:
edge_index = remove_self_loops(edge_index)[0]
edge_attr = cls._create_edge_attr(
pos, edge_index, edge_attr, custom_edge_func or cls._build_edge_attr
)
Expand Down Expand Up @@ -374,11 +377,8 @@ def compute_knn_graph(points, neighbours):
representing the edge indices of the graph.
:rtype: torch.Tensor
"""

dist = torch.cdist(points, points, p=2)
knn_indices = torch.topk(dist, k=neighbours + 1, largest=False).indices[
:, 1:
]
knn_indices = torch.topk(dist, k=neighbours, largest=False).indices
row = torch.arange(points.size(0)).repeat_interleave(neighbours)
col = knn_indices.flatten()
return torch.stack([row, col], dim=0).as_subclass(torch.Tensor)
Expand Down
28 changes: 24 additions & 4 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,9 @@ def test_build_graph(x, pos):
),
],
)
def test_build_radius_graph(x, pos):
graph = RadiusGraph(x=x, pos=pos, radius=0.5)
@pytest.mark.parametrize("loop", [True, False])
def test_build_radius_graph(x, pos, loop):
graph = RadiusGraph(x=x, pos=pos, radius=0.5, loop=loop)
assert hasattr(graph, "x")
assert hasattr(graph, "pos")
assert hasattr(graph, "edge_index")
Expand All @@ -84,6 +85,15 @@ def test_build_radius_graph(x, pos):
assert graph.pos.labels == pos.labels
else:
assert isinstance(graph.pos, torch.Tensor)
if not loop:
assert (
len(
torch.nonzero(
graph.edge_index[0] == graph.edge_index[1], as_tuple=True
)[0]
)
== 0
) # Detect self loops


@pytest.mark.parametrize(
Expand Down Expand Up @@ -168,8 +178,9 @@ def test_build_radius_graph_custom_edge_attr(x, pos):
),
],
)
def test_build_knn_graph(x, pos):
graph = KNNGraph(x=x, pos=pos, neighbours=2)
@pytest.mark.parametrize("loop", [True, False])
def test_build_knn_graph(x, pos, loop):
graph = KNNGraph(x=x, pos=pos, neighbours=2, loop=loop)
assert hasattr(graph, "x")
assert hasattr(graph, "pos")
assert hasattr(graph, "edge_index")
Expand All @@ -186,6 +197,15 @@ def test_build_knn_graph(x, pos):
else:
assert isinstance(graph.pos, torch.Tensor)
assert graph.edge_attr is None
self_loops = len(
torch.nonzero(
graph.edge_index[0] == graph.edge_index[1], as_tuple=True
)[0]
)
if loop:
assert self_loops != 0
else:
assert self_loops == 0


@pytest.mark.parametrize(
Expand Down
Loading