@@ -129,8 +129,8 @@ def __init__(
         self.mgpo_beta = mgpo_beta

         # EMA of gradient magnitudes for adaptive normalization
-        self._grad_magnitude_ema_down = torch.nn.Parameter(torch.tensor(1.0), requires_grad=False)
-        self._grad_magnitude_ema_up = torch.nn.Parameter(torch.tensor(1.0), requires_grad=False)
+        self.register_buffer('_grad_magnitude_ema_down', torch.tensor(1.0), persistent=False)
+        self.register_buffer('_grad_magnitude_ema_up', torch.tensor(1.0), persistent=False)

         self.optimizer: torch.optim.Optimizer | None = None
@@ -322,24 +322,23 @@ def update_grad_norms(self):
     def update_gradient_ema(self):
         """
         Update EMA of gradient magnitudes for adaptive perturbation normalization
-
         Formula: ḡₗ⁽ᵗ⁾ = β * ḡₗ⁽ᵗ⁻¹⁾ + (1 - β) * ||∇ΔWₗL||₂
         """
         if self.mgpo_beta is None:
             return
-
+
         # Update EMA for lora_down gradient magnitude
         if self.lora_down.weight.grad is not None:
             current_grad_norm = torch.norm(self.lora_down.weight.grad, p=2)
-            self._grad_magnitude_ema_down.data = (
-                self.mgpo_beta * self._grad_magnitude_ema_down.data + (1 - self.mgpo_beta) * current_grad_norm
+            self._grad_magnitude_ema_down.mul_(self.mgpo_beta).add_(
+                current_grad_norm, alpha=(1 - self.mgpo_beta)
             )
-
+
         # Update EMA for lora_up gradient magnitude
         if self.lora_up.weight.grad is not None:
             current_grad_norm = torch.norm(self.lora_up.weight.grad, p=2)
-            self._grad_magnitude_ema_up.data = (
-                self.mgpo_beta * self._grad_magnitude_ema_up.data + (1 - self.mgpo_beta) * current_grad_norm
+            self._grad_magnitude_ema_up.mul_(self.mgpo_beta).add_(
+                current_grad_norm, alpha=(1 - self.mgpo_beta)
             )

     @torch.no_grad()
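
For context, a minimal standalone sketch of the pattern this commit moves to: the EMA scalar is stored as a non-persistent buffer rather than a frozen `nn.Parameter`, and the update ḡ ← β·ḡ + (1 − β)·‖∇‖₂ is applied in place with `mul_`/`add_`. The class, layer, and values below are illustrative, not the repository's actual code.

```python
import torch
import torch.nn as nn


class GradEmaSketch(nn.Module):
    """Illustrative only: mirrors the buffer + in-place EMA pattern above."""

    def __init__(self, beta: float = 0.9):
        super().__init__()
        self.beta = beta
        self.layer = nn.Linear(8, 8)
        # A non-persistent buffer follows .to(device)/.cuda() moves but is
        # excluded from parameters() and state_dict(), unlike a frozen nn.Parameter.
        self.register_buffer("_grad_magnitude_ema", torch.tensor(1.0), persistent=False)

    @torch.no_grad()
    def update_gradient_ema(self):
        if self.layer.weight.grad is None:
            return
        grad_norm = torch.norm(self.layer.weight.grad, p=2)
        # In place: ema = beta * ema + (1 - beta) * ||grad||_2
        self._grad_magnitude_ema.mul_(self.beta).add_(grad_norm, alpha=1 - self.beta)


m = GradEmaSketch()
m.layer(torch.randn(4, 8)).sum().backward()
m.update_gradient_ema()
print(list(m.state_dict()))          # ['layer.weight', 'layer.bias'] — the EMA buffer is not checkpointed
print(float(m._grad_magnitude_ema))  # EMA value after one update
```

Updating in place also keeps the registered buffer object (and its device placement) intact, whereas the previous `.data = ...` assignment allocated a fresh tensor each step and kept the EMA listed among the module's parameters.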