
Commit 16d69cd

handle channel last, for some unet middles may already be in that format for transformer
1 parent 0e74386 commit 16d69cd

3 files changed: +49 −6 lines

README.md

Lines changed: 1 addition & 2 deletions
@@ -74,8 +74,8 @@ assert noised_video.shape == denoised_video.shape
 
 - [x] expose only temporal parameters for learning, freeze everything else
 - [x] figure out the best way to deal with the time conditioning after temporal downsampling - instead of pytree transform at the beginning, probably will need to hook into all the modules and inspect the batch sizes
+- [x] handle middle modules that may have output shape as `(batch, seq, dim)`
 
-- [ ] handle shapes of `(batch, seq, dim)` as well as channel last
 - [ ] following the conclusions of Tero Karras, improvise a variant of the 4 modules with magnitude preservation
 - [ ] test out on <a href="https://github.com/lucidrains/imagen-pytorch">imagen-pytorch</a>

@@ -100,4 +100,3 @@ assert noised_video.shape == denoised_video.shape
     url = {https://api.semanticscholar.org/CorpusID:265659032}
 }
 ```
-

lumiere_pytorch/lumiere_pytorch.py

Lines changed: 47 additions & 3 deletions
@@ -134,6 +134,30 @@ def inner(
 
     return inner
 
+# handle channel last
+
+def handle_maybe_channel_last(fn):
+
+    @wraps(fn)
+    def inner(
+        self,
+        x,
+        *args,
+        **kwargs
+    ):
+
+        if self.channel_last:
+            x = rearrange(x, 'b ... c -> b c ...')
+
+        out = fn(self, x, *args, **kwargs)
+
+        if self.channel_last:
+            out = rearrange(out, 'b c ... -> b ... c')
+
+        return out
+
+    return inner
+
 # helpers
 
 def Sequential(*modules):
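For orientation (not part of the diff): a minimal, self-contained sketch of how the new decorator behaves. `ToyTemporal` below is a hypothetical stand-in for the temporal modules; the only assumption is that the decorated module carries a boolean `channel_last` attribute, as the classes in this commit now do.

```python
import torch
from torch import nn
from functools import wraps
from einops import rearrange

def handle_maybe_channel_last(fn):
    @wraps(fn)
    def inner(self, x, *args, **kwargs):
        if self.channel_last:
            x = rearrange(x, 'b ... c -> b c ...')      # to channel-first for the wrapped forward
        out = fn(self, x, *args, **kwargs)
        if self.channel_last:
            out = rearrange(out, 'b c ... -> b ... c')  # restore channel-last on the way out
        return out
    return inner

class ToyTemporal(nn.Module):  # hypothetical stand-in for a temporal module
    def __init__(self, dim, channel_last = False):
        super().__init__()
        self.channel_last = channel_last
        self.conv = nn.Conv1d(dim, dim, 3, padding = 1)  # a channel-first op

    @handle_maybe_channel_last
    def forward(self, x):
        return self.conv(x)

x = torch.randn(2, 1024, 320)                   # (batch, seq, dim) - a channel-last unet middle
out = ToyTemporal(320, channel_last = True)(x)
assert out.shape == x.shape                     # layout round-trips
```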
@@ -164,14 +188,17 @@ class TemporalDownsample(Module):
     def __init__(
         self,
         dim,
+        channel_last = False,
         time_dim = None
     ):
         super().__init__()
         self.time_dim = time_dim
+        self.channel_last = channel_last
 
         self.conv = nn.Conv1d(dim, dim, kernel_size = 3, stride = 2, padding = 1)
         init_bilinear_kernel_1d_(self.conv)
 
+    @handle_maybe_channel_last
     @image_or_video_to_time
     def forward(
         self,
@@ -185,14 +212,17 @@ class TemporalUpsample(Module):
     def __init__(
         self,
         dim,
+        channel_last = False,
         time_dim = None
     ):
         super().__init__()
         self.time_dim = time_dim
+        self.channel_last = channel_last
 
         self.conv = nn.ConvTranspose1d(dim, dim, kernel_size = 3, stride = 2, padding = 1, output_padding = 1)
         init_bilinear_kernel_1d_(self.conv)
 
+    @handle_maybe_channel_last
     @image_or_video_to_time
     def forward(
         self,
@@ -210,13 +240,15 @@ def __init__(
         conv2d_kernel_size = 3,
         conv1d_kernel_size = 3,
         groups = 8,
+        channel_last = False,
         time_dim = None
     ):
         super().__init__()
         assert is_odd(conv2d_kernel_size)
         assert is_odd(conv1d_kernel_size)
 
         self.time_dim = time_dim
+        self.channel_last = channel_last
 
         self.spatial_conv = nn.Sequential(
             nn.Conv2d(dim, dim, conv2d_kernel_size, padding = conv2d_kernel_size // 2),
@@ -235,6 +267,7 @@ def __init__(
         nn.init.zeros_(self.proj_out.weight)
         nn.init.zeros_(self.proj_out.bias)
 
+    @handle_maybe_channel_last
     def forward(
         self,
         x,
@@ -277,11 +310,13 @@ def __init__(
         prenorm = True,
         residual_attn = True,
         time_dim = None,
+        channel_last = False,
         **attn_kwargs
     ):
         super().__init__()
 
         self.time_dim = time_dim
+        self.channel_last = channel_last
 
         self.temporal_attns = ModuleList([])
 
@@ -304,6 +339,7 @@ def __init__(
         nn.init.zeros_(self.proj_out.weight)
         nn.init.zeros_(self.proj_out.bias)
 
+    @handle_maybe_channel_last
     def forward(
         self,
         x,
@@ -312,6 +348,9 @@ def forward(
         is_video = x.ndim == 5
         assert is_video ^ (exists(batch_size) or exists(self.time_dim)), 'either a tensor of shape (batch, channels, time, height, width) is passed in, or (batch * time, channels, height, width) along with `batch_size`'
 
+        if self.channel_last:
+            x = rearrange(x, 'b ... c -> b c ...')
+
         if is_video:
             batch_size = x.shape[0]
             x = rearrange(x, 'b c t h w -> b h w t c')
@@ -339,6 +378,9 @@ def forward(
         else:
             x = rearrange(x, 'b h w t c -> (b t) c h w')
 
+        if self.channel_last:
+            x = rearrange(x, 'b c ... -> b ... c')
+
         return x
 
 # post module hook wrapper
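To make the added layout handling concrete (an illustrative snippet, not part of the commit): the two rearranges round-trip a `(batch, seq, dim)` tensor of the kind the commit message describes.

```python
import torch
from einops import rearrange

x = torch.randn(4, 1024, 320)           # (batch, seq, dim) from a transformer middle
y = rearrange(x, 'b ... c -> b c ...')  # (4, 320, 1024): channel-first for the attention body
z = rearrange(y, 'b c ... -> b ... c')  # (4, 1024, 320): channel-last restored
assert torch.equal(x, z)
```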
@@ -375,7 +417,9 @@ def __init__(
         upsample_module_names: List[str] = [],
         channels: int = 3,
         conv_inflation_kwargs: dict = dict(),
-        attn_inflation_kwargs: dict = dict()
+        attn_inflation_kwargs: dict = dict(),
+        downsample_kwargs: dict = dict(),
+        upsample_kwargs: dict = dict(),
     ):
         super().__init__()
 
@@ -421,8 +465,8 @@ def __init__(
 
         self.convs = ModuleList([ConvolutionInflationBlock(dim = shape[1], **conv_inflation_kwargs) for shape in conv_shapes])
         self.attns = ModuleList([AttentionInflationBlock(dim = shape[1], **attn_inflation_kwargs) for shape in attn_shapes])
-        self.downsamples = ModuleList([TemporalDownsample(dim = shape[1]) for shape in downsample_shapes])
-        self.upsamples = ModuleList([TemporalUpsample(dim = shape[1]) for shape in upsample_shapes])
+        self.downsamples = ModuleList([TemporalDownsample(dim = shape[1], **downsample_kwargs) for shape in downsample_shapes])
+        self.upsamples = ModuleList([TemporalUpsample(dim = shape[1], **upsample_kwargs) for shape in upsample_shapes])
 
         # insert all the temporal modules with hooks
 
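A hypothetical sketch of how the new kwargs would be threaded through. `model` and the module path are placeholders, and only `downsample_kwargs` / `upsample_kwargs` are actually introduced by this commit; the rest of the constructor signature is assumed.

```python
from lumiere_pytorch import Lumiere

# hypothetical wiring; `model` is an existing text-to-image unet (placeholder)
lumiere = Lumiere(
    model,
    upsample_module_names = ['ups.2'],               # assumed module path into the wrapped unet
    downsample_kwargs = dict(channel_last = True),   # forwarded to each TemporalDownsample
    upsample_kwargs = dict(channel_last = True)      # forwarded to each TemporalUpsample
)
```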
setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'lumiere-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.11',
+  version = '0.0.14',
   license='MIT',
   description = 'Lumiere',
   author = 'Phil Wang',
