Commit dce4dc1

just move towards using rmsnorm, given the success of llama, and to avoid using pytorch layernorm, which has issues
1 parent 995f1c5 commit dce4dc1
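
For context: RMSNorm, as used throughout llama, drops LayerNorm's mean subtraction and rescales by the root-mean-square alone. A minimal sketch of the relationship, and of the F.normalize identity the new modules below rely on (toy tensors, not the repo's exact code):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 64, 16, 16)  # (batch, channels, height, width)

# LayerNorm over channels: center by the mean, then rescale by the std
mean = x.mean(dim = 1, keepdim = True)
var = x.var(dim = 1, unbiased = False, keepdim = True)
ln = (x - mean) * (var + 1e-5).rsqrt()

# RMSNorm over channels: rescale by the root-mean-square, no centering
rms = x.pow(2).mean(dim = 1, keepdim = True).sqrt()
rn = x / rms

# F.normalize divides by the L2 norm; since ||x|| = rms * sqrt(C),
# multiplying by sqrt(num_channels) recovers division by the rms,
# which is the one-liner the new RMSNorm modules use below
assert torch.allclose(rn, F.normalize(x, dim = 1) * (x.shape[1] ** 0.5), atol = 1e-5)
```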

6 files changed: +34 -49 lines changed

denoising_diffusion_pytorch/classifier_free_guidance.py

Lines changed: 5 additions & 8 deletions

@@ -104,26 +104,23 @@ def forward(self, x):
         weight = self.weight
         mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
         var = reduce(weight, 'o ... -> o 1 1 1', partial(torch.var, unbiased = False))
-        normalized_weight = (weight - mean) * (var + eps).rsqrt()
+        normalized_weight = (weight - mean) * var.clamp(min = eps).rsqrt()

         return F.conv2d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

-class LayerNorm(nn.Module):
+class RMSNorm(nn.Module):
     def __init__(self, dim):
         super().__init__()
         self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

     def forward(self, x):
-        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
-        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
-        mean = torch.mean(x, dim = 1, keepdim = True)
-        return (x - mean) * (var + eps).rsqrt() * self.g
+        return F.normalize(x, dim = 1) * self.g * (x.shape[1] ** 0.5)

 class PreNorm(nn.Module):
     def __init__(self, dim, fn):
         super().__init__()
         self.fn = fn
-        self.norm = LayerNorm(dim)
+        self.norm = RMSNorm(dim)

     def forward(self, x):
         x = self.norm(x)

@@ -220,7 +217,7 @@ def __init__(self, dim, heads = 4, dim_head = 32):

         self.to_out = nn.Sequential(
             nn.Conv2d(hidden_dim, dim, 1),
-            LayerNorm(dim)
+            RMSNorm(dim)
         )

     def forward(self, x):
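
The var.clamp(min = eps) change in the weight standardization above uses eps as a floor rather than an additive shift, so ordinary variances pass through unbiased while near-zero ones stay protected. A toy comparison (illustrative values only):

```python
import torch

eps = 1e-5
var = torch.tensor([0.0, 1e-6, 1.0])  # degenerate, tiny, and well-scaled variances

# old: eps is always added, slightly biasing every variance
old = (var + eps).rsqrt()

# new: eps only engages as a lower bound; var = 1.0 stays exactly 1.0
new = var.clamp(min = eps).rsqrt()

print(old)  # all three entries shifted by eps before rsqrt
print(new)  # first two floored at eps, the last untouched
```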

denoising_diffusion_pytorch/denoising_diffusion_pytorch.py

Lines changed: 5 additions & 8 deletions

@@ -109,26 +109,23 @@ def forward(self, x):
         weight = self.weight
         mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
         var = reduce(weight, 'o ... -> o 1 1 1', partial(torch.var, unbiased = False))
-        normalized_weight = (weight - mean) * (var + eps).rsqrt()
+        normalized_weight = (weight - mean) * var.clamp(min = eps).rsqrt()

         return F.conv2d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

-class LayerNorm(nn.Module):
+class RMSNorm(nn.Module):
     def __init__(self, dim):
         super().__init__()
         self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

     def forward(self, x):
-        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
-        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
-        mean = torch.mean(x, dim = 1, keepdim = True)
-        return (x - mean) * (var + eps).rsqrt() * self.g
+        return F.normalize(x, dim = 1) * self.g * (x.shape[1] ** 0.5)

 class PreNorm(nn.Module):
     def __init__(self, dim, fn):
         super().__init__()
         self.fn = fn
-        self.norm = LayerNorm(dim)
+        self.norm = RMSNorm(dim)

     def forward(self, x):
         x = self.norm(x)

@@ -223,7 +220,7 @@ def __init__(self, dim, heads = 4, dim_head = 32):

         self.to_out = nn.Sequential(
             nn.Conv2d(hidden_dim, dim, 1),
-            LayerNorm(dim)
+            RMSNorm(dim)
         )

     def forward(self, x):

denoising_diffusion_pytorch/denoising_diffusion_pytorch_1d.py

Lines changed: 7 additions & 10 deletions

@@ -96,26 +96,23 @@ def forward(self, x):
         weight = self.weight
         mean = reduce(weight, 'o ... -> o 1 1', 'mean')
         var = reduce(weight, 'o ... -> o 1 1', partial(torch.var, unbiased = False))
-        normalized_weight = (weight - mean) * (var + eps).rsqrt()
+        normalized_weight = (weight - mean) * var.clamp(min = eps).rsqrt()

         return F.conv1d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

-class LayerNorm(nn.Module):
+class RMSNorm(nn.Module):
     def __init__(self, dim):
         super().__init__()
         self.g = nn.Parameter(torch.ones(1, dim, 1))

     def forward(self, x):
-        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
-        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
-        mean = torch.mean(x, dim = 1, keepdim = True)
-        return (x - mean) * (var + eps).rsqrt() * self.g
+        return F.normalize(x, dim = 1) * self.g * (x.shape[1] ** 0.5)

 class PreNorm(nn.Module):
     def __init__(self, dim, fn):
         super().__init__()
         self.fn = fn
-        self.norm = LayerNorm(dim)
+        self.norm = RMSNorm(dim)

     def forward(self, x):
         x = self.norm(x)

@@ -210,7 +207,7 @@ def __init__(self, dim, heads = 4, dim_head = 32):

         self.to_out = nn.Sequential(
             nn.Conv1d(hidden_dim, dim, 1),
-            LayerNorm(dim)
+            RMSNorm(dim)
         )

     def forward(self, x):

@@ -868,9 +865,9 @@ def train(self):
                         milestone = self.step // self.save_and_sample_every
                         batches = num_to_groups(self.num_samples, self.batch_size)
                         all_samples_list = list(map(lambda n: self.ema.ema_model.sample(batch_size=n), batches))
-                        #
+
                         all_samples = torch.cat(all_samples_list, dim = 0)
-                        #
+
                         torch.save(all_samples, str(self.results_folder / f'sample-{milestone}.png'))
                         self.save(milestone)
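
In the sampling block above, num_to_groups splits the requested number of samples into batch-sized chunks. A sketch of such a helper, assuming the usual implementation in this repo's trainer utilities:

```python
def num_to_groups(num, divisor):
    # split `num` into full groups of size `divisor`, plus one remainder group
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

# e.g. 25 samples at batch size 8 are generated as four calls
assert num_to_groups(25, 8) == [8, 8, 8, 1]
```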

denoising_diffusion_pytorch/guided_diffusion.py

Lines changed: 5 additions & 8 deletions

@@ -104,26 +104,23 @@ def forward(self, x):
         weight = self.weight
         mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
         var = reduce(weight, 'o ... -> o 1 1 1', partial(torch.var, unbiased = False))
-        normalized_weight = (weight - mean) * (var + eps).rsqrt()
+        normalized_weight = (weight - mean) * var.clamp(min = eps).rsqrt()

         return F.conv2d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

-class LayerNorm(nn.Module):
+class RMSNorm(nn.Module):
     def __init__(self, dim):
         super().__init__()
         self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

     def forward(self, x):
-        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
-        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
-        mean = torch.mean(x, dim = 1, keepdim = True)
-        return (x - mean) * (var + eps).rsqrt() * self.g
+        return F.normalize(x, dim = 1) * self.g * (x.shape[1] ** 0.5)

 class PreNorm(nn.Module):
     def __init__(self, dim, fn):
         super().__init__()
         self.fn = fn
-        self.norm = LayerNorm(dim)
+        self.norm = RMSNorm(dim)

     def forward(self, x):
         x = self.norm(x)

@@ -218,7 +215,7 @@ def __init__(self, dim, heads = 4, dim_head = 32):

         self.to_out = nn.Sequential(
             nn.Conv2d(hidden_dim, dim, 1),
-            LayerNorm(dim)
+            RMSNorm(dim)
         )

     def forward(self, x):

denoising_diffusion_pytorch/simple_diffusion.py

Lines changed: 11 additions & 14 deletions

@@ -83,7 +83,7 @@ def Downsample(
         nn.Conv2d(dim * (factor ** 2), default(dim_out, dim), 1)
     )

-class LayerNorm(nn.Module):
+class RMSNorm(nn.Module):
     def __init__(self, dim, scale = True, normalize_dim = 2):
         super().__init__()
         self.g = nn.Parameter(torch.ones(dim)) if scale else 1

@@ -94,11 +94,7 @@ def __init__(self, dim, scale = True, normalize_dim = 2):
     def forward(self, x):
         normalize_dim = self.normalize_dim
         scale = append_dims(self.g, x.ndim - self.normalize_dim - 1) if self.scale else 1
-
-        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
-        var = torch.var(x, dim = normalize_dim, unbiased = False, keepdim = True)
-        mean = torch.mean(x, dim = normalize_dim, keepdim = True)
-        return (x - mean) * var.clamp(min = eps).rsqrt() * scale
+        return F.normalize(x, dim = normalize_dim) * scale * (x.shape[normalize_dim] ** 0.5)

 # sinusoidal positional embeds

@@ -169,12 +165,12 @@ def __init__(self, dim, heads = 4, dim_head = 32):
         self.heads = heads
         hidden_dim = dim_head * heads

-        self.norm = LayerNorm(dim, normalize_dim = 1)
+        self.norm = RMSNorm(dim, normalize_dim = 1)
         self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)

         self.to_out = nn.Sequential(
             nn.Conv2d(hidden_dim, dim, 1),
-            LayerNorm(dim, normalize_dim = 1)
+            RMSNorm(dim, normalize_dim = 1)
         )

     def forward(self, x):

@@ -207,7 +203,7 @@ def __init__(self, dim, heads = 4, dim_head = 32, scale = 8, dropout = 0.):
         self.heads = heads
         hidden_dim = dim_head * heads

-        self.norm = LayerNorm(dim)
+        self.norm = RMSNorm(dim)

         self.attn_dropout = nn.Dropout(dropout)
         self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias = False)

@@ -247,7 +243,7 @@ def __init__(
         dropout = 0.
     ):
         super().__init__()
-        self.norm = LayerNorm(dim, scale = False)
+        self.norm = RMSNorm(dim, scale = False)
         dim_hidden = dim * mult

         self.to_scale_shift = nn.Sequential(

@@ -359,10 +355,11 @@ def __init__(
             self.init_conv = nn.Conv2d(channels, init_dim, patch_size, stride = patch_size)
         else:
             self.init_conv = nn.Sequential(
-                Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = patch_size, p2 = patch_size),
-                LayerNorm(input_channels, normalize_dim = 1),
-                nn.Conv2d(input_channels, init_dim, 1),
-                LayerNorm(init_dim, normalize_dim = 1)
+                Rearrange('b c (h p1) (w p2) -> b h w (c p1 p2)', p1 = patch_size, p2 = patch_size),
+                nn.LayerNorm(input_channels),
+                nn.Linear(input_channels, init_dim),
+                nn.LayerNorm(init_dim),
+                Rearrange('b h w c -> b c h w')
             )

         self.unpatchify = nn.ConvTranspose2d(input_channels, channels, patch_size, stride = patch_size)
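
The reworked stem above patchifies in channels-last order so plain nn.LayerNorm and nn.Linear can act on the flattened patch features. A shape-level sketch under assumed sizes (note the flattened dimension works out to channels * patch_size ** 2):

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

channels, patch_size, init_dim = 3, 4, 128
patch_dim = channels * patch_size ** 2  # 48: each patch flattens c * p1 * p2 values

stem = nn.Sequential(
    Rearrange('b c (h p1) (w p2) -> b h w (c p1 p2)', p1 = patch_size, p2 = patch_size),
    nn.LayerNorm(patch_dim),         # normalizes the flattened patch features
    nn.Linear(patch_dim, init_dim),  # per-patch projection, equivalent to a 1x1 conv
    nn.LayerNorm(init_dim),
    Rearrange('b h w c -> b c h w')  # back to channels-first for the conv trunk
)

x = torch.randn(1, channels, 64, 64)
assert stem(x).shape == (1, init_dim, 16, 16)
```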
Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-__version__ = '1.6.4'
+__version__ = '1.7.1'
