Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions src/diffusers/loaders/lora_conversion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2624,6 +2624,38 @@ def get_alpha_scales(down_weight, alpha_key):
converted_state_dict[diffusers_down] = down_weight * scale_down
converted_state_dict[diffusers_up] = up_weight * scale_up

# Handle LoKr format: .alpha, .lokr_w1, .lokr_w2 (e.g. from Kohya/LyCORIS Z-image trainers).
# LoKr decomposition: delta = alpha * (lokr_w1 @ lokr_w2). Map to LoRA: lora_B @ lora_A.
lokr_w1_key = ".lokr_w1"
lokr_w2_key = ".lokr_w2"
has_lokr_format = any(lokr_w1_key in k for k in state_dict)

if has_lokr_format:
lokr_keys = [k for k in list(state_dict.keys()) if lokr_w1_key in k]
for k in lokr_keys:
if k not in state_dict:
continue
if not k.endswith(lokr_w1_key):
continue

base = k[: -len(lokr_w1_key)]
lokr_w2_key_full = base + lokr_w2_key
alpha_key = base + ".alpha"

if lokr_w2_key_full not in state_dict or alpha_key not in state_dict:
continue

lokr_w1 = state_dict.pop(k)
lokr_w2 = state_dict.pop(lokr_w2_key_full)
scale_down, scale_up = get_alpha_scales(lokr_w2, alpha_key)

# lora_A = lokr_w2 (r, in), lora_B = lokr_w1 (out, r)
diffusers_a_key = base + ".lora_A.weight"
diffusers_b_key = base + ".lora_B.weight"
converted_state_dict[diffusers_a_key] = lokr_w2 * scale_down
converted_state_dict[diffusers_b_key] = lokr_w1 * scale_up
state_dict.pop(alpha_key, None)

if len(state_dict) > 0:
raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")

Expand Down