Skip to content

Commit 3dcd22b

Browse files
committed
1 parent 0a642ce commit 3dcd22b

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

lumiere_pytorch/lumiere.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,20 @@ def __init__(self, fn: Module):
193193
def forward(self, t, *args, **kwargs):
194194
return self.fn(t, *args, **kwargs) + t
195195

196+
# channel rmsnorm
197+
198+
class ChanFirstRMSNorm(Module):
199+
def __init__(self, dim):
200+
super().__init__()
201+
self.scale = dim ** 0.5
202+
self.gamma = nn.Parameter(torch.ones(dim))
203+
204+
def forward(self, x):
205+
assert x.ndim > 2
206+
dims = (1,) * (x.ndim - 2)
207+
gamma = self.gamma.reshape(-1, *dims)
208+
return F.normalize(x, dim = -1) * self.scale * gamma
209+
196210
# temporal down and upsample
197211

198212
def init_bilinear_kernel_1d_(conv: Module):
@@ -262,7 +276,6 @@ def __init__(
262276
dim,
263277
conv2d_kernel_size = 3,
264278
conv1d_kernel_size = 3,
265-
groups = 8,
266279
channel_last = False,
267280
time_dim = None
268281
):
@@ -276,13 +289,13 @@ def __init__(
276289

277290
self.spatial_conv = nn.Sequential(
278291
nn.Conv2d(dim, dim, conv2d_kernel_size, padding = conv2d_kernel_size // 2),
279-
nn.GroupNorm(groups, num_channels = dim),
292+
ChanFirstRMSNorm(dim),
280293
nn.SiLU()
281294
)
282295

283296
self.temporal_conv = nn.Sequential(
284297
nn.Conv1d(dim, dim, conv1d_kernel_size, padding = conv1d_kernel_size // 2),
285-
nn.GroupNorm(groups, num_channels = dim),
298+
ChanFirstRMSNorm(dim),
286299
nn.SiLU()
287300
)
288301

@@ -373,6 +386,8 @@ def forward(
373386
batch_size = None
374387
):
375388
is_video = x.ndim == 5
389+
390+
batch_size = default(batch_size, self.batch_dim)
376391
assert is_video ^ (exists(batch_size) or exists(self.time_dim)), 'either a tensor of shape (batch, channels, time, height, width) is passed in, or (batch * time, channels, height, width) along with `batch_size`'
377392

378393
if self.channel_last:

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'lumiere-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.22',
6+
version = '0.0.23',
77
license='MIT',
88
description = 'Lumiere',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)