Skip to content

Commit 8d1f387

Browse files
committed
Bug fix and reimplement tests
1 parent b5a8150 commit 8d1f387

File tree

9 files changed

+540
-623
lines changed

9 files changed

+540
-623
lines changed

pina/graph.py

Lines changed: 186 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,11 @@
22
This module provides an interface to build torch_geometric.data.Data objects.
33
"""
44

5-
import warnings
6-
75
import torch
8-
9-
from . import LabelTensor
10-
from .utils import check_consistency, is_function
116
from torch_geometric.data import Data
127
from torch_geometric.utils import to_undirected
8+
from . import LabelTensor
9+
from .utils import check_consistency, is_function
1310

1411

1512
class Graph(Data):
@@ -23,18 +20,18 @@ def __new__(
2320
):
2421
"""
2522
:param kwargs: Parameters to construct the Graph object.
26-
:return: The Data object.
27-
:rtype: torch_geometric.data.Data
23+
:return: A new instance of the Graph class.
24+
:rtype: Graph
2825
"""
2926
# create class instance
3027
instance = Data.__new__(cls)
3128

3229
# check the consistency of types defined in __init__, the others are not
3330
# checked (as in pyg Data object)
3431
instance._check_type_consistency(**kwargs)
35-
32+
3633
return instance
37-
34+
3835
def __init__(
3936
self,
4037
x=None,
@@ -46,18 +43,30 @@ def __init__(
4643
):
4744
"""
4845
Initialize the Graph object.
49-
:param torch.Tensor pos: The position tensor.
50-
:param torch.Tensor edge_index: The edge index tensor.
51-
:param torch.Tensor edge_attr: The edge attribute tensor.
52-
:param bool build_edge_attr: Whether to build the edge attributes.
53-
:param kwargs: Additional parameters.
46+
47+
:param x: Optional tensor of node features (N, F) where F is the number
48+
of features per node.
49+
:type x: torch.Tensor, LabelTensor
50+
:param torch.Tensor edge_index: A tensor of shape (2, E) representing
51+
the indices of the graph's edges.
52+
:param pos: A tensor of shape (N, D) representing the positions of N
53+
points in D-dimensional space.
54+
:type pos: torch.Tensor | LabelTensor
55+
:param edge_attr: Optional tensor of edge_featured (E, F') where F' is
56+
the number of edge features
57+
:param bool undirected: Whether to make the graph undirected
58+
:param kwargs: Additional keyword arguments passed to the
59+
`torch_geometric.data.Data` class constructor. If the argument
60+
is a `torch.Tensor` or `LabelTensor`, it is included in the Data
61+
object as a graph parameter.
5462
"""
5563
# preprocessing
5664
self._preprocess_edge_index(edge_index, undirected)
5765

5866
# calling init
59-
super().__init__(x=x, edge_index=edge_index, edge_attr=edge_attr,
60-
pos=pos, **kwargs)
67+
super().__init__(
68+
x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos, **kwargs
69+
)
6170

6271
def _check_type_consistency(self, **kwargs):
6372
# default types, specified in cls.__new__, by default they are Nont
@@ -85,9 +94,10 @@ def _check_pos_consistency(pos):
8594
Check if the position tensor is consistent.
8695
:param torch.Tensor pos: The position tensor.
8796
"""
88-
check_consistency(pos, (torch.Tensor, LabelTensor))
89-
if pos.ndim != 2:
90-
raise ValueError("pos must be a 2D tensor.")
97+
if pos is not None:
98+
check_consistency(pos, (torch.Tensor, LabelTensor))
99+
if pos.ndim != 2:
100+
raise ValueError("pos must be a 2D tensor.")
91101

92102
@staticmethod
93103
def _check_edge_index_consistency(edge_index):
@@ -109,16 +119,17 @@ def _check_edge_attr_consistency(edge_attr, edge_index):
109119
110120
:param torch.Tensor edge_index: The edge index tensor.
111121
"""
112-
check_consistency(edge_attr, (torch.Tensor, LabelTensor))
113-
if edge_attr.ndim != 2:
114-
raise ValueError("edge_attr must be a 2D tensor.")
115-
if edge_attr.size(1) != edge_index.size(0):
116-
raise ValueError(
117-
"edge_attr must have shape "
118-
"[num_edges, num_edge_features], expected "
119-
f"num_edges {edge_index.size(0)} "
120-
f"got {edge_attr.size(1)}."
121-
)
122+
if edge_attr is not None:
123+
check_consistency(edge_attr, (torch.Tensor, LabelTensor))
124+
if edge_attr.ndim != 2:
125+
raise ValueError("edge_attr must be a 2D tensor.")
126+
if edge_attr.size(0) != edge_index.size(1):
127+
raise ValueError(
128+
"edge_attr must have shape "
129+
"[num_edges, num_edge_features], expected "
130+
f"num_edges {edge_index.size(1)} "
131+
f"got {edge_attr.size(0)}."
132+
)
122133

123134
@staticmethod
124135
def _check_x_consistency(x, pos=None):
@@ -134,6 +145,9 @@ def _check_x_consistency(x, pos=None):
134145
if pos is not None:
135146
if x.size(0) != pos.size(0):
136147
raise ValueError("Inconsistent number of nodes.")
148+
if pos is not None:
149+
if x.size(0) != pos.size(0):
150+
raise ValueError("Inconsistent number of nodes.")
137151

138152
@staticmethod
139153
def _preprocess_edge_index(edge_index, undirected):
@@ -148,73 +162,158 @@ def _preprocess_edge_index(edge_index, undirected):
148162
edge_index = to_undirected(edge_index)
149163
return edge_index
150164

151-
class RadiusGraph(Graph):
152-
def __init__(
153-
self,
154-
radius,
165+
166+
class GraphBuilder:
167+
"""
168+
A class that allows the simple definition of Graph instances.
169+
"""
170+
171+
def __new__(
172+
cls,
173+
pos,
174+
edge_index,
155175
x=None,
156-
pos=None,
157-
edge_attr=None,
158-
undirected=False,
176+
edge_attr=False,
177+
custom_edge_func=None,
159178
**kwargs,
160179
):
161-
super().__init__(x=x, edge_index=None, edge_attr=edge_attr,
162-
pos=pos, undirected=undirected, **kwargs)
163-
edge_index = self._radius_graph(pos, radius)
164-
self.radius = radius
165-
self.edge_index = edge_index
166-
180+
"""
181+
Creates a new instance of the Graph class.
182+
183+
:param pos: A tensor of shape (N, D) representing the positions of N
184+
points in D-dimensional space.
185+
:type pos: torch.Tensor | LabelTensor
186+
:param edge_index: A tensor of shape (2, E) representing the indices of
187+
the graph's edges.
188+
:type edge_index: torch.Tensor
189+
:param x: Optional tensor of node features (N, F) where F is the number
190+
of features per node.
191+
:type x: torch.Tensor, LabelTensor
192+
:param bool edge_attr: Optional edge attributes (E, F) where F is the
193+
number of features per edge.
194+
:param callable custom_edge_func: A custom function to compute edge
195+
attributes.
196+
:param kwargs: Additional keyword arguments passed to the Graph class
197+
constructor.
198+
:return: A Graph instance constructed using the provided information.
199+
:rtype: Graph
200+
"""
201+
edge_attr = cls._create_edge_attr(
202+
pos, edge_index, edge_attr, custom_edge_func or cls._build_edge_attr
203+
)
204+
return Graph(
205+
x=x,
206+
edge_index=edge_index,
207+
edge_attr=edge_attr,
208+
pos=pos,
209+
**kwargs,
210+
)
211+
167212
@staticmethod
168-
def _radius_graph(points, r):
169-
"""
170-
Implementation of the radius graph construction.
171-
:param points: The input points.
172-
:type points: torch.Tensor
173-
:param r: The radius.
174-
:type r: float
175-
:return: The edge index.
176-
:rtype: torch.Tensor
213+
def _create_edge_attr(pos, edge_index, edge_attr, func):
214+
check_consistency(edge_attr, bool)
215+
if edge_attr:
216+
if is_function(func):
217+
return func(pos, edge_index)
218+
raise ValueError("custom_edge_func must be a function.")
219+
return None
220+
221+
@staticmethod
222+
def _build_edge_attr(pos, edge_index):
223+
return (
224+
(pos[edge_index[0]] - pos[edge_index[1]])
225+
.abs()
226+
.as_subclass(torch.Tensor)
227+
)
228+
229+
230+
class RadiusGraph(GraphBuilder):
231+
"""
232+
A class to build a radius graph.
233+
"""
234+
235+
def __new__(cls, pos, radius, **kwargs):
236+
"""
237+
Creates a new instance of the Graph class using a radius-based graph
238+
construction.
239+
240+
:param pos: A tensor of shape (N, D) representing the positions of N
241+
points in D-dimensional space.
242+
:type pos: torch.Tensor | LabelTensor
243+
:param float radius: The radius within which points are connected.
244+
:Keyword Arguments:
245+
The additional keyword arguments to be passed to GraphBuilder
246+
and Graph classes
247+
:return: Graph instance containg the information passed in input and
248+
the computed edge_index
249+
:rtype: Graph
250+
"""
251+
edge_index = cls.compute_radius_graph(pos, radius)
252+
return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs)
253+
254+
@staticmethod
255+
def compute_radius_graph(points, radius):
256+
"""
257+
Computes a radius-based graph for a given set of points.
258+
259+
:param points: A tensor of shape (N, D) representing the positions of
260+
N points in D-dimensional space.
261+
:type points: torch.Tensor | LabelTensor
262+
:param float radius: The number of nearest neighbors to find for each
263+
point.
264+
:rtype torch.Tensor: A tensor of shape (2, E), where E is the number of
265+
edges, representing the edge indices of the KNN graph.
177266
"""
178267
dist = torch.cdist(points, points, p=2)
179-
edge_index = torch.nonzero(dist <= r, as_tuple=False).t()
180-
if isinstance(edge_index, LabelTensor):
181-
edge_index = edge_index.tensor
182-
return edge_index
183-
184-
class KNNGraph(Graph):
185-
def __init__(
186-
self,
187-
neighboors,
188-
x=None,
189-
pos=None,
190-
edge_attr=None,
191-
undirected=False,
192-
**kwargs,
193-
):
194-
super().__init__(x=x, edge_index=None, edge_attr=edge_attr,
195-
pos=pos, undirected=undirected, **kwargs)
196-
edge_index = self._knn_graph(pos, neighboors)
197-
self.neighboors = neighboors
198-
self.edge_index = edge_index
199-
268+
return (
269+
torch.nonzero(dist <= radius, as_tuple=False)
270+
.t()
271+
.as_subclass(torch.Tensor)
272+
)
273+
274+
275+
class KNNGraph(GraphBuilder):
276+
"""
277+
A class to build a KNN graph.
278+
"""
279+
280+
def __new__(cls, pos, neighbours, **kwargs):
281+
"""
282+
Creates a new instance of the Graph class using k-nearest neighbors
283+
to compute edge_index.
284+
285+
:param pos: A tensor of shape (N, D) representing the positions of N
286+
points in D-dimensional space.
287+
:type pos: torch.Tensor | LabelTensor
288+
:param int neighbours: The number of nearest neighbors to consider when
289+
building the graph.
290+
:Keyword Arguments:
291+
The additional keyword arguments to be passed to GraphBuilder
292+
and Graph classes
293+
294+
:return: Graph instance containg the information passed in input and
295+
the computed edge_index
296+
:rtype: Graph
297+
"""
298+
299+
edge_index = cls.compute_knn_graph(pos, neighbours)
300+
return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs)
301+
200302
@staticmethod
201-
def _knn_graph(points, k):
202-
"""
203-
Implementation of the k-nearest neighbors graph construction.
204-
:param points: The input points.
205-
:type points: torch.Tensor
206-
:param k: The number of nearest neighbors.
207-
:type k: int
208-
:return: The edge index.
209-
:rtype: torch.Tensor
303+
def compute_knn_graph(points, k):
304+
"""
305+
Computes the edge_index based k-nearest neighbors graph algorithm
306+
307+
:param points: A tensor of shape (N, D) representing the positions of
308+
N points in D-dimensional space.
309+
:type points: torch.Tensor | LabelTensor
310+
:param int k: The number of nearest neighbors to find for each point.
311+
:rtype torch.Tensor: A tensor of shape (2, E), where E is the number of
312+
edges, representing the edge indices of the KNN graph.
210313
"""
211-
if isinstance(points, LabelTensor):
212-
points = points.tensor
314+
213315
dist = torch.cdist(points, points, p=2)
214316
knn_indices = torch.topk(dist, k=k + 1, largest=False).indices[:, 1:]
215317
row = torch.arange(points.size(0)).repeat_interleave(k)
216318
col = knn_indices.flatten()
217-
edge_index = torch.stack([row, col], dim=0)
218-
if isinstance(edge_index, LabelTensor):
219-
edge_index = edge_index.tensor
220-
return edge_index
319+
return torch.stack([row, col], dim=0).as_subclass(torch.Tensor)

tests/test_collector.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -112,14 +112,12 @@ def poisson_sol(self, pts):
112112
def test_supervised_graph_collector():
113113
pos = torch.rand((100, 3))
114114
x = [torch.rand((100, 3)) for _ in range(10)]
115-
graph = RadiusGraph(pos=pos, build_edge_attr=True, r=0.4)
116-
graph_list_1 = [graph(x=x_) for x_ in x]
115+
graph_list_1 = [RadiusGraph(pos=pos, radius=0.4, x=x_) for x_ in x]
117116
out_1 = torch.rand((10, 100, 3))
118117

119118
pos = torch.rand((50, 3))
120119
x = [torch.rand((50, 3)) for _ in range(10)]
121-
graph = RadiusGraph(pos=pos, build_edge_attr=True, r=0.4)
122-
graph_list_2 = [graph(x=x_) for x_ in x]
120+
graph_list_2 = [RadiusGraph(pos=pos, radius=0.4, x=x_) for x_ in x]
123121
out_2 = torch.rand((10, 50, 3))
124122

125123
class SupervisedProblem(AbstractProblem):
@@ -135,6 +133,3 @@ class SupervisedProblem(AbstractProblem):
135133
# assert all(collector._is_conditions_ready.values())
136134
for v in collector.conditions_name.values():
137135
assert v in problem.conditions.keys()
138-
139-
140-
test_supervised_graph_collector()

tests/test_data/test_data_module.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@
1616
x = torch.rand((100, 50, 10))
1717
pos = torch.rand((100, 50, 2))
1818
input_graph = [
19-
RadiusGraph(x=x_, pos=pos_, r=0.2, build_edge_attr=True)
20-
for x_, pos_, in zip(x, pos)
19+
RadiusGraph(x=x_, pos=pos_, radius=0.2) for x_, pos_, in zip(x, pos)
2120
]
2221
output_graph = torch.rand((100, 50, 10))
2322

@@ -183,8 +182,7 @@ def test_dataloader(input_, output_, automatic_batching):
183182
x = LabelTensor(torch.rand((100, 50, 3)), ["u", "v", "w"])
184183
pos = LabelTensor(torch.rand((100, 50, 2)), ["x", "y"])
185184
input_graph = [
186-
RadiusGraph(x=x[i], pos=pos[i], r=0.1, build_edge_attr=True)
187-
for i in range(len(x))
185+
RadiusGraph(x=x[i], pos=pos[i], radius=0.1) for i in range(len(x))
188186
]
189187
output_graph = LabelTensor(torch.rand((100, 50, 3)), ["u", "v", "w"])
190188

0 commit comments

Comments
 (0)