Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 18 additions & 67 deletions pina/solver/physic_informed_solver/competitive_pinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
import copy

from pina.problem import InverseProblem
from ...problem import InverseProblem
from .pinn_interface import PINNInterface
from ..solver import MultiSolverInterface

Expand Down Expand Up @@ -125,10 +125,15 @@ def training_step(self, batch):
:return: The sum of the loss functions.
:rtype: LabelTensor
"""
# train model
self.optimizer_model.instance.zero_grad()
self.optimizer_discriminator.instance.zero_grad()
loss = super().training_step(batch)
self.manual_backward(loss)
self.optimizer_model.instance.step()
# train discriminator
self.optimizer_discriminator.instance.zero_grad()
loss = super().training_step(batch)
self.manual_backward(-loss)
self.optimizer_discriminator.instance.step()
return loss

Expand All @@ -144,20 +149,18 @@ def loss_phys(self, samples, equation):
samples and equation.
:rtype: LabelTensor
"""
# Train the model for one step
with torch.no_grad():
discriminator_bets = self.discriminator(samples)
loss_val = self._train_model(samples, equation, discriminator_bets)

# Detach samples from the existing computational graph and
# create a new one by setting requires_grad to True.
# In alternative set `retain_graph=True`.
samples = samples.detach()
samples.requires_grad_()

# Train the discriminator for one step
# Compute discriminator bets
discriminator_bets = self.discriminator(samples)
self._train_discriminator(samples, equation, discriminator_bets)

# Compute residual and multiply discriminator_bets
residual = self.compute_residual(samples=samples, equation=equation)
residual = residual * discriminator_bets

# Compute competitive residual.
loss_val = self.loss(
torch.zeros_like(residual, requires_grad=True),
residual,
)
return loss_val

def configure_optimizers(self):
Expand Down Expand Up @@ -213,58 +216,6 @@ def on_train_batch_end(self, outputs, batch, batch_idx):

return super().on_train_batch_end(outputs, batch, batch_idx)

def _train_discriminator(self, samples, equation, discriminator_bets):
"""
Trains the discriminator network of the Competitive PINN.

:param LabelTensor samples: Input samples to evaluate the physics loss.
:param EquationInterface equation: The governing equation representing
the physics.
:param Tensor discriminator_bets: Predictions made by the discriminator
network.
"""
# Compute residual. Detach since discriminator weights are fixed
residual = self.compute_residual(
samples=samples, equation=equation
).detach()

# Compute competitive residual, then maximise the loss
competitive_residual = residual * discriminator_bets
loss_val = -self.loss(
torch.zeros_like(competitive_residual, requires_grad=True),
competitive_residual,
)
# prepare for optimizer step called in training step
self.manual_backward(loss_val)

def _train_model(self, samples, equation, discriminator_bets):
"""
Trains the model network of the Competitive PINN.

:param LabelTensor samples: Input samples to evaluate the physics loss.
:param EquationInterface equation: The governing equation representing
the physics.
:param Tensor discriminator_bets: Predictions made by the discriminator.
network.
:return: The computed data loss.
:rtype: torch.Tensor
"""
# Compute residual
residual = self.compute_residual(samples=samples, equation=equation)
with torch.no_grad():
loss_residual = self.loss(torch.zeros_like(residual), residual)

# Compute competitive residual. Detach discriminator_bets
# to optimize only the generator model
competitive_residual = residual * discriminator_bets.detach()
loss_val = self.loss(
torch.zeros_like(competitive_residual, requires_grad=True),
competitive_residual,
)
# prepare for optimizer step called in training step
self.manual_backward(loss_val)
return loss_residual

@property
def neural_net(self):
"""
Expand Down
Loading