Skip to content

Commit 0a46f55

Browse files
GiovanniCanalindem0
authored andcommitted
Fixing self adaptive pinns (#469)
* fix self adaptive pinn * clean competitive pinn
1 parent 67570d8 commit 0a46f55

File tree

2 files changed

+9
-69
lines changed

2 files changed

+9
-69
lines changed

pina/solver/physic_informed_solver/competitive_pinn.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -160,22 +160,6 @@ def loss_phys(self, samples, equation):
160160
self._train_discriminator(samples, equation, discriminator_bets)
161161
return loss_val
162162

163-
def loss_data(self, input_pts, output_pts):
164-
"""
165-
The data loss for the CompetitivePINN solver. It computes the loss
166-
between the network output against the true solution.
167-
168-
:param LabelTensor input_tensor: The input to the neural networks.
169-
:param LabelTensor output_tensor: The true solution to compare the
170-
network solution.
171-
:return: The computed data loss.
172-
:rtype: torch.Tensor
173-
"""
174-
loss_val = super().loss_data(input_pts, output_pts)
175-
# prepare for optimizer step called in training step
176-
loss_val.backward()
177-
return loss_val
178-
179163
def configure_optimizers(self):
180164
"""
181165
Optimizer configuration for the Competitive PINN solver.
@@ -252,7 +236,6 @@ def _train_discriminator(self, samples, equation, discriminator_bets):
252236
)
253237
# prepare for optimizer step called in training step
254238
self.manual_backward(loss_val)
255-
return
256239

257240
def _train_model(self, samples, equation, discriminator_bets):
258241
"""

pina/solver/physic_informed_solver/self_adaptive_pinn.py

Lines changed: 9 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -178,62 +178,19 @@ def training_step(self, batch):
178178
:return: The sum of the loss functions.
179179
:rtype: LabelTensor
180180
"""
181-
self.optimizer_model.instance.zero_grad()
181+
# Weights optimization
182182
self.optimizer_weights.instance.zero_grad()
183183
loss = super().training_step(batch)
184-
self.optimizer_model.instance.step()
184+
self.manual_backward(-loss)
185185
self.optimizer_weights.instance.step()
186-
return loss
187-
188-
def loss_phys(self, samples, equation):
189-
"""
190-
Computes the physics loss for the SAPINN solver based on given
191-
samples and equation.
192-
193-
:param LabelTensor samples: The samples to evaluate the physics loss.
194-
:param EquationInterface equation: The governing equation
195-
representing the physics.
196-
:return: The physics loss calculated based on given
197-
samples and equation.
198-
:rtype: torch.Tensor
199-
"""
200-
# Train the weights
201-
weighted_loss = self._loss_phys(samples, equation)
202-
loss_value = -weighted_loss.as_subclass(torch.Tensor)
203-
self.manual_backward(loss_value)
204-
205-
# Detach samples from the existing computational graph and
206-
# create a new one by setting requires_grad to True.
207-
# In alternative set `retain_graph=True`.
208-
samples = samples.detach()
209-
samples.requires_grad_() # = True
210-
211-
# Train the model
212-
weighted_loss = self._loss_phys(samples, equation)
213-
loss_value = weighted_loss.as_subclass(torch.Tensor)
214-
self.manual_backward(loss_value)
215186

216-
return loss_value
217-
218-
def loss_data(self, input_pts, output_pts):
219-
"""
220-
Computes the data loss for the SAPINN solver based on input and
221-
output. It computes the loss between the
222-
network output against the true solution.
187+
# Model optimization
188+
self.optimizer_model.instance.zero_grad()
189+
loss = super().training_step(batch)
190+
self.manual_backward(loss)
191+
self.optimizer_model.instance.step()
223192

224-
:param LabelTensor input_pts: The input to the neural networks.
225-
:param LabelTensor output_pts: The true solution to compare the
226-
network solution.
227-
:return: The computed data loss.
228-
:rtype: torch.Tensor
229-
"""
230-
residual = self.forward(input_pts) - output_pts
231-
loss = self._vectorial_loss(
232-
torch.zeros_like(residual, requires_grad=True), residual
233-
)
234-
loss_value = self._vect_to_scalar(loss).as_subclass(torch.Tensor)
235-
self.manual_backward(loss_value)
236-
return loss_value
193+
return loss
237194

238195
def configure_optimizers(self):
239196
"""
@@ -330,7 +287,7 @@ def on_load_checkpoint(self, checkpoint):
330287
)
331288
return super().on_load_checkpoint(checkpoint)
332289

333-
def _loss_phys(self, samples, equation):
290+
def loss_phys(self, samples, equation):
334291
"""
335292
Computation of the physical loss for SelfAdaptive PINN solver.
336293

0 commit comments

Comments
 (0)