Skip to content

Commit aa69b18

Browse files
committed
Remove weight standardization, as it somehow does not play well with AMP (automatic mixed precision)
1 parent 5aeacd4 commit aa69b18

File tree

6 files changed

+6
-75
lines changed

6 files changed

+6
-75
lines changed

README.md

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -229,16 +229,6 @@ You could consider adding a suitable metric to the training loop yourself after
229229
}
230230
```
231231

232-
```bibtex
233-
@article{Qiao2019WeightS,
234-
title = {Weight Standardization},
235-
author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Loddon Yuille},
236-
journal = {ArXiv},
237-
year = {2019},
238-
volume = {abs/1903.10520}
239-
}
240-
```
241-
242232
```bibtex
243233
@article{Salimans2022ProgressiveDF,
244234
title = {Progressive Distillation for Fast Sampling of Diffusion Models},

denoising_diffusion_pytorch/classifier_free_guidance.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -93,21 +93,6 @@ def Upsample(dim, dim_out = None):
9393
def Downsample(dim, dim_out = None):
9494
return nn.Conv2d(dim, default(dim_out, dim), 4, 2, 1)
9595

96-
class WeightStandardizedConv2d(nn.Conv2d):
97-
"""
98-
https://arxiv.org/abs/1903.10520
99-
weight standardization purportedly works synergistically with group normalization
100-
"""
101-
def forward(self, x):
102-
eps = 1e-5 if x.dtype == torch.float32 else 1e-3
103-
104-
weight = self.weight
105-
mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
106-
var = reduce(weight, 'o ... -> o 1 1 1', partial(torch.var, unbiased = False))
107-
normalized_weight = (weight - mean) * var.clamp(min = eps).rsqrt()
108-
109-
return F.conv2d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
110-
11196
class RMSNorm(nn.Module):
11297
def __init__(self, dim):
11398
super().__init__()
@@ -164,7 +149,7 @@ def forward(self, x):
164149
class Block(nn.Module):
165150
def __init__(self, dim, dim_out, groups = 8):
166151
super().__init__()
167-
self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1)
152+
self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
168153
self.norm = nn.GroupNorm(groups, dim_out)
169154
self.act = nn.SiLU()
170155

denoising_diffusion_pytorch/denoising_diffusion_pytorch.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import torch
1010
from torch import nn, einsum
11+
from torch.cuda.amp import autocast
1112
import torch.nn.functional as F
1213
from torch.utils.data import Dataset, DataLoader
1314

@@ -98,21 +99,6 @@ def Downsample(dim, dim_out = None):
9899
nn.Conv2d(dim * 4, default(dim_out, dim), 1)
99100
)
100101

101-
class WeightStandardizedConv2d(nn.Conv2d):
102-
"""
103-
https://arxiv.org/abs/1903.10520
104-
weight standardization purportedly works synergistically with group normalization
105-
"""
106-
def forward(self, x):
107-
eps = 1e-5 if x.dtype == torch.float32 else 1e-3
108-
109-
weight = self.weight
110-
mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
111-
var = reduce(weight, 'o ... -> o 1 1 1', partial(torch.var, unbiased = False))
112-
normalized_weight = (weight - mean) * var.clamp(min = eps).rsqrt()
113-
114-
return F.conv2d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
115-
116102
class RMSNorm(nn.Module):
117103
def __init__(self, dim):
118104
super().__init__()
@@ -169,7 +155,7 @@ def forward(self, x):
169155
class Block(nn.Module):
170156
def __init__(self, dim, dim_out, groups = 8):
171157
super().__init__()
172-
self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1)
158+
self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
173159
self.norm = nn.GroupNorm(groups, dim_out)
174160
self.act = nn.SiLU()
175161

denoising_diffusion_pytorch/denoising_diffusion_pytorch_1d.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -85,21 +85,6 @@ def Upsample(dim, dim_out = None):
8585
def Downsample(dim, dim_out = None):
8686
return nn.Conv1d(dim, default(dim_out, dim), 4, 2, 1)
8787

88-
class WeightStandardizedConv2d(nn.Conv1d):
89-
"""
90-
https://arxiv.org/abs/1903.10520
91-
weight standardization purportedly works synergistically with group normalization
92-
"""
93-
def forward(self, x):
94-
eps = 1e-5 if x.dtype == torch.float32 else 1e-3
95-
96-
weight = self.weight
97-
mean = reduce(weight, 'o ... -> o 1 1', 'mean')
98-
var = reduce(weight, 'o ... -> o 1 1', partial(torch.var, unbiased = False))
99-
normalized_weight = (weight - mean) * var.clamp(min = eps).rsqrt()
100-
101-
return F.conv1d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
102-
10388
class RMSNorm(nn.Module):
10489
def __init__(self, dim):
10590
super().__init__()
@@ -156,7 +141,7 @@ def forward(self, x):
156141
class Block(nn.Module):
157142
def __init__(self, dim, dim_out, groups = 8):
158143
super().__init__()
159-
self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1)
144+
self.proj = nn.Conv1d(dim, dim_out, 3, padding = 1)
160145
self.norm = nn.GroupNorm(groups, dim_out)
161146
self.act = nn.SiLU()
162147

denoising_diffusion_pytorch/guided_diffusion.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -93,21 +93,6 @@ def Downsample(dim, dim_out = None):
9393
nn.Conv2d(dim * 4, default(dim_out, dim), 1)
9494
)
9595

96-
class WeightStandardizedConv2d(nn.Conv2d):
97-
"""
98-
https://arxiv.org/abs/1903.10520
99-
weight standardization purportedly works synergistically with group normalization
100-
"""
101-
def forward(self, x):
102-
eps = 1e-5 if x.dtype == torch.float32 else 1e-3
103-
104-
weight = self.weight
105-
mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
106-
var = reduce(weight, 'o ... -> o 1 1 1', partial(torch.var, unbiased = False))
107-
normalized_weight = (weight - mean) * var.clamp(min = eps).rsqrt()
108-
109-
return F.conv2d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
110-
11196
class RMSNorm(nn.Module):
11297
def __init__(self, dim):
11398
super().__init__()
@@ -164,7 +149,7 @@ def forward(self, x):
164149
class Block(nn.Module):
165150
def __init__(self, dim, dim_out, groups = 8):
166151
super().__init__()
167-
self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1)
152+
self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
168153
self.norm = nn.GroupNorm(groups, dim_out)
169154
self.act = nn.SiLU()
170155

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.7.1'
1+
__version__ = '1.7.3'

0 commit comments

Comments
 (0)