Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 0 additions & 17 deletions pina/solver/physic_informed_solver/competitive_pinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,22 +160,6 @@ def loss_phys(self, samples, equation):
self._train_discriminator(samples, equation, discriminator_bets)
return loss_val

def loss_data(self, input_pts, output_pts):
"""
The data loss for the CompetitivePINN solver. It computes the loss
between the network output against the true solution.

:param LabelTensor input_tensor: The input to the neural networks.
:param LabelTensor output_tensor: The true solution to compare the
network solution.
:return: The computed data loss.
:rtype: torch.Tensor
"""
loss_val = super().loss_data(input_pts, output_pts)
# prepare for optimizer step called in training step
loss_val.backward()
return loss_val

def configure_optimizers(self):
"""
Optimizer configuration for the Competitive PINN solver.
Expand Down Expand Up @@ -252,7 +236,6 @@ def _train_discriminator(self, samples, equation, discriminator_bets):
)
# prepare for optimizer step called in training step
self.manual_backward(loss_val)
return

def _train_model(self, samples, equation, discriminator_bets):
"""
Expand Down
61 changes: 9 additions & 52 deletions pina/solver/physic_informed_solver/self_adaptive_pinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,62 +178,19 @@ def training_step(self, batch):
:return: The sum of the loss functions.
:rtype: LabelTensor
"""
self.optimizer_model.instance.zero_grad()
# Weights optimization
self.optimizer_weights.instance.zero_grad()
loss = super().training_step(batch)
self.optimizer_model.instance.step()
self.manual_backward(-loss)
self.optimizer_weights.instance.step()
return loss

def loss_phys(self, samples, equation):
"""
Computes the physics loss for the SAPINN solver based on given
samples and equation.

:param LabelTensor samples: The samples to evaluate the physics loss.
:param EquationInterface equation: The governing equation
representing the physics.
:return: The physics loss calculated based on given
samples and equation.
:rtype: torch.Tensor
"""
# Train the weights
weighted_loss = self._loss_phys(samples, equation)
loss_value = -weighted_loss.as_subclass(torch.Tensor)
self.manual_backward(loss_value)

# Detach samples from the existing computational graph and
# create a new one by setting requires_grad to True.
# In alternative set `retain_graph=True`.
samples = samples.detach()
samples.requires_grad_() # = True

# Train the model
weighted_loss = self._loss_phys(samples, equation)
loss_value = weighted_loss.as_subclass(torch.Tensor)
self.manual_backward(loss_value)

return loss_value

def loss_data(self, input_pts, output_pts):
"""
Computes the data loss for the SAPINN solver based on input and
output. It computes the loss between the
network output against the true solution.
# Model optimization
self.optimizer_model.instance.zero_grad()
loss = super().training_step(batch)
self.manual_backward(loss)
self.optimizer_model.instance.step()

:param LabelTensor input_pts: The input to the neural networks.
:param LabelTensor output_pts: The true solution to compare the
network solution.
:return: The computed data loss.
:rtype: torch.Tensor
"""
residual = self.forward(input_pts) - output_pts
loss = self._vectorial_loss(
torch.zeros_like(residual, requires_grad=True), residual
)
loss_value = self._vect_to_scalar(loss).as_subclass(torch.Tensor)
self.manual_backward(loss_value)
return loss_value
return loss

def configure_optimizers(self):
"""
Expand Down Expand Up @@ -330,7 +287,7 @@ def on_load_checkpoint(self, checkpoint):
)
return super().on_load_checkpoint(checkpoint)

def _loss_phys(self, samples, equation):
def loss_phys(self, samples, equation):
"""
Computation of the physical loss for SelfAdaptive PINN solver.

Expand Down
Loading