diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index df1d351ca1f7..4b5426947906 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -14,6 +14,8 @@
 
 import re
 
+import torch
+
 from ..utils import is_peft_version, logging
 
 
@@ -326,3 +328,294 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
         prefix = "text_encoder_2."
     new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
     return {new_name: alpha}
+
+
+# The utilities under `_convert_kohya_flux_lora_to_diffusers()`
+# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
+# All credits go to `kohya-ss`.
+def _convert_kohya_flux_lora_to_diffusers(state_dict):
+    def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
+        if sds_key + ".lora_down.weight" not in sds_sd:
+            return
+        down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
+
+        # scale weight by alpha and dim
+        rank = down_weight.shape[0]
+        alpha = sds_sd.pop(sds_key + ".alpha").item()  # alpha is scalar
+        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+
+        # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+
+        ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
+        ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up
+
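+    # Worked example for the rescaling above (illustrative numbers, not from the
+    # upstream script): with rank 16 and alpha 8, scale = 0.5, so lora_A is
+    # multiplied by 0.5 and lora_B is left untouched; the product lora_B @ lora_A
+    # then already carries the alpha / rank factor, and the converted adapter can
+    # be loaded as if alpha == rank.
+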
+    def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
+        if sds_key + ".lora_down.weight" not in sds_sd:
+            return
+        down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
+        up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
+        sd_lora_rank = down_weight.shape[0]
+
+        # scale weight by alpha and dim
+        alpha = sds_sd.pop(sds_key + ".alpha")
+        scale = alpha / sd_lora_rank
+
+        # calculate scale_down and scale_up
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+
+        down_weight = down_weight * scale_down
+        up_weight = up_weight * scale_up
+
+        # calculate dims if not provided
+        num_splits = len(ait_keys)
+        if dims is None:
+            dims = [up_weight.shape[0] // num_splits] * num_splits
+        else:
+            assert sum(dims) == up_weight.shape[0]
+
+        # check upweight is sparse or not
+        is_sparse = False
+        if sd_lora_rank % num_splits == 0:
+            ait_rank = sd_lora_rank // num_splits
+            is_sparse = True
+            i = 0
+            for j in range(len(dims)):
+                for k in range(len(dims)):
+                    if j == k:
+                        continue
+                    is_sparse = is_sparse and torch.all(
+                        up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
+                    )
+                i += dims[j]
+            if is_sparse:
+                logger.info(f"weight is sparse: {sds_key}")
+
+        # make ai-toolkit weight
+        ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
+        ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
+        if not is_sparse:
+            # down_weight is copied to each split
+            ait_sd.update({k: down_weight for k in ait_down_keys})
+
+            # up_weight is split to each split
+            ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})  # noqa: C416
+        else:
+            # down_weight is chunked to each split
+            ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))})  # noqa: C416
+
+            # up_weight is sparse: only non-zero values are copied to each split
+            i = 0
+            for j in range(len(dims)):
+                ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
+                i += dims[j]
+
+    def _convert_sd_scripts_to_ai_toolkit(sds_sd):
+        ait_sd = {}
+        for i in range(19):
+            _convert_to_ai_toolkit(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_double_blocks_{i}_img_attn_proj",
+                f"transformer.transformer_blocks.{i}.attn.to_out.0",
+            )
+            _convert_to_ai_toolkit_cat(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_double_blocks_{i}_img_attn_qkv",
+                [
+                    f"transformer.transformer_blocks.{i}.attn.to_q",
+                    f"transformer.transformer_blocks.{i}.attn.to_k",
+                    f"transformer.transformer_blocks.{i}.attn.to_v",
+                ],
+            )
+            _convert_to_ai_toolkit(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_double_blocks_{i}_img_mlp_0",
+                f"transformer.transformer_blocks.{i}.ff.net.0.proj",
+            )
+            _convert_to_ai_toolkit(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_double_blocks_{i}_img_mlp_2",
+                f"transformer.transformer_blocks.{i}.ff.net.2",
+            )
+            _convert_to_ai_toolkit(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_double_blocks_{i}_img_mod_lin",
+                f"transformer.transformer_blocks.{i}.norm1.linear",
+            )
+            _convert_to_ai_toolkit(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_double_blocks_{i}_txt_attn_proj",
+                f"transformer.transformer_blocks.{i}.attn.to_add_out",
+            )
+            _convert_to_ai_toolkit_cat(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_double_blocks_{i}_txt_attn_qkv",
+                [
+                    f"transformer.transformer_blocks.{i}.attn.add_q_proj",
+                    f"transformer.transformer_blocks.{i}.attn.add_k_proj",
+                    f"transformer.transformer_blocks.{i}.attn.add_v_proj",
+                ],
+            )
+            _convert_to_ai_toolkit(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_double_blocks_{i}_txt_mlp_0",
+                f"transformer.transformer_blocks.{i}.ff_context.net.0.proj",
+            )
+            _convert_to_ai_toolkit(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_double_blocks_{i}_txt_mlp_2",
+                f"transformer.transformer_blocks.{i}.ff_context.net.2",
+            )
+            _convert_to_ai_toolkit(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_double_blocks_{i}_txt_mod_lin",
+                f"transformer.transformer_blocks.{i}.norm1_context.linear",
+            )
+
+        for i in range(38):
+            _convert_to_ai_toolkit_cat(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_single_blocks_{i}_linear1",
+                [
+                    f"transformer.single_transformer_blocks.{i}.attn.to_q",
+                    f"transformer.single_transformer_blocks.{i}.attn.to_k",
+                    f"transformer.single_transformer_blocks.{i}.attn.to_v",
+                    f"transformer.single_transformer_blocks.{i}.proj_mlp",
+                ],
+                dims=[3072, 3072, 3072, 12288],
+            )
+            _convert_to_ai_toolkit(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_single_blocks_{i}_linear2",
+                f"transformer.single_transformer_blocks.{i}.proj_out",
+            )
+            _convert_to_ai_toolkit(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_single_blocks_{i}_modulation_lin",
+                f"transformer.single_transformer_blocks.{i}.norm.linear",
+            )
+
+        if len(sds_sd) > 0:
+            logger.warning(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}")
+
+        return ait_sd
+
+    return _convert_sd_scripts_to_ai_toolkit(state_dict)
+
+
+# Adapted from https://gist.github.com/Leommm-byte/6b331a1e9bd53271210b26543a7065d6
+# Some utilities were reused from
+# https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
+def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
+    new_state_dict = {}
+    orig_keys = list(old_state_dict.keys())
+
+    def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
+        down_weight = sds_sd.pop(sds_key)
+        up_weight = sds_sd.pop(sds_key.replace(".down.weight", ".up.weight"))
+
+        # calculate dims if not provided
+        num_splits = len(ait_keys)
+        if dims is None:
+            dims = [up_weight.shape[0] // num_splits] * num_splits
+        else:
+            assert sum(dims) == up_weight.shape[0]
+
+        # make ai-toolkit weight
+        ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
+        ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
+
+        # down_weight is copied to each split
+        ait_sd.update({k: down_weight for k in ait_down_keys})
+
+        # up_weight is split to each split
+        ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})  # noqa: C416
+
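+    # Shape sketch for `handle_qkv` (hypothetical rank-16 adapter): a fused qkv
+    # up weight of shape (9216, 16) is split row-wise into three (3072, 16)
+    # lora_B matrices, while the single (16, 3072) down weight is copied as the
+    # lora_A of each of the q / k / v projections.
+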
+    for old_key in orig_keys:
+        # Handle double_blocks
+        if old_key.startswith(("diffusion_model.double_blocks", "double_blocks")):
+            block_num = re.search(r"double_blocks\.(\d+)", old_key).group(1)
+            new_key = f"transformer.transformer_blocks.{block_num}"
+
+            if "processor.proj_lora1" in old_key:
+                new_key += ".attn.to_out.0"
+            elif "processor.proj_lora2" in old_key:
+                new_key += ".attn.to_add_out"
+            elif "processor.qkv_lora1" in old_key and "up" not in old_key:
+                handle_qkv(
+                    old_state_dict,
+                    new_state_dict,
+                    old_key,
+                    [
+                        f"transformer.transformer_blocks.{block_num}.attn.add_q_proj",
+                        f"transformer.transformer_blocks.{block_num}.attn.add_k_proj",
+                        f"transformer.transformer_blocks.{block_num}.attn.add_v_proj",
+                    ],
+                )
+                # continue
+            elif "processor.qkv_lora2" in old_key and "up" not in old_key:
+                handle_qkv(
+                    old_state_dict,
+                    new_state_dict,
+                    old_key,
+                    [
+                        f"transformer.transformer_blocks.{block_num}.attn.to_q",
+                        f"transformer.transformer_blocks.{block_num}.attn.to_k",
+                        f"transformer.transformer_blocks.{block_num}.attn.to_v",
+                    ],
+                )
+                # continue
+
+            if "down" in old_key:
+                new_key += ".lora_A.weight"
+            elif "up" in old_key:
+                new_key += ".lora_B.weight"
+
+        # Handle single_blocks
+        elif old_key.startswith(("diffusion_model.single_blocks", "single_blocks")):
+            block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1)
+            new_key = f"transformer.single_transformer_blocks.{block_num}"
+
+            if "proj_lora1" in old_key or "proj_lora2" in old_key:
+                new_key += ".proj_out"
+            elif "qkv_lora1" in old_key or "qkv_lora2" in old_key:
+                new_key += ".norm.linear"
+
+            if "down" in old_key:
+                new_key += ".lora_A.weight"
+            elif "up" in old_key:
+                new_key += ".lora_B.weight"
+
+        else:
+            # Handle other potential key patterns here
+            new_key = old_key
+
+        # Since we already handled qkv above.
+        if "qkv" not in old_key:
+            new_state_dict[new_key] = old_state_dict.pop(old_key)
+
+    if len(old_state_dict) > 0:
+        raise ValueError(f"`old_state_dict` should be empty at this point but has: {list(old_state_dict.keys())}.")
+
+    return new_state_dict
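+
+
+# Key-mapping examples for the xlabs conversion above (illustrative, not
+# exhaustive): `double_blocks.0.processor.qkv_lora2.down.weight` fans out to
+# `transformer.transformer_blocks.0.attn.to_{q,k,v}.lora_A.weight`, while
+# `single_blocks.0.processor.proj_lora1.up.weight` becomes
+# `transformer.single_transformer_blocks.0.proj_out.lora_B.weight`.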
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index cefe66bc8cb6..7d644d684153 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -31,7 +31,12 @@
     scale_lora_layers,
 )
 from .lora_base import LoraBaseMixin
-from .lora_conversion_utils import _convert_non_diffusers_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers
+from .lora_conversion_utils import (
+    _convert_kohya_flux_lora_to_diffusers,
+    _convert_non_diffusers_lora_to_diffusers,
+    _convert_xlabs_flux_lora_to_diffusers,
+    _maybe_map_sgm_blocks_to_diffusers,
+)
 
 
 if is_transformers_available():
@@ -1583,6 +1588,20 @@ def lora_state_dict(
             allow_pickle=allow_pickle,
         )
 
+        # TODO (sayakpaul): do a follow-up to clean and try to unify the conditions.
+
+        is_kohya = any(".lora_down.weight" in k for k in state_dict)
+        if is_kohya:
+            state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
+            # Kohya already takes care of scaling the LoRA parameters with alpha.
+            return (state_dict, None) if return_alphas else state_dict
+
+        is_xlabs = any("processor" in k for k in state_dict)
+        if is_xlabs:
+            state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict)
+            # xlabs doesn't use `alpha`.
+            return (state_dict, None) if return_alphas else state_dict
+
         # For state dicts like
         # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
         keys = list(state_dict.keys())
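
A quick sanity check of the Kohya-path key mapping, as a minimal sketch: the single-entry state dict below is synthetic, built only for illustration (3072 is Flux's hidden size), not a real checkpoint.

```python
import torch

from diffusers.loaders.lora_conversion_utils import _convert_kohya_flux_lora_to_diffusers

# Synthetic rank-4 LoRA for one attention output projection (illustration only).
rank = 4
sds = {
    "lora_unet_double_blocks_0_img_attn_proj.lora_down.weight": torch.randn(rank, 3072),
    "lora_unet_double_blocks_0_img_attn_proj.lora_up.weight": torch.randn(3072, rank),
    "lora_unet_double_blocks_0_img_attn_proj.alpha": torch.tensor(float(rank)),
}

ait = _convert_kohya_flux_lora_to_diffusers(sds)
print(sorted(ait))
# ['transformer.transformer_blocks.0.attn.to_out.0.lora_A.weight',
#  'transformer.transformer_blocks.0.attn.to_out.0.lora_B.weight']
```

With alpha == rank the rescaling is a no-op, so both tensors pass through unchanged. `lora_state_dict` applies the same conversion automatically whenever it sees `.lora_down.weight` (Kohya) or `processor` (xlabs) in the keys.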