Commit 1b26c9f

address #4
1 parent b806d16 commit 1b26c9f

File tree

4 files changed: +19 -10 lines changed


lumiere_pytorch/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -2,8 +2,7 @@
     ConvolutionInflationBlock,
     AttentionInflationBlock,
     TemporalDownsample,
-    TemporalUpsample,
-    set_time_dim_
+    TemporalUpsample
 )
 
 from lumiere_pytorch.lumiere import Lumiere
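For reference, the top-level import surface after this change looks like the following (a sketch, assuming the re-exports shown in the hunk above):

# set_time_dim_ is no longer exported; Lumiere now stamps batch_dim
# internally through set_attr_on_klasses_ (see lumiere.py below)
from lumiere_pytorch import (
    ConvolutionInflationBlock,
    AttentionInflationBlock,
    TemporalDownsample,
    TemporalUpsample,
    Lumiere
)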

lumiere_pytorch/lumiere.py

Lines changed: 13 additions & 6 deletions
@@ -16,7 +16,7 @@
 import torch.nn.functional as F
 
 from beartype import beartype
-from beartype.typing import List, Tuple, Optional, Type
+from beartype.typing import List, Tuple, Optional, Type, Any
 
 from einops import rearrange, pack, unpack, repeat
 
@@ -92,14 +92,15 @@ def freeze_all_layers_(module):
 
 # function that takes in the entire text-to-video network, and sets the time dimension
 
-def set_time_dim_(
+def set_attr_on_klasses_(
     klasses: Tuple[Type[Module]],
     model: Module,
-    time_dim: int
+    attr_name: str,
+    value: Any
 ):
     for model in model.modules():
         if isinstance(model, klasses):
-            model.time_dim = time_dim
+            setattr(model, attr_name, value)
 
 # decorator for residual
 
@@ -135,7 +136,7 @@ def inner(
             batch_size = x.shape[0]
             x = rearrange(x, 'b c t h w -> b h w c t')
         else:
-            assert exists(batch_size) or exists(self.time_dim)
+            batch_size = default(batch_size, self.batch_dim)
             rearrange_kwargs = dict(b = batch_size, t = self.time_dim)
             x = rearrange(x, '(b t) c h w -> b h w c t', **compact_values(rearrange_kwargs))
 
@@ -212,6 +213,7 @@ def __init__(
         time_dim = None
     ):
         super().__init__()
+        self.batch_dim = None
         self.time_dim = time_dim
         self.channel_last = channel_last
 
@@ -236,6 +238,7 @@ def __init__(
         time_dim = None
     ):
         super().__init__()
+        self.batch_dim = None
         self.time_dim = time_dim
         self.channel_last = channel_last
 
@@ -267,6 +270,7 @@ def __init__(
         assert is_odd(conv2d_kernel_size)
         assert is_odd(conv1d_kernel_size)
 
+        self.batch_dim = None
         self.time_dim = time_dim
         self.channel_last = channel_last
 
@@ -302,6 +306,7 @@ def forward(
 
         x = self.spatial_conv(x)
 
+        batch_size = default(batch_size, self.batch_dim)
         rearrange_kwargs = compact_values(dict(b = batch_size, t = self.time_dim))
 
         assert len(rearrange_kwargs) > 0, 'either batch_size is passed in on forward, or time_dim is set on init'
@@ -335,6 +340,7 @@ def __init__(
     ):
         super().__init__()
 
+        self.batch_dim = None
         self.time_dim = time_dim
         self.channel_last = channel_last
 
@@ -376,6 +382,7 @@ def forward(
             batch_size = x.shape[0]
             x = rearrange(x, 'b c t h w -> b h w t c')
         else:
+            batch_size = default(batch_size, self.batch_dim)
            assert exists(batch_size) or exists(self.time_dim)
 
             rearrange_kwargs = dict(b = batch_size, t = self.time_dim)
@@ -579,7 +586,7 @@ def forward(
 
         # set the correct time dimension for all temporal layers
 
-        set_time_dim_(self.temporal_klasses, self, time)
+        set_attr_on_klasses_(self.temporal_klasses, self, 'batch_dim', batch)
 
         # forward all images into text-to-image model
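The renamed helper is now generic: it walks the module tree and stamps an arbitrary attribute onto every instance of the given classes, rather than hard-coding time_dim. A minimal runnable sketch of that behavior; the toy nn.Sequential network and the nn.Linear target are illustrative stand-ins, not from the repo:

from torch import nn
from typing import Any, Tuple, Type

def set_attr_on_klasses_(
    klasses: Tuple[Type[nn.Module], ...],
    model: nn.Module,
    attr_name: str,
    value: Any
):
    # walk all submodules and set the attribute in-place on matching classes
    for module in model.modules():
        if isinstance(module, klasses):
            setattr(module, attr_name, value)

# illustrative usage: nn.Linear stands in for the temporal blocks
net = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
set_attr_on_klasses_((nn.Linear,), net, 'batch_dim', 8)

assert all(m.batch_dim == 8 for m in net.modules() if isinstance(m, nn.Linear))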
lumiere_pytorch/mp_lumiere.py

Lines changed: 4 additions & 1 deletion
@@ -290,6 +290,7 @@ def __init__(
         dropout = 0.
     ):
         super().__init__()
+        self.batch_dim = None
         self.time_dim = time_dim
         self.channel_last = channel_last
 
@@ -327,6 +328,7 @@ def forward(
 
         x = self.spatial_conv(x)
 
+        batch_size = default(batch_size, self.batch_dim)
         rearrange_kwargs = compact_values(dict(b = batch_size, t = self.time_dim))
 
         assert len(rearrange_kwargs) > 0, 'either batch_size is passed in on forward, or time_dim is set on init'
@@ -388,13 +390,14 @@ def forward(
         batch_size = None
     ):
         is_video = x.ndim == 5
+
+        batch_size = default(batch_size, self.batch_dim)
         assert is_video ^ (exists(batch_size) or exists(self.time_dim)), 'either a tensor of shape (batch, channels, time, height, width) is passed in, or (batch * time, channels, height, width) along with `batch_size`'
 
         if self.channel_last:
             x = rearrange(x, 'b ... c -> b c ...')
 
         if is_video:
-            batch_size = x.shape[0]
             x = rearrange(x, 'b c t h w -> b h w t c')
         else:
             assert exists(batch_size) or exists(self.time_dim)
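Every `batch_size = default(batch_size, self.batch_dim)` line above follows the same fallback: Lumiere stamps batch_dim onto each temporal layer at the start of its forward pass (via set_attr_on_klasses_), so a layer receiving a flattened (batch * time, ...) tensor can recover the batch dimension without the caller threading batch_size through. A sketch of the pattern, assuming exists, default, and compact_values mirror the one-line helpers in lumiere.py:

import torch
from einops import rearrange

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def compact_values(d: dict):
    return {k: v for k, v in d.items() if exists(v)}

batch_dim = 2        # stamped on the layer by set_attr_on_klasses_
batch_size = None    # caller did not pass batch_size on forward

# flattened (batch * time, channels, height, width), batch = 2, time = 3
x = torch.randn(2 * 3, 8, 16, 16)

batch_size = default(batch_size, batch_dim)
rearrange_kwargs = compact_values(dict(b = batch_size, t = None))  # time_dim unset

x = rearrange(x, '(b t) c h w -> b h w c t', **rearrange_kwargs)
print(x.shape)  # torch.Size([2, 16, 16, 8, 3])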

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'lumiere-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.20',
+  version = '0.0.21',
   license='MIT',
   description = 'Lumiere',
   author = 'Phil Wang',
