Commit 6523fa6

Commit message: updates.
Parent: 993f3d3

File tree: 3 files changed, 44 additions and 29 deletions

scripts/convert_flux_control_lora_to_diffusers.py

Lines changed: 4 additions & 23 deletions
@@ -1,15 +1,9 @@
 import argparse
-from contextlib import nullcontext
 
 import safetensors.torch
 import torch
-from accelerate import init_empty_weights
 from huggingface_hub import hf_hub_download
 
-from diffusers.utils.import_utils import is_accelerate_available
-
-
-CTX = init_empty_weights if is_accelerate_available else nullcontext
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
@@ -22,27 +16,13 @@
 dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32
 
 
-# Adapted from from the original BFL codebase.
-def optionally_expand_state_dict(name: str, param: torch.Tensor, state_dict: dict) -> dict:
-    if name in state_dict:
-        print(f"Expanding '{name}' with shape {state_dict[name].shape} to model parameter with shape {param.shape}.")
-        # expand with zeros:
-        expanded_state_dict_weight = torch.zeros_like(param, device=state_dict[name].device)
-        # popular with pre-trained param for the first half. Remaining half stays with zeros.
-        slices = tuple(slice(0, dim) for dim in state_dict[name].shape)
-        expanded_state_dict_weight[slices] = state_dict[name]
-        state_dict[name] = expanded_state_dict_weight
-
-    return state_dict
-
-
 def load_original_checkpoint(args):
     if args.original_state_dict_repo_id is not None:
         ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
     elif args.checkpoint_path is not None:
         ckpt_path = args.checkpoint_path
     else:
-        raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
+        raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
 
     original_state_dict = safetensors.torch.load_file(ckpt_path)
     return original_state_dict
@@ -60,7 +40,7 @@ def convert_flux_control_lora_checkpoint_to_diffusers(
     original_state_dict, num_layers, num_single_layers, inner_dim, mlp_ratio=4.0
 ):
     converted_state_dict = {}
-    original_state_dict_keys = original_state_dict.keys()
+    original_state_dict_keys = list(original_state_dict.keys())
 
     for lora_key in ["lora_A", "lora_B"]:
         ## time_text_embed.timestep_embedder <- time_in
@@ -346,7 +326,8 @@ def convert_flux_control_lora_checkpoint_to_diffusers(
            original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.bias")
        )
 
-    print("Remaining:", original_state_dict.keys())
+    if len(original_state_dict) > 0:
+        raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
 
     for key in list(converted_state_dict.keys()):
         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
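A note on the converter changes above: `original_state_dict.keys()` returns a live view that shrinks as entries are popped during conversion, while `list(original_state_dict.keys())` takes a stable snapshot up front; the new strict check at the end then guarantees every original key was actually consumed. Below is a minimal sketch of the difference, using a made-up two-entry state dict that is not part of this commit:

import torch

state_dict = {"x.lora_A.weight": torch.zeros(2), "x.lora_B.weight": torch.zeros(2)}

# popping while iterating the live keys view fails
try:
    for key in state_dict.keys():
        state_dict.pop(key)
except RuntimeError as err:
    print(err)  # dictionary changed size during iteration

# snapshotting the keys first makes the pops safe
state_dict = {"x.lora_A.weight": torch.zeros(2), "x.lora_B.weight": torch.zeros(2)}
for key in list(state_dict.keys()):
    state_dict.pop(key)

# the converter's final check relies on the dict being fully drained
assert len(state_dict) == 0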

src/diffusers/loaders/lora_pipeline.py

Lines changed: 7 additions & 5 deletions
@@ -1925,9 +1925,11 @@ def _load_norm_into_transformer(
         transformer_keys = set(transformer_state_dict.keys())
         state_dict_keys = set(state_dict.keys())
         extra_keys = list(state_dict_keys - transformer_keys)
-        logger.warning(
-            f"Unsupported keys found in state dict when trying to load normalization layers into the transformer. The following keys will be ignored:\n{extra_keys}."
-        )
+
+        if extra_keys:
+            logger.warning(
+                f"Unsupported keys found in state dict when trying to load normalization layers into the transformer. The following keys will be ignored:\n{extra_keys}."
+            )
 
         for key in extra_keys:
             state_dict.pop(key)
@@ -2292,15 +2294,15 @@ def get_submodule(module, name):
                )
 
                new_weight = torch.zeros_like(
-                    expanded_module.weight.data.shape, device=module_weight.device, dtype=module_weight.dtype
+                    expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
                )
                slices = tuple(slice(0, dim) for dim in module_weight.shape)
                new_weight[slices] = module_weight
                expanded_module.weight.data.copy_(new_weight)
 
                if bias:
                    new_bias = torch.zeros_like(
-                        expanded_module.bias.data.shape, device=module_bias.device, dtype=module_bias.dtype
+                        expanded_module.bias.data, device=module_bias.device, dtype=module_bias.dtype
                    )
                    slices = tuple(slice(0, dim) for dim in module_bias.shape)
                    new_bias[slices] = module_bias
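The two one-line fixes in the second hunk correct a type error: `torch.zeros_like` expects a tensor to mirror, not a `torch.Size`, so passing `expanded_module.weight.data.shape` would fail at runtime. Passing the tensor itself allocates a zero buffer of the expanded shape, into which the smaller pre-trained weight is copied slice-wise. A self-contained sketch of that zero-padded expansion, with invented shapes (a 4 to 8 input-feature expansion) rather than the actual diffusers internals:

import torch

module_weight = torch.randn(4, 4)    # pre-trained weight, in_features=4
expanded_weight = torch.empty(4, 8)  # weight of the expanded module, in_features=8

# torch.zeros_like(expanded_weight.shape) would raise a TypeError; pass the tensor instead
new_weight = torch.zeros_like(expanded_weight, device=module_weight.device, dtype=module_weight.dtype)
slices = tuple(slice(0, dim) for dim in module_weight.shape)
new_weight[slices] = module_weight   # pre-trained block lands in the top-left corner
expanded_weight.copy_(new_weight)

assert torch.equal(expanded_weight[:, :4], module_weight)
assert torch.equal(expanded_weight[:, 4:], torch.zeros(4, 4))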

src/diffusers/loaders/peft.py

Lines changed: 33 additions & 1 deletion
@@ -56,6 +56,37 @@
 }
 
 
+def _maybe_adjust_config(config):
+    rank_pattern = config["rank_pattern"].copy()
+    target_modules = config["target_modules"]
+    original_r = config["r"]
+
+    for key in list(rank_pattern.keys()):
+        key_rank = rank_pattern[key]
+
+        # try to detect ambiguity
+        exact_matches = [mod for mod in target_modules if mod == key]
+        substring_matches = [mod for mod in target_modules if key in mod and mod != key]
+        ambiguous_key = key
+
+        if exact_matches and substring_matches:
+            # if ambiguous we update the rank associated with the ambiguous key (`proj_out`, for example)
+            config["r"] = key_rank
+            # remove the ambiguous key from `rank_pattern` and update its rank to `r`, instead
+            del config["rank_pattern"][key]
+            for mod in substring_matches:
+                # avoid overwriting if the module already has a specific rank
+                if mod not in config["rank_pattern"]:
+                    config["rank_pattern"][mod] = original_r
+
+        # update the rest of the keys with the `original_r`
+        for mod in target_modules:
+            if mod != ambiguous_key and mod not in config["rank_pattern"]:
+                config["rank_pattern"][mod] = original_r
+
+    return config
+
+
 class PeftAdapterMixin:
     """
     A class containing all functions for loading and using adapters weights that are supported in PEFT library. For
@@ -226,7 +257,8 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
             network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
 
             lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
-            print(lora_config_kwargs)
+            lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
+
             if "use_dora" in lora_config_kwargs:
                 if lora_config_kwargs["use_dora"]:
                     if is_peft_version("<", "0.9.0"):
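`_maybe_adjust_config` replaces the leftover debug `print(lora_config_kwargs)` and, as its comments describe, handles the case where a `rank_pattern` key (for example `proj_out`) both matches one target module exactly and appears as a substring of other target modules: the ambiguous key's rank is promoted to the global `r`, and the remaining modules are pinned to the previous `r` through explicit `rank_pattern` entries. A rough walk-through with an invented config (module names and ranks are made up; the expected result simply follows the function as added above):

config = {
    "r": 4,
    "target_modules": ["proj_out", "single_transformer_blocks.0.proj_out", "to_q"],
    "rank_pattern": {"proj_out": 8},
}

adjusted = _maybe_adjust_config(config)

# "proj_out" is both an exact target module and a substring of
# "single_transformer_blocks.0.proj_out", so its rank 8 becomes the global `r`,
# while the unambiguous modules keep the old rank 4 via `rank_pattern`.
assert adjusted["r"] == 8
assert adjusted["rank_pattern"] == {"single_transformer_blocks.0.proj_out": 4, "to_q": 4}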
