src/diffusers/models/autoencoders/vae.py (1 addition, 1 deletion)

@@ -286,7 +286,7 @@ def forward(

sample = self.conv_in(sample)

-    upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
+    upscale_dtype = self.up_blocks[0].resnets[0].norm1.weight.dtype
Collaborator:
I think the tests currently failing in CI are due to the fact that not every decoder block has a norm1 with a weight. Hence the use of the generator here, to avoid such cases.

@ppadjinTT I noticed you initially used self.conv_out.weight here? What was the issue you ran into with that?
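
For context, a minimal sketch (toy modules, not the diffusers classes) of why the parameter-generator lookup is robust: it reads the dtype of whatever parameter comes first, without assuming anything about the block's internal layout.

```python
import torch
import torch.nn as nn

# Toy up blocks, assumed for illustration: one has a norm layer, one doesn't.
blocks = nn.ModuleList([
    nn.Sequential(nn.Conv2d(4, 4, 3, padding=1), nn.GroupNorm(2, 4)),
    nn.Sequential(nn.Upsample(scale_factor=2), nn.Conv2d(4, 4, 3, padding=1)),
])
blocks.to(torch.float16)

# Robust: dtype of the first parameter, no assumptions about block internals.
print(next(iter(blocks.parameters())).dtype)  # torch.float16

# Fragile: a hard-coded attribute path breaks on blocks without that layer.
# blocks[1][0].weight  # AttributeError: 'Upsample' object has no attribute 'weight'
```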

Author:

Okay, I will change that too, thanks! I initially changed self.conv_out.weight because there are some tests that check what happens when conv_out and the up blocks have different dtypes.
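
For reference, a minimal sketch of the mixed-dtype situation those tests exercise (toy modules and my assumption about the scenario, not the actual test code): after casting only the up blocks, conv_out no longer reflects the dtype the upsampling path actually runs in.

```python
import torch
import torch.nn as nn

# Hypothetical toy decoder, not the diffusers class.
class ToyDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.up_blocks = nn.ModuleList([nn.Conv2d(4, 4, 3, padding=1)])
        self.conv_out = nn.Conv2d(4, 3, 3, padding=1)

decoder = ToyDecoder()
decoder.up_blocks.to(torch.float16)  # cast only the upsampling path

print(decoder.conv_out.weight.dtype)                     # torch.float32
print(next(iter(decoder.up_blocks.parameters())).dtype)  # torch.float16
# Inferring upscale_dtype from conv_out would report float32 here, even
# though the up blocks actually run in float16.
```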

Collaborator:

Could you point me to those tests? Setting it to conv_out seems more robust.

Author:

Yup, these are the tests: pytest -svvv tests/models/autoencoders/test_models_autoencoder_kl.py

This is one of the tests from this set that fails: tests/models/autoencoders/test_models_autoencoder_kl.py::AutoencoderKLTests::test_layerwise_casting_inference

Author:

I added better logic for inferring the dtype, to cover the case where the direct lookup doesn't work.
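
A minimal sketch of what such fallback logic could look like (the helper name and structure are mine, not the PR's): try the cheap direct lookup first, and fall back to iterating parameters when the expected attribute is missing.

```python
import torch
import torch.nn as nn

# Hypothetical helper, for illustration only.
def infer_upscale_dtype(decoder: nn.Module) -> torch.dtype:
    try:
        # Fast path: a specific weight most decoders are expected to have.
        return decoder.up_blocks[0].resnets[0].norm1.weight.dtype
    except (AttributeError, IndexError):
        # Fallback: the first parameter of the up blocks, whatever their layout.
        return next(iter(decoder.up_blocks.parameters())).dtype
```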

Collaborator:

Hmm, I think we can remove upscale_dtype entirely here. I think all tests should still pass without it.
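
For background, a runnable sketch of what the explicit cast guarded against in the first place (my reading, not stated in this thread): feeding a float32 tensor to a half-precision block raises a dtype mismatch, and the removed cast made the dtypes agree before the up blocks ran.

```python
import torch
import torch.nn as nn

block = nn.Conv2d(4, 4, 3, padding=1).to(torch.float16)  # stand-in for an up block
sample = torch.randn(1, 4, 8, 8)  # float32, as it leaves a float32 mid block

try:
    block(sample)
except RuntimeError as e:
    print(e)  # dtype mismatch between input and weight

# The cast that upscale_dtype enabled:
sample = sample.to(next(iter(block.parameters())).dtype)
out = block(sample)  # runs in float16
```

If the layerwise-casting hooks now handle this conversion themselves, dropping the manual cast should indeed be safe.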

Author:

Okay, let's try that; I'm pushing the change.

Author:

Do you think it's okay now? @DN6

if torch.is_grad_enabled() and self.gradient_checkpointing:
# middle
sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds)