[LoRA] Support original format loras for HunyuanVideo #10376
Changes from 6 commits
66fc85e
893b9c0
904e3a4
8be9180
a040c5d
f682d76
63d5e9f
4ac0c12
5fbc59c
95a7e0f
738f50d
23854f2
2cc3683
3b64bd5
```diff
@@ -36,6 +36,7 @@
 from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin, _fetch_state_dict  # noqa
 from .lora_conversion_utils import (
     _convert_bfl_flux_control_lora_to_diffusers,
+    _convert_hunyuan_video_lora_to_diffusers,
     _convert_kohya_flux_lora_to_diffusers,
     _convert_non_diffusers_lora_to_diffusers,
     _convert_xlabs_flux_lora_to_diffusers,
```
```diff
@@ -4007,7 +4008,6 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
     @classmethod
     @validate_hf_hub_args
-    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
     def lora_state_dict(
         cls,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
```
```diff
@@ -4018,7 +4018,7 @@ def lora_state_dict(
         <Tip warning={true}>

-        We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+        We support loading original format HunyuanVideo LoRA checkpoints.

         This function is experimental and might change in the future.
```
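As a usage illustration of the capability the updated tip describes, an original format checkpoint goes through the same `load_lora_weights` entry point and is converted transparently. The repository and weight file names below are placeholders, not assets referenced by this PR:

```py
import torch
from diffusers import HunyuanVideoPipeline

pipe = HunyuanVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", torch_dtype=torch.float16
)
# Hypothetical repo/file containing a LoRA saved in the original HunyuanVideo key layout.
pipe.load_lora_weights(
    "your-username/hunyuan-video-lora", weight_name="original_format_lora.safetensors"
)
```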
```diff
@@ -4101,6 +4101,10 @@ def lora_state_dict(
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

+        is_original_hunyuan_video = any("img_attn_qkv" in k for k in state_dict)
+        if is_original_hunyuan_video:
+            state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)
+
         return state_dict

     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
```
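The detection hinges on key names that only occur in the original (non-diffusers) layout, such as the fused `img_attn_qkv` projections. Below is a minimal sketch of the idea with illustrative key names on both sides; the real `_convert_hunyuan_video_lora_to_diffusers` renames every module and also splits the fused QKV LoRA weights, so this is not the actual implementation:

```py
import torch


def looks_like_original_hunyuan_video(state_dict):
    # Original-format checkpoints carry keys such as "double_blocks.0.img_attn_qkv.lora_A.weight".
    return any("img_attn_qkv" in k for k in state_dict)


def convert_key(key):
    # Toy prefix remapping only; the real converter splits fused QKV LoRA matrices
    # into separate to_q/to_k/to_v entries and covers many more modules.
    return key.replace("double_blocks.", "transformer_blocks.").replace("img_attn_qkv", "attn.to_qkv")


state_dict = {"double_blocks.0.img_attn_qkv.lora_A.weight": torch.zeros(8, 16)}
if looks_like_original_hunyuan_video(state_dict):
    state_dict = {convert_key(k): v for k, v in state_dict.items()}

print(list(state_dict))  # ['transformer_blocks.0.attn.to_qkv.lora_A.weight']
```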
```diff
@@ -4239,10 +4243,9 @@ def save_lora_weights(
             safe_serialization=safe_serialization,
         )

-    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
     def fuse_lora(
         self,
-        components: List[str] = ["transformer", "text_encoder"],
+        components: List[str] = ["transformer"],
         lora_scale: float = 1.0,
         safe_fusing: bool = False,
         adapter_names: Optional[List[str]] = None,
```

Review thread on the removed "Copied from" comment:

- Could leverage the CogVideoX `fuse_lora` for the "Copied from" statement, no? If so, I'd prefer that.
- Not really, because we have a HunyuanVideo-specific example here.
- Well, we follow "Copied from ..." with the same example to play it to our advantage (of maintenance) for the other classes, too. So let's perhaps maintain that consistency. @stevhliu, WDYT about that?
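For context on the convention being debated: a `# Copied from ...` annotation lets diffusers' `make fix-copies` check regenerate the annotated method from its reference implementation, optionally applying `with old->new` token substitutions. A hypothetical stub of what adopting the CogVideoX reference would look like; the class name and body here are placeholders, not code from this PR:

```py
class HunyuanVideoLoraLoaderMixinSketch:
    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
    def fuse_lora(self, components=["transformer"], lora_scale=1.0, safe_fusing=False, adapter_names=None, **kwargs):
        # `make fix-copies` would keep this body (docstring example included) in sync
        # with the CogVideoX reference, which is why the reviewer prefers the shared
        # example over a HunyuanVideo-specific one.
        raise NotImplementedError("illustrative placeholder")
```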
````diff
@@ -4269,14 +4272,16 @@ def fuse_lora(
         Example:

         ```py
-        from diffusers import DiffusionPipeline
-        import torch
-
-        pipeline = DiffusionPipeline.from_pretrained(
-            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
-        ).to("cuda")
-        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
-        pipeline.fuse_lora(lora_scale=0.7)
+        >>> import torch
+        >>> from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
+
+        >>> model_id = "hunyuanvideo-community/HunyuanVideo"
+        >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained(
+        ...     model_id, subfolder="transformer", torch_dtype=torch.bfloat16
+        ... )
+        >>> pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
+        >>> pipe.load_lora_weights("a-r-r-o-w/HunyuanVideo-tuxemons", adapter_name="tuxemons")
+        >>> pipe.set_adapters("tuxemons", 1.2)
         ```
         """
         super().fuse_lora(
````
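The new docstring example stops at activating the adapter; actually fusing follows the same pattern as the other loader mixins. A possible continuation of that example (the prompt and sampling values are illustrative, not from the PR):

```py
# Continuing the docstring example: `pipe` already has the "tuxemons" adapter loaded.
pipe.fuse_lora(components=["transformer"], lora_scale=1.0)

video = pipe(
    prompt="A tuxemon wandering through a mossy forest, cinematic lighting",
    num_frames=61,
    num_inference_steps=30,
).frames[0]
```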
The second file in the diff is the HunyuanVideo LoRA test module.

```diff
@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import gc
 import sys
 import unittest

+import numpy as np
 import torch
 from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
```

```diff
@@ -26,8 +28,12 @@
 )
 from diffusers.utils.testing_utils import (
     floats_tensor,
+    nightly,
+    numpy_cosine_similarity_distance,
     require_peft_backend,
+    require_torch_gpu,
     skip_mps,
+    slow,
 )
```
```diff
@@ -182,3 +188,70 @@ def test_simple_inference_with_text_lora_fused(self):
     @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
     def test_simple_inference_with_text_lora_save_load(self):
         pass
+
+
+@slow
+@nightly
+@require_torch_gpu
+@require_peft_backend
+# @unittest.skip("We cannot run inference on this model with the current CI hardware")
+# TODO (DN6, sayakpaul): move these tests to a beefier GPU
```

A review suggestion on this hunk swaps the GPU requirement for the big-GPU marker:

```suggestion
@require_big_gpu_with_torch_cuda
```
- The Flux LoRA ones will be in after #9845 is merged. Since we already have a test suite for LoRA that uses the big-model marker, I think it's fine to utilize that here.
- Would it be a blocker to do it in a separate PR? If not, I'll revert the "Copied from" changes and proceed to merge, as this seems like something folks want without more delay, and I don't really have the bandwidth at the moment.
- Fine by me.
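The body of the new slow test suite is truncated in this view. Below is a hedged sketch of what such an integration test conventionally looks like in diffusers; the class name, prompt, resolution, and expected slice are placeholders rather than values from the PR:

```py
import gc
import unittest

import numpy as np
import torch

from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils.testing_utils import (
    nightly,
    numpy_cosine_similarity_distance,
    require_peft_backend,
    require_torch_gpu,
    slow,
)


@slow
@nightly
@require_torch_gpu
@require_peft_backend
class HunyuanVideoLoRAIntegrationSketch(unittest.TestCase):
    def setUp(self):
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()

    def test_original_format_lora(self):
        model_id = "hunyuanvideo-community/HunyuanVideo"
        transformer = HunyuanVideoTransformer3DModel.from_pretrained(
            model_id, subfolder="transformer", torch_dtype=torch.bfloat16
        )
        pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
        pipe.enable_model_cpu_offload()
        pipe.load_lora_weights("a-r-r-o-w/HunyuanVideo-tuxemons", adapter_name="tuxemons")

        frames = pipe(
            prompt="A tuxemon in a forest",  # placeholder prompt
            height=320,
            width=512,
            num_frames=9,
            num_inference_steps=2,
            generator=torch.manual_seed(0),
            output_type="np",
        ).frames[0]

        # Placeholder reference values; a real test would compare against a recorded slice.
        expected_slice = np.array([0.5, 0.5, 0.5])
        assert numpy_cosine_similarity_distance(frames.flatten()[:3], expected_slice) < 1e-2
```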