Skip to content

Commit 85fcee0

Browse files
committed
address #293
1 parent 7a77b45 commit 85fcee0

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

denoising_diffusion_pytorch/karras_unet.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,12 @@ def forward(self, x):
124124
# forced weight normed conv2d and linear
125125
# algorithm 1 in paper
126126

127+
def normalize_weight(weight, eps = 1e-4):
128+
weight, ps = pack_one(weight, 'o *')
129+
normed_weight = l2norm(weight, eps = eps)
130+
normed_weight = normed_weight * sqrt(weight.numel() / weight.shape[0])
131+
return unpack_one(normed_weight, ps, 'o *')
132+
127133
class Conv2d(Module):
128134
def __init__(
129135
self,
@@ -142,14 +148,13 @@ def __init__(
142148
self.concat_ones_to_input = concat_ones_to_input
143149

144150
def forward(self, x):
151+
145152
if self.training:
146153
with torch.no_grad():
147-
weight, ps = pack_one(self.weight, 'o *')
148-
normed_weight = l2norm(weight, eps = self.eps)
149-
normed_weight = unpack_one(normed_weight, ps, 'o *')
154+
normed_weight = normalize_weight(self.weight, eps = self.eps)
150155
self.weight.copy_(normed_weight)
151156

152-
weight = l2norm(self.weight, eps = self.eps) / sqrt(self.fan_in)
157+
weight = normalize_weight(self.weight, eps = self.eps) / sqrt(self.fan_in)
153158

154159
if self.concat_ones_to_input:
155160
x = F.pad(x, (0, 0, 0, 0, 1, 0), value = 1.)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.10.10'
1+
__version__ = '1.10.11'

0 commit comments

Comments
 (0)