77import torch
88from torch import nn
99from torch .distributed import all_reduce
10- from torch .nn import functional as f
10+ from torch .nn . functional import cosine_similarity
1111from torch .nn .modules .batchnorm import _BatchNorm
1212from torch .nn .utils import clip_grad_norm_
1313
@@ -62,7 +62,7 @@ def to_real(x: torch.Tensor) -> torch.Tensor:
6262 return x .real if torch .is_complex (x ) else x
6363
6464
65- def normalize_gradient (x : torch .Tensor , use_channels : bool = False , epsilon : float = 1e-8 ):
65+ def normalize_gradient (x : torch .Tensor , use_channels : bool = False , epsilon : float = 1e-8 ) -> None :
6666 r"""Normalize gradient with stddev.
6767
6868 :param x: torch.Tensor. gradient.
@@ -119,7 +119,7 @@ def cosine_similarity_by_view(
119119 """
120120 x = view_func (x )
121121 y = view_func (y )
122- return f . cosine_similarity (x , y , dim = 1 , eps = eps ).abs_ ()
122+ return cosine_similarity (x , y , dim = 1 , eps = eps ).abs_ ()
123123
124124
125125def clip_grad_norm (
@@ -315,6 +315,7 @@ def reduce_max_except_dim(x: torch.Tensor, dim: int) -> torch.Tensor:
315315 return x
316316
317317
318+ @torch .no_grad ()
318319def reg_noise (
319320 network1 : nn .Module , network2 : nn .Module , num_data : int , lr : float , eta : float = 8e-3 , temperature : float = 1e-4
320321) -> Union [torch .Tensor , float ]:
@@ -332,11 +333,14 @@ def reg_noise(
332333 reg_coef : float = 0.5 / (eta * num_data )
333334 noise_coef : float = math .sqrt (2.0 / lr / num_data * temperature )
334335
335- loss = 0
336- for param1 , param2 in zip (network1 .parameters (), network2 .parameters (), strict = True ):
337- reg = torch .sub (param1 , param2 ).pow_ (2 ) * reg_coef
338- noise1 = param1 * torch .randn_like (param1 ) * noise_coef
339- noise2 = param2 * torch .randn_like (param2 ) * noise_coef
340- loss += torch .sum (reg - noise1 - noise2 )
336+ loss = torch .tensor (0.0 , device = next (network1 .parameters ()).device )
337+
338+ for param1 , param2 in zip (network1 .parameters (), network2 .parameters ()):
339+ reg = (param1 - param2 ).pow_ (2 ).mul_ (reg_coef ).sum ()
340+
341+ noise = param1 * torch .randn_like (param1 )
342+ noise .add_ (param2 * torch .randn_like (param2 ))
343+
344+ loss .add_ (reg - noise .mul_ (noise_coef ).sum ())
341345
342346 return loss
0 commit comments