33import torch
44import copy
55
6- from pina .problem import InverseProblem
6+ from .. .problem import InverseProblem
77from .pinn_interface import PINNInterface
88from ..solver import MultiSolverInterface
99
@@ -125,10 +125,15 @@ def training_step(self, batch):
125125 :return: The sum of the loss functions.
126126 :rtype: LabelTensor
127127 """
128+ # train model
128129 self .optimizer_model .instance .zero_grad ()
129- self .optimizer_discriminator .instance .zero_grad ()
130130 loss = super ().training_step (batch )
131+ self .manual_backward (loss )
131132 self .optimizer_model .instance .step ()
133+ # train discriminator
134+ self .optimizer_discriminator .instance .zero_grad ()
135+ loss = super ().training_step (batch )
136+ self .manual_backward (- loss )
132137 self .optimizer_discriminator .instance .step ()
133138 return loss
134139
@@ -144,20 +149,18 @@ def loss_phys(self, samples, equation):
144149 samples and equation.
145150 :rtype: LabelTensor
146151 """
147- # Train the model for one step
148- with torch .no_grad ():
149- discriminator_bets = self .discriminator (samples )
150- loss_val = self ._train_model (samples , equation , discriminator_bets )
151-
152- # Detach samples from the existing computational graph and
153- # create a new one by setting requires_grad to True.
154- # In alternative set `retain_graph=True`.
155- samples = samples .detach ()
156- samples .requires_grad_ ()
157-
158- # Train the discriminator for one step
152+ # Compute discriminator bets
159153 discriminator_bets = self .discriminator (samples )
160- self ._train_discriminator (samples , equation , discriminator_bets )
154+
155+ # Compute residual and multiply discriminator_bets
156+ residual = self .compute_residual (samples = samples , equation = equation )
157+ residual = residual * discriminator_bets
158+
159+ # Compute competitive residual.
160+ loss_val = self .loss (
161+ torch .zeros_like (residual , requires_grad = True ),
162+ residual ,
163+ )
161164 return loss_val
162165
163166 def configure_optimizers (self ):
@@ -213,58 +216,6 @@ def on_train_batch_end(self, outputs, batch, batch_idx):
213216
214217 return super ().on_train_batch_end (outputs , batch , batch_idx )
215218
216- def _train_discriminator (self , samples , equation , discriminator_bets ):
217- """
218- Trains the discriminator network of the Competitive PINN.
219-
220- :param LabelTensor samples: Input samples to evaluate the physics loss.
221- :param EquationInterface equation: The governing equation representing
222- the physics.
223- :param Tensor discriminator_bets: Predictions made by the discriminator
224- network.
225- """
226- # Compute residual. Detach since discriminator weights are fixed
227- residual = self .compute_residual (
228- samples = samples , equation = equation
229- ).detach ()
230-
231- # Compute competitive residual, then maximise the loss
232- competitive_residual = residual * discriminator_bets
233- loss_val = - self .loss (
234- torch .zeros_like (competitive_residual , requires_grad = True ),
235- competitive_residual ,
236- )
237- # prepare for optimizer step called in training step
238- self .manual_backward (loss_val )
239-
240- def _train_model (self , samples , equation , discriminator_bets ):
241- """
242- Trains the model network of the Competitive PINN.
243-
244- :param LabelTensor samples: Input samples to evaluate the physics loss.
245- :param EquationInterface equation: The governing equation representing
246- the physics.
247- :param Tensor discriminator_bets: Predictions made by the discriminator.
248- network.
249- :return: The computed data loss.
250- :rtype: torch.Tensor
251- """
252- # Compute residual
253- residual = self .compute_residual (samples = samples , equation = equation )
254- with torch .no_grad ():
255- loss_residual = self .loss (torch .zeros_like (residual ), residual )
256-
257- # Compute competitive residual. Detach discriminator_bets
258- # to optimize only the generator model
259- competitive_residual = residual * discriminator_bets .detach ()
260- loss_val = self .loss (
261- torch .zeros_like (competitive_residual , requires_grad = True ),
262- competitive_residual ,
263- )
264- # prepare for optimizer step called in training step
265- self .manual_backward (loss_val )
266- return loss_residual
267-
268219 @property
269220 def neural_net (self ):
270221 """
0 commit comments