Skip to content

Commit b585827

Browse files
committed
feature: implement reg_noise
1 parent 416d91b commit b585827

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

pytorch_optimizer/optimizer/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,3 +278,19 @@ def reduce_max_except_dim(x: torch.Tensor, dim: int) -> torch.Tensor:
278278
if d != dim:
279279
x = x.max(dim=d, keepdim=True).values
280280
return x
281+
282+
283+
def reg_noise(
284+
network1: nn.Module, network2: nn.Module, num_data: int, lr: float, eta: float = 8e-3, temperature: float = 1e-4
285+
) -> torch.Tensor | float:
286+
reg_coef: float = 0.5 / (eta * num_data)
287+
noise_coef: float = math.sqrt(2.0 / lr / num_data * temperature)
288+
289+
loss = 0
290+
for param1, param2 in zip(network1.parameters(), network2.parameters(), strict=True):
291+
reg = torch.sub(param1, param2).pow_(2) * reg_coef
292+
noise1 = param1 * torch.randn_like(param1) * noise_coef
293+
noise2 = param2 * torch.randn_like(param2) * noise_coef
294+
loss += torch.sum(reg - noise1 - noise2)
295+
296+
return loss

0 commit comments

Comments
 (0)