Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/diffusers/loaders/lora_conversion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1712,3 +1712,10 @@ def _convert_non_diffusers_hidream_lora_to_diffusers(state_dict, non_diffusers_p
converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
return converted_state_dict

def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):
if not all(k.startswith(non_diffusers_prefix) for k in state_dict):
raise ValueError("Invalid LoRA state dict for LTX-Video.")
converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
return converted_state_dict
6 changes: 5 additions & 1 deletion src/diffusers/loaders/lora_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
_convert_non_diffusers_hidream_lora_to_diffusers,
_convert_non_diffusers_lora_to_diffusers,
_convert_non_diffusers_lumina2_lora_to_diffusers,
_convert_non_diffusers_ltxv_lora_to_diffusers,
_convert_non_diffusers_wan_lora_to_diffusers,
_convert_xlabs_flux_lora_to_diffusers,
_maybe_map_sgm_blocks_to_diffusers,
Expand Down Expand Up @@ -3418,7 +3419,6 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):

@classmethod
@validate_hf_hub_args
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
Expand Down Expand Up @@ -3512,6 +3512,10 @@ def lora_state_dict(
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

is_non_diffusers_format = any("diffusion_model" in k for k in state_dict)
if is_non_diffusers_format:
state_dict = _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict)

return state_dict

# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
Expand Down
Loading