Skip to content

Commit 419ac7f

Browse files
GiovanniCanalidario-coscia
authored andcommitted
fix device problem of residual weights
1 parent 4ad939f commit 419ac7f

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

pina/solver/physics_informed_solver/rba_pinn.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,21 @@ def __init__(
140140
# Set the loss function to return non-aggregated losses
141141
self._loss_fn = type(self._loss_fn)(reduction="none")
142142

143+
def on_train_start(self):
144+
"""
145+
Ensure that all residual weight buffers registered during initialization
146+
are moved to the correct computation device.
147+
"""
148+
# Move all weight buffers to the correct device
149+
for cond in self.problem.input_pts:
150+
151+
# Get the buffer for the current condition
152+
weight_buf = getattr(self, f"weight_{cond}")
153+
154+
# Move the buffer to the correct device
155+
weight_buf.data = weight_buf.data.to(self.device)
156+
self.weights[cond] = weight_buf
157+
143158
def training_step(self, batch, batch_idx, **kwargs):
144159
"""
145160
Solver training step. It computes the optimization cycle and aggregates
@@ -235,7 +250,7 @@ def _optimization_cycle(self, batch, batch_idx, **kwargs):
235250
idx = torch.arange(
236251
batch_idx * len_res,
237252
(batch_idx + 1) * len_res,
238-
device=res.device,
253+
device=self.weights[cond].device,
239254
) % len(self.problem.input_pts[cond])
240255

241256
losses[cond] = self._apply_reduction(
@@ -271,7 +286,7 @@ def _update_weights(self, batch, batch_idx, residuals):
271286

272287
# Compute normalized residuals
273288
res = residuals[cond]
274-
res_abs = res.abs()
289+
res_abs = torch.linalg.vector_norm(res, ord=2, dim=1, keepdim=True)
275290
r_norm = (self.eta * res_abs) / (res_abs.max() + 1e-12)
276291

277292
# Get the correct indices for the weights. Modulus is used according
@@ -280,7 +295,7 @@ def _update_weights(self, batch, batch_idx, residuals):
280295
idx = torch.arange(
281296
batch_idx * len_pts,
282297
(batch_idx + 1) * len_pts,
283-
device=res.device,
298+
device=self.weights[cond].device,
284299
) % len(self.problem.input_pts[cond])
285300

286301
# Update weights

0 commit comments

Comments
 (0)