@@ -124,6 +124,12 @@ def forward(self, x):
124124# forced weight normed conv2d and linear
125125# algorithm 1 in paper
126126
127+ def normalize_weight (weight , eps = 1e-4 ):
128+ weight , ps = pack_one (weight , 'o *' )
129+ normed_weight = l2norm (weight , eps = eps )
130+ normed_weight = normed_weight * sqrt (weight .numel () / weight .shape [0 ])
131+ return unpack_one (normed_weight , ps , 'o *' )
132+
127133class Conv2d (Module ):
128134 def __init__ (
129135 self ,
@@ -142,14 +148,13 @@ def __init__(
142148 self .concat_ones_to_input = concat_ones_to_input
143149
144150 def forward (self , x ):
151+
145152 if self .training :
146153 with torch .no_grad ():
147- weight , ps = pack_one (self .weight , 'o *' )
148- normed_weight = l2norm (weight , eps = self .eps )
149- normed_weight = unpack_one (normed_weight , ps , 'o *' )
154+ normed_weight = normalize_weight (self .weight , eps = self .eps )
150155 self .weight .copy_ (normed_weight )
151156
152- weight = l2norm (self .weight , eps = self .eps ) / sqrt (self .fan_in )
157+ weight = normalize_weight (self .weight , eps = self .eps ) / sqrt (self .fan_in )
153158
154159 if self .concat_ones_to_input :
155160 x = F .pad (x , (0 , 0 , 0 , 0 , 1 , 0 ), value = 1. )
0 commit comments