Skip to content

Commit 9d9a01b

Browse files
committed
Fix rendering graph
1 parent 89f8e4e commit 9d9a01b

File tree

1 file changed

+22
-18
lines changed

1 file changed

+22
-18
lines changed

pina/graph.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,9 @@ def __init__(
6363
:type pos: torch.Tensor | LabelTensor
6464
:param edge_attr: Optional tensor of edge_featured ``(E, F')`` where
6565
``F'`` is the number of edge features
66+
:type edge_attr: torch.Tensor | LabelTensor
6667
:param bool undirected: Whether to make the graph undirected
67-
:param kwargs: Additional keyword arguments passed to the
68+
:param dict kwargs: Additional keyword arguments passed to the
6869
:class:`~torch_geometric.data.Data` class constructor.
6970
"""
7071
# preprocessing
@@ -201,7 +202,7 @@ def extract(self, labels, attr="x"):
201202

202203
class GraphBuilder:
203204
"""
204-
A class that allows the simple definition of Graph instances.
205+
A class that allows an easy definition of :class:`Graph` instances.
205206
"""
206207

207208
def __new__(
@@ -217,25 +218,25 @@ def __new__(
217218
Compute the edge attributes and create a new instance of the
218219
:class:`~pina.graph.Graph` class.
219220
220-
:param pos: A tensor of shape `(N, D)` representing the positions of `N`
221-
points in `D`-dimensional space.
221+
:param pos: A tensor of shape ``(N, D)`` representing the positions of
222+
``N`` points in ``D``-dimensional space.
222223
:type pos: torch.Tensor or LabelTensor
223-
:param edge_index: A tensor of shape `(2, E)` representing the indices
224+
:param edge_index: A tensor of shape ``(2, E)`` representing the indices
224225
of the graph's edges.
225226
:type edge_index: torch.Tensor
226-
:param x: Optional tensor of node features of shape `(N, F)`, where `F`
227-
is the number of features per node.
227+
:param x: Optional tensor of node features of shape ``(N, F)``, where
228+
``F`` is the number of features per node.
228229
:type x: torch.Tensor | LabelTensor, optional
229-
:param edge_attr: Optional tensor of edge attributes of shape `(E, F)`,
230-
where `F` is the number of features per edge.
230+
:param edge_attr: Optional tensor of edge attributes of shape ``(E, F)``
231+
, where ``F`` is the number of features per edge.
231232
:type edge_attr: torch.Tensor, optional
232233
:param custom_edge_func: A custom function to compute edge attributes.
233-
If provided, overrides `edge_attr`.
234+
If provided, overrides ``edge_attr``.
234235
:type custom_edge_func: Callable, optional
235236
:param kwargs: Additional keyword arguments passed to the
236237
:class:`~pina.graph.Graph` class constructor.
237238
:return: A :class:`~pina.graph.Graph` instance constructed using the
238-
provided information.
239+
provided information.
239240
:rtype: Graph
240241
"""
241242
edge_attr = cls._create_edge_attr(
@@ -274,6 +275,7 @@ def _create_edge_attr(pos, edge_index, edge_attr, func):
274275
def _build_edge_attr(pos, edge_index):
275276
"""
276277
Default function to compute the edge attributes.
278+
277279
:param pos: Positions of the points.
278280
:type pos: torch.Tensor | LabelTensor
279281
:param torch.Tensor edge_index: Edge indices.
@@ -289,14 +291,15 @@ def _build_edge_attr(pos, edge_index):
289291

290292
class RadiusGraph(GraphBuilder):
291293
"""
292-
A class to build a graph based on a radius.
294+
Extends the :class:`~pina.graph.GraphBuilder` class to compute
295+
edge_index based on a radius. Each point is connected to all the points
296+
within the radius.
293297
"""
294298

295299
def __new__(cls, pos, radius, **kwargs):
296300
"""
297-
Extends the :class:`~pina.graph.GraphBuilder` class to compute
298-
edge_index based on a radius. Each point is connected to all the points
299-
within the radius.
301+
Instantiate the :class:`~pina.graph.Graph` class by computing the
302+
``edge_index`` based on the radius provided.
300303
301304
:param pos: A tensor of shape ``(N, D)`` representing the positions of
302305
``N`` points in ``D``-dimensional space.
@@ -336,13 +339,14 @@ def compute_radius_graph(points, radius):
336339

337340
class KNNGraph(GraphBuilder):
338341
"""
339-
A class to build a K-nearest neighbors graph.
342+
Extends the :class:`~pina.graph.GraphBuilder` class to compute
343+
edge_index based on a K-nearest neighbors algorithm.
340344
"""
341345

342346
def __new__(cls, pos, neighbours, **kwargs):
343347
"""
344-
Extends the :class:`~pina.graph.GraphBuilder` class to compute
345-
edge_index based on a K-nearest neighbors algorithm.
348+
Instantiate the :class:`~pina.graph.Graph` class by computing the
349+
``edge_index`` based on the K-nearest neighbors algorithm.
346350
347351
:param pos: A tensor of shape ``(N, D)`` representing the positions of
348352
``N`` points in ``D``-dimensional space.

0 commit comments

Comments
 (0)