@@ -140,6 +140,21 @@ def __init__(
         # Set the loss function to return non-aggregated losses
         self._loss_fn = type(self._loss_fn)(reduction="none")

+    def on_train_start(self):
+        """
+        Ensure that all residual weight buffers registered during initialization
+        are moved to the correct computation device.
+        """
+        # Move all weight buffers to the correct device
+        for cond in self.problem.input_pts:
+
+            # Get the buffer for the current condition
+            weight_buf = getattr(self, f"weight_{cond}")
+
+            # Move the buffer to the correct device
+            weight_buf.data = weight_buf.data.to(self.device)
+            self.weights[cond] = weight_buf
+
     def training_step(self, batch, batch_idx, **kwargs):
         """
         Solver training step. It computes the optimization cycle and aggregates
@@ -235,7 +250,7 @@ def _optimization_cycle(self, batch, batch_idx, **kwargs):
             idx = torch.arange(
                 batch_idx * len_res,
                 (batch_idx + 1) * len_res,
-                device=res.device,
+                device=self.weights[cond].device,
             ) % len(self.problem.input_pts[cond])

             losses[cond] = self._apply_reduction(
@@ -271,7 +286,7 @@ def _update_weights(self, batch, batch_idx, residuals):

             # Compute normalized residuals
             res = residuals[cond]
-            res_abs = res.abs()
+            res_abs = torch.linalg.vector_norm(res, ord=2, dim=1, keepdim=True)
             r_norm = (self.eta * res_abs) / (res_abs.max() + 1e-12)

             # Get the correct indices for the weights. Modulus is used according
@@ -280,7 +295,7 @@ def _update_weights(self, batch, batch_idx, residuals):
             idx = torch.arange(
                 batch_idx * len_pts,
                 (batch_idx + 1) * len_pts,
-                device=res.device,
+                device=self.weights[cond].device,
             ) % len(self.problem.input_pts[cond])

             # Update weights
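For context on the new `on_train_start` hook: buffers registered on the module are moved along with it during device placement, but a plain-dict view created at initialization keeps pointing at the old tensors, so it has to be re-synced once the device is known. A minimal sketch of that pitfall, with toy names (`Toy`, `weight_x`) that are illustrative rather than taken from the codebase:

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        # Registered buffers are moved by Module.to() / Lightning device placement
        self.register_buffer("weight_x", torch.ones(4, 1))
        # A plain dict is NOT tracked by .to(), so this reference can go stale
        self.weights = {"x": self.weight_x}

toy = Toy()
if torch.cuda.is_available():
    toy.to("cuda")
    print(toy.weight_x.device)       # cuda:0 -- buffer moved with the module
    print(toy.weights["x"].device)   # cpu    -- stale reference
    toy.weights["x"] = toy.weight_x  # the on_train_start re-sync, in miniature
```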
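And a minimal sketch of the residual-normalization change: for a residual of shape `(num_points, num_components)`, the old `res.abs()` kept one magnitude per entry, while `torch.linalg.vector_norm(..., dim=1, keepdim=True)` collapses the components into a single per-point Euclidean norm. The shapes and the `eta` value below are assumptions for illustration:

```python
import torch

res = torch.randn(5, 3)  # 5 collocation points, 3 equation components

old = res.abs()  # element-wise magnitude, shape (5, 3)
new = torch.linalg.vector_norm(res, ord=2, dim=1, keepdim=True)  # shape (5, 1)

# Normalization as in the patch: scale by eta, divide by the largest norm
eta = 0.1  # illustrative value
r_norm = (eta * new) / (new.max() + 1e-12)
print(old.shape, new.shape, r_norm.shape)
# torch.Size([5, 3]) torch.Size([5, 1]) torch.Size([5, 1])
```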