Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 222 additions & 0 deletions src/diffusers/loaders/lora_conversion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1346,6 +1346,228 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
return converted_state_dict


def _convert_fal_kontext_lora_to_diffusers(original_state_dict):
converted_state_dict = {}
original_state_dict_keys = list(original_state_dict.keys())
num_layers = 19
num_single_layers = 38
inner_dim = 3072
mlp_ratio = 4.0

# double transformer blocks
for i in range(num_layers):
block_prefix = f"transformer_blocks.{i}."
original_block_prefix = "base_model.model."

for lora_key in ["lora_A", "lora_B"]:
# norms
converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.weight"
)
if f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.bias"
)

converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.txt_mod.lin.{lora_key}.weight"
)

# Q, K, V
if lora_key == "lora_A":
sample_lora_weight = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
)
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight])
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight])
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight])

context_lora_weight = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
)
converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat(
[context_lora_weight]
)
converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat(
[context_lora_weight]
)
converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat(
[context_lora_weight]
)
else:
sample_q, sample_k, sample_v = torch.chunk(
original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
),
3,
dim=0,
)
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q])
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k])
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v])

context_q, context_k, context_v = torch.chunk(
original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
),
3,
dim=0,
)
converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q])
converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k])
converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v])

if f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.bias"),
3,
dim=0,
)
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([sample_q_bias])
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias])
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias])

if f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
context_q_bias, context_k_bias, context_v_bias = torch.chunk(
original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"),
3,
dim=0,
)
converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias])
converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias])
converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias])

# ff img_mlp
converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.weight"
)
if f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias"
)

converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.weight"
)
if f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias"
)

converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.weight"
)
if f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias"
)

converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.weight"
)
if f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias"
)

# output projections.
converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.weight"
)
if f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias"
)
converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.weight"
)
if f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias"
)

# single transformer blocks
for i in range(num_single_layers):
block_prefix = f"single_transformer_blocks.{i}."

for lora_key in ["lora_A", "lora_B"]:
# norm.linear <- single_blocks.0.modulation.lin
converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.weight"
)
if f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop(
f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias"
)

# Q, K, V, mlp
mlp_hidden_dim = int(inner_dim * mlp_ratio)
split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)

if lora_key == "lora_A":
lora_weight = original_state_dict.pop(
f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"
)
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight])
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight])
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight])
converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight])

if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
lora_bias = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias")
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias])
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias])
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias])
converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias])
else:
q, k, v, mlp = torch.split(
original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"),
split_size,
dim=0,
)
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q])
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k])
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v])
converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp])

if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
q_bias, k_bias, v_bias, mlp_bias = torch.split(
original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias"),
split_size,
dim=0,
)
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias])
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias])
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias])
converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias])

# output projections.
converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.weight"
)
if f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop(
f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias"
)

for lora_key in ["lora_A", "lora_B"]:
converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}final_layer.linear.{lora_key}.weight"
)
if f"{original_block_prefix}final_layer.linear.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop(
f"{original_block_prefix}final_layer.linear.{lora_key}.bias"
)

if len(original_state_dict) > 0:
raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")

for key in list(converted_state_dict.keys()):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)

return converted_state_dict


def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}

Expand Down
12 changes: 12 additions & 0 deletions src/diffusers/loaders/lora_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
)
from .lora_conversion_utils import (
_convert_bfl_flux_control_lora_to_diffusers,
_convert_fal_kontext_lora_to_diffusers,
_convert_hunyuan_video_lora_to_diffusers,
_convert_kohya_flux_lora_to_diffusers,
_convert_musubi_wan_lora_to_diffusers,
Expand Down Expand Up @@ -2062,6 +2063,17 @@ def lora_state_dict(
return_metadata=return_lora_metadata,
)

is_fal_kontext = any("base_model" in k for k in state_dict)
if is_fal_kontext:
state_dict = _convert_fal_kontext_lora_to_diffusers(state_dict)
return cls._prepare_outputs(
state_dict,
metadata=metadata,
alphas=None,
return_alphas=return_alphas,
return_metadata=return_lora_metadata,
)

# For state dicts like
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
keys = list(state_dict.keys())
Expand Down
Loading