Skip to content

Commit 1fef4c7

Browse files
committed
omi lora.
1 parent c934720 commit 1fef4c7

File tree

2 files changed

+55
-7
lines changed

2 files changed

+55
-7
lines changed

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1789,12 +1789,58 @@ def get_alpha_scales(down_weight, key):
17891789
return converted_state_dict
17901790

17911791

1792-
def _convert_non_diffusers_hidream_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):
1793-
if not all(k.startswith(non_diffusers_prefix) for k in state_dict):
1794-
raise ValueError("Invalid LoRA state dict for HiDream.")
1795-
converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
1796-
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
1797-
return converted_state_dict
1792+
def _convert_non_diffusers_hidream_lora_to_diffusers(state_dict):
1793+
non_diffusers_prefix = "diffusion_model"
1794+
is_kohya = all(k.startswith(f"{non_diffusers_prefix}.") for k in state_dict)
1795+
1796+
def _convert_kohya(state_dict):
1797+
converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
1798+
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
1799+
return converted_state_dict
1800+
1801+
if is_kohya:
1802+
return _convert_kohya(state_dict)
1803+
1804+
else:
1805+
assert any(k.startswith(("clip_g.", "clip_l.", "t5.", "llama.", "transformer.")) for k in state_dict)
1806+
converted_state_dict = {}
1807+
component = "transformer"
1808+
compoent_sd = {k: v for k, v in state_dict.items() if k.startswith(f"{component}.")}
1809+
1810+
def _convert_omi(key, state_dict, component):
1811+
down_key = f"{key}.lora_down.weight"
1812+
down_weight = state_dict.pop(down_key)
1813+
lora_rank = down_weight.shape[0]
1814+
1815+
up_weight_key = f"{key}.lora_up.weight"
1816+
up_weight = state_dict.pop(up_weight_key)
1817+
1818+
alpha_key = f"{key}.alpha"
1819+
alpha = state_dict.pop(alpha_key)
1820+
1821+
# scale weight by alpha and dim
1822+
scale = alpha / lora_rank
1823+
# calculate scale_down and scale_up
1824+
scale_down = scale
1825+
scale_up = 1.0
1826+
while scale_down * 2 < scale_up:
1827+
scale_down *= 2
1828+
scale_up /= 2
1829+
down_weight = down_weight * scale_down
1830+
up_weight = up_weight * scale_up
1831+
1832+
diffusers_down_key = f"{key}.lora_A.weight"
1833+
converted_state_dict[f"{component}.{diffusers_down_key}"] = down_weight
1834+
converted_state_dict[f"{component}.{diffusers_down_key.replace('.lora_A.', '.lora_B.')}"] = up_weight
1835+
1836+
all_unique_keys = {
1837+
k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "")
1838+
for k in compoent_sd
1839+
}
1840+
for k in all_unique_keys:
1841+
_convert_omi(k, compoent_sd, component=component)
1842+
1843+
return converted_state_dict
17981844

17991845

18001846
def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):

src/diffusers/loaders/lora_pipeline.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5489,7 +5489,9 @@ def lora_state_dict(
54895489
logger.warning(warn_msg)
54905490
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
54915491

5492-
is_non_diffusers_format = any("diffusion_model" in k for k in state_dict)
5492+
kohya_format = any("diffusion_model" in k for k in state_dict)
5493+
is_omi_format = any(k.startswith(("clip_g.", "clip_l.", "t5.", "llama.", "transformer.")) for k in state_dict)
5494+
is_non_diffusers_format = kohya_format or is_omi_format
54935495
if is_non_diffusers_format:
54945496
state_dict = _convert_non_diffusers_hidream_lora_to_diffusers(state_dict)
54955497

0 commit comments

Comments
 (0)