
Commit 5abff4e

sayakpaul and a-r-r-o-w committed
[LoRA] feat: support loading regular Flux LoRAs into Flux Control, and Fill (#10259)
* lora expansion with dummy zeros.
* updates
* fix working 🥳
* working.
* use torch.device meta for state dict expansion.
* tests

Co-authored-by: a-r-r-o-w <[email protected]>

* fixes
* fixes
* switch to debug
* fix
* Apply suggestions from code review

Co-authored-by: Aryan <[email protected]>

* fix stuff
* docs

---------

Co-authored-by: a-r-r-o-w <[email protected]>
Co-authored-by: Aryan <[email protected]>
1 parent bb482ba commit 5abff4e
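
At a high level, the change teaches `load_lora_weights` to zero-pad the `lora_A` matrices of a regular Flux LoRA so it can be applied on top of a Flux Control (or Fill) transformer whose input-facing layers were expanded in place. A minimal sketch of the workflow this enables (the second repo id is a placeholder, not a real checkpoint):

```py
import torch
from diffusers import FluxControlPipeline

pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
# Loading the Control LoRA expands the transformer's input-facing Linear layers in place.
pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora", adapter_name="depth")
# A regular Flux LoRA trained against the unexpanded transformer; its lora_A weights get
# zero-padded to the new input width when loaded (placeholder repo id).
pipe.load_lora_weights("your-username/your-regular-flux-lora", adapter_name="style")
pipe.set_adapters(["depth", "style"], adapter_weights=[0.9, 0.8])
```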

File tree

3 files changed: +239 -84 lines changed


docs/source/en/api/pipelines/flux.md

Lines changed: 37 additions & 0 deletions
@@ -268,6 +268,43 @@ images = pipe(
 images[0].save("flux-redux.png")
 ```

+## Combining Flux Turbo LoRAs with Flux Control, Fill, and Redux
+
+We can combine Flux Turbo LoRAs with Flux Control and other pipelines like Fill and Redux to enable few-step inference. The example below shows how to do so with the Flux Control LoRA for depth and a Turbo LoRA from [`ByteDance/Hyper-SD`](https://hf.co/ByteDance/Hyper-SD).
+
+```py
+from diffusers import FluxControlPipeline
+from image_gen_aux import DepthPreprocessor
+from diffusers.utils import load_image
+from huggingface_hub import hf_hub_download
+import torch
+
+control_pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
+control_pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora", adapter_name="depth")
+control_pipe.load_lora_weights(
+    hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
+)
+control_pipe.set_adapters(["depth", "hyper-sd"], adapter_weights=[0.85, 0.125])
+control_pipe.enable_model_cpu_offload()
+
+prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
+control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
+
+processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
+control_image = processor(control_image)[0].convert("RGB")
+
+image = control_pipe(
+    prompt=prompt,
+    control_image=control_image,
+    height=1024,
+    width=1024,
+    num_inference_steps=8,
+    guidance_scale=10.0,
+    generator=torch.Generator().manual_seed(42),
+).images[0]
+image.save("output.png")
+```
+
 ## Running FP16 inference

 Flux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details.
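
One way to apply that recipe is sketched below (not part of this diff; sampler settings are illustrative): encode the prompt with the text encoders kept in FP32, then cast the resulting embeddings to FP16 for the transformer.

```py
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.float16)
# Keep both text encoders in FP32 so their activations are not clipped.
pipe.text_encoder.to(torch.float32)
pipe.text_encoder_2.to(torch.float32)
pipe.enable_model_cpu_offload()

prompt = "A cat holding a sign that says hello world"
with torch.no_grad():
    prompt_embeds, pooled_prompt_embeds, _ = pipe.encode_prompt(prompt=prompt, prompt_2=None)

image = pipe(
    # Pass the precomputed FP32 embeddings, cast to FP16 to match the transformer.
    prompt_embeds=prompt_embeds.to(torch.float16),
    pooled_prompt_embeds=pooled_prompt_embeds.to(torch.float16),
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("flux-fp16.png")
```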

src/diffusers/loaders/lora_pipeline.py

Lines changed: 95 additions & 42 deletions
@@ -1863,6 +1863,9 @@ def load_lora_weights(
                 "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
                 "To get a comprehensive list of parameter names that were modified, enable debug logging."
             )
+        transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
+            transformer=transformer, lora_state_dict=transformer_lora_state_dict
+        )

         if len(transformer_lora_state_dict) > 0:
             self.load_lora_into_transformer(

@@ -2309,16 +2312,17 @@ def _maybe_expand_transformer_param_shape_or_error_(

         # Expand transformer parameter shapes if they don't match lora
         has_param_with_shape_update = False
-
+        is_peft_loaded = getattr(transformer, "peft_config", None) is not None
         for name, module in transformer.named_modules():
             if isinstance(module, torch.nn.Linear):
                 module_weight = module.weight.data
                 module_bias = module.bias.data if module.bias is not None else None
                 bias = module_bias is not None

-                lora_A_weight_name = f"{name}.lora_A.weight"
-                lora_B_weight_name = f"{name}.lora_B.weight"
-                if lora_A_weight_name not in state_dict.keys():
+                lora_base_name = name.replace(".base_layer", "") if is_peft_loaded else name
+                lora_A_weight_name = f"{lora_base_name}.lora_A.weight"
+                lora_B_weight_name = f"{lora_base_name}.lora_B.weight"
+                if lora_A_weight_name not in state_dict:
                     continue

                 in_features = state_dict[lora_A_weight_name].shape[1]
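
As context for the `.base_layer` handling above: once a LoRA adapter has been injected, PEFT wraps each targeted `nn.Linear`, so `transformer.named_modules()` reports the original layer under a `.base_layer` suffix while the LoRA keys keep the unwrapped name. A toy illustration (assuming PEFT's standard LoRA wrapping; not code from this commit):

```py
import torch
from peft import LoraConfig, inject_adapter_in_model


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)


model = inject_adapter_in_model(LoraConfig(r=4, target_modules=["proj"]), Toy())
# The wrapped Linear now shows up as "proj.base_layer", which is why the lookup above strips
# ".base_layer" before building "proj.lora_A.weight"-style keys.
print([n for n, m in model.named_modules() if isinstance(m, torch.nn.Linear)])
```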
@@ -2329,56 +2333,105 @@ def _maybe_expand_transformer_param_shape_or_error_(
                     continue

                 module_out_features, module_in_features = module_weight.shape
-                if out_features < module_out_features or in_features < module_in_features:
-                    raise NotImplementedError(
-                        f"Only LoRAs with input/output features higher than the current module's input/output features "
-                        f"are currently supported. The provided LoRA contains {in_features=} and {out_features=}, which "
-                        f"are lower than {module_in_features=} and {module_out_features=}. If you require support for "
-                        f"this please open an issue at https://github.com/huggingface/diffusers/issues."
+                debug_message = ""
+                if in_features > module_in_features:
+                    debug_message += (
+                        f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
+                        f"checkpoint contains higher number of features than expected. The number of input_features will be "
+                        f"expanded from {module_in_features} to {in_features}"
                     )
-
-                debug_message = (
-                    f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
-                    f"checkpoint contains higher number of features than expected. The number of input_features will be "
-                    f"expanded from {module_in_features} to {in_features}"
-                )
-                if module_out_features != out_features:
+                if out_features > module_out_features:
                     debug_message += (
                         ", and the number of output features will be "
                         f"expanded from {module_out_features} to {out_features}."
                     )
                 else:
                     debug_message += "."
-                logger.debug(debug_message)
+                if debug_message:
+                    logger.debug(debug_message)
+
+                if out_features > module_out_features or in_features > module_in_features:
+                    has_param_with_shape_update = True
+                    parent_module_name, _, current_module_name = name.rpartition(".")
+                    parent_module = transformer.get_submodule(parent_module_name)
+
+                    with torch.device("meta"):
+                        expanded_module = torch.nn.Linear(
+                            in_features, out_features, bias=bias, dtype=module_weight.dtype
+                        )
+                    # Only weights are expanded and biases are not. This is because only the input dimensions
+                    # are changed while the output dimensions remain the same. The shape of the weight tensor
+                    # is (out_features, in_features), while the shape of bias tensor is (out_features,), which
+                    # explains the reason why only weights are expanded.
+                    new_weight = torch.zeros_like(
+                        expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
+                    )
+                    slices = tuple(slice(0, dim) for dim in module_weight.shape)
+                    new_weight[slices] = module_weight
+                    tmp_state_dict = {"weight": new_weight}
+                    if module_bias is not None:
+                        tmp_state_dict["bias"] = module_bias
+                    expanded_module.load_state_dict(tmp_state_dict, strict=True, assign=True)
+
+                    setattr(parent_module, current_module_name, expanded_module)
+
+                    del tmp_state_dict
+
+                    if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
+                        attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
+                        new_value = int(expanded_module.weight.data.shape[1])
+                        old_value = getattr(transformer.config, attribute_name)
+                        setattr(transformer.config, attribute_name, new_value)
+                        logger.info(
+                            f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
+                        )

-                has_param_with_shape_update = True
-                parent_module_name, _, current_module_name = name.rpartition(".")
-                parent_module = transformer.get_submodule(parent_module_name)
+        return has_param_with_shape_update

-                # TODO: consider initializing this under meta device for optims.
-                expanded_module = torch.nn.Linear(
-                    in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype
-                )
-                # Only weights are expanded and biases are not.
-                new_weight = torch.zeros_like(
-                    expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
+    @classmethod
+    def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
+        expanded_module_names = set()
+        transformer_state_dict = transformer.state_dict()
+        prefix = f"{cls.transformer_name}."
+
+        lora_module_names = [
+            key[: -len(".lora_A.weight")] for key in lora_state_dict if key.endswith(".lora_A.weight")
+        ]
+        lora_module_names = [name[len(prefix) :] for name in lora_module_names if name.startswith(prefix)]
+        lora_module_names = sorted(set(lora_module_names))
+        transformer_module_names = sorted({name for name, _ in transformer.named_modules()})
+        unexpected_modules = set(lora_module_names) - set(transformer_module_names)
+        if unexpected_modules:
+            logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")
+
+        is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+        for k in lora_module_names:
+            if k in unexpected_modules:
+                continue
+
+            base_param_name = (
+                f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight"
+            )
+            base_weight_param = transformer_state_dict[base_param_name]
+            lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
+
+            if base_weight_param.shape[1] > lora_A_param.shape[1]:
+                shape = (lora_A_param.shape[0], base_weight_param.shape[1])
+                expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
+                expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
+                lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
+                expanded_module_names.add(k)
+            elif base_weight_param.shape[1] < lora_A_param.shape[1]:
+                raise NotImplementedError(
+                    f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
                 )
-                slices = tuple(slice(0, dim) for dim in module_weight.shape)
-                new_weight[slices] = module_weight
-                expanded_module.weight.data.copy_(new_weight)
-                if module_bias is not None:
-                    expanded_module.bias.data.copy_(module_bias)
-
-                setattr(parent_module, current_module_name, expanded_module)

-                if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
-                    attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
-                    new_value = int(expanded_module.weight.data.shape[1])
-                    old_value = getattr(transformer.config, attribute_name)
-                    setattr(transformer.config, attribute_name, new_value)
-                    logger.info(f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}.")
+        if expanded_module_names:
+            logger.info(
+                f"The following LoRA modules were zero padded to match the state dict of {cls.transformer_name}: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new."
+            )

-        return has_param_with_shape_update
+        return lora_state_dict


 # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
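
To make the expansion logic concrete, here is a small self-contained sketch (not from the commit) of the same zero-padding idea with toy shapes: the base `Linear` grows from 64 to 128 input features, `lora_A` of shape `(rank, 64)` is padded with zero columns to `(rank, 128)`, and `lora_B` of shape `(out_features, rank)` is left untouched because the output dimension did not change.

```py
import torch

# Toy shapes for illustration only; they are not the real Flux dimensions.
old_in, new_in, out_features, rank = 64, 128, 32, 4

# Base-weight expansion, mirroring _maybe_expand_transformer_param_shape_or_error_:
# allocate zeros at the new shape and copy the old weight into the leading slice.
old_weight = torch.randn(out_features, old_in)
new_weight = torch.zeros(out_features, new_in)
new_weight[:, :old_in] = old_weight

# LoRA state-dict expansion, mirroring _maybe_expand_lora_state_dict: only lora_A depends on
# in_features, so only lora_A gets zero-padded; lora_B is (out_features, rank) and is unchanged.
lora_A = torch.randn(rank, old_in)
lora_B = torch.randn(out_features, rank)
lora_A_expanded = torch.zeros(rank, new_in)
lora_A_expanded[:, :old_in] = lora_A

# The padded LoRA produces the same update on the original input channels and a zero update
# on the newly added ones, so a regular Flux LoRA keeps behaving as it did before expansion.
x = torch.randn(2, new_in)
delta = x @ lora_A_expanded.T @ lora_B.T
assert torch.allclose(delta, x[:, :old_in] @ lora_A.T @ lora_B.T)
```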

0 commit comments
