@@ -45,16 +45,22 @@ def __init__(
4545 :param torch.nn.Module generator: The generator model.
4646 :param torch.nn.Module discriminator: The discriminator model.
4747 :param torch.nn.Module loss: The loss function to be minimized.
48- If ``None``, ``PowerLoss(p=1)`` is used. Default is ``None``.
48+ If ``None``, :class:`~pina.loss.power_loss.PowerLoss` with ``p=1``
49+ is used. Default is ``None``.
4950 :param Optimizer optimizer_generator: The optimizer for the generator.
50- If `None`, the Adam optimizer is used. Default is ``None``.
51- :param Optimizer optimizer_discriminator: The optimizer for the
52- discriminator. If `None`, the Adam optimizer is used.
51+ If ``None``, the :class:`torch.optim.Adam` optimizer is used.
5352 Default is ``None``.
53+ :param Optimizer optimizer_discriminator: The optimizer for the
54+ discriminator. If ``None``, the :class:`torch.optim.Adam` optimizer is
55+ used. Default is ``None``.
5456 :param Scheduler scheduler_generator: The learning rate scheduler for
5557 the generator.
58+ If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
59+ scheduler is used. Default is ``None``.
5660 :param Scheduler scheduler_discriminator: The learning rate scheduler
5761 for the discriminator.
62+ If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
63+ scheduler is used. Default is ``None``.
5864 :param float gamma: Ratio of expected loss for generator and
5965 discriminator. Default is ``0.3``.
6066 :param float lambda_k: Learning rate for control theory optimization.
@@ -109,7 +115,7 @@ def forward(self, x, mc_steps=20, variance=False):
109115 of the solution. Default is ``False``.
110116 :return: The expected value of the generator distribution. If
111117 ``variance=True``, the method returns also the variance.
112- :rtype: torch.Tensor | tuple( torch.Tensor, torch.Tensor)
118+ :rtype: torch.Tensor | tuple[torch.Tensor, torch.Tensor]
113119 """
114120
115121 # sampling
@@ -143,7 +149,7 @@ def _train_generator(self, parameters, snapshots):
143149 :param torch.Tensor parameters: The input tensor.
144150 :param torch.Tensor snapshots: The target tensor.
145151 :return: The residual loss and the generator loss.
146- :rtype: tuple( torch.Tensor, torch.Tensor)
152+ :rtype: tuple[torch.Tensor, torch.Tensor]
147153 """
148154 optimizer = self.optimizer_generator
149155 optimizer.zero_grad()
@@ -170,7 +176,8 @@ def on_train_batch_end(self, outputs, batch, batch_idx):
170176
171177 :param torch.Tensor outputs: The ``model``'s output for the current
172178 batch.
173- :param dict batch: The current batch of data.
179+ :param list[tuple[str, dict]] batch: A batch of data. Each element is a
180+ tuple containing a condition name and a dictionary of points.
174181 :param int batch_idx: The index of the current batch.
175182 """
176183 # increase by one the counter of optimization to save loggers
@@ -187,7 +194,7 @@ def _train_discriminator(self, parameters, snapshots):
187194 :param torch.Tensor parameters: The input tensor.
188195 :param torch.Tensor snapshots: The target tensor.
189196 :return: The residual loss and the discriminator loss.
190- :rtype: tuple( torch.Tensor, torch.Tensor)
197+ :rtype: tuple[torch.Tensor, torch.Tensor]
191198 """
192199 optimizer = self.optimizer_discriminator
193200 optimizer.zero_grad()
@@ -234,9 +241,12 @@ def optimization_cycle(self, batch):
234241 """
235242 The optimization cycle for the GAROM solver.
236243
237- :param tuple batch: The batch element in the dataloader.
238- :return: The loss of the optimization cycle.
239- :rtype: LabelTensor
244+ :param list[tuple[str, dict]] batch: A batch of data. Each element is a
245+ tuple containing a condition name and a dictionary of points.
246+ :return: The losses computed for all conditions in the batch, cast
247+ to a subclass of :class:`torch.Tensor`. The result is a dict
248+ mapping each condition name to its associated scalar loss.
249+ :rtype: dict
240250 """
241251 condition_loss = {}
242252 for condition_name, points in batch:
@@ -265,7 +275,8 @@ def validation_step(self, batch):
265275 """
266276 The validation step for the GAROM solver.
267277
268- :param dict batch: The batch of data to use in the validation step.
278+ :param list[tuple[str, dict]] batch: A batch of data. Each element is a
279+ tuple containing a condition name and a dictionary of points.
269280 :return: The loss of the validation step.
270281 :rtype: torch.Tensor
271282 """
@@ -287,7 +298,8 @@ def test_step(self, batch):
287298 """
288299 The test step for the GAROM solver.
289300
290- :param dict batch: The batch of data to use in the test step.
301+ :param list[tuple[str, dict]] batch: A batch of data. Each element is a
302+ tuple containing a condition name and a dictionary of points.
291303 :return: The loss of the test step.
292304 :rtype: torch.Tensor
293305 """
0 commit comments