Skip to content

Commit da8c40a

Browse files
committed
feature: improve reg_noise performance
1 parent 0035722 commit da8c40a

File tree

1 file changed

+12 -7 lines changed

1 file changed

+12 -7 lines changed

pytorch_optimizer/optimizer/utils.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def to_real(x: torch.Tensor) -> torch.Tensor:
6262
return x.real if torch.is_complex(x) else x
6363

6464

65-
def normalize_gradient(x: torch.Tensor, use_channels: bool = False, epsilon: float = 1e-8):
65+
def normalize_gradient(x: torch.Tensor, use_channels: bool = False, epsilon: float = 1e-8) -> None:
6666
r"""Normalize gradient with stddev.
6767
6868
:param x: torch.Tensor. gradient.
@@ -315,6 +315,7 @@ def reduce_max_except_dim(x: torch.Tensor, dim: int) -> torch.Tensor:
315315
return x
316316

317317

318+
@torch.no_grad()
318319
def reg_noise(
319320
network1: nn.Module, network2: nn.Module, num_data: int, lr: float, eta: float = 8e-3, temperature: float = 1e-4
320321
) -> Union[torch.Tensor, float]:
@@ -332,11 +333,15 @@ def reg_noise(
332333
reg_coef: float = 0.5 / (eta * num_data)
333334
noise_coef: float = math.sqrt(2.0 / lr / num_data * temperature)
334335

335-
loss = 0
336-
for param1, param2 in zip(network1.parameters(), network2.parameters(), strict=True):
337-
reg = torch.sub(param1, param2).pow_(2) * reg_coef
338-
noise1 = param1 * torch.randn_like(param1) * noise_coef
339-
noise2 = param2 * torch.randn_like(param2) * noise_coef
340-
loss += torch.sum(reg - noise1 - noise2)
336+
loss = torch.tensor(0.0, device=next(network1.parameters()).device)
337+
338+
for param1, param2 in zip(network1.parameters(), network2.parameters()):
339+
reg = (param1 - param2).pow_(2).mul_(reg_coef).sum()
340+
341+
noise = param1 * torch.randn_like(param1)
342+
noise.add_(param2 * torch.randn_like(param2))
343+
344+
loss.add_(reg - noise.mul_(noise_coef).sum())
341345

342346
return loss
347+

0 commit comments

Comments (0)