1- """Module for GAROM"""
1+ """Module for the GAROM solver. """
22
33import torch
44from torch .nn .modules .loss import _Loss
1010
1111class GAROM (MultiSolverInterface ):
1212 """
13- GAROM solver class. This class implements Generative Adversarial
14- Reduced Order Model solver, using user specified ``models`` to solve
15- a specific order reduction``problem``.
13+ GAROM solver class. This class implements Generative Adversarial Reduced
14+ Order Model solver, using user specified ``models`` to solve a specific
15+ order reduction ``problem``.
1616
1717 .. seealso::
1818
@@ -39,40 +39,28 @@ def __init__(
         regularizer=False,
     ):
         """
-        :param AbstractProblem problem: The formualation of the problem.
-        :param torch.nn.Module generator: The neural network model to use
-            for the generator.
-        :param torch.nn.Module discriminator: The neural network model to use
+        Initialization of the :class:`GAROM` class.
+
+        :param AbstractProblem problem: The formulation of the problem.
+        :param torch.nn.Module generator: The generator model.
+        :param torch.nn.Module discriminator: The discriminator model.
+        :param torch.nn.Module loss: The loss function to be minimized.
+            If ``None``, ``PowerLoss(p=1)`` is used. Default is ``None``.
+        :param Optimizer optimizer_generator: The optimizer for the generator.
+            If ``None``, the Adam optimizer is used. Default is ``None``.
+        :param Optimizer optimizer_discriminator: The optimizer for the
+            discriminator. If ``None``, the Adam optimizer is used.
+            Default is ``None``.
+        :param Scheduler scheduler_generator: The learning rate scheduler for
+            the generator.
+        :param Scheduler scheduler_discriminator: The learning rate scheduler
             for the discriminator.
-        :param torch.nn.Module loss: The loss function used as minimizer,
-            default ``None``. If ``loss`` is ``None`` the defualt
-            ``PowerLoss(p=1)`` is used, as in the original paper.
-        :param Optimizer optimizer_generator: The neural
-            network optimizer to use for the generator network
-            , default is `torch.optim.Adam`.
-        :param Optimizer optimizer_discriminator: The neural
-            network optimizer to use for the discriminator network
-            , default is `torch.optim.Adam`.
-        :param Scheduler scheduler_generator: Learning
-            rate scheduler for the generator.
-        :param Scheduler scheduler_discriminator: Learning
-            rate scheduler for the discriminator.
-        :param dict scheduler_discriminator_kwargs: LR scheduler constructor
-            keyword args.
-        :param gamma: Ratio of expected loss for generator and discriminator,
-            defaults to 0.3.
-        :type gamma: float
-        :param lambda_k: Learning rate for control theory optimization,
-            defaults to 0.001.
-        :type lambda_k: float
-        :param regularizer: Regularization term in the GAROM loss,
-            defaults to False.
-        :type regularizer: bool
-
-        .. warning::
-            The algorithm works only for data-driven model. Hence in the
-            ``problem`` definition the codition must only contain ``input``
-            (e.g. coefficient parameters, time parameters), and ``target``.
+        :param float gamma: Ratio of expected loss for the generator and
+            the discriminator. Default is ``0.3``.
+        :param float lambda_k: Learning rate for the control theory
+            optimization. Default is ``0.001``.
+        :param bool regularizer: If ``True``, uses a regularization term in
+            the GAROM loss. Default is ``False``.
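+
+        :Example:
+            A minimal usage sketch. The ``problem`` object and the
+            ``generator``/``discriminator`` modules below are assumed to be
+            defined by the user and are not part of this module:
+
+            >>> solver = GAROM(
+            ...     problem=problem,
+            ...     generator=generator,
+            ...     discriminator=discriminator,
+            ...     gamma=0.3,
+            ...     lambda_k=0.001,
+            ... )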
7664 """
7765
7866 # set loss
@@ -112,19 +100,15 @@ def __init__(
 
     def forward(self, x, mc_steps=20, variance=False):
         """
-        Forward step for GAROM solver
+        Forward pass implementation.
 
-        :param x: The input tensor.
-        :type x: torch.Tensor
-        :param mc_steps: Number of montecarlo samples to approximate the
-            expected value, defaults to 20.
-        :type mc_steps: int
-        :param variance: Returining also the sample variance of the solution,
-            defaults to False.
-        :type variance: bool
+        :param torch.Tensor x: The input tensor.
+        :param int mc_steps: Number of Monte Carlo samples used to
+            approximate the expected value. Default is ``20``.
+        :param bool variance: If ``True``, the method also returns the sample
+            variance of the solution. Default is ``False``.
         :return: The expected value of the generator distribution. If
-            ``variance=True`` also the
-            sample variance is returned.
+            ``variance=True``, the sample variance is also returned.
         :rtype: torch.Tensor | tuple(torch.Tensor, torch.Tensor)
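+
+        :Example:
+            A hedged sketch; ``solver`` and ``params`` stand for a trained
+            :class:`GAROM` instance and a batch of input parameters, neither
+            of which is defined in this module:
+
+            >>> mean = solver(params, mc_steps=20)
+            >>> mean, var = solver(params, mc_steps=20, variance=True)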
         """
 
@@ -142,13 +126,24 @@ def forward(self, x, mc_steps=20, variance=False):
         return mean
 
     def sample(self, x):
-        """TODO"""
+        """
+        Sample from the generator distribution.
+
+        :param torch.Tensor x: The input tensor.
+        :return: The generated sample.
+        :rtype: torch.Tensor
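+
+        :Example:
+            A hedged sketch; ``solver`` and ``params`` are an assumed
+            :class:`GAROM` instance and input batch, not defined here:
+
+            >>> realization = solver.sample(params)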
+        """
         # sampling
         return self.generator(x)
 
     def _train_generator(self, parameters, snapshots):
         """
-        Private method to train the generator network.
+        Train the generator model.
+
+        :param torch.Tensor parameters: The input tensor.
+        :param torch.Tensor snapshots: The target tensor.
+        :return: The residual loss and the generator loss.
+        :rtype: tuple(torch.Tensor, torch.Tensor)
         """
         optimizer = self.optimizer_generator
         optimizer.zero_grad()
@@ -170,16 +165,13 @@ def _train_generator(self, parameters, snapshots):
 
     def on_train_batch_end(self, outputs, batch, batch_idx):
         """
-        This method is called at the end of each training batch, and ovverides
-        the PytorchLightining implementation for logging the checkpoints.
+        This method is called at the end of each training batch and overrides
+        the PyTorch Lightning implementation to log checkpoints.
 
-        :param torch.Tensor outputs: The output from the model for the
-            current batch.
-        :param tuple batch: The current batch of data.
+        :param torch.Tensor outputs: The model's output for the current
+            batch.
+        :param dict batch: The current batch of data.
         :param int batch_idx: The index of the current batch.
-        :return: Whatever is returned by the parent
-            method ``on_train_batch_end``.
-        :rtype: Any
         """
         # increase by one the counter of optimization to save loggers
         (
@@ -190,7 +182,12 @@ def on_train_batch_end(self, outputs, batch, batch_idx):
 
     def _train_discriminator(self, parameters, snapshots):
         """
-        Private method to train the discriminator network.
+        Train the discriminator model.
+
+        :param torch.Tensor parameters: The input tensor.
+        :param torch.Tensor snapshots: The target tensor.
+        :return: The discriminator losses computed on the dataset samples and
+            on the generated samples.
+        :rtype: tuple(torch.Tensor, torch.Tensor)
         """
         optimizer = self.optimizer_discriminator
         optimizer.zero_grad()
@@ -215,8 +212,15 @@ def _train_discriminator(self, parameters, snapshots):
 
     def _update_weights(self, d_loss_real, d_loss_fake):
         """
-        Private method to Update the weights of the generator and discriminator
-        networks.
+        Update the weights of the generator and discriminator models.
+
+        :param torch.Tensor d_loss_real: The discriminator loss computed on
+            dataset samples.
+        :param torch.Tensor d_loss_fake: The discriminator loss computed on
+            generated samples.
+        :return: The ``gamma``-weighted difference between the loss computed
+            on the dataset samples and the loss computed on the generated
+            samples.
+        :rtype: torch.Tensor
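+
+        .. note::
+            A sketch of the control-theory balancing this method performs,
+            following the usual BEGAN-style rule (the exact clipping is
+            defined in the method body and may differ):
+            ``diff = mean(gamma * d_loss_real - d_loss_fake)``, followed by
+            the update ``k <- clip(k + lambda_k * diff, 0, 1)``.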
         """
 
         diff = torch.mean(self.gamma * d_loss_real - d_loss_fake)
@@ -227,11 +231,11 @@ def _update_weights(self, d_loss_real, d_loss_fake):
 
         return diff
     def optimization_cycle(self, batch):
-        """GAROM solver training step.
+        """
+        The optimization cycle for the GAROM solver.
 
-        :param batch: The batch element in the dataloader.
-        :type batch: tuple
-        :return: The sum of the loss functions.
-        :rtype: LabelTensor
+        :param tuple batch: The batch element in the dataloader.
+        :return: The losses computed for all the conditions in the batch,
+            stored per condition name.
+        :rtype: dict
         """
         condition_loss = {}
@@ -258,6 +262,13 @@ def optimization_cycle(self, batch):
         return condition_loss
 
     def validation_step(self, batch):
+        """
+        The validation step for the GAROM solver.
+
+        :param dict batch: The batch of data to use in the validation step.
+        :return: The loss of the validation step.
+        :rtype: torch.Tensor
+        """
         condition_loss = {}
         for condition_name, points in batch:
             parameters, snapshots = (
@@ -273,6 +284,13 @@ def validation_step(self, batch):
         return loss
 
     def test_step(self, batch):
+        """
+        The test step for the GAROM solver.
+
+        :param dict batch: The batch of data to use in the test step.
+        :return: The loss of the test step.
+        :rtype: torch.Tensor
+        """
         condition_loss = {}
         for condition_name, points in batch:
             parameters, snapshots = (
@@ -289,30 +307,60 @@ def test_step(self, batch):
 
     @property
     def generator(self):
-        """TODO"""
+        """
+        The generator model.
+
+        :return: The generator model.
+        :rtype: torch.nn.Module
+        """
         return self.models[0]
 
     @property
     def discriminator(self):
-        """TODO"""
+        """
+        The discriminator model.
+
+        :return: The discriminator model.
+        :rtype: torch.nn.Module
+        """
         return self.models[1]
 
     @property
     def optimizer_generator(self):
-        """TODO"""
+        """
+        The optimizer for the generator.
+
+        :return: The optimizer for the generator.
+        :rtype: Optimizer
+        """
         return self.optimizers[0].instance
 
     @property
     def optimizer_discriminator(self):
-        """TODO"""
+        """
+        The optimizer for the discriminator.
+
+        :return: The optimizer for the discriminator.
+        :rtype: Optimizer
+        """
         return self.optimizers[1].instance
 
     @property
     def scheduler_generator(self):
-        """TODO"""
+        """
+        The scheduler for the generator.
+
+        :return: The scheduler for the generator.
+        :rtype: Scheduler
+        """
         return self.schedulers[0].instance
 
     @property
     def scheduler_discriminator(self):
-        """TODO"""
+        """
+        The scheduler for the discriminator.
+
+        :return: The scheduler for the discriminator.
+        :rtype: Scheduler
+        """
         return self.schedulers[1].instance