Skip to content

Commit cba9f90

Browse files
GiovanniCanalidario-coscia
authored andcommitted
fix doc solver
1 parent 33abb66 commit cba9f90

File tree

6 files changed

+376
-253
lines changed

6 files changed

+376
-253
lines changed

pina/solver/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
"""
2-
TODO
3-
"""
1+
"""Module for the solver classes."""
42

53
__all__ = [
64
"SolverInterface",

pina/solver/garom.py

Lines changed: 119 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Module for GAROM"""
1+
"""Module for the GAROM solver."""
22

33
import torch
44
from torch.nn.modules.loss import _Loss
@@ -10,9 +10,9 @@
1010

1111
class GAROM(MultiSolverInterface):
1212
"""
13-
GAROM solver class. This class implements Generative Adversarial
14-
Reduced Order Model solver, using user specified ``models`` to solve
15-
a specific order reduction``problem``.
13+
GAROM solver class. This class implements Generative Adversarial Reduced
14+
Order Model solver, using user specified ``models`` to solve a specific
15+
order reduction ``problem``.
1616
1717
.. seealso::
1818
@@ -39,40 +39,28 @@ def __init__(
3939
regularizer=False,
4040
):
4141
"""
42-
:param AbstractProblem problem: The formualation of the problem.
43-
:param torch.nn.Module generator: The neural network model to use
44-
for the generator.
45-
:param torch.nn.Module discriminator: The neural network model to use
42+
Initialization of the :class:`GAROM` class.
43+
44+
:param AbstractProblem problem: The formulation of the problem.
45+
:param torch.nn.Module generator: The generator model.
46+
:param torch.nn.Module discriminator: The discriminator model.
47+
:param torch.nn.Module loss: The loss function to be minimized.
48+
If ``None``, ``PowerLoss(p=1)`` is used. Default is ``None``.
49+
:param Optimizer optimizer_generator: The optimizer for the generator.
50+
If `None`, the Adam optimizer is used. Default is ``None``.
51+
:param Optimizer optimizer_discriminator: The optimizer for the
52+
discriminator. If `None`, the Adam optimizer is used.
53+
Default is ``None``.
54+
:param Scheduler scheduler_generator: The learning rate scheduler for
55+
the generator.
56+
:param Scheduler scheduler_discriminator: The learning rate scheduler
4657
for the discriminator.
47-
:param torch.nn.Module loss: The loss function used as minimizer,
48-
default ``None``. If ``loss`` is ``None`` the defualt
49-
``PowerLoss(p=1)`` is used, as in the original paper.
50-
:param Optimizer optimizer_generator: The neural
51-
network optimizer to use for the generator network
52-
, default is `torch.optim.Adam`.
53-
:param Optimizer optimizer_discriminator: The neural
54-
network optimizer to use for the discriminator network
55-
, default is `torch.optim.Adam`.
56-
:param Scheduler scheduler_generator: Learning
57-
rate scheduler for the generator.
58-
:param Scheduler scheduler_discriminator: Learning
59-
rate scheduler for the discriminator.
60-
:param dict scheduler_discriminator_kwargs: LR scheduler constructor
61-
keyword args.
62-
:param gamma: Ratio of expected loss for generator and discriminator,
63-
defaults to 0.3.
64-
:type gamma: float
65-
:param lambda_k: Learning rate for control theory optimization,
66-
defaults to 0.001.
67-
:type lambda_k: float
68-
:param regularizer: Regularization term in the GAROM loss,
69-
defaults to False.
70-
:type regularizer: bool
71-
72-
.. warning::
73-
The algorithm works only for data-driven model. Hence in the
74-
``problem`` definition the codition must only contain ``input``
75-
(e.g. coefficient parameters, time parameters), and ``target``.
58+
:param float gamma: Ratio of expected loss for generator and
59+
discriminator. Default is ``0.3``.
60+
:param float lambda_k: Learning rate for control theory optimization.
61+
Default is ``0.001``.
62+
:param bool regularizer: If ``True``, uses a regularization term in the
63+
GAROM loss. Default is ``False``.
7664
"""
7765

7866
# set loss
@@ -112,19 +100,15 @@ def __init__(
112100

113101
def forward(self, x, mc_steps=20, variance=False):
114102
"""
115-
Forward step for GAROM solver
103+
Forward pass implementation.
116104
117-
:param x: The input tensor.
118-
:type x: torch.Tensor
119-
:param mc_steps: Number of montecarlo samples to approximate the
120-
expected value, defaults to 20.
121-
:type mc_steps: int
122-
:param variance: Returining also the sample variance of the solution,
123-
defaults to False.
124-
:type variance: bool
105+
:param torch.Tensor x: The input tensor.
106+
:param int mc_steps: Number of Montecarlo samples to approximate the
107+
expected value. Default is ``20``.
108+
:param bool variance: If ``True``, the method returns also the variance
109+
of the solution. Default is ``False``.
125110
:return: The expected value of the generator distribution. If
126-
``variance=True`` also the
127-
sample variance is returned.
111+
``variance=True``, the method returns also the variance.
128112
:rtype: torch.Tensor | tuple(torch.Tensor, torch.Tensor)
129113
"""
130114

@@ -142,13 +126,24 @@ def forward(self, x, mc_steps=20, variance=False):
142126
return mean
143127

144128
def sample(self, x):
145-
"""TODO"""
129+
"""
130+
Sample from the generator distribution.
131+
132+
:param torch.Tensor x: The input tensor.
133+
:return: The generated sample.
134+
:rtype: torch.Tensor
135+
"""
146136
# sampling
147137
return self.generator(x)
148138

149139
def _train_generator(self, parameters, snapshots):
150140
"""
151-
Private method to train the generator network.
141+
Train the generator model.
142+
143+
:param torch.Tensor parameters: The input tensor.
144+
:param torch.Tensor snapshots: The target tensor.
145+
:return: The residual loss and the generator loss.
146+
:rtype: tuple(torch.Tensor, torch.Tensor)
152147
"""
153148
optimizer = self.optimizer_generator
154149
optimizer.zero_grad()
@@ -170,16 +165,13 @@ def _train_generator(self, parameters, snapshots):
170165

171166
def on_train_batch_end(self, outputs, batch, batch_idx):
172167
"""
173-
This method is called at the end of each training batch, and ovverides
174-
the PytorchLightining implementation for logging the checkpoints.
168+
This method is called at the end of each training batch and overrides
169+
the PyTorch Lightning implementation to log checkpoints.
175170
176-
:param torch.Tensor outputs: The output from the model for the
177-
current batch.
178-
:param tuple batch: The current batch of data.
171+
:param torch.Tensor outputs: The ``model``'s output for the current
172+
batch.
173+
:param dict batch: The current batch of data.
179174
:param int batch_idx: The index of the current batch.
180-
:return: Whatever is returned by the parent
181-
method ``on_train_batch_end``.
182-
:rtype: Any
183175
"""
184176
# increase by one the counter of optimization to save loggers
185177
(
@@ -190,7 +182,12 @@ def on_train_batch_end(self, outputs, batch, batch_idx):
190182

191183
def _train_discriminator(self, parameters, snapshots):
192184
"""
193-
Private method to train the discriminator network.
185+
Train the discriminator model.
186+
187+
:param torch.Tensor parameters: The input tensor.
188+
:param torch.Tensor snapshots: The target tensor.
189+
:return: The residual loss and the generator loss.
190+
:rtype: tuple(torch.Tensor, torch.Tensor)
194191
"""
195192
optimizer = self.optimizer_discriminator
196193
optimizer.zero_grad()
@@ -215,8 +212,15 @@ def _train_discriminator(self, parameters, snapshots):
215212

216213
def _update_weights(self, d_loss_real, d_loss_fake):
217214
"""
218-
Private method to Update the weights of the generator and discriminator
219-
networks.
215+
Update the weights of the generator and discriminator models.
216+
217+
:param torch.Tensor d_loss_real: The discriminator loss computed on
218+
dataset samples.
219+
:param torch.Tensor d_loss_fake: The discriminator loss computed on
220+
generated samples.
221+
:return: The difference between the loss computed on the dataset samples
222+
and the loss computed on the generated samples.
223+
:rtype: torch.Tensor
220224
"""
221225

222226
diff = torch.mean(self.gamma * d_loss_real - d_loss_fake)
@@ -227,11 +231,11 @@ def _update_weights(self, d_loss_real, d_loss_fake):
227231
return diff
228232

229233
def optimization_cycle(self, batch):
230-
"""GAROM solver training step.
234+
"""
235+
The optimization cycle for the GAROM solver.
231236
232-
:param batch: The batch element in the dataloader.
233-
:type batch: tuple
234-
:return: The sum of the loss functions.
237+
:param tuple batch: The batch element in the dataloader.
238+
:return: The loss of the optimization cycle.
235239
:rtype: LabelTensor
236240
"""
237241
condition_loss = {}
@@ -258,6 +262,13 @@ def optimization_cycle(self, batch):
258262
return condition_loss
259263

260264
def validation_step(self, batch):
265+
"""
266+
The validation step for the PINN solver.
267+
268+
:param dict batch: The batch of data to use in the validation step.
269+
:return: The loss of the validation step.
270+
:rtype: torch.Tensor
271+
"""
261272
condition_loss = {}
262273
for condition_name, points in batch:
263274
parameters, snapshots = (
@@ -273,6 +284,13 @@ def validation_step(self, batch):
273284
return loss
274285

275286
def test_step(self, batch):
287+
"""
288+
The test step for the PINN solver.
289+
290+
:param dict batch: The batch of data to use in the test step.
291+
:return: The loss of the test step.
292+
:rtype: torch.Tensor
293+
"""
276294
condition_loss = {}
277295
for condition_name, points in batch:
278296
parameters, snapshots = (
@@ -289,30 +307,60 @@ def test_step(self, batch):
289307

290308
@property
291309
def generator(self):
292-
"""TODO"""
310+
"""
311+
The generator model.
312+
313+
:return: The generator model.
314+
:rtype: torch.nn.Module
315+
"""
293316
return self.models[0]
294317

295318
@property
296319
def discriminator(self):
297-
"""TODO"""
320+
"""
321+
The discriminator model.
322+
323+
:return: The discriminator model.
324+
:rtype: torch.nn.Module
325+
"""
298326
return self.models[1]
299327

300328
@property
301329
def optimizer_generator(self):
302-
"""TODO"""
330+
"""
331+
The optimizer for the generator.
332+
333+
:return: The optimizer for the generator.
334+
:rtype: Optimizer
335+
"""
303336
return self.optimizers[0].instance
304337

305338
@property
306339
def optimizer_discriminator(self):
307-
"""TODO"""
340+
"""
341+
The optimizer for the discriminator.
342+
343+
:return: The optimizer for the discriminator.
344+
:rtype: Optimizer
345+
"""
308346
return self.optimizers[1].instance
309347

310348
@property
311349
def scheduler_generator(self):
312-
"""TODO"""
350+
"""
351+
The scheduler for the generator.
352+
353+
:return: The scheduler for the generator.
354+
:rtype: Scheduler
355+
"""
313356
return self.schedulers[0].instance
314357

315358
@property
316359
def scheduler_discriminator(self):
317-
"""TODO"""
360+
"""
361+
The scheduler for the discriminator.
362+
363+
:return: The scheduler for the discriminator.
364+
:rtype: Scheduler
365+
"""
318366
return self.schedulers[1].instance

pina/solver/physic_informed_solver/pinn_interface.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,8 @@ def test_step(self, batch):
110110
def loss_data(self, input_pts, output_pts):
111111
"""
112112
Compute the data loss for the PINN solver by evaluating the loss
113-
between the network's output and the true solution. This method
114-
should only be overridden intentionally.
113+
between the network's output and the true solution. This method should
114+
not be overridden, if not intentionally.
115115
116116
:param LabelTensor input_pts: The input points to the neural network.
117117
:param LabelTensor output_pts: The true solution to compare with the

0 commit comments

Comments
 (0)