1- """Module for Causal PINN."""
1+ """Module for the Causal PINN solver ."""
22
33import torch
44
99
1010class CausalPINN (PINN ):
1111 r"""
12- Causal Physics Informed Neural Network (CausalPINN) solver class.
13- This class implements Causal Physics Informed Neural
14- Network solver, using a user specified ``model`` to solve a specific
15- ``problem``. It can be used for solving both forward and inverse problems.
12+ Causal Physics- Informed Neural Network (CausalPINN) solver class.
13+ This class implements the Causal Physics- Informed Neural Network solver,
14+ using a user specified ``model`` to solve a specific ``problem``.
15+ It can be used to solve both forward and inverse problems.
1616
17- The Causal Physics Informed Network aims to find
18- the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`
19- of the differential problem:
17+ The Causal Physics-Informed Neural Network solver aims to find the solution
18+ :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m` of a differential problem:
2019
2120 .. math::
2221
@@ -26,7 +25,7 @@ class CausalPINN(PINN):
        \mathbf{x}\in\partial\Omega
        \end{cases}

-    minimizing the loss function
+    minimizing the loss function:

    .. math::
        \mathcal{L}_{\rm{problem}} = \frac{1}{N_t}\sum_{i=1}^{N_t}
@@ -45,14 +44,12 @@ class CausalPINN(PINN):
    .. math::
        \omega_i = \exp\left(-\epsilon \sum_{k=1}^{i-1}\mathcal{L}_r(t_k)\right).

-    :math:`\epsilon` is an hyperparameter, default set to :math:`100`, while
-    :math:`\mathcal{L}` is a specific loss function,
-    default Mean Square Error:
+    :math:`\epsilon` is a hyperparameter, set by default to :math:`100`, while
+    :math:`\mathcal{L}` is a specific loss function, typically the MSE:

    .. math::
        \mathcal{L}(v) = \| v \|^2_2.

-
    .. seealso::

        **Original reference**: Wang, Sifan, Shyam Sankaran, and Paris
@@ -62,9 +59,8 @@ class CausalPINN(PINN):
        DOI `10.1016 <https://doi.org/10.1016/j.cma.2024.116813>`_.

    .. note::
-        This class can only work for problems inheriting
-        from at least
-        :class:`~pina.problem.timedep_problem.TimeDependentProblem` class.
+        This class is only compatible with problems that inherit from the
+        :class:`~pina.problem.TimeDependentProblem` class.
    """

    def __init__(
@@ -78,17 +74,23 @@ def __init__(
        eps=100,
    ):
        """
-        :param torch.nn.Module model: The neural network model to use.
-        :param AbstractProblem problem: The formulation of the problem.
-        :param torch.optim.Optimizer optimizer: The neural network optimizer to
-            use; default `None`.
-        :param torch.optim.LRScheduler scheduler: Learning rate scheduler;
-            default `None`.
-        :param WeightingInterface weighting: The weighting schema to use;
-            default `None`.
-        :param torch.nn.Module loss: The loss function to be minimized;
-            default `None`.
-        :param float eps: The exponential decay parameter; default `100`.
+        Initialization of the :class:`CausalPINN` class.
+
+        :param AbstractProblem problem: The problem to be solved. It must
+            inherit from at least :class:`~pina.problem.TimeDependentProblem`.
+        :param torch.nn.Module model: The neural network model to be used.
+        :param torch.optim.Optimizer optimizer: The optimizer to be used.
+            If ``None``, the Adam optimizer is used. Default is ``None``.
+        :param torch.optim.LRScheduler scheduler: Learning rate scheduler.
+            If ``None``, the constant learning rate scheduler is used.
+            Default is ``None``.
+        :param WeightingInterface weighting: The weighting schema to be used.
+            If ``None``, no weighting schema is used. Default is ``None``.
+        :param torch.nn.Module loss: The loss function to be minimized.
+            If ``None``, the Mean Squared Error (MSE) loss is used.
+            Default is ``None``.
+        :param float eps: The exponential decay parameter. Default is ``100``.
+        :raises ValueError: If the problem is not a TimeDependentProblem.
9294 """
9395 super ().__init__ (
9496 model = model ,
@@ -110,14 +112,12 @@ def __init__(

    def loss_phys(self, samples, equation):
        """
-        Computes the physics loss for the Causal PINN solver based on given
-        samples and equation.
+        Computes the physics loss for the physics-informed solver based on the
+        provided samples and equation.

        :param LabelTensor samples: The samples to evaluate the physics loss.
-        :param EquationInterface equation: The governing equation
-            representing the physics.
-        :return: The physics loss calculated based on given
-            samples and equation.
+        :param EquationInterface equation: The governing equation.
+        :return: The computed physics loss.
        :rtype: LabelTensor
        """
        # split sequentially ordered time tensors into chunks
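
The chunk-wise aggregation referred to by the comment above is elided by this hunk. As a rough, stand-alone sketch of the causal weighting described in the class docstring (``chunks`` and ``residual_fn`` are placeholder names, not PINA API, and the actual implementation may differ)::

    import torch

    def causal_physics_loss(chunks, residual_fn, eps=100.0):
        # one mean-squared residual per time chunk t_1 < t_2 < ... < t_Nt
        losses = torch.stack([residual_fn(c).pow(2).mean() for c in chunks])
        # causal weights w_i = exp(-eps * sum_{k<i} L_r(t_k)), detached so
        # they act as constants during backpropagation
        with torch.no_grad():
            cumulative = torch.cat([torch.zeros(1), losses.cumsum(dim=0)[:-1]])
            weights = torch.exp(-eps * cumulative)
        # weighted mean over the N_t time chunks
        return (weights * losses).mean()

With a large ``eps``, a chunk only receives appreciable weight once the residuals of all earlier chunks are small, which is the causality constraint the solver enforces.
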
@@ -146,13 +146,16 @@ def loss_phys(self, samples, equation):
    def eps(self):
        """
        The exponential decay parameter.
+
+        :return: The exponential decay parameter.
+        :rtype: float
        """
        return self._eps

    @eps.setter
    def eps(self, value):
        """
-        Setter method for the eps parameter.
+        Set the exponential decay parameter.

        :param float value: The exponential decay parameter.
        """
@@ -161,10 +164,10 @@ def eps(self, value):

    def _sort_label_tensor(self, tensor):
        """
-        Sorts the label tensor based on time variables.
+        Sort the tensor with respect to the temporal variables.

-        :param LabelTensor tensor: The label tensor to be sorted.
-        :return: The sorted label tensor based on time variables.
+        :param LabelTensor tensor: The tensor to be sorted.
+        :return: The tensor sorted with respect to the temporal variables.
        :rtype: LabelTensor
        """
        # labels input tensors
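
The sorting performed after this comment is elided by the hunk. The idea, sketched on plain tensors (the assumption that the temporal coordinate sits in the last column is illustrative only; PINA addresses it through the tensor labels)::

    import torch

    def sort_by_time(points, time_col=-1):
        # reorder the (N, d) space-time points by ascending time value
        order = torch.argsort(points[:, time_col])
        return points[order]
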
@@ -179,11 +182,12 @@ def _sort_label_tensor(self, tensor):

    def _split_tensor_into_chunks(self, tensor):
        """
-        Splits the label tensor into chunks based on time.
+        Split the tensor into chunks based on time.

-        :param LabelTensor tensor: The label tensor to be split.
-        :return: Tuple containing the chunks and the original labels.
-        :rtype: Tuple[List[LabelTensor], List]
+        :param LabelTensor tensor: The tensor to be split.
+        :return: A tuple containing the list of tensor chunks and the
+            corresponding labels.
+        :rtype: tuple[list[LabelTensor], list[str]]
        """
        # extract labels
        labels = tensor.labels
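
The chunking itself is elided by this hunk. On a plain, already time-sorted tensor the same effect can be obtained roughly as follows (again assuming the temporal coordinate is the last column; the real method also carries the labels along)::

    import torch

    def split_by_time(points, time_col=-1):
        # group consecutive rows sharing the same time value
        _, counts = torch.unique(
            points[:, time_col], sorted=True, return_counts=True
        )
        return torch.split(points, counts.tolist(), dim=0)
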
@@ -199,7 +203,7 @@ def _split_tensor_into_chunks(self, tensor):

    def _compute_weights(self, loss):
        """
-        Computes the weights for the physics loss based on the cumulative loss.
+        Compute the weights for the physics loss based on the cumulative loss.

        :param LabelTensor loss: The physics loss values.
        :return: The computed weights for the physics loss.
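
As a worked example of the documented ``eps`` parameter: with illustrative per-chunk residual losses of 0.5, 0.2 and 0.1 and the default ``eps=100``, the causal weights decay extremely fast, so later times are effectively ignored until the earlier residuals have been driven down::

    import torch

    losses = torch.tensor([0.5, 0.2, 0.1])  # residual loss per time chunk
    cumulative = torch.cat([torch.zeros(1), losses.cumsum(dim=0)[:-1]])
    weights = torch.exp(-100.0 * cumulative)  # exp(-eps * sum of earlier losses)
    print(weights)  # approximately [1.0, 1.9e-22, 4.0e-31]
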