Skip to content

Commit e102537

Browse files
committed
Fix rendering and codacy
1 parent 53be672 commit e102537

File tree

8 files changed

+101
-99
lines changed

8 files changed

+101
-99
lines changed

pina/collector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def store_sample_domains(self):
118118
"""
119119
Store inside data collections the sampled data of the problem. These
120120
comes from the conditions that require sampling (e.g.
121-
:class:`~pina.condition.domain_equation_condition.
121+
:class:`~pina.condition.domain_equation_condition.\
122122
DomainEquationCondition`).
123123
"""
124124

pina/data/data_module.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def __new__(cls, dataset, shuffle):
244244

245245
class PinaDataModule(LightningDataModule):
246246
"""
247-
This class extends :class:`lightning.pytorch.LightningDataModule`,
247+
This class extends :class:`~lightning.pytorch.core.LightningDataModule`,
248248
allowing proper creation and management of different types of datasets
249249
defined in PINA.
250250
"""
@@ -268,24 +268,24 @@ def __init__(
268268
:param AbstractProblem problem: The problem containing the data on which
269269
to create the datasets and dataloaders.
270270
:param float train_size: Fraction of elements in the training split. It
271-
must be in the range [0, 1].
271+
must be in the range [0, 1].
272272
:param float test_size: Fraction of elements in the test split. It must
273-
be in the range [0, 1].
273+
be in the range [0, 1].
274274
:param float val_size: Fraction of elements in the validation split. It
275-
must be in the range [0, 1].
275+
must be in the range [0, 1].
276276
:param batch_size: The batch size used for training. If ``None``, the
277-
entire dataset is returned in a single batch.
278-
:type batch_size: int | None
277+
entire dataset is returned in a single batch. Default is ``None``.
278+
:type batch_size: int
279279
:param bool shuffle: Whether to shuffle the dataset before splitting.
280-
Default True.
280+
Default ``Tru``e.
281281
:param bool repeat: Whether to repeat the dataset indefinitely.
282-
Default False.
282+
Default ``False``.
283283
:param automatic_batching: Whether to enable automatic batching.
284-
Default False.
284+
Default ``False``.
285285
:param int num_workers: Number of worker threads for data loading.
286-
Default 0 (serial loading).
286+
Default ``0`` (serial loading).
287287
:param bool pin_memory: Whether to use pinned memory for faster data
288-
transfer to GPU. Default False.
288+
transfer to GPU. Default ``False``.
289289
290290
:raises ValueError: If at least one of the splits is negative.
291291
:raises ValueError: If the sum of the splits is different from 1.
@@ -643,7 +643,7 @@ def input(self):
643643
Return all the input points coming from all the datasets.
644644
645645
:return: The input points for training.
646-
:rtype dict
646+
:rtype: dict
647647
"""
648648

649649
to_return = {}

pina/data/dataset.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@ class PinaDatasetFactory:
1212
"""
1313
Factory class for the PINA dataset.
1414
15-
Depending on the type inside the conditions, it creates a different dataset
16-
object:
15+
Depending on the data type inside the conditions, it instanciate an object
16+
belonging to the appropriate subclass of
17+
:class:`~pina.data.dataset.PinaDataset`. The possible subclasses are:
1718
18-
- :class:`~pina.data.dataset.PinaTensorDataset` for handling
19+
- :class:`~pina.data.dataset.PinaTensorDataset`, for handling \
1920
:class:`torch.Tensor` and :class:`~pina.label_tensor.LabelTensor` data.
20-
- :class:`~pina.data.dataset.PinaGraphDataset` for handling
21+
- :class:`~pina.data.dataset.PinaGraphDataset`, for handling \
2122
:class:`~pina.graph.Graph` and :class:`~torch_geometric.data.Data` data.
2223
"""
2324

@@ -33,8 +34,7 @@ def __new__(cls, conditions_dict, **kwargs):
3334
:param dict conditions_dict: Dictionary containing all the conditions
3435
to be included in the dataset instance.
3536
:return: A subclass of :class:`~pina.data.dataset.PinaDataset`.
36-
:rtype: pina.data.dataset.PinaTensorDataset |
37-
pina.data.dataset.PinaGraphDataset
37+
:rtype: PinaTensorDataset | PinaGraphDataset
3838
3939
:raises ValueError: If an empty dictionary is provided.
4040
"""
@@ -74,22 +74,25 @@ def _is_graph_dataset(conditions_dict):
7474

7575
class PinaDataset(Dataset, ABC):
7676
"""
77-
Abstract class for the PINA dataset. It defines the common interface for
78-
:class:`~pina.data.dataset.PinaTensorDataset` and
77+
Abstract class for the PINA dataset which extends the PyTorch
78+
:class:`~torch.utils.data.Dataset` class. It defines the common interface
79+
for :class:`~pina.data.dataset.PinaTensorDataset` and
7980
:class:`~pina.data.dataset.PinaGraphDataset` classes.
8081
"""
8182

8283
def __init__(
8384
self, conditions_dict, max_conditions_lengths, automatic_batching
8485
):
8586
"""
86-
Initialize :class:`~pina.data.dataset.PinaDataset` instance by storing
87-
the provided conditions dictionary, and the automatic batching flag.
88-
89-
:param dict conditions_dict: Dictionary containing the conditions with
90-
data.
91-
:param dict max_conditions_lengths: Specifies the maximum number of data
92-
points to include in a single batch for each condition.
87+
Initialize the instance by storing the conditions dictionary, the
88+
maximum number of items per conditions to consider, and the automatic
89+
batching flag.
90+
91+
:param dict conditions_dict: A dictionary mapping condition names to
92+
their respective data. Each key represents a condition name, and the
93+
corresponding value is a dictionary containing the associated data.
94+
:param dict max_conditions_lengths: Maximum number of data points that
95+
can be included in a single batch per condition.
9396
:param bool automatic_batching: Indicates whether PyTorch automatic
9497
batching is enabled in
9598
:class:`~pina.data.data_module.PinaDataModule`.
@@ -258,8 +261,8 @@ def _create_tensor_batch(self, data):
258261
Reshape properly ``data`` tensor to be processed handle by the graph
259262
based models.
260263
261-
:param data: torch.Tensor object of shape (N, ...) where N is the
262-
number of data points.
264+
:param data: torch.Tensor object of shape ``(N, ...)`` where ``N`` is
265+
the number of data objects.
263266
:type data: torch.Tensor | LabelTensor
264267
:return: Reshaped tensor object.
265268
:rtype: torch.Tensor | LabelTensor
@@ -275,7 +278,7 @@ def create_batch(self, data):
275278
:param data: List of items to collate in a single batch.
276279
:type data: list[Data] | list[Graph]
277280
:return: Batch object.
278-
:rtype: Batch | PinaBatch
281+
:rtype: Batch | LabelBatch
279282
"""
280283

281284
if isinstance(data[0], Data):

pina/graph.py

Lines changed: 41 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,8 @@ def __new__(
2323
Create a new instance of the :class:`~pina.graph.Graph` class by
2424
checking the consistency of the input data and storing the attributes.
2525
26-
:param kwargs: Parameters used to initialize the
26+
:param dict kwargs: Parameters used to initialize the
2727
:class:`~pina.graph.Graph` object.
28-
:type kwargs: dict
2928
:return: A new instance of the :class:`~pina.graph.Graph` class.
3029
:rtype: Graph
3130
"""
@@ -56,8 +55,8 @@ def __init__(
5655
:param x: Optional tensor of node features ``(N, F)`` where ``F`` is the
5756
number of features per node.
5857
:type x: torch.Tensor, LabelTensor
59-
:param torch.Tensor edge_index: A tensor of shape ``(2, E)`` representing
60-
the indices of the graph's edges.
58+
:param torch.Tensor edge_index: A tensor of shape ``(2, E)``
59+
representing the indices of the graph's edges.
6160
:param pos: A tensor of shape ``(N, D)`` representing the positions of
6261
``N`` points in ``D``-dimensional space.
6362
:type pos: torch.Tensor | LabelTensor
@@ -80,8 +79,7 @@ def _check_type_consistency(self, **kwargs):
8079
"""
8180
Check the consistency of the types of the input data.
8281
83-
:param kwargs: Attributes to be checked for consistency.
84-
:type kwargs: dict
82+
:param dict kwargs: Attributes to be checked for consistency.
8583
"""
8684
# default types, specified in cls.__new__, by default they are Nont
8785
# if specified in **kwargs they get override
@@ -134,7 +132,8 @@ def _check_edge_attr_consistency(edge_attr, edge_index):
134132
Check if the edge attribute tensor is consistent in type and shape
135133
with the edge index.
136134
137-
:param torch.Tensor edge_attr: The edge attribute tensor.
135+
:param edge_attr: The edge attribute tensor.
136+
:type edge_attr: torch.Tensor | LabelTensor
138137
:param torch.Tensor edge_index: The edge index tensor.
139138
:raises ValueError: If the edge attribute tensor is not consistent.
140139
"""
@@ -156,8 +155,11 @@ def _check_x_consistency(x, pos=None):
156155
Check if the input tensor x is consistent with the position tensor
157156
`pos`.
158157
159-
:param torch.Tensor x: The input tensor.
160-
:param torch.Tensor pos: The position tensor.
158+
:param x: The input tensor.
159+
:type x: torch.Tensor | LabelTensor
160+
:param pos: The position tensor.
161+
:type pos: torch.Tensor | LabelTensor
162+
:raises ValueError: If the input tensor is not consistent.
161163
"""
162164
if x is not None:
163165
check_consistency(x, (torch.Tensor, LabelTensor))
@@ -166,9 +168,6 @@ def _check_x_consistency(x, pos=None):
166168
if pos is not None:
167169
if x.size(0) != pos.size(0):
168170
raise ValueError("Inconsistent number of nodes.")
169-
if pos is not None:
170-
if x.size(0) != pos.size(0):
171-
raise ValueError("Inconsistent number of nodes.")
172171

173172
@staticmethod
174173
def _preprocess_edge_index(edge_index, undirected):
@@ -292,7 +291,7 @@ def _build_edge_attr(pos, edge_index):
292291
class RadiusGraph(GraphBuilder):
293292
"""
294293
Extends the :class:`~pina.graph.GraphBuilder` class to compute
295-
edge_index based on a radius. Each point is connected to all the points
294+
``edge_index`` based on a radius. Each point is connected to all the points
296295
within the radius.
297296
"""
298297

@@ -305,11 +304,10 @@ def __new__(cls, pos, radius, **kwargs):
305304
``N`` points in ``D``-dimensional space.
306305
:type pos: torch.Tensor | LabelTensor
307306
:param float radius: The radius within which points are connected.
308-
:param kwargs: Additional keyword arguments to be passed to the
309-
:class:`~pina.graph.GraphBuilder` and :class:`~pina.graph.Graph`
310-
constructors.
311-
:return: A :class:`~pina.graph.Graph` instance containing the input
312-
information and the computed ``edge_index``.
307+
:param dict kwargs: The additional keyword arguments to be passed to
308+
:class:`GraphBuilder` and :class:`Graph` classes.
309+
:return: A :class:`~pina.graph.Graph` instance with the computed
310+
``edge_index``.
313311
:rtype: Graph
314312
"""
315313
edge_index = cls.compute_radius_graph(pos, radius)
@@ -318,16 +316,16 @@ def __new__(cls, pos, radius, **kwargs):
318316
@staticmethod
319317
def compute_radius_graph(points, radius):
320318
"""
321-
Computes ``edge_index`` for a given set of points base on the radius.
322-
Each point is connected to all the points within the radius.
319+
Computes the ``edge_index`` based on the radius. Each point is connected
320+
to all the points within the radius.
323321
324322
:param points: A tensor of shape ``(N, D)`` representing the positions
325323
of ``N`` points in ``D``-dimensional space.
326324
:type points: torch.Tensor | LabelTensor
327-
:param float radius: The number of nearest neighbors to find for each
328-
point.
329-
:rtype torch.Tensor: A tensor of shape ``(2, E)``, where ``E`` is the
330-
number of edges, representing the edge indices of the KNN graph.
325+
:param float radius: The radius within which points are connected.
326+
:return: A tensor of shape ``(2, E)``, with ``E`` number of edges,
327+
representing the edge indices of the graph.
328+
:rtype: torch.Tensor
331329
"""
332330
dist = torch.cdist(points, points, p=2)
333331
return (
@@ -340,7 +338,7 @@ def compute_radius_graph(points, radius):
340338
class KNNGraph(GraphBuilder):
341339
"""
342340
Extends the :class:`~pina.graph.GraphBuilder` class to compute
343-
edge_index based on a K-nearest neighbors algorithm.
341+
``edge_index`` based on a K-nearest neighbors algorithm.
344342
"""
345343

346344
def __new__(cls, pos, neighbours, **kwargs):
@@ -353,54 +351,57 @@ def __new__(cls, pos, neighbours, **kwargs):
353351
:type pos: torch.Tensor | LabelTensor
354352
:param int neighbours: The number of nearest neighbors to consider when
355353
building the graph.
356-
:Keyword Arguments:
357-
The additional keyword arguments to be passed to GraphBuilder
358-
and Graph classes
354+
:param dict kwargs: The additional keyword arguments to be passed to
355+
:class:`GraphBuilder` and :class:`Graph` classes.
359356
360-
:return: A :class:`~pina.graph.Graph` instance containg the
361-
information passed in input and the computed ``edge_index``
357+
:return: A :class:`~pina.graph.Graph` instance with the computed
358+
``edge_index``.
362359
:rtype: Graph
363360
"""
364361

365362
edge_index = cls.compute_knn_graph(pos, neighbours)
366363
return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs)
367364

368365
@staticmethod
369-
def compute_knn_graph(points, k):
366+
def compute_knn_graph(points, neighbours):
370367
"""
371-
Computes the edge_index based k-nearest neighbors graph algorithm
368+
Computes the ``edge_index`` based on the K-nearest neighbors algorithm.
372369
373370
:param points: A tensor of shape ``(N, D)`` representing the positions
374371
of ``N`` points in ``D``-dimensional space.
375372
:type points: torch.Tensor | LabelTensor
376-
:param int k: The number of nearest neighbors to find for each point.
377-
:return: A tensor of shape ``(2, E)``, where ``E`` is the number of
378-
edges, representing the edge indices of the KNN graph.
373+
:param int neighbours: The number of nearest neighbors to consider when
374+
building the graph.
375+
:return: A tensor of shape ``(2, E)``, with ``E`` number of edges,
376+
representing the edge indices of the graph.
379377
:rtype: torch.Tensor
380378
"""
381379

382380
dist = torch.cdist(points, points, p=2)
383-
knn_indices = torch.topk(dist, k=k + 1, largest=False).indices[:, 1:]
384-
row = torch.arange(points.size(0)).repeat_interleave(k)
381+
knn_indices = torch.topk(dist, k=neighbours + 1, largest=False).indices[
382+
:, 1:
383+
]
384+
row = torch.arange(points.size(0)).repeat_interleave(neighbours)
385385
col = knn_indices.flatten()
386386
return torch.stack([row, col], dim=0).as_subclass(torch.Tensor)
387387

388388

389389
class LabelBatch(Batch):
390390
"""
391-
Add extract function to torch_geometric Batch object
391+
Extends the :class:`~torch_geometric.data.Batch` class to include
392+
:class:`~pina.label_tensor.LabelTensor` objects.
392393
"""
393394

394395
@classmethod
395396
def from_data_list(cls, data_list):
396397
"""
397398
Create a Batch object from a list of :class:`~torch_geometric.data.Data`
398-
objects.
399+
or :class:`~pina.graph.Graph` objects.
399400
400401
:param data_list: List of :class:`~torch_geometric.data.Data` or
401402
:class:`~pina.graph.Graph` objects.
402403
:type data_list: list[Data] | list[Graph]
403-
:return: A Batch object containing the data in the list
404+
:return: A :class:`Batch` object containing the input data.
404405
:rtype: Batch
405406
"""
406407
# Store the labels of Data/Graph objects (all data have the same labels)

pina/label_tensor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -216,10 +216,10 @@ def _init_labels_from_list(self, labels):
216216
def extract(self, labels_to_extract):
217217
"""
218218
Extract the subset of the original tensor by returning all the positions
219-
corresponding to the passed ``label_to_extract``. If ``label_to_extract``
220-
is a dictionary, the keys are the dimension names and the values are the
221-
labels to extract. If a single label or a list of labels is passed, the
222-
last dimension is considered.
219+
corresponding to the passed ``label_to_extract``. If
220+
``label_to_extract`` is a dictionary, the keys are the dimension names
221+
and the values are the labels to extract. If a single label or a list
222+
of labels is passed, the last dimension is considered.
223223
224224
:Example:
225225
>>> from pina import LabelTensor

pina/model/feed_forward.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def __init__(
154154
:param transformer_nets: The two :class:`torch.nn.Module` acting as
155155
transformer network. The input dimension of both networks must be
156156
equal to ``input_dimensions``, and the output dimension must be
157-
equal to ``inner_size``. If ``None``, two
157+
equal to ``inner_size``. If ``None``, two
158158
:class:`~pina.model.block.residual.EnhancedLinear` layers are used.
159159
Default is ``None``.
160160
:type transformer_nets: list[torch.nn.Module] | tuple[torch.nn.Module]

0 commit comments

Comments
 (0)