I have been working through your code trying to get it working, and I believe I found an issue with how you set the `time_dim` for the temporal layers here:

```python
def set_time_dim_(
    klasses: Tuple[Type[Module]],
    model: Module,
    time_dim: int
):
    for model in model.modules():
        if isinstance(model, klasses):
            model.time_dim = time_dim
```

You are setting the same `time_dim` for all of the layers, but the size of the temporal dimension is cut in half after each downsampling step in the UNet. Because of this, the model crashes when trying to reshape/rearrange the tensors in intermediate layers, for instance here (maybe others as well?):
```python
if is_video:
    batch_size = x.shape[0]
    x = rearrange(x, 'b c t h w -> b h w t c')
else:
    assert exists(batch_size) or exists(self.time_dim)
    rearrange_kwargs = dict(b = batch_size, t = self.time_dim)
    x = rearrange(x, '(b t) c h w -> b h w t c', **compact_values(rearrange_kwargs))
```

I am working on my own workaround in the same `set_time_dim_` function, but thought I would report it in case it is helpful.