Invalid API call in Cosmos VAE Encoder #12025

@akhilg-nv

Describe the bug

At the line linked below, torch.reshape appears to be called incorrectly: the input tensor is missing from the call. I believe it should be either hidden_states.reshape(*) or torch.reshape(hidden_states, *).

https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py#L171

This part of the model does not appear to be used in the text2world pipeline, so the bug does not surface there, but running this line in local testing raises an error.

Reproduction

>>> hidden_states = torch.reshape(batch_size, num_channels, num_frames // p, p, height // p, p, width // p, p)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: reshape() takes 2 positional arguments but 8 were given
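
For comparison, here is a minimal sketch of the two call forms suggested above. The tensor sizes and the patch size p = 2 are made up for illustration and are not taken from the Cosmos VAE code. As far as I can tell, the method form accepts the dimensions as separate integers, while the functional form expects the input tensor first and the shape as a single tuple.

```python
import torch

# Hypothetical sizes chosen only for illustration; p stands in for the patch size.
batch_size, num_channels, num_frames, height, width = 2, 3, 4, 8, 8
p = 2

hidden_states = torch.randn(batch_size, num_channels, num_frames, height, width)

# Same target shape as in the failing call from the reproduction.
target_shape = (
    batch_size, num_channels,
    num_frames // p, p,
    height // p, p,
    width // p, p,
)

# Option 1: tensor method, dimensions unpacked as separate ints.
a = hidden_states.reshape(*target_shape)

# Option 2: functional form, input tensor first, shape passed as one tuple.
b = torch.reshape(hidden_states, target_shape)

# Both forms should produce the same view of the data.
assert a.shape == b.shape
assert torch.equal(a, b)
```

Either form keeps the element order unchanged, so the choice between them is stylistic; the method form is closer to the existing line.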

Logs

System Info

Python 3.9

Who can help?

@DN6 @sayakpaul

Labels

bug