-
Couldn't load subscription status.
- Fork 6.5k
Add missing MochiEncoder3D.gradient_checkpointing attribute #11146
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
a-r-r-o-w
merged 4 commits into
huggingface:main
from
mjkvaak-amd:add-mochiencoder3d-grad-chkpt-attribute
Apr 5, 2025
Merged
Changes from 2 commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
121 changes: 121 additions & 0 deletions
121
tests/models/autoencoders/test_models_autoencoder_mochi.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,121 @@ | ||
| # coding=utf-8 | ||
| # Copyright 2024 HuggingFace Inc. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import unittest | ||
|
|
||
| from diffusers import AutoencoderKLMochi | ||
| from diffusers.utils.testing_utils import ( | ||
| enable_full_determinism, | ||
| floats_tensor, | ||
| torch_device, | ||
| ) | ||
|
|
||
| from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin | ||
|
|
||
|
|
||
| enable_full_determinism() | ||
|
|
||
|
|
||
| class AutoencoderKLMochiTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): | ||
| model_class = AutoencoderKLMochi | ||
| main_input_name = "sample" | ||
| base_precision = 1e-2 | ||
|
|
||
|
|
||
| def get_autoencoder_kl_mochi_config(self): | ||
| return { | ||
| "in_channels": 15, | ||
| "out_channels": 3, | ||
| "latent_channels": 4, | ||
| "encoder_block_out_channels": (32, 32, 32, 32), | ||
| "decoder_block_out_channels": (32, 32, 32, 32), | ||
| "layers_per_block": (1, 1, 1, 1, 1), | ||
| "act_fn": "silu", | ||
| "scaling_factor": 1, | ||
| } | ||
|
|
||
| @property | ||
| def dummy_input(self): | ||
| batch_size = 2 | ||
| num_frames = 7 | ||
| num_channels = 3 | ||
| sizes = (16, 16) | ||
|
|
||
| image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) | ||
|
|
||
| return {"sample": image} | ||
|
|
||
| @property | ||
| def input_shape(self): | ||
| return (3, 7, 16, 16) | ||
|
|
||
| @property | ||
| def output_shape(self): | ||
| return (3, 7, 16, 16) | ||
|
|
||
| def prepare_init_args_and_inputs_for_common(self): | ||
| init_dict = self.get_autoencoder_kl_mochi_config() | ||
| inputs_dict = self.dummy_input | ||
| return init_dict, inputs_dict | ||
|
|
||
| def test_gradient_checkpointing_is_applied(self): | ||
| expected_set = { | ||
| "MochiDecoder3D", | ||
| "MochiDownBlock3D", | ||
| "MochiEncoder3D", | ||
| "MochiMidBlock3D", | ||
| "MochiUpBlock3D", | ||
| } | ||
| super().test_gradient_checkpointing_is_applied(expected_set=expected_set) | ||
|
|
||
| @unittest.skip("Unsupported test.") | ||
| def test_effective_gradient_checkpointing(self): | ||
| """ Fails because of conv_cache: | ||
| tests/models/autoencoders/test_models_autoencoder_mochi.py::AutoencoderKLMochiTests::test_effective_gradient_checkpointing - | ||
| TypeError: ModelMixin.enable_gradient_checkpointing.<locals>._gradient_checkpointing_func() got an unexpected keyword argument 'conv_cache' | ||
| """ | ||
| pass | ||
|
|
||
| @unittest.skip("Unsupported test.") | ||
| def test_forward_with_norm_groups(self): | ||
| """ | ||
| tests/models/autoencoders/test_models_autoencoder_mochi.py::AutoencoderKLMochiTests::test_forward_with_norm_groups - | ||
| TypeError: AutoencoderKLMochi.__init__() got an unexpected keyword argument 'norm_num_groups' | ||
| """ | ||
| pass | ||
|
|
||
|
|
||
| @unittest.skip("Unsupported test.") | ||
| def test_model_parallelism(self): | ||
| """ | ||
| tests/models/autoencoders/test_models_autoencoder_mochi.py::AutoencoderKLMochiTests::test_outputs_equivalence - | ||
| RuntimeError: values expected sparse tensor layout but got Strided | ||
| """ | ||
| pass | ||
|
|
||
| @unittest.skip("Unsupported test.") | ||
| def test_outputs_equivalence(self): | ||
| """ | ||
| tests/models/autoencoders/test_models_autoencoder_mochi.py::AutoencoderKLMochiTests::test_outputs_equivalence - | ||
| RuntimeError: values expected sparse tensor layout but got Strided | ||
| """ | ||
| pass | ||
|
|
||
| @unittest.skip("Unsupported test.") | ||
| def test_sharded_checkpoints_device_map(self): | ||
| """ | ||
| tests/models/autoencoders/test_models_autoencoder_mochi.py::AutoencoderKLMochiTests::test_sharded_checkpoints_device_map - | ||
| RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:5! | ||
| """ | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this test should be made to pass. It fails because we don't accept kwargs in gradient checkpointing func:
diffusers/src/diffusers/models/modeling_utils.py
Line 318 in d6f4774
Could you update the vae implementation to not pass conv_cache argument as a keyword arg and instead just use normal arg? LMK if you'd like me to take this up in a separate PR. Other than, changes LGTM, thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done ✅