diff --git a/pina/graph.py b/pina/graph.py index 1340ed69a..201f37a24 100644 --- a/pina/graph.py +++ b/pina/graph.py @@ -3,6 +3,7 @@ import torch from torch_geometric.data import Data, Batch from torch_geometric.utils import to_undirected +from torch_geometric.utils.loop import remove_self_loops from .label_tensor import LabelTensor from .utils import check_consistency, is_function @@ -209,6 +210,7 @@ def __new__( x=None, edge_attr=False, custom_edge_func=None, + loop=True, **kwargs, ): """ @@ -224,18 +226,19 @@ def __new__( :param x: Optional tensor of node features of shape ``(N, F)``, where ``F`` is the number of features per node. :type x: torch.Tensor | LabelTensor, optional - :param edge_attr: Optional tensor of edge attributes of shape ``(E, F)`` - , where ``F`` is the number of features per edge. - :type edge_attr: torch.Tensor, optional + :param bool edge_attr: Whether to compute the edge attributes. :param custom_edge_func: A custom function to compute edge attributes. If provided, overrides ``edge_attr``. :type custom_edge_func: Callable, optional + :param bool loop: Whether to include self-loops. :param kwargs: Additional keyword arguments passed to the :class:`~pina.graph.Graph` class constructor. :return: A :class:`~pina.graph.Graph` instance constructed using the provided information. :rtype: Graph """ + if not loop: + edge_index = remove_self_loops(edge_index)[0] edge_attr = cls._create_edge_attr( pos, edge_index, edge_attr, custom_edge_func or cls._build_edge_attr ) @@ -374,11 +377,8 @@ def compute_knn_graph(points, neighbours): representing the edge indices of the graph. :rtype: torch.Tensor """ - dist = torch.cdist(points, points, p=2) - knn_indices = torch.topk(dist, k=neighbours + 1, largest=False).indices[ - :, 1: - ] + knn_indices = torch.topk(dist, k=neighbours, largest=False).indices row = torch.arange(points.size(0)).repeat_interleave(neighbours) col = knn_indices.flatten() return torch.stack([row, col], dim=0).as_subclass(torch.Tensor) diff --git a/tests/test_graph.py b/tests/test_graph.py index bf053a89f..1ea51cfa3 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -67,8 +67,9 @@ def test_build_graph(x, pos): ), ], ) -def test_build_radius_graph(x, pos): - graph = RadiusGraph(x=x, pos=pos, radius=0.5) +@pytest.mark.parametrize("loop", [True, False]) +def test_build_radius_graph(x, pos, loop): + graph = RadiusGraph(x=x, pos=pos, radius=0.5, loop=loop) assert hasattr(graph, "x") assert hasattr(graph, "pos") assert hasattr(graph, "edge_index") @@ -84,6 +85,15 @@ def test_build_radius_graph(x, pos): assert graph.pos.labels == pos.labels else: assert isinstance(graph.pos, torch.Tensor) + if not loop: + assert ( + len( + torch.nonzero( + graph.edge_index[0] == graph.edge_index[1], as_tuple=True + )[0] + ) + == 0 + ) # Detect self loops @pytest.mark.parametrize( @@ -168,8 +178,9 @@ def test_build_radius_graph_custom_edge_attr(x, pos): ), ], ) -def test_build_knn_graph(x, pos): - graph = KNNGraph(x=x, pos=pos, neighbours=2) +@pytest.mark.parametrize("loop", [True, False]) +def test_build_knn_graph(x, pos, loop): + graph = KNNGraph(x=x, pos=pos, neighbours=2, loop=loop) assert hasattr(graph, "x") assert hasattr(graph, "pos") assert hasattr(graph, "edge_index") @@ -186,6 +197,15 @@ def test_build_knn_graph(x, pos): else: assert isinstance(graph.pos, torch.Tensor) assert graph.edge_attr is None + self_loops = len( + torch.nonzero( + graph.edge_index[0] == graph.edge_index[1], as_tuple=True + )[0] + ) + if loop: + assert self_loops != 0 + else: + assert self_loops == 0 @pytest.mark.parametrize(