Skip to content

Commit b947c66

Browse files
dario-cosciandem0
authored andcommitted
fixing competitive pinn (#470)
1 parent 1d85476 commit b947c66

File tree

1 file changed

+18
-67
lines changed

1 file changed

+18
-67
lines changed

pina/solver/physic_informed_solver/competitive_pinn.py

Lines changed: 18 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44
import copy
55

6-
from pina.problem import InverseProblem
6+
from ...problem import InverseProblem
77
from .pinn_interface import PINNInterface
88
from ..solver import MultiSolverInterface
99

@@ -125,10 +125,15 @@ def training_step(self, batch):
125125
:return: The sum of the loss functions.
126126
:rtype: LabelTensor
127127
"""
128+
# train model
128129
self.optimizer_model.instance.zero_grad()
129-
self.optimizer_discriminator.instance.zero_grad()
130130
loss = super().training_step(batch)
131+
self.manual_backward(loss)
131132
self.optimizer_model.instance.step()
133+
# train discriminator
134+
self.optimizer_discriminator.instance.zero_grad()
135+
loss = super().training_step(batch)
136+
self.manual_backward(-loss)
132137
self.optimizer_discriminator.instance.step()
133138
return loss
134139

@@ -144,20 +149,18 @@ def loss_phys(self, samples, equation):
144149
samples and equation.
145150
:rtype: LabelTensor
146151
"""
147-
# Train the model for one step
148-
with torch.no_grad():
149-
discriminator_bets = self.discriminator(samples)
150-
loss_val = self._train_model(samples, equation, discriminator_bets)
151-
152-
# Detach samples from the existing computational graph and
153-
# create a new one by setting requires_grad to True.
154-
# In alternative set `retain_graph=True`.
155-
samples = samples.detach()
156-
samples.requires_grad_()
157-
158-
# Train the discriminator for one step
152+
# Compute discriminator bets
159153
discriminator_bets = self.discriminator(samples)
160-
self._train_discriminator(samples, equation, discriminator_bets)
154+
155+
# Compute residual and multiply discriminator_bets
156+
residual = self.compute_residual(samples=samples, equation=equation)
157+
residual = residual * discriminator_bets
158+
159+
# Compute competitive residual.
160+
loss_val = self.loss(
161+
torch.zeros_like(residual, requires_grad=True),
162+
residual,
163+
)
161164
return loss_val
162165

163166
def configure_optimizers(self):
@@ -213,58 +216,6 @@ def on_train_batch_end(self, outputs, batch, batch_idx):
213216

214217
return super().on_train_batch_end(outputs, batch, batch_idx)
215218

216-
def _train_discriminator(self, samples, equation, discriminator_bets):
217-
"""
218-
Trains the discriminator network of the Competitive PINN.
219-
220-
:param LabelTensor samples: Input samples to evaluate the physics loss.
221-
:param EquationInterface equation: The governing equation representing
222-
the physics.
223-
:param Tensor discriminator_bets: Predictions made by the discriminator
224-
network.
225-
"""
226-
# Compute residual. Detach since discriminator weights are fixed
227-
residual = self.compute_residual(
228-
samples=samples, equation=equation
229-
).detach()
230-
231-
# Compute competitive residual, then maximise the loss
232-
competitive_residual = residual * discriminator_bets
233-
loss_val = -self.loss(
234-
torch.zeros_like(competitive_residual, requires_grad=True),
235-
competitive_residual,
236-
)
237-
# prepare for optimizer step called in training step
238-
self.manual_backward(loss_val)
239-
240-
def _train_model(self, samples, equation, discriminator_bets):
241-
"""
242-
Trains the model network of the Competitive PINN.
243-
244-
:param LabelTensor samples: Input samples to evaluate the physics loss.
245-
:param EquationInterface equation: The governing equation representing
246-
the physics.
247-
:param Tensor discriminator_bets: Predictions made by the discriminator.
248-
network.
249-
:return: The computed data loss.
250-
:rtype: torch.Tensor
251-
"""
252-
# Compute residual
253-
residual = self.compute_residual(samples=samples, equation=equation)
254-
with torch.no_grad():
255-
loss_residual = self.loss(torch.zeros_like(residual), residual)
256-
257-
# Compute competitive residual. Detach discriminator_bets
258-
# to optimize only the generator model
259-
competitive_residual = residual * discriminator_bets.detach()
260-
loss_val = self.loss(
261-
torch.zeros_like(competitive_residual, requires_grad=True),
262-
competitive_residual,
263-
)
264-
# prepare for optimizer step called in training step
265-
self.manual_backward(loss_val)
266-
return loss_residual
267-
268219
@property
269220
def neural_net(self):
270221
"""

0 commit comments

Comments
 (0)