11"""Module for GAROM"""
22
33import torch
4-
4+ from torch . nn . modules . loss import _Loss
55from .solver import MultiSolverInterface
6- from ..utils import check_consistency
7- from ..loss .loss_interface import LossInterface
86from ..condition import InputTargetCondition
97from ..utils import check_consistency
108from ..loss import LossInterface , PowerLoss
11- from torch .nn .modules .loss import _Loss
129
1310
1411class GAROM (MultiSolverInterface ):
@@ -60,18 +57,22 @@ def __init__(
             rate scheduler for the generator.
         :param Scheduler scheduler_discriminator: Learning
             rate scheduler for the discriminator.
-        :param dict scheduler_discriminator_kwargs: LR scheduler constructor keyword args.
-        :param gamma: Ratio of expected loss for generator and discriminator, defaults to 0.3.
+        :param dict scheduler_discriminator_kwargs: LR scheduler constructor
+            keyword args.
+        :param gamma: Ratio of expected loss for generator and discriminator,
+            defaults to 0.3.
         :type gamma: float
-        :param lambda_k: Learning rate for control theory optimization, defaults to 0.001.
+        :param lambda_k: Learning rate for control theory optimization,
+            defaults to 0.001.
         :type lambda_k: float
-        :param regularizer: Regularization term in the GAROM loss, defaults to False.
+        :param regularizer: Regularization term in the GAROM loss,
+            defaults to False.
         :type regularizer: bool
 
         .. warning::
-            The algorithm works only for data-driven model. Hence in the ``problem`` definition
-            the codition must only contain ``input`` (e.g. coefficient parameters, time
-            parameters), and ``target``.
+            The algorithm works only for data-driven models. Hence, in the
+            ``problem`` definition the condition must only contain ``input``
+            (e.g. coefficient parameters, time parameters) and ``target``.
         """
 
         # set loss
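For orientation, ``gamma`` and ``lambda_k`` drive a BEGAN-style balance between the generator and discriminator losses in GAROM-style training. A minimal sketch of that control update, under assumed illustrative names (``k``, ``loss_g``, ``loss_d`` do not appear in this diff):

    # Hypothetical BEGAN-style equilibrium update; names are illustrative.
    gamma, lambda_k, k = 0.3, 0.001, 0.0   # defaults from the docstring above
    loss_d, loss_g = 0.8, 0.2              # placeholder per-step loss values
    # Nudge k so that loss_g tracks gamma * loss_d, clipping k to [0, 1].
    k = min(max(k + lambda_k * (gamma * loss_d - loss_g), 0.0), 1.0)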
@@ -118,9 +119,11 @@ def forward(self, x, mc_steps=20, variance=False):
         :param mc_steps: Number of Monte Carlo samples to approximate the
             expected value, defaults to 20.
         :type mc_steps: int
-        :param variance: Returining also the sample variance of the solution, defaults to False.
+        :param variance: Whether to also return the sample variance of the
+            solution, defaults to False.
         :type variance: bool
-        :return: The expected value of the generator distribution. If ``variance=True`` also the
+        :return: The expected value of the generator distribution. If
+            ``variance=True``, the
             sample variance is returned.
         :rtype: torch.Tensor | tuple(torch.Tensor, torch.Tensor)
         """
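The re-wrapped ``forward`` docstring describes a Monte Carlo estimate of the generator distribution. A self-contained sketch of that behaviour, assuming the generator is a plain callable (the actual sampling code is not part of this hunk):

    import torch

    def mc_forward(generator, x, mc_steps=20, variance=False):
        # Draw mc_steps realisations and approximate the expected value;
        # optionally also return the sample variance.
        samples = torch.stack([generator(x) for _ in range(mc_steps)])
        mean = samples.mean(dim=0)
        if variance:
            return mean, samples.var(dim=0)
        return mean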
@@ -139,6 +142,7 @@ def forward(self, x, mc_steps=20, variance=False):
         return mean
 
     def sample(self, x):
+        """TODO"""
         # sampling
         return self.generator(x)
 
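In use, ``sample`` returns a single stochastic draw while ``forward`` averages over many. A hedged usage example (``solver`` and ``params`` are assumed, not defined in this diff):

    single = solver.sample(params)                           # one draw
    mean, var = solver(params, mc_steps=20, variance=True)   # MC estimate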
@@ -285,24 +289,30 @@ def test_step(self, batch):
 
     @property
     def generator(self):
+        """TODO"""
         return self.models[0]
 
     @property
     def discriminator(self):
+        """TODO"""
         return self.models[1]
 
     @property
     def optimizer_generator(self):
+        """TODO"""
         return self.optimizers[0].instance
 
     @property
     def optimizer_discriminator(self):
+        """TODO"""
         return self.optimizers[1].instance
 
     @property
     def scheduler_generator(self):
+        """TODO"""
         return self.schedulers[0].instance
 
     @property
     def scheduler_discriminator(self):
+        """TODO"""
         return self.schedulers[1].instance
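These properties make the index convention explicit: slot 0 holds the generator's model, optimizer, and scheduler, and slot 1 the discriminator's. For example (``solver`` assumed as above):

    gen = solver.generator                   # same as solver.models[0]
    opt_d = solver.optimizer_discriminator   # solver.optimizers[1].instance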