Incorrect time_dim for intermediate temporal layers #4

@mlomnitz

Description

I have been working through your code trying to get it running, and I believe I found an issue in how the time_dim is set for the temporal layers here:

from typing import Tuple, Type

from torch.nn import Module

def set_time_dim_(
    klasses: Tuple[Type[Module], ...],
    model: Module,
    time_dim: int
):
    # note: the loop variable shadows the `model` argument in the
    # original; renamed to `module` here for clarity
    for module in model.modules():
        if isinstance(module, klasses):
            module.time_dim = time_dim

You are setting the same time_dim for all of the layers, but the size of the temporal dimension is cut in half after each downsampling step in the UNet. Because of this, the model crashes when trying to reshape/rearrange the tensors in the intermediate layers, for instance here (and maybe elsewhere as well):

if is_video:
    batch_size = x.shape[0]
    x = rearrange(x, 'b c t h w -> b h w t c')
else:
    assert exists(batch_size) or exists(self.time_dim)

    rearrange_kwargs = dict(b = batch_size, t = self.time_dim)
    x = rearrange(x, '(b t) c h w -> b h w t c', **compact_values(rearrange_kwargs))

I am working on my own workaround in the same set_time_dim_ function, but thought I would report it in case it is helpful.
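For what it's worth, a minimal sketch of one possible workaround, assuming the temporal dimension is halved once per UNet stage and that each temporal layer's depth can be determined while walking the module tree (the depth_of mapping below is hypothetical, not part of the repo):

```python
from typing import Dict, Tuple, Type

from torch.nn import Module

def set_time_dim_per_depth_(
    klasses: Tuple[Type[Module], ...],
    model: Module,
    time_dim: int,
    depth_of: Dict[int, int],  # hypothetical: id(module) -> UNet depth
):
    # assumption: temporal size is halved once per downsampling stage,
    # so a layer at depth d sees time_dim // (2 ** d) frames
    for module in model.modules():
        if isinstance(module, klasses):
            depth = depth_of.get(id(module), 0)
            module.time_dim = time_dim // (2 ** depth)
```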
