1- """Trainer module ."""
1+ """Module for the Trainer ."""
22
33import sys
44import torch
1010
1111class Trainer (lightning .pytorch .Trainer ):
1212 """
13- PINA custom Trainer class which allows to customize standard Lightning
14- Trainer class for PINNs training.
13+ PINA custom Trainer class to extend the standard Lightning functionality.
14+
15+ This class enables specific features or behaviors required by the PINA
16+ framework. It modifies the standard :class:`lightning.pytorch.Trainer` class
17+ to better support the training process in PINA.
1518 """
1619
1720 def __init__ (
@@ -29,42 +32,35 @@ def __init__(
         **kwargs,
     ):
         """
-        Initialize the Trainer class for by calling Lightning costructor and
-        adding many other functionalities.
-
-        :param solver: A pina:class:`SolverInterface` solver for the
-            differential problem.
-        :type solver: SolverInterface
-        :param batch_size: How many samples per batch to load.
-            If ``batch_size=None`` all
-            samples are loaded and data are not batched, defaults to None.
-        :type batch_size: int | None
-        :param train_size: Percentage of elements in the train dataset.
-        :type train_size: float
-        :param test_size: Percentage of elements in the test dataset.
-        :type test_size: float
-        :param val_size: Percentage of elements in the val dataset.
-        :type val_size: float
-        :param compile: if True model is compiled before training,
-            default False. For Windows users compilation is always disabled.
-        :type compile: bool
-        :param automatic_batching: if True automatic PyTorch batching is
-            performed. Please avoid using automatic batching when batch_size is
-            large, default False.
-        :type automatic_batching: bool
-        :param num_workers: Number of worker threads for data loading.
-            Default 0 (serial loading).
-        :type num_workers: int
-        :param pin_memory: Whether to use pinned memory for faster data
-            transfer to GPU. Default False.
-        :type pin_memory: bool
-        :param shuffle: Whether to shuffle the data for training. Default True.
-        :type pin_memory: bool
+        Initialization of the :class:`Trainer` class.
+
+        :param SolverInterface solver: A :class:`~pina.solver.SolverInterface`
+            solver used to solve a :class:`~pina.problem.AbstractProblem`.
+        :param int batch_size: The number of samples per batch to load.
+            If ``None``, all samples are loaded and data is not batched.
+            Default is ``None``.
+        :param float train_size: The percentage of elements to include in the
+            training dataset. Default is ``1.0``.
+        :param float test_size: The percentage of elements to include in the
+            test dataset. Default is ``0.0``.
+        :param float val_size: The percentage of elements to include in the
+            validation dataset. Default is ``0.0``.
+        :param bool compile: If ``True``, the model is compiled before training.
+            Default is ``False``. For Windows users, it is always disabled.
+        :param bool automatic_batching: If ``True``, automatic PyTorch batching
+            is performed. Avoid using automatic batching when ``batch_size`` is
+            large. Default is ``False``.
+        :param int num_workers: The number of worker threads for data loading.
+            Default is ``0`` (serial loading).
+        :param bool pin_memory: Whether to use pinned memory for faster data
+            transfer to GPU. Default is ``False``.
+        :param bool shuffle: Whether to shuffle the data during training.
+            Default is ``True``.
 
         :Keyword Arguments:
-            The additional keyword arguments specify the training setup
-            and can be choosen from the `pytorch-lightning
-            Trainer API <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_
+            Additional keyword arguments that specify the training setup.
+            These can be selected from the `pytorch-lightning Trainer API
+            <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_.
         """
         # check consistency for init types
         self._check_input_consistency(
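For reference, a minimal usage sketch of the constructor documented above. It assumes the usual top-level import and a `solver` object implementing `SolverInterface` that was built elsewhere for a fully discretised problem; `max_epochs` and `accelerator` are ordinary Lightning keyword arguments forwarded through `**kwargs`. This is an illustration of the documented parameters, not code from the file.

from pina import Trainer

# `solver` is assumed to be any SolverInterface implementation (e.g. a PINN)
# configured beforehand for an already-discretised problem.
solver = ...

trainer = Trainer(
    solver,
    batch_size=32,       # None would load all samples in a single batch
    train_size=0.8,      # documented split percentages for train/val/test
    val_size=0.1,
    test_size=0.1,
    shuffle=True,
    max_epochs=100,      # forwarded to lightning.pytorch.Trainer
    accelerator="cpu",   # forwarded to lightning.pytorch.Trainer
)
trainer.train()          # wraps Lightning's fit() with the PINA datamodule
trainer.test()           # wraps Lightning's test()

As the later hunks show, `train()` and `test()` simply forward to Lightning's `fit()` and `test()` together with the internally created datamodule.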
@@ -134,6 +130,10 @@ def __init__(
         }
 
     def _move_to_device(self):
+        """
+        Moves the ``unknown_parameters`` of an instance of
+        :class:`~pina.problem.AbstractProblem` to the :class:`Trainer` device.
+        """
         device = self._accelerator_connector._parallel_devices[0]
         # move parameters to device
         pb = self.solver.problem
@@ -155,9 +155,25 @@ def _create_datamodule(
         shuffle,
     ):
         """
-        This method is used here because is resampling is needed
-        during training, there is no need to define to touch the
-        trainer dataloader, just call the method.
+        This method is designed to handle the creation of a data module when
+        resampling is needed during training. Instead of manually defining and
+        modifying the trainer's dataloaders, this method is called to
+        automatically configure the data module.
+
+        :param float train_size: The percentage of elements to include in the
+            training dataset.
+        :param float test_size: The percentage of elements to include in the
+            test dataset.
+        :param float val_size: The percentage of elements to include in the
+            validation dataset.
+        :param int batch_size: The number of samples per batch to load.
+        :param bool automatic_batching: Whether to perform automatic batching
+            with PyTorch.
+        :param bool pin_memory: Whether to use pinned memory for faster data
+            transfer to GPU.
+        :param int num_workers: The number of worker threads for data loading.
+        :param bool shuffle: Whether to shuffle the data during training.
+        :raises RuntimeError: If not all conditions are sampled.
         """
         if not self.solver.problem.are_all_domains_discretised:
             error_message = "\n".join(
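For context on the ``RuntimeError`` documented above: the datamodule can only be built once every domain of the problem has been sampled. The sketch below shows the expected workflow; `SomeProblem` is a hypothetical `AbstractProblem` subclass, and `discretise_domain` is the usual PINA sampling entry point whose exact signature may differ between versions, so treat the call as illustrative.

# Illustrative only: sample every domain before building the Trainer,
# otherwise _create_datamodule raises a RuntimeError.
problem = SomeProblem()                       # hypothetical AbstractProblem subclass
problem.discretise_domain(n=1000)             # assumed sampling call; check your PINA version
assert problem.are_all_domains_discretised    # property checked by the Trainer above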
@@ -188,33 +204,52 @@ def _create_datamodule(
 
     def train(self, **kwargs):
         """
-        Train the solver method.
+        Manage the training process of the solver.
         """
         return super().fit(self.solver, datamodule=self.data_module, **kwargs)
 
     def test(self, **kwargs):
         """
-        Test the solver method.
+        Manage the test process of the solver.
         """
         return super().test(self.solver, datamodule=self.data_module, **kwargs)
 
     @property
     def solver(self):
         """
-        Returning trainer solver.
+        Get the solver.
+
+        :return: The solver.
+        :rtype: SolverInterface
         """
         return self._solver
 
     @solver.setter
     def solver(self, solver):
+        """
+        Set the solver.
+
+        :param SolverInterface solver: The solver to set.
+        """
         self._solver = solver
 
     @staticmethod
     def _check_input_consistency(
         solver, train_size, test_size, val_size, automatic_batching, compile
     ):
         """
-        Check the consistency of the input parameters."
+        Verifies the consistency of the parameters for the solver configuration.
+
+        :param SolverInterface solver: The solver.
+        :param float train_size: The percentage of elements to include in the
+            training dataset.
+        :param float test_size: The percentage of elements to include in the
+            test dataset.
+        :param float val_size: The percentage of elements to include in the
+            validation dataset.
+        :param bool automatic_batching: Whether to perform automatic batching
+            with PyTorch.
+        :param bool compile: If ``True``, the model is compiled before training.
         """
 
         check_consistency(solver, SolverInterface)
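The `check_consistency` calls above are PINA's type-validation helper (it lives in `pina.utils`). As a rough, self-contained approximation only, not the library's actual implementation, such a check reduces to an `isinstance` test that raises on mismatch:

# Rough approximation of a type-consistency check; PINA's real
# check_consistency may accept more argument forms and raise differently.
def check_consistency_sketch(obj, expected_type):
    """Raise ValueError if `obj` is not an instance of `expected_type`."""
    if not isinstance(obj, expected_type):
        raise ValueError(
            f"Expected {expected_type.__name__}, got {type(obj).__name__}."
        )

check_consistency_sketch(0.8, float)   # passes silently
check_consistency_sketch(True, bool)   # passes silently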
@@ -231,8 +266,14 @@ def _check_consistency_and_set_defaults(
         pin_memory, num_workers, shuffle, batch_size
     ):
         """
-        Check the consistency of the input parameters and set the default
-        values.
+        Checks the consistency of input parameters and sets default values
+        for missing or invalid parameters.
+
+        :param bool pin_memory: Whether to use pinned memory for faster data
+            transfer to GPU.
+        :param int num_workers: The number of worker threads for data loading.
+        :param bool shuffle: Whether to shuffle the data during training.
+        :param int batch_size: The number of samples per batch to load.
+        """
         if pin_memory is not None:
             check_consistency(pin_memory, bool)
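The body continuing beyond this hunk follows a check-then-default pattern: validate each option if the user passed one, otherwise fall back to a default. A generic sketch of that pattern, not the file's exact code or default values, looks like this:

# Generic check-then-default sketch; the actual defaults and validation in
# the source may differ from what is shown here.
def resolve_option(value, expected_type, default):
    """Validate `value` if given, otherwise return `default`."""
    if value is None:
        return default
    if not isinstance(value, expected_type):
        raise ValueError(
            f"Expected {expected_type.__name__}, got {type(value).__name__}."
        )
    return value

pin_memory = resolve_option(None, bool, default=False)   # -> False
num_workers = resolve_option(4, int, default=0)          # -> 4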