@@ -23,9 +23,8 @@ def __new__(
2323 Create a new instance of the :class:`~pina.graph.Graph` class by
2424 checking the consistency of the input data and storing the attributes.
2525
26- :param kwargs: Parameters used to initialize the
26+ :param dict kwargs: Parameters used to initialize the
2727 :class:`~pina.graph.Graph` object.
28- :type kwargs: dict
2928 :return: A new instance of the :class:`~pina.graph.Graph` class.
3029 :rtype: Graph
3130 """
@@ -56,8 +55,8 @@ def __init__(
5655 :param x: Optional tensor of node features ``(N, F)`` where ``F`` is the
5756 number of features per node.
5857 :type x: torch.Tensor, LabelTensor
59- :param torch.Tensor edge_index: A tensor of shape ``(2, E)`` representing
60- the indices of the graph's edges.
58+ :param torch.Tensor edge_index: A tensor of shape ``(2, E)``
59+ representing the indices of the graph's edges.
6160 :param pos: A tensor of shape ``(N, D)`` representing the positions of
6261 ``N`` points in ``D``-dimensional space.
6362 :type pos: torch.Tensor | LabelTensor
@@ -80,8 +79,7 @@ def _check_type_consistency(self, **kwargs):
8079 """
8180 Check the consistency of the types of the input data.
8281
83- :param kwargs: Attributes to be checked for consistency.
84- :type kwargs: dict
82+ :param dict kwargs: Attributes to be checked for consistency.
8583 """
8684 # default types, specified in cls.__new__, by default they are Nont
8785 # if specified in **kwargs they get override
@@ -134,7 +132,8 @@ def _check_edge_attr_consistency(edge_attr, edge_index):
134132 Check if the edge attribute tensor is consistent in type and shape
135133 with the edge index.
136134
137- :param torch.Tensor edge_attr: The edge attribute tensor.
135+ :param edge_attr: The edge attribute tensor.
136+ :type edge_attr: torch.Tensor | LabelTensor
138137 :param torch.Tensor edge_index: The edge index tensor.
139138 :raises ValueError: If the edge attribute tensor is not consistent.
140139 """
@@ -156,8 +155,11 @@ def _check_x_consistency(x, pos=None):
156155 Check if the input tensor x is consistent with the position tensor
157156 `pos`.
158157
159- :param torch.Tensor x: The input tensor.
160- :param torch.Tensor pos: The position tensor.
158+ :param x: The input tensor.
159+ :type x: torch.Tensor | LabelTensor
160+ :param pos: The position tensor.
161+ :type pos: torch.Tensor | LabelTensor
162+ :raises ValueError: If the input tensor is not consistent.
161163 """
162164 if x is not None :
163165 check_consistency (x , (torch .Tensor , LabelTensor ))
@@ -166,9 +168,6 @@ def _check_x_consistency(x, pos=None):
166168 if pos is not None :
167169 if x .size (0 ) != pos .size (0 ):
168170 raise ValueError ("Inconsistent number of nodes." )
169- if pos is not None :
170- if x .size (0 ) != pos .size (0 ):
171- raise ValueError ("Inconsistent number of nodes." )
172171
173172 @staticmethod
174173 def _preprocess_edge_index (edge_index , undirected ):
@@ -292,7 +291,7 @@ def _build_edge_attr(pos, edge_index):
292291class RadiusGraph (GraphBuilder ):
293292 """
294293 Extends the :class:`~pina.graph.GraphBuilder` class to compute
295- edge_index based on a radius. Each point is connected to all the points
294+ `` edge_index`` based on a radius. Each point is connected to all the points
296295 within the radius.
297296 """
298297
@@ -305,11 +304,10 @@ def __new__(cls, pos, radius, **kwargs):
305304 ``N`` points in ``D``-dimensional space.
306305 :type pos: torch.Tensor | LabelTensor
307306 :param float radius: The radius within which points are connected.
308- :param kwargs: Additional keyword arguments to be passed to the
309- :class:`~pina.graph.GraphBuilder` and :class:`~pina.graph.Graph`
310- constructors.
311- :return: A :class:`~pina.graph.Graph` instance containing the input
312- information and the computed ``edge_index``.
307+ :param dict kwargs: The additional keyword arguments to be passed to
308+ :class:`GraphBuilder` and :class:`Graph` classes.
309+ :return: A :class:`~pina.graph.Graph` instance with the computed
310+ ``edge_index``.
313311 :rtype: Graph
314312 """
315313 edge_index = cls .compute_radius_graph (pos , radius )
@@ -318,16 +316,16 @@ def __new__(cls, pos, radius, **kwargs):
318316 @staticmethod
319317 def compute_radius_graph (points , radius ):
320318 """
321- Computes ``edge_index`` for a given set of points base on the radius.
322- Each point is connected to all the points within the radius.
319+ Computes the ``edge_index`` based on the radius. Each point is connected
320+ to all the points within the radius.
323321
324322 :param points: A tensor of shape ``(N, D)`` representing the positions
325323 of ``N`` points in ``D``-dimensional space.
326324 :type points: torch.Tensor | LabelTensor
327- :param float radius: The number of nearest neighbors to find for each
328- point.
329- :rtype torch.Tensor: A tensor of shape ``(2, E)``, where ``E`` is the
330- number of edges, representing the edge indices of the KNN graph.
325+ :param float radius: The radius within which points are connected.
326+ :return: A tensor of shape ``(2, E)``, with ``E`` number of edges,
327+ representing the edge indices of the graph.
328+ :rtype: torch.Tensor
331329 """
332330 dist = torch .cdist (points , points , p = 2 )
333331 return (
@@ -340,7 +338,7 @@ def compute_radius_graph(points, radius):
340338class KNNGraph (GraphBuilder ):
341339 """
342340 Extends the :class:`~pina.graph.GraphBuilder` class to compute
343- edge_index based on a K-nearest neighbors algorithm.
341+ `` edge_index`` based on a K-nearest neighbors algorithm.
344342 """
345343
346344 def __new__ (cls , pos , neighbours , ** kwargs ):
@@ -353,54 +351,57 @@ def __new__(cls, pos, neighbours, **kwargs):
353351 :type pos: torch.Tensor | LabelTensor
354352 :param int neighbours: The number of nearest neighbors to consider when
355353 building the graph.
356- :Keyword Arguments:
357- The additional keyword arguments to be passed to GraphBuilder
358- and Graph classes
354+ :param dict kwargs: The additional keyword arguments to be passed to
355+ :class:`GraphBuilder` and :class:`Graph` classes.
359356
360- :return: A :class:`~pina.graph.Graph` instance containg the
361- information passed in input and the computed ``edge_index``
357+ :return: A :class:`~pina.graph.Graph` instance with the computed
358+ ``edge_index``.
362359 :rtype: Graph
363360 """
364361
365362 edge_index = cls .compute_knn_graph (pos , neighbours )
366363 return super ().__new__ (cls , pos = pos , edge_index = edge_index , ** kwargs )
367364
368365 @staticmethod
369- def compute_knn_graph (points , k ):
366+ def compute_knn_graph (points , neighbours ):
370367 """
371- Computes the edge_index based k -nearest neighbors graph algorithm
368+ Computes the `` edge_index`` based on the K -nearest neighbors algorithm.
372369
373370 :param points: A tensor of shape ``(N, D)`` representing the positions
374371 of ``N`` points in ``D``-dimensional space.
375372 :type points: torch.Tensor | LabelTensor
376- :param int k: The number of nearest neighbors to find for each point.
377- :return: A tensor of shape ``(2, E)``, where ``E`` is the number of
378- edges, representing the edge indices of the KNN graph.
373+ :param int neighbours: The number of nearest neighbors to consider when
374+ building the graph.
375+ :return: A tensor of shape ``(2, E)``, with ``E`` number of edges,
376+ representing the edge indices of the graph.
379377 :rtype: torch.Tensor
380378 """
381379
382380 dist = torch .cdist (points , points , p = 2 )
383- knn_indices = torch .topk (dist , k = k + 1 , largest = False ).indices [:, 1 :]
384- row = torch .arange (points .size (0 )).repeat_interleave (k )
381+ knn_indices = torch .topk (dist , k = neighbours + 1 , largest = False ).indices [
382+ :, 1 :
383+ ]
384+ row = torch .arange (points .size (0 )).repeat_interleave (neighbours )
385385 col = knn_indices .flatten ()
386386 return torch .stack ([row , col ], dim = 0 ).as_subclass (torch .Tensor )
387387
388388
389389class LabelBatch (Batch ):
390390 """
391- Add extract function to torch_geometric Batch object
391+ Extends the :class:`~torch_geometric.data.Batch` class to include
392+ :class:`~pina.label_tensor.LabelTensor` objects.
392393 """
393394
394395 @classmethod
395396 def from_data_list (cls , data_list ):
396397 """
397398 Create a Batch object from a list of :class:`~torch_geometric.data.Data`
398- objects.
399+ or :class:`~pina.graph.Graph` objects.
399400
400401 :param data_list: List of :class:`~torch_geometric.data.Data` or
401402 :class:`~pina.graph.Graph` objects.
402403 :type data_list: list[Data] | list[Graph]
403- :return: A Batch object containing the data in the list
404+ :return: A :class:` Batch` object containing the input data.
404405 :rtype: Batch
405406 """
406407 # Store the labels of Data/Graph objects (all data have the same labels)
0 commit comments