71 changes: 71 additions & 0 deletions src/diffusers/loaders/lora_conversion_utils.py
@@ -1276,3 +1276,74 @@ def remap_single_transformer_blocks_(key, state_dict):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)

return converted_state_dict


def _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict):
# Remove "diffusion_model." prefix from keys.
state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
converted_state_dict = {}

def get_num_layers(keys, pattern):
layers = set()
for key in keys:
match = re.search(pattern, key)
if match:
layers.add(int(match.group(1)))
return len(layers)

def process_block(prefix, index, convert_norm):
# Process attention qkv: pop lora_A and lora_B weights.
lora_down = state_dict.pop(f"{prefix}.{index}.attention.qkv.lora_A.weight")
lora_up = state_dict.pop(f"{prefix}.{index}.attention.qkv.lora_B.weight")
for attn_key in ["to_q", "to_k", "to_v"]:
converted_state_dict[f"{prefix}.{index}.attn.{attn_key}.lora_A.weight"] = lora_down
for attn_key, weight in zip(["to_q", "to_k", "to_v"], torch.split(lora_up, [2304, 768, 768], dim=0)):
converted_state_dict[f"{prefix}.{index}.attn.{attn_key}.lora_B.weight"] = weight

# Process attention out weights.
converted_state_dict[f"{prefix}.{index}.attn.to_out.0.lora_A.weight"] = state_dict.pop(
f"{prefix}.{index}.attention.out.lora_A.weight"
)
converted_state_dict[f"{prefix}.{index}.attn.to_out.0.lora_B.weight"] = state_dict.pop(
f"{prefix}.{index}.attention.out.lora_B.weight"
)

# Process feed-forward weights for layers 1, 2, and 3.
for layer in range(1, 4):
converted_state_dict[f"{prefix}.{index}.feed_forward.linear_{layer}.lora_A.weight"] = state_dict.pop(
f"{prefix}.{index}.feed_forward.w{layer}.lora_A.weight"
)
converted_state_dict[f"{prefix}.{index}.feed_forward.linear_{layer}.lora_B.weight"] = state_dict.pop(
f"{prefix}.{index}.feed_forward.w{layer}.lora_B.weight"
)

if convert_norm:
converted_state_dict[f"{prefix}.{index}.norm1.linear.lora_A.weight"] = state_dict.pop(
f"{prefix}.{index}.adaLN_modulation.1.lora_A.weight"
)
converted_state_dict[f"{prefix}.{index}.norm1.linear.lora_B.weight"] = state_dict.pop(
f"{prefix}.{index}.adaLN_modulation.1.lora_B.weight"
)

noise_refiner_pattern = r"noise_refiner\.(\d+)\."
num_noise_refiner_layers = get_num_layers(state_dict.keys(), noise_refiner_pattern)
for i in range(num_noise_refiner_layers):
process_block("noise_refiner", i, convert_norm=True)

context_refiner_pattern = r"context_refiner\.(\d+)\."
num_context_refiner_layers = get_num_layers(state_dict.keys(), context_refiner_pattern)
for i in range(num_context_refiner_layers):
process_block("context_refiner", i, convert_norm=False)

core_transformer_pattern = r"layers\.(\d+)\."
num_core_transformer_layers = get_num_layers(state_dict.keys(), core_transformer_pattern)
for i in range(num_core_transformer_layers):
process_block("layers", i, convert_norm=True)

if len(state_dict) > 0:
raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")

for key in list(converted_state_dict.keys()):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)

return converted_state_dict
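For reference, a minimal sketch of what the converter does to a single core transformer block, importing the private helper directly just for illustration. The rank and tensor shapes below are placeholders; only the key names and the fused-qkv row split of 2304/768/768 matter to the conversion.

```python
import torch

from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_lumina2_lora_to_diffusers

rank = 4  # illustrative LoRA rank
block = "diffusion_model.layers.0"

# One core transformer block in the external-trainer ("non-diffusers") layout.
state_dict = {
    # Fused qkv: lora_A is shared across q/k/v, lora_B stacks 2304 (q) + 768 (k) + 768 (v) rows.
    f"{block}.attention.qkv.lora_A.weight": torch.zeros(rank, 2304),
    f"{block}.attention.qkv.lora_B.weight": torch.zeros(2304 + 768 + 768, rank),
    f"{block}.attention.out.lora_A.weight": torch.zeros(rank, 2304),
    f"{block}.attention.out.lora_B.weight": torch.zeros(2304, rank),
    # Feed-forward w1/w2/w3 (shapes are placeholders; only the key names are remapped).
    **{f"{block}.feed_forward.w{i}.lora_A.weight": torch.zeros(rank, 2304) for i in (1, 2, 3)},
    **{f"{block}.feed_forward.w{i}.lora_B.weight": torch.zeros(2304, rank) for i in (1, 2, 3)},
    # adaLN modulation, remapped to norm1.linear because convert_norm=True for "layers".
    f"{block}.adaLN_modulation.1.lora_A.weight": torch.zeros(rank, 2304),
    f"{block}.adaLN_modulation.1.lora_B.weight": torch.zeros(2304, rank),
}

converted = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict)
print(sorted(converted))
# ['transformer.layers.0.attn.to_k.lora_A.weight',
#  'transformer.layers.0.attn.to_k.lora_B.weight',
#  'transformer.layers.0.attn.to_out.0.lora_A.weight',
#  ...,
#  'transformer.layers.0.norm1.linear.lora_B.weight']
```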
7 changes: 6 additions & 1 deletion src/diffusers/loaders/lora_pipeline.py
@@ -41,6 +41,7 @@
    _convert_hunyuan_video_lora_to_diffusers,
    _convert_kohya_flux_lora_to_diffusers,
    _convert_non_diffusers_lora_to_diffusers,
    _convert_non_diffusers_lumina2_lora_to_diffusers,
    _convert_xlabs_flux_lora_to_diffusers,
    _maybe_map_sgm_blocks_to_diffusers,
)
@@ -3815,7 +3816,6 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):

    @classmethod
    @validate_hf_hub_args
    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -3909,6 +3909,11 @@ def lora_state_dict(
            logger.warning(warn_msg)
            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

        # conversion.
        non_diffusers = any(k.startswith("diffusion_model.") for k in state_dict)
        if non_diffusers:
            state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict)

Collaborator:
Is this prefix specific to Lumina? Should we always just remove it?

@sayakpaul (Member, Author) on Mar 3, 2025:

It is not specific to Lumina2 but specific to external trainer libraries. In all the past iterations where we have supported non-diffusers LoRA checkpoints, we have removed it because it's not in the diffusers-compatible format.

We are not removing the prefix here. We are using it to detect if the state dict is non-diffusers; if so, we convert the state dict.

This is how the rest of the non-diffusers checkpoints across different models have been supported in diffusers.
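To make the detection concrete, a small hypothetical illustration of the two key layouts the check distinguishes (the key names below are made-up examples, not taken from a real checkpoint):

```python
# External-trainer layout: keys carry the "diffusion_model." prefix, so the check fires.
non_diffusers_key = "diffusion_model.layers.0.attention.qkv.lora_A.weight"
# After conversion, the same tensor ends up under diffusers-style keys such as:
diffusers_key = "transformer.layers.0.attn.to_q.lora_A.weight"

sample_state_dict = {non_diffusers_key: None}
print(any(k.startswith("diffusion_model.") for k in sample_state_dict))  # True
```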


        return state_dict

    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
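Since `Lumina2LoraLoaderMixin.load_lora_weights` goes through `lora_state_dict`, the conversion happens transparently when such a checkpoint is loaded. A minimal sketch of the resulting user-facing flow, assuming the Lumina Image 2.0 base weights are available in diffusers format at "Alpha-VLLM/Lumina-Image-2.0"; the LoRA repo id and weight name are placeholders:

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16)

# load_lora_weights -> lora_state_dict detects the "diffusion_model." prefix and
# converts the state dict before the adapter is loaded into pipe.transformer.
pipe.load_lora_weights("your-username/lumina2-lora", weight_name="pytorch_lora_weights.safetensors")
```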