@@ -63,8 +63,9 @@ def __init__(
6363 :type pos: torch.Tensor | LabelTensor
6464 :param edge_attr: Optional tensor of edge_featured ``(E, F')`` where
6565 ``F'`` is the number of edge features
66+ :type edge_attr: torch.Tensor | LabelTensor
6667 :param bool undirected: Whether to make the graph undirected
67- :param kwargs: Additional keyword arguments passed to the
68+ :param dict kwargs: Additional keyword arguments passed to the
6869 :class:`~torch_geometric.data.Data` class constructor.
6970 """
7071 # preprocessing
@@ -201,7 +202,7 @@ def extract(self, labels, attr="x"):
201202
202203class GraphBuilder :
203204 """
204- A class that allows the simple definition of Graph instances.
205+ A class that allows an easy definition of :class:` Graph` instances.
205206 """
206207
207208 def __new__ (
@@ -217,25 +218,25 @@ def __new__(
217218 Compute the edge attributes and create a new instance of the
218219 :class:`~pina.graph.Graph` class.
219220
220- :param pos: A tensor of shape `(N, D)` representing the positions of `N`
221- points in `D `-dimensional space.
221+ :param pos: A tensor of shape `` (N, D)`` representing the positions of
222+ ``N`` points in ``D` `-dimensional space.
222223 :type pos: torch.Tensor or LabelTensor
223- :param edge_index: A tensor of shape `(2, E)` representing the indices
224+ :param edge_index: A tensor of shape `` (2, E)` ` representing the indices
224225 of the graph's edges.
225226 :type edge_index: torch.Tensor
226- :param x: Optional tensor of node features of shape `(N, F)`, where `F`
227- is the number of features per node.
227+ :param x: Optional tensor of node features of shape `` (N, F)`` , where
228+ ``F`` is the number of features per node.
228229 :type x: torch.Tensor | LabelTensor, optional
229- :param edge_attr: Optional tensor of edge attributes of shape `(E, F)`,
230- where `F ` is the number of features per edge.
230+ :param edge_attr: Optional tensor of edge attributes of shape `` (E, F)``
231+ , where ``F` ` is the number of features per edge.
231232 :type edge_attr: torch.Tensor, optional
232233 :param custom_edge_func: A custom function to compute edge attributes.
233- If provided, overrides `edge_attr`.
234+ If provided, overrides `` edge_attr` `.
234235 :type custom_edge_func: Callable, optional
235236 :param kwargs: Additional keyword arguments passed to the
236237 :class:`~pina.graph.Graph` class constructor.
237238 :return: A :class:`~pina.graph.Graph` instance constructed using the
238- provided information.
239+ provided information.
239240 :rtype: Graph
240241 """
241242 edge_attr = cls ._create_edge_attr (
@@ -274,6 +275,7 @@ def _create_edge_attr(pos, edge_index, edge_attr, func):
274275 def _build_edge_attr (pos , edge_index ):
275276 """
276277 Default function to compute the edge attributes.
278+
277279 :param pos: Positions of the points.
278280 :type pos: torch.Tensor | LabelTensor
279281 :param torch.Tensor edge_index: Edge indices.
@@ -289,14 +291,15 @@ def _build_edge_attr(pos, edge_index):
289291
290292class RadiusGraph (GraphBuilder ):
291293 """
292- A class to build a graph based on a radius.
294+ Extends the :class:`~pina.graph.GraphBuilder` class to compute
295+ edge_index based on a radius. Each point is connected to all the points
296+ within the radius.
293297 """
294298
295299 def __new__ (cls , pos , radius , ** kwargs ):
296300 """
297- Extends the :class:`~pina.graph.GraphBuilder` class to compute
298- edge_index based on a radius. Each point is connected to all the points
299- within the radius.
301+ Instantiate the :class:`~pina.graph.Graph` class by computing the
302+ ``edge_index`` based on the radius provided.
300303
301304 :param pos: A tensor of shape ``(N, D)`` representing the positions of
302305 ``N`` points in ``D``-dimensional space.
@@ -336,13 +339,14 @@ def compute_radius_graph(points, radius):
336339
337340class KNNGraph (GraphBuilder ):
338341 """
339- A class to build a K-nearest neighbors graph.
342+ Extends the :class:`~pina.graph.GraphBuilder` class to compute
343+ edge_index based on a K-nearest neighbors algorithm.
340344 """
341345
342346 def __new__ (cls , pos , neighbours , ** kwargs ):
343347 """
344- Extends the :class:`~pina.graph.GraphBuilder ` class to compute
345- edge_index based on a K-nearest neighbors algorithm.
348+ Instantiate the :class:`~pina.graph.Graph ` class by computing the
349+ `` edge_index`` based on the K-nearest neighbors algorithm.
346350
347351 :param pos: A tensor of shape ``(N, D)`` representing the positions of
348352 ``N`` points in ``D``-dimensional space.
0 commit comments