diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index 9a1cc96e93e9..da65a1208b69 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -1833,6 +1833,17 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
         k.startswith("time_projection") and k.endswith(".weight") for k in original_state_dict
     )
 
+    def get_alpha_scales(down_weight, alpha_key):
+        rank = down_weight.shape[0]
+        alpha = original_state_dict.pop(alpha_key).item()
+        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+        return scale_down, scale_up
+
     for key in list(original_state_dict.keys()):
         if key.endswith((".diff", ".diff_b")) and "norm" in key:
             # NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it
@@ -1852,15 +1863,26 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     for i in range(min_block, max_block + 1):
         # Self-attention
         for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
-            original_key = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
-            converted_key = f"blocks.{i}.attn1.{c}.lora_A.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            alpha_key = f"blocks.{i}.self_attn.{o}.alpha"
+            has_alpha = alpha_key in original_state_dict
+            original_key_A = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.attn1.{c}.lora_A.weight"
 
-            original_key = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
-            converted_key = f"blocks.{i}.attn1.{c}.lora_B.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            original_key_B = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.attn1.{c}.lora_B.weight"
+
+            if has_alpha:
+                down_weight = original_state_dict.pop(original_key_A)
+                up_weight = original_state_dict.pop(original_key_B)
+                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                converted_state_dict[converted_key_A] = down_weight * scale_down
+                converted_state_dict[converted_key_B] = up_weight * scale_up
+
+            else:
+                if original_key_A in original_state_dict:
+                    converted_state_dict[converted_key_A] = original_state_dict.pop(original_key_A)
+                if original_key_B in original_state_dict:
+                    converted_state_dict[converted_key_B] = original_state_dict.pop(original_key_B)
 
             original_key = f"blocks.{i}.self_attn.{o}.diff_b"
             converted_key = f"blocks.{i}.attn1.{c}.lora_B.bias"
@@ -1869,15 +1891,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
 
         # Cross-attention
         for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
-            original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
-            converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
-
-            original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
-            converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
+            has_alpha = alpha_key in original_state_dict
+            original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
+
+            original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
+
+            if original_key_A in original_state_dict:
+                down_weight = original_state_dict.pop(original_key_A)
+                converted_state_dict[converted_key_A] = down_weight
+            if original_key_B in original_state_dict:
+                up_weight = original_state_dict.pop(original_key_B)
+                converted_state_dict[converted_key_B] = up_weight
+            if has_alpha:
+                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                converted_state_dict[converted_key_A] *= scale_down
+                converted_state_dict[converted_key_B] *= scale_up
 
             original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
             converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
@@ -1886,15 +1917,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
 
         if is_i2v_lora:
             for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
-                original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
-                converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
-                if original_key in original_state_dict:
-                    converted_state_dict[converted_key] = original_state_dict.pop(original_key)
-
-                original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
-                converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
-                if original_key in original_state_dict:
-                    converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+                alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
+                has_alpha = alpha_key in original_state_dict
+                original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
+                converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
+
+                original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
+                converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
+
+                if original_key_A in original_state_dict:
+                    down_weight = original_state_dict.pop(original_key_A)
+                    converted_state_dict[converted_key_A] = down_weight
+                if original_key_B in original_state_dict:
+                    up_weight = original_state_dict.pop(original_key_B)
+                    converted_state_dict[converted_key_B] = up_weight
+                if has_alpha:
+                    scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                    converted_state_dict[converted_key_A] *= scale_down
+                    converted_state_dict[converted_key_B] *= scale_up
 
                 original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
                 converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
@@ -1903,15 +1943,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
 
         # FFN
        for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
-            original_key = f"blocks.{i}.{o}.{lora_down_key}.weight"
-            converted_key = f"blocks.{i}.ffn.{c}.lora_A.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
-
-            original_key = f"blocks.{i}.{o}.{lora_up_key}.weight"
-            converted_key = f"blocks.{i}.ffn.{c}.lora_B.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            alpha_key = f"blocks.{i}.{o}.alpha"
+            has_alpha = alpha_key in original_state_dict
+            original_key_A = f"blocks.{i}.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.ffn.{c}.lora_A.weight"
+
+            original_key_B = f"blocks.{i}.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.ffn.{c}.lora_B.weight"
+
+            if original_key_A in original_state_dict:
+                down_weight = original_state_dict.pop(original_key_A)
+                converted_state_dict[converted_key_A] = down_weight
+            if original_key_B in original_state_dict:
+                up_weight = original_state_dict.pop(original_key_B)
+                converted_state_dict[converted_key_B] = up_weight
+            if has_alpha:
+                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                converted_state_dict[converted_key_A] *= scale_down
+                converted_state_dict[converted_key_B] *= scale_up
 
             original_key = f"blocks.{i}.{o}.diff_b"
             converted_key = f"blocks.{i}.ffn.{c}.lora_B.bias"
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 24fcd37fd75d..7461143ad5f1 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -5270,15 +5270,37 @@ def load_lora_weights(
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
-        self.load_lora_into_transformer(
-            state_dict,
-            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
-            adapter_name=adapter_name,
-            metadata=metadata,
-            _pipeline=self,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-            hotswap=hotswap,
-        )
+        load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
+        if load_into_transformer_2:
+            if not hasattr(self, "transformer_2"):
+                raise AttributeError(
+                    f"'{type(self).__name__}' object has no attribute transformer_2"
+                    "Note that Wan2.1 models do not have a transformer_2 component."
+                    "Ensure the model has a transformer_2 component before setting load_into_transformer_2=True."
+                )
+            if "transformer_2" not in self._lora_loadable_modules:
+                self._lora_loadable_modules.append("transformer_2")
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=self.transformer_2,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
+        else:
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=getattr(self, self.transformer_name)
+                if not hasattr(self, "transformer")
+                else self.transformer,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
@@ -5668,15 +5690,37 @@ def load_lora_weights(
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
-        self.load_lora_into_transformer(
-            state_dict,
-            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
-            adapter_name=adapter_name,
-            metadata=metadata,
-            _pipeline=self,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-            hotswap=hotswap,
-        )
+        load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
+        if load_into_transformer_2:
+            if not hasattr(self, "transformer_2"):
+                raise AttributeError(
+                    f"'{type(self).__name__}' object has no attribute transformer_2"
+                    "Note that Wan2.1 models do not have a transformer_2 component."
+                    "Ensure the model has a transformer_2 component before setting load_into_transformer_2=True."
+                )
+            if "transformer_2" not in self._lora_loadable_modules:
+                self._lora_loadable_modules.append("transformer_2")
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=self.transformer_2,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
+        else:
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=getattr(self, self.transformer_name)
+                if not hasattr(self, "transformer")
+                else self.transformer,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SkyReelsV2Transformer3DModel
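Note on the conversion change above: the new `get_alpha_scales` helper folds a kohya-style `alpha` key into the LoRA matrices, because the effective LoRA scale at inference is `alpha / rank`. The snippet below is an illustrative, self-contained restatement of that rebalancing; the `split_alpha_scale` name and the toy tensor are ours, not part of the diff.

import torch

def split_alpha_scale(down_weight: torch.Tensor, alpha: float):
    # The effective LoRA scale is alpha / rank. Rather than scaling a single
    # matrix by the full factor, distribute it across lora_A and lora_B while
    # keeping their product equal to alpha / rank (same idea as get_alpha_scales
    # in the diff above).
    rank = down_weight.shape[0]
    scale = alpha / rank
    scale_down, scale_up = scale, 1.0
    while scale_down * 2 < scale_up:
        scale_down *= 2
        scale_up /= 2
    return scale_down, scale_up

# Toy check: a rank-4 LoRA with alpha=1.0 has an overall scale of 0.25,
# split here as 0.5 * 0.5 so neither matrix is multiplied by a tiny factor.
down = torch.randn(4, 128)
scale_down, scale_up = split_alpha_scale(down, alpha=1.0)
assert scale_down * scale_up == 0.25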
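For the pipeline change, a minimal usage sketch of the new `load_into_transformer_2` kwarg follows. It assumes a Wan 2.2-style pipeline that exposes both `transformer` and `transformer_2`; the repo id and LoRA path are placeholders, not values taken from this diff.

import torch
from diffusers import WanPipeline

# Placeholder repo id and LoRA file; substitute a real Wan 2.2 checkpoint and LoRA.
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers", torch_dtype=torch.bfloat16)

# Load the LoRA into the primary transformer, as before.
pipe.load_lora_weights("my_wan_lora.safetensors", adapter_name="default")

# Load a LoRA into the second transformer (transformer_2) via the new kwarg.
pipe.load_lora_weights(
    "my_wan_lora.safetensors",
    adapter_name="default_2",
    load_into_transformer_2=True,
)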