Skip to content

Commit 158c91b

Browse files
committed
complete magnitude preserving temporal unet layers for space-time karras unet
1 parent cbb294d commit 158c91b

File tree

3 files changed

+79
-18
lines changed

3 files changed

+79
-18
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ assert noised_video.shape == denoised_video.shape
7575
- [x] expose only temporal parameters for learning, freeze everything else
7676
- [x] figure out the best way to deal with the time conditioning after temporal downsampling - instead of pytree transform at the beginning, probably will need to hook into all the modules and inspect the batch sizes
7777
- [x] handle middle modules that may have output shape as `(batch, seq, dim)`
78+
- [x] following the conclusions of Tero Karras, improvise a variant of the 4 modules with magnitude preservation
7879

79-
- [ ] following the conclusions of Tero Karras, improvise a variant of the 4 modules with magnitude preservation
8080
- [ ] test out on <a href="https://github.com/lucidrains/imagen-pytorch">imagen-pytorch</a>
8181

8282
## Citations

lumiere_pytorch/mp_lumiere.py

Lines changed: 77 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ def compact_values(d: dict):
4949
def l2norm(t, dim = -1, eps = 1e-12):
5050
return F.normalize(t, dim = dim, eps = eps)
5151

52+
def interpolate_1d(x, length, mode = 'bilinear'):
    """
    Resample the last (temporal) axis of a (batch, channels, time) tensor
    to `length` frames with F.interpolate.

    F.interpolate's 'bilinear' mode requires a 4d input, so a singleton
    trailing spatial axis is appended, the tensor is interpolated to
    (length, 1), and the axis is removed again. The reshape is done with
    native unsqueeze/squeeze instead of an einops rearrange round-trip —
    same result, no third-party call for a trivial axis insertion.
    """
    x = x.unsqueeze(-1)                            # (b, c, t) -> (b, c, t, 1)
    x = F.interpolate(x, (length, 1), mode = mode)
    return x.squeeze(-1)                           # (b, c, length, 1) -> (b, c, length)
56+
5257
# mp activations
5358
# section 2.5
5459

@@ -85,6 +90,65 @@ def forward(self, x):
8590
weight = l2norm(self.weight, eps = self.eps) / sqrt(self.fan_in)
8691
return F.linear(x, weight)
8792

93+
# forced weight normed conv2d and linear
94+
# algorithm 1 in paper
95+
96+
class Conv2d(Module):
    """
    Bias-free 2d convolution with forced weight normalization
    (algorithm 1 of the magnitude-preserving paper).

    While training, the stored parameter is re-projected onto the unit
    hypersphere — one row per output filter — under no_grad, so the raw
    weights never drift in magnitude. The weight actually convolved is an
    l2-normalized view scaled by 1 / sqrt(fan_in), applied with 'same'
    padding.

    NOTE(review): the training-time projection normalizes each flattened
    output filter, but the forward-path l2norm runs over the weight's last
    axis only (kernel width) — confirm this asymmetry is intended.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        kernel_size,
        eps = 1e-4
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim_out, dim_in, kernel_size, kernel_size))

        self.eps = eps
        # inputs feeding each output unit: channels * kernel area
        self.fan_in = dim_in * kernel_size ** 2

    def forward(self, x):
        if self.training:
            # snap the raw parameter back onto the unit hypersphere,
            # per output filter, without tracking gradients
            with torch.no_grad():
                flat_weight, packed_shape = pack_one(self.weight, 'o *')
                renormed = l2norm(flat_weight, eps = self.eps)
                self.weight.copy_(unpack_one(renormed, packed_shape, 'o *'))

        effective_weight = l2norm(self.weight, eps = self.eps) / sqrt(self.fan_in)
        return F.conv2d(x, effective_weight, padding = 'same')
121+
122+
class Conv1d(Module):
    """
    Bias-free 1d convolution with forced weight normalization
    (algorithm 1 of the magnitude-preserving paper), mirroring Conv2d.

    `init_dirac = True` initializes the kernel with nn.init.dirac_
    (identity-style kernel), so the layer starts close to a passthrough
    before the magnitude normalization is applied.

    While training, the stored parameter is re-projected onto the unit
    hypersphere per output filter under no_grad; the convolved weight is
    an l2-normalized view scaled by 1 / sqrt(fan_in) with 'same' padding.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        kernel_size,
        eps = 1e-4,
        init_dirac = False
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim_out, dim_in, kernel_size))

        if init_dirac:
            nn.init.dirac_(self.weight)

        self.eps = eps
        # inputs feeding each output unit: channels * kernel width
        self.fan_in = dim_in * kernel_size

    def forward(self, x):
        if self.training:
            # snap the raw parameter back onto the unit hypersphere,
            # per output filter, without tracking gradients
            with torch.no_grad():
                flat_weight, packed_shape = pack_one(self.weight, 'o *')
                renormed = l2norm(flat_weight, eps = self.eps)
                self.weight.copy_(unpack_one(renormed, packed_shape, 'o *'))

        effective_weight = l2norm(self.weight, eps = self.eps) / sqrt(self.fan_in)
        return F.conv1d(x, effective_weight, padding = 'same')
151+
88152
# pixelnorm
89153
# equation (30)
90154

@@ -183,18 +247,18 @@ def __init__(
183247
super().__init__()
184248
self.time_dim = time_dim
185249
self.channel_last = channel_last
186-
187-
self.conv = nn.Conv1d(dim, dim, kernel_size = 3, stride = 2, padding = 1)
188-
init_bilinear_kernel_1d_(self.conv)
250+
self.conv = Conv1d(dim, dim, 3, init_dirac = True)
189251

190252
@handle_maybe_channel_last
@image_or_video_to_time
def forward(
    self,
    x
):
    """
    Halve the temporal axis: bilinear-resize time to t // 2, then apply
    the dirac-initialized magnitude-preserving conv.

    The decorators presumably fold the spatial/channel-last layout into a
    (batch, channels, time) view before this body runs — confirm against
    the decorator definitions.
    """
    time = x.shape[-1]
    assert time > 1, 'time dimension must be greater than 1 to be compressed'

    return self.conv(interpolate_1d(x, time // 2))
199263

200264
class MPTemporalUpsample(Module):
@@ -207,16 +271,16 @@ def __init__(
207271
super().__init__()
208272
self.time_dim = time_dim
209273
self.channel_last = channel_last
210-
211-
self.conv = nn.ConvTranspose1d(dim, dim, kernel_size = 3, stride = 2, padding = 1, output_padding = 1)
212-
init_bilinear_kernel_1d_(self.conv)
274+
self.conv = Conv1d(dim, dim, 3, init_dirac = True)
213275

214276
@handle_maybe_channel_last
@image_or_video_to_time
def forward(
    self,
    x
):
    """
    Double the temporal axis: bilinear-resize time to t * 2, then apply
    the dirac-initialized magnitude-preserving conv.

    The decorators presumably fold the spatial/channel-last layout into a
    (batch, channels, time) view before this body runs — confirm against
    the decorator definitions.
    """
    return self.conv(interpolate_1d(x, x.shape[-1] * 2))
221285

222286
# main modules
@@ -233,26 +297,23 @@ def __init__(
233297
mp_add_t = 0.3
234298
):
235299
super().__init__()
236-
assert is_odd(conv2d_kernel_size)
237-
assert is_odd(conv1d_kernel_size)
238-
239300
self.time_dim = time_dim
240301
self.channel_last = channel_last
241302

242303
self.spatial_conv = nn.Sequential(
243-
nn.Conv2d(dim, dim, conv2d_kernel_size, padding = conv2d_kernel_size // 2),
304+
Conv2d(dim, dim, conv2d_kernel_size, 3),
244305
MPSiLU()
245306
)
246307

247308
self.temporal_conv = nn.Sequential(
248-
nn.Conv1d(dim, dim, conv1d_kernel_size, padding = conv1d_kernel_size // 2),
309+
Conv1d(dim, dim, conv1d_kernel_size, 3),
249310
MPSiLU()
250311
)
251312

252-
self.proj_out = nn.Conv1d(dim, dim, 1)
253-
254-
nn.init.zeros_(self.proj_out.weight)
255-
nn.init.zeros_(self.proj_out.bias)
313+
self.proj_out = nn.Sequential(
314+
Conv1d(dim, dim, 1),
315+
Gain()
316+
)
256317

257318
self.residual_mp_add = MPAdd(t = mp_add_t)
258319

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'lumiere-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.16',
6+
version = '0.0.17',
77
license='MIT',
88
description = 'Lumiere',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)