22This module provides an interface to build torch_geometric.data.Data objects.
33"""
44
5- import warnings
6-
75import torch
8-
9- from . import LabelTensor
10- from .utils import check_consistency , is_function
116from torch_geometric .data import Data
127from torch_geometric .utils import to_undirected
8+ from . import LabelTensor
9+ from .utils import check_consistency , is_function
1310
1411
1512class Graph (Data ):
@@ -23,18 +20,18 @@ def __new__(
2320 ):
2421 """
2522 :param kwargs: Parameters to construct the Graph object.
26- :return: The Data object .
27- :rtype: torch_geometric.data.Data
23+ :return: A new instance of the Graph class .
24+ :rtype: Graph
2825 """
2926 # create class instance
3027 instance = Data .__new__ (cls )
3128
3229 # check the consistency of types defined in __init__, the others are not
3330 # checked (as in pyg Data object)
3431 instance ._check_type_consistency (** kwargs )
35-
32+
3633 return instance
37-
34+
3835 def __init__ (
3936 self ,
4037 x = None ,
@@ -46,18 +43,30 @@ def __init__(
4643 ):
4744 """
4845 Initialize the Graph object.
49- :param torch.Tensor pos: The position tensor.
50- :param torch.Tensor edge_index: The edge index tensor.
51- :param torch.Tensor edge_attr: The edge attribute tensor.
52- :param bool build_edge_attr: Whether to build the edge attributes.
53- :param kwargs: Additional parameters.
46+
47+ :param x: Optional tensor of node features (N, F) where F is the number
48+ of features per node.
49+ :type x: torch.Tensor, LabelTensor
50+ :param torch.Tensor edge_index: A tensor of shape (2, E) representing
51+ the indices of the graph's edges.
52+ :param pos: A tensor of shape (N, D) representing the positions of N
53+ points in D-dimensional space.
54+ :type pos: torch.Tensor | LabelTensor
55+ :param edge_attr: Optional tensor of edge_featured (E, F') where F' is
56+ the number of edge features
57+ :param bool undirected: Whether to make the graph undirected
58+ :param kwargs: Additional keyword arguments passed to the
59+ `torch_geometric.data.Data` class constructor. If the argument
60+ is a `torch.Tensor` or `LabelTensor`, it is included in the Data
61+ object as a graph parameter.
5462 """
5563 # preprocessing
5664 self ._preprocess_edge_index (edge_index , undirected )
5765
5866 # calling init
59- super ().__init__ (x = x , edge_index = edge_index , edge_attr = edge_attr ,
60- pos = pos , ** kwargs )
67+ super ().__init__ (
68+ x = x , edge_index = edge_index , edge_attr = edge_attr , pos = pos , ** kwargs
69+ )
6170
6271 def _check_type_consistency (self , ** kwargs ):
6372 # default types, specified in cls.__new__, by default they are Nont
@@ -85,9 +94,10 @@ def _check_pos_consistency(pos):
8594 Check if the position tensor is consistent.
8695 :param torch.Tensor pos: The position tensor.
8796 """
88- check_consistency (pos , (torch .Tensor , LabelTensor ))
89- if pos .ndim != 2 :
90- raise ValueError ("pos must be a 2D tensor." )
97+ if pos is not None :
98+ check_consistency (pos , (torch .Tensor , LabelTensor ))
99+ if pos .ndim != 2 :
100+ raise ValueError ("pos must be a 2D tensor." )
91101
92102 @staticmethod
93103 def _check_edge_index_consistency (edge_index ):
@@ -109,16 +119,17 @@ def _check_edge_attr_consistency(edge_attr, edge_index):
109119
110120 :param torch.Tensor edge_index: The edge index tensor.
111121 """
112- check_consistency (edge_attr , (torch .Tensor , LabelTensor ))
113- if edge_attr .ndim != 2 :
114- raise ValueError ("edge_attr must be a 2D tensor." )
115- if edge_attr .size (1 ) != edge_index .size (0 ):
116- raise ValueError (
117- "edge_attr must have shape "
118- "[num_edges, num_edge_features], expected "
119- f"num_edges { edge_index .size (0 )} "
120- f"got { edge_attr .size (1 )} ."
121- )
122+ if edge_attr is not None :
123+ check_consistency (edge_attr , (torch .Tensor , LabelTensor ))
124+ if edge_attr .ndim != 2 :
125+ raise ValueError ("edge_attr must be a 2D tensor." )
126+ if edge_attr .size (0 ) != edge_index .size (1 ):
127+ raise ValueError (
128+ "edge_attr must have shape "
129+ "[num_edges, num_edge_features], expected "
130+ f"num_edges { edge_index .size (1 )} "
131+ f"got { edge_attr .size (0 )} ."
132+ )
122133
123134 @staticmethod
124135 def _check_x_consistency (x , pos = None ):
@@ -134,6 +145,9 @@ def _check_x_consistency(x, pos=None):
134145 if pos is not None :
135146 if x .size (0 ) != pos .size (0 ):
136147 raise ValueError ("Inconsistent number of nodes." )
148+ if pos is not None :
149+ if x .size (0 ) != pos .size (0 ):
150+ raise ValueError ("Inconsistent number of nodes." )
137151
138152 @staticmethod
139153 def _preprocess_edge_index (edge_index , undirected ):
@@ -148,73 +162,158 @@ def _preprocess_edge_index(edge_index, undirected):
148162 edge_index = to_undirected (edge_index )
149163 return edge_index
150164
151- class RadiusGraph (Graph ):
152- def __init__ (
153- self ,
154- radius ,
165+
166+ class GraphBuilder :
167+ """
168+ A class that allows the simple definition of Graph instances.
169+ """
170+
171+ def __new__ (
172+ cls ,
173+ pos ,
174+ edge_index ,
155175 x = None ,
156- pos = None ,
157- edge_attr = None ,
158- undirected = False ,
176+ edge_attr = False ,
177+ custom_edge_func = None ,
159178 ** kwargs ,
160179 ):
161- super ().__init__ (x = x , edge_index = None , edge_attr = edge_attr ,
162- pos = pos , undirected = undirected , ** kwargs )
163- edge_index = self ._radius_graph (pos , radius )
164- self .radius = radius
165- self .edge_index = edge_index
166-
180+ """
181+ Creates a new instance of the Graph class.
182+
183+ :param pos: A tensor of shape (N, D) representing the positions of N
184+ points in D-dimensional space.
185+ :type pos: torch.Tensor | LabelTensor
186+ :param edge_index: A tensor of shape (2, E) representing the indices of
187+ the graph's edges.
188+ :type edge_index: torch.Tensor
189+ :param x: Optional tensor of node features (N, F) where F is the number
190+ of features per node.
191+ :type x: torch.Tensor, LabelTensor
192+ :param bool edge_attr: Optional edge attributes (E, F) where F is the
193+ number of features per edge.
194+ :param callable custom_edge_func: A custom function to compute edge
195+ attributes.
196+ :param kwargs: Additional keyword arguments passed to the Graph class
197+ constructor.
198+ :return: A Graph instance constructed using the provided information.
199+ :rtype: Graph
200+ """
201+ edge_attr = cls ._create_edge_attr (
202+ pos , edge_index , edge_attr , custom_edge_func or cls ._build_edge_attr
203+ )
204+ return Graph (
205+ x = x ,
206+ edge_index = edge_index ,
207+ edge_attr = edge_attr ,
208+ pos = pos ,
209+ ** kwargs ,
210+ )
211+
167212 @staticmethod
168- def _radius_graph (points , r ):
169- """
170- Implementation of the radius graph construction.
171- :param points: The input points.
172- :type points: torch.Tensor
173- :param r: The radius.
174- :type r: float
175- :return: The edge index.
176- :rtype: torch.Tensor
213+ def _create_edge_attr (pos , edge_index , edge_attr , func ):
214+ check_consistency (edge_attr , bool )
215+ if edge_attr :
216+ if is_function (func ):
217+ return func (pos , edge_index )
218+ raise ValueError ("custom_edge_func must be a function." )
219+ return None
220+
221+ @staticmethod
222+ def _build_edge_attr (pos , edge_index ):
223+ return (
224+ (pos [edge_index [0 ]] - pos [edge_index [1 ]])
225+ .abs ()
226+ .as_subclass (torch .Tensor )
227+ )
228+
229+
230+ class RadiusGraph (GraphBuilder ):
231+ """
232+ A class to build a radius graph.
233+ """
234+
235+ def __new__ (cls , pos , radius , ** kwargs ):
236+ """
237+ Creates a new instance of the Graph class using a radius-based graph
238+ construction.
239+
240+ :param pos: A tensor of shape (N, D) representing the positions of N
241+ points in D-dimensional space.
242+ :type pos: torch.Tensor | LabelTensor
243+ :param float radius: The radius within which points are connected.
244+ :Keyword Arguments:
245+ The additional keyword arguments to be passed to GraphBuilder
246+ and Graph classes
247+ :return: Graph instance containg the information passed in input and
248+ the computed edge_index
249+ :rtype: Graph
250+ """
251+ edge_index = cls .compute_radius_graph (pos , radius )
252+ return super ().__new__ (cls , pos = pos , edge_index = edge_index , ** kwargs )
253+
254+ @staticmethod
255+ def compute_radius_graph (points , radius ):
256+ """
257+ Computes a radius-based graph for a given set of points.
258+
259+ :param points: A tensor of shape (N, D) representing the positions of
260+ N points in D-dimensional space.
261+ :type points: torch.Tensor | LabelTensor
262+ :param float radius: The number of nearest neighbors to find for each
263+ point.
264+ :rtype torch.Tensor: A tensor of shape (2, E), where E is the number of
265+ edges, representing the edge indices of the KNN graph.
177266 """
178267 dist = torch .cdist (points , points , p = 2 )
179- edge_index = torch .nonzero (dist <= r , as_tuple = False ).t ()
180- if isinstance (edge_index , LabelTensor ):
181- edge_index = edge_index .tensor
182- return edge_index
183-
184- class KNNGraph (Graph ):
185- def __init__ (
186- self ,
187- neighboors ,
188- x = None ,
189- pos = None ,
190- edge_attr = None ,
191- undirected = False ,
192- ** kwargs ,
193- ):
194- super ().__init__ (x = x , edge_index = None , edge_attr = edge_attr ,
195- pos = pos , undirected = undirected , ** kwargs )
196- edge_index = self ._knn_graph (pos , neighboors )
197- self .neighboors = neighboors
198- self .edge_index = edge_index
199-
268+ return (
269+ torch .nonzero (dist <= radius , as_tuple = False )
270+ .t ()
271+ .as_subclass (torch .Tensor )
272+ )
273+
274+
275+ class KNNGraph (GraphBuilder ):
276+ """
277+ A class to build a KNN graph.
278+ """
279+
280+ def __new__ (cls , pos , neighbours , ** kwargs ):
281+ """
282+ Creates a new instance of the Graph class using k-nearest neighbors
283+ to compute edge_index.
284+
285+ :param pos: A tensor of shape (N, D) representing the positions of N
286+ points in D-dimensional space.
287+ :type pos: torch.Tensor | LabelTensor
288+ :param int neighbours: The number of nearest neighbors to consider when
289+ building the graph.
290+ :Keyword Arguments:
291+ The additional keyword arguments to be passed to GraphBuilder
292+ and Graph classes
293+
294+ :return: Graph instance containg the information passed in input and
295+ the computed edge_index
296+ :rtype: Graph
297+ """
298+
299+ edge_index = cls .compute_knn_graph (pos , neighbours )
300+ return super ().__new__ (cls , pos = pos , edge_index = edge_index , ** kwargs )
301+
200302 @staticmethod
201- def _knn_graph (points , k ):
202- """
203- Implementation of the k-nearest neighbors graph construction.
204- :param points: The input points.
205- :type points: torch.Tensor
206- :param k: The number of nearest neighbors.
207- :type k: int
208- :return: The edge index.
209- :rtype: torch.Tensor
303+ def compute_knn_graph (points , k ):
304+ """
305+ Computes the edge_index based k-nearest neighbors graph algorithm
306+
307+ :param points: A tensor of shape (N, D) representing the positions of
308+ N points in D-dimensional space.
309+ :type points: torch.Tensor | LabelTensor
310+ :param int k: The number of nearest neighbors to find for each point.
311+ :rtype torch.Tensor: A tensor of shape (2, E), where E is the number of
312+ edges, representing the edge indices of the KNN graph.
210313 """
211- if isinstance (points , LabelTensor ):
212- points = points .tensor
314+
213315 dist = torch .cdist (points , points , p = 2 )
214316 knn_indices = torch .topk (dist , k = k + 1 , largest = False ).indices [:, 1 :]
215317 row = torch .arange (points .size (0 )).repeat_interleave (k )
216318 col = knn_indices .flatten ()
217- edge_index = torch .stack ([row , col ], dim = 0 )
218- if isinstance (edge_index , LabelTensor ):
219- edge_index = edge_index .tensor
220- return edge_index
319+ return torch .stack ([row , col ], dim = 0 ).as_subclass (torch .Tensor )
0 commit comments