
Commit 1084c6f

Fix _grad_magnitude_ema_up and _grad_magnitude_ema_down getting saved to LoRA
1 parent cd73cad commit 1084c6f
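
Why the fix works: an nn.Parameter appears in the module's state_dict() even with requires_grad=False, so the two EMA scalars were being serialized into the saved LoRA weights. register_buffer(..., persistent=False) keeps the tensor attached to the module (it still moves with .to() and .cuda()) but excludes it from state_dict(). A minimal sketch of the difference, assuming only stock PyTorch; the Before/After module names are illustrative, not from the repo:

import torch
import torch.nn as nn

class Before(nn.Module):
    def __init__(self):
        super().__init__()
        # Old approach: a frozen Parameter is still part of state_dict()
        self._grad_magnitude_ema_down = nn.Parameter(torch.tensor(1.0), requires_grad=False)

class After(nn.Module):
    def __init__(self):
        super().__init__()
        # New approach: a non-persistent buffer is excluded from state_dict()
        self.register_buffer("_grad_magnitude_ema_down", torch.tensor(1.0), persistent=False)

print(list(Before().state_dict()))  # ['_grad_magnitude_ema_down'] -> leaks into the saved LoRA
print(list(After().state_dict()))   # [] -> runtime-only, nothing saved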

File tree

1 file changed (+8, -9 lines)


networks/lora_flux.py

Lines changed: 8 additions & 9 deletions
@@ -129,8 +129,8 @@ def __init__(
         self.mgpo_beta = mgpo_beta

         # EMA of gradient magnitudes for adaptive normalization
-        self._grad_magnitude_ema_down = torch.nn.Parameter(torch.tensor(1.0), requires_grad=False)
-        self._grad_magnitude_ema_up = torch.nn.Parameter(torch.tensor(1.0), requires_grad=False)
+        self.register_buffer('_grad_magnitude_ema_down', torch.tensor(1.0), persistent=False)
+        self.register_buffer('_grad_magnitude_ema_up', torch.tensor(1.0), persistent=False)

         self.optimizer: torch.optim.Optimizer | None = None

@@ -322,24 +322,23 @@ def update_grad_norms(self):
     def update_gradient_ema(self):
         """
         Update EMA of gradient magnitudes for adaptive perturbation normalization
-
         Formula: ḡₗ⁽ᵗ⁾ = β * ḡₗ⁽ᵗ⁻¹⁾ + (1 - β) * ||∇ΔWₗL||₂
         """
         if self.mgpo_beta is None:
             return
-
+
         # Update EMA for lora_down gradient magnitude
         if self.lora_down.weight.grad is not None:
             current_grad_norm = torch.norm(self.lora_down.weight.grad, p=2)
-            self._grad_magnitude_ema_down.data = (
-                self.mgpo_beta * self._grad_magnitude_ema_down.data + (1 - self.mgpo_beta) * current_grad_norm
+            self._grad_magnitude_ema_down.mul_(self.mgpo_beta).add_(
+                current_grad_norm, alpha=(1 - self.mgpo_beta)
             )
-
+
         # Update EMA for lora_up gradient magnitude
         if self.lora_up.weight.grad is not None:
             current_grad_norm = torch.norm(self.lora_up.weight.grad, p=2)
-            self._grad_magnitude_ema_up.data = (
-                self.mgpo_beta * self._grad_magnitude_ema_up.data + (1 - self.mgpo_beta) * current_grad_norm
+            self._grad_magnitude_ema_up.mul_(self.mgpo_beta).add_(
+                current_grad_norm, alpha=(1 - self.mgpo_beta)
             )

     @torch.no_grad()
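The EMA update itself is unchanged mathematically; mul_().add_(alpha=...) just performs it in place on the registered buffer instead of rebinding .data, so the module keeps pointing at the same tensor on the same device. A standalone check of the equivalence, a sketch assuming a scalar EMA buffer; beta and the gradient values here are illustrative stand-ins, not from the repo:

import torch

beta = 0.9                # stands in for self.mgpo_beta
grad = torch.randn(4, 4)  # stands in for lora_down.weight.grad
ema = torch.tensor(1.0)   # stands in for the registered buffer

current_grad_norm = torch.norm(grad, p=2)
# Docstring formula: ḡ⁽ᵗ⁾ = β * ḡ⁽ᵗ⁻¹⁾ + (1 - β) * ||∇||₂
expected = beta * ema + (1 - beta) * current_grad_norm

ema.mul_(beta).add_(current_grad_norm, alpha=(1 - beta))  # in-place, same result
assert torch.allclose(ema, expected)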
