-
Notifications
You must be signed in to change notification settings - Fork 6.5k
Support dynamically loading/unloading loras with group offloading #11804
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c740127
03335bf
3897c57
9c0ae73
17e55b8
ef8f7df
9b9fa85
52d9c20
955a406
a0fff85
06de952
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,6 +16,7 @@ | |
| import unittest | ||
|
|
||
| import torch | ||
| from parameterized import parameterized | ||
| from transformers import AutoTokenizer, T5EncoderModel | ||
|
|
||
| from diffusers import ( | ||
|
|
@@ -28,6 +29,7 @@ | |
| from diffusers.utils.testing_utils import ( | ||
| floats_tensor, | ||
| require_peft_backend, | ||
| require_torch_accelerator, | ||
| ) | ||
|
|
||
|
|
||
|
|
@@ -127,6 +129,13 @@ def test_simple_inference_with_text_denoiser_lora_unfused(self): | |
| def test_lora_scale_kwargs_match_fusion(self): | ||
| super().test_lora_scale_kwargs_match_fusion(expected_atol=9e-3, expected_rtol=9e-3) | ||
|
|
||
| @parameterized.expand([("block_level", True), ("leaf_level", False)]) | ||
| @require_torch_accelerator | ||
| def test_group_offloading_inference_denoiser(self, offload_type, use_stream): | ||
| # TODO: We don't run the (leaf_level, True) test here that is enabled for other models. | ||
| # The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338 | ||
| super()._test_group_offloading_inference_denoiser(offload_type, use_stream) | ||
|
Comment on lines
+132
to
+137
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can do this or we could detect if the test class is either of CogView4 or CogVideoX and use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Prefer the base test class not having any information about the model test classes that derive it. Current implementation will work for any model that overrides the test, so also much cleaner |
||
|
|
||
| @unittest.skip("Not supported in CogVideoX.") | ||
| def test_simple_inference_with_text_denoiser_block_scale(self): | ||
| pass | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.