
Commit 3374176

GiovanniCanali authored and dario-coscia committed
fix utils and trainer doc
1 parent 593ab6b · commit 3374176

File tree

2 files changed: +162 -91 lines changed


pina/trainer.py

Lines changed: 87 additions & 46 deletions
@@ -1,4 +1,4 @@
-"""Trainer module."""
+"""Module for the Trainer."""
 
 import sys
 import torch
@@ -10,8 +10,11 @@
 
 class Trainer(lightning.pytorch.Trainer):
     """
-    PINA custom Trainer class which allows to customize standard Lightning
-    Trainer class for PINNs training.
+    PINA custom Trainer class to extend the standard Lightning functionality.
+
+    This class enables specific features or behaviors required by the PINA
+    framework. It modifies the standard :class:`lightning.pytorch.Trainer` class
+    to better support the training process in PINA.
     """
 
     def __init__(
@@ -29,42 +32,35 @@ def __init__(
         **kwargs,
     ):
         """
-        Initialize the Trainer class for by calling Lightning costructor and
-        adding many other functionalities.
-
-        :param solver: A pina:class:`SolverInterface` solver for the
-            differential problem.
-        :type solver: SolverInterface
-        :param batch_size: How many samples per batch to load.
-            If ``batch_size=None`` all
-            samples are loaded and data are not batched, defaults to None.
-        :type batch_size: int | None
-        :param train_size: Percentage of elements in the train dataset.
-        :type train_size: float
-        :param test_size: Percentage of elements in the test dataset.
-        :type test_size: float
-        :param val_size: Percentage of elements in the val dataset.
-        :type val_size: float
-        :param compile: if True model is compiled before training,
-            default False. For Windows users compilation is always disabled.
-        :type compile: bool
-        :param automatic_batching: if True automatic PyTorch batching is
-            performed. Please avoid using automatic batching when batch_size is
-            large, default False.
-        :type automatic_batching: bool
-        :param num_workers: Number of worker threads for data loading.
-            Default 0 (serial loading).
-        :type num_workers: int
-        :param pin_memory: Whether to use pinned memory for faster data
-            transfer to GPU. Default False.
-        :type pin_memory: bool
-        :param shuffle: Whether to shuffle the data for training. Default True.
-        :type pin_memory: bool
+        Initialization of the :class:`Trainer` class.
+
+        :param SolverInterface solver: A :class:`~pina.solver.SolverInterface`
+            solver used to solve a :class:`~pina.problem.AbstractProblem`.
+        :param int batch_size: The number of samples per batch to load.
+            If ``None``, all samples are loaded and data is not batched.
+            Default is ``None``.
+        :param float train_size: The percentage of elements to include in the
+            training dataset. Default is ``1.0``.
+        :param float test_size: The percentage of elements to include in the
+            test dataset. Default is ``0.0``.
+        :param float val_size: The percentage of elements to include in the
+            validation dataset. Default is ``0.0``.
+        :param bool compile: If ``True``, the model is compiled before training.
+            Default is ``False``. For Windows users, it is always disabled.
+        :param bool automatic_batching: If ``True``, automatic PyTorch batching
+            is performed. Avoid using automatic batching when ``batch_size`` is
+            large. Default is ``False``.
+        :param int num_workers: The number of worker threads for data loading.
+            Default is ``0`` (serial loading).
+        :param bool pin_memory: Whether to use pinned memory for faster data
+            transfer to GPU. Default is ``False``.
+        :param bool shuffle: Whether to shuffle the data during training.
+            Default is ``True``.
 
         :Keyword Arguments:
-            The additional keyword arguments specify the training setup
-            and can be choosen from the `pytorch-lightning
-            Trainer API <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_
+            Additional keyword arguments that specify the training setup.
+            These can be selected from the pytorch-lightning Trainer API
+            <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>_.
         """
         # check consistency for init types
         self._check_input_consistency(
@@ -134,6 +130,10 @@ def __init__(
         }
 
     def _move_to_device(self):
+        """
+        Moves the ``unknown_parameters`` of an instance of
+        :class:`~pina.problem.AbstractProblem` to the :class:`Trainer` device.
+        """
         device = self._accelerator_connector._parallel_devices[0]
         # move parameters to device
         pb = self.solver.problem
@@ -155,9 +155,25 @@ def _create_datamodule(
         shuffle,
     ):
         """
-        This method is used here because is resampling is needed
-        during training, there is no need to define to touch the
-        trainer dataloader, just call the method.
+        This method is designed to handle the creation of a data module when
+        resampling is needed during training. Instead of manually defining and
+        modifying the trainer's dataloaders, this method is called to
+        automatically configure the data module.
+
+        :param float train_size: The percentage of elements to include in the
+            training dataset.
+        :param float test_size: The percentage of elements to include in the
+            test dataset.
+        :param float val_size: The percentage of elements to include in the
+            validation dataset.
+        :param int batch_size: The number of samples per batch to load.
+        :param bool automatic_batching: Whether to perform automatic batching
+            with PyTorch.
+        :param bool pin_memory: Whether to use pinned memory for faster data
+            transfer to GPU.
+        :param int num_workers: The number of worker threads for data loading.
+        :param bool shuffle: Whether to shuffle the data during training.
+        :raises RuntimeError: If not all conditions are sampled.
         """
         if not self.solver.problem.are_all_domains_discretised:
             error_message = "\n".join(
@@ -188,33 +204,52 @@ def _create_datamodule(
 
     def train(self, **kwargs):
         """
-        Train the solver method.
+        Manage the training process of the solver.
         """
         return super().fit(self.solver, datamodule=self.data_module, **kwargs)
 
     def test(self, **kwargs):
         """
-        Test the solver method.
+        Manage the test process of the solver.
         """
         return super().test(self.solver, datamodule=self.data_module, **kwargs)
 
     @property
     def solver(self):
         """
-        Returning trainer solver.
+        Get the solver.
+
+        :return: The solver.
+        :rtype: SolverInterface
        """
        return self._solver
 
     @solver.setter
     def solver(self, solver):
+        """
+        Set the solver.
+
+        :param SolverInterface solver: The solver to set.
+        """
        self._solver = solver
 
     @staticmethod
     def _check_input_consistency(
         solver, train_size, test_size, val_size, automatic_batching, compile
     ):
         """
-        Check the consistency of the input parameters."
+        Verifies the consistency of the parameters for the solver configuration.
+
+        :param SolverInterface solver: The solver.
+        :param float train_size: The percentage of elements to include in the
+            training dataset.
+        :param float test_size: The percentage of elements to include in the
+            test dataset.
+        :param float val_size: The percentage of elements to include in the
+            validation dataset.
+        :param bool automatic_batching: Whether to perform automatic batching
+            with PyTorch.
+        :param bool compile: If ``True``, the model is compiled before training.
         """
 
         check_consistency(solver, SolverInterface)
@@ -231,8 +266,14 @@ def _check_consistency_and_set_defaults(
         pin_memory, num_workers, shuffle, batch_size
     ):
         """
-        Check the consistency of the input parameters and set the default
-        values.
+        Checks the consistency of input parameters and sets default values
+        for missing or invalid parameters.
+
+        :param bool pin_memory: Whether to use pinned memory for faster data
+            transfer to GPU.
+        :param int num_workers: The number of worker threads for data loading.
+        :param bool shuffle: Whether to shuffle the data during training.
+        :param int batch_size: The number of samples per batch to load.
         """
         if pin_memory is not None:
             check_consistency(pin_memory, bool)
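For orientation, a minimal usage sketch of the constructor arguments and the train/test methods documented above follows. It assumes Trainer is importable from the package root; the solver object is a placeholder for any already-built PINA solver (a SolverInterface instance whose problem domains have been discretised), and building one is outside the scope of this commit.

from pina import Trainer

# `solver` is assumed to exist already; its problem must have all domains
# discretised, otherwise the RuntimeError documented above is raised.
trainer = Trainer(
    solver=solver,
    batch_size=None,   # load all samples at once, no batching
    train_size=1.0,
    val_size=0.0,
    test_size=0.0,
    compile=False,     # always disabled on Windows, per the docstring
    shuffle=True,
    max_epochs=100,    # extra keywords are forwarded to lightning.pytorch.Trainer
)
trainer.train()
trainer.test()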
pina/utils.py

Lines changed: 75 additions & 45 deletions
@@ -1,4 +1,4 @@
-"""Utils module."""
+"""Module for utility functions."""
 
 import types
 from functools import reduce
@@ -12,35 +12,36 @@ def custom_warning_format(
     message, category, filename, lineno, file=None, line=None
 ):
     """
-    Depewarning custom format.
+    Custom warning formatting function.
 
     :param str message: The warning message.
-    :param class category: The warning category.
-    :param str filename: The filename where the warning was raised.
-    :param int lineno: The line number where the warning was raised.
-    :param str file: The file object where the warning was raised.
-    :param inr line: The line where the warning was raised.
+    :param Warning category: The warning category.
+    :param str filename: The filename where the warning is raised.
+    :param int lineno: The line number where the warning is raised.
+    :param str file: The file object where the warning is raised.
+        Default is None.
+    :param int line: The line where the warning is raised.
     :return: The formatted warning message.
     :rtype: str
     """
     return f"{filename}: {category.__name__}: {message}\n"
 
 
 def check_consistency(object_, object_instance, subclass=False):
-    """Helper function to check object inheritance consistency.
-    Given a specific ``'object'`` we check if the object is
-    instance of a specific ``'object_instance'``, or in case
-    ``'subclass=True'`` we check if the object is subclass
-    if the ``'object_instance'``.
-
-    :param (iterable or class object) object: The object to check the
-        inheritance
-    :param Object object_instance: The parent class from where the object
-        is expected to inherit
-    :param str object_name: The name of the object
-    :param bool subclass: Check if is a subclass and not instance
-    :raises ValueError: If the object does not inherit from the
-        specified class
+    """
+    Check if an object maintains inheritance consistency.
+
+    This function checks whether a given object is an instance of a specified
+    class or, if ``subclass=True``, whether it is a subclass of the specified
+    class.
+
+    :param object: The object to check.
+    :type object: Iterable | Object
+    :param Object object_instance: The expected parent class.
+    :param bool subclass: If True, checks whether ``object_`` is a subclass
+        of ``object_instance`` instead of an instance. Default is ``False``.
+    :raises ValueError: If ``object_`` does not inherit from ``object_instance``
+        as expected.
     """
     if not isinstance(object_, (list, set, tuple)):
        object_ = [object_]
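A brief usage sketch of the helper documented above; torch classes are used only as familiar example types, and the import path assumes the function stays in pina/utils.py.

import torch
from pina.utils import check_consistency

# Instance check: every element of the iterable must be a torch.Tensor.
check_consistency([torch.rand(3), torch.rand(2)], torch.Tensor)

# Subclass check instead of instance check.
check_consistency(torch.nn.Linear, torch.nn.Module, subclass=True)

# A mismatch raises ValueError, as the new docstring states.
try:
    check_consistency("not a tensor", torch.Tensor)
except ValueError as err:
    print(err)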
@@ -59,18 +60,28 @@ def check_consistency(object_, object_instance, subclass=False):
 
 def labelize_forward(forward, input_variables, output_variables):
     """
-    Wrapper decorator to allow users to enable or disable the use of
-    LabelTensors during the forward pass.
-
-    :param forward: The torch.nn.Module forward function.
-    :type forward: Callable
-    :param input_variables: The problem input variables.
-    :type input_variables: list[str] | tuple[str]
-    :param output_variables: The problem output variables.
-    :type output_variables: list[str] | tuple[str]
+    Decorator to enable or disable the use of :class:`~pina.LabelTensor`
+    during the forward pass.
+
+    :param Callable forward: The forward function of a :class:`torch.nn.Module`.
+    :param list[str] input_variables: The names of the input variables of a
+        :class:`~pina.problem.AbstractProblem`.
+    :param list[str] output_variables: The names of the output variables of a
+        :class:`~pina.problem.AbstractProblem`.
+    :return: The decorated forward function.
+    :rtype: Callable
     """
 
     def wrapper(x):
+        """
+        Decorated forward function.
+
+        :param LabelTensor x: The labelized input of the forward pass of an
+            instance of :class:`torch.nn.Module`.
+        :return: The labelized output of the forward pass of an instance of
+            :class:`torch.nn.Module`.
+        :rtype: LabelTensor
+        """
         x = x.extract(input_variables)
         output = forward(x)
         # keep it like this, directly using LabelTensor(...) raises errors
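To make the wrapper concrete, a hypothetical usage sketch follows; the two-feature linear model, the variable names, and the LabelTensor(data, labels) constructor call are illustrative assumptions, not part of this commit.

import torch
from pina import LabelTensor
from pina.utils import labelize_forward

net = torch.nn.Linear(2, 1)
labeled_forward = labelize_forward(
    net.forward, input_variables=["x", "y"], output_variables=["u"]
)

# The wrapper extracts the ["x", "y"] columns, runs the plain forward pass,
# and re-attaches the output label "u" to the result.
points = LabelTensor(torch.rand(10, 2), labels=["x", "y"])
output = labeled_forward(points)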
@@ -82,15 +93,32 @@ def wrapper(x):
     return wrapper
 
 
-def merge_tensors(tensors): # name to be changed
-    """TODO"""
+def merge_tensors(tensors):
+    """
+    Merge a list of :class:`~pina.LabelTensor` instances into a single
+    :class:`~pina.LabelTensor` tensor, by applying iteratively the cartesian
+    product.
+
+    :param list[LabelTensor] tensors: The list of tensors to merge.
+    :raises ValueError: If the list of tensors is empty.
+    :return: The merged tensor.
+    :rtype: LabelTensor
+    """
     if tensors:
         return reduce(merge_two_tensors, tensors[1:], tensors[0])
     raise ValueError("Expected at least one tensor")
 
 
 def merge_two_tensors(tensor1, tensor2):
-    """TODO"""
+    """
+    Merge two :class:`~pina.LabelTensor` instances into a single
+    :class:`~pina.LabelTensor` tensor, by applying the cartesian product.
+
+    :param LabelTensor tensor1: The first tensor to merge.
+    :param LabelTensor tensor2: The second tensor to merge.
+    :return: The merged tensor.
+    :rtype: LabelTensor
+    """
     n1 = tensor1.shape[0]
     n2 = tensor2.shape[0]
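The cartesian product that the new docstrings describe can be illustrated with plain torch tensors; this is a conceptual sketch only, since the functions above operate on LabelTensor objects and preserve their labels.

import torch

def cartesian_merge_sketch(t1, t2):
    # Pair every row of t1 with every row of t2 -> shape (n1 * n2, d1 + d2).
    n1, n2 = t1.shape[0], t2.shape[0]
    left = t1.repeat_interleave(n2, dim=0)  # each row of t1 repeated n2 times
    right = t2.repeat(n1, 1)                # t2 tiled n1 times
    return torch.cat([left, right], dim=1)

x = torch.tensor([[0.0], [1.0]])            # two samples of one variable
y = torch.tensor([[10.0], [20.0], [30.0]])  # three samples of another
print(cartesian_merge_sketch(x, y).shape)   # torch.Size([6, 2])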

@@ -102,12 +130,14 @@ def merge_two_tensors(tensor1, tensor2):
 
 
 def torch_lhs(n, dim):
-    """Latin Hypercube Sampling torch routine.
-    Sampling in range $[0, 1)^d$.
+    """
+    The Latin Hypercube Sampling torch routine, sampling in :math:`[0, 1)`$.
 
-    :param int n: number of samples
-    :param int dim: dimensions of latin hypercube
-    :return: samples
+    :param int n: The number of points to sample.
+    :param int dim: The number of dimensions of the sampling space.
+    :raises TypeError: If `n` or `dim` are not integers.
+    :raises ValueError: If `dim` is less than 1.
+    :return: The sampled points.
     :rtype: torch.tensor
     """
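For readers unfamiliar with the routine, an illustrative Latin Hypercube sketch in plain torch is shown below (one point per stratum along each dimension, with the strata shuffled independently); it is not necessarily the exact implementation behind torch_lhs.

import torch

def lhs_sketch(n, dim):
    # One point per stratum of width 1/n in each dimension; shuffling the
    # strata independently keeps every 1D projection evenly spread.
    strata = torch.arange(n, dtype=torch.get_default_dtype()) / n
    samples = torch.empty(n, dim)
    for d in range(dim):
        jittered = strata + torch.rand(n) / n
        samples[:, d] = jittered[torch.randperm(n)]
    return samples

print(lhs_sketch(5, 2))  # 5 points in [0, 1)^2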

@@ -137,22 +167,22 @@ def torch_lhs(n, dim):
 
 def is_function(f):
     """
-    Checks whether the given object `f` is a function or lambda.
+    Check if the given object is a function or a lambda.
 
-    :param object f: The object to be checked.
-    :return: `True` if `f` is a function, `False` otherwise.
+    :param Object f: The object to be checked.
+    :return: ``True`` if ``f`` is a function, ``False`` otherwise.
     :rtype: bool
     """
     return isinstance(f, (types.FunctionType, types.LambdaType))
 
 
 def chebyshev_roots(n):
     """
-    Return the roots of *n* Chebyshev polynomials (between [-1, 1]).
+    Compute the roots of the Chebyshev polynomial of degree ``n``.
 
-    :param int n: number of roots
-    :return: roots
-    :rtype: torch.tensor
+    :param int n: The number of roots to return.
+    :return: The roots of the Chebyshev polynomials.
+    :rtype: torch.Tensor
     """
     pi = torch.acos(torch.zeros(1)).item() * 2
     k = torch.arange(n)
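For reference, the degree-n Chebyshev polynomial of the first kind has roots cos((2k + 1) * pi / (2n)) for k = 0, ..., n - 1. Below is a minimal sketch consistent with the snippet above; the library version may additionally sort or rescale the nodes.

import torch

def chebyshev_roots_sketch(n):
    pi = torch.acos(torch.zeros(1)).item() * 2  # same pi trick as in the diff
    k = torch.arange(n)
    return torch.cos((2 * k + 1) * pi / (2 * n))

print(chebyshev_roots_sketch(4))  # four roots, all inside (-1, 1)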
