
Commit d44f39c

feat: support unload_lora_weights() for Flux Control.
1 parent: ec9bfa9

File tree (2 files changed: +115 -0 lines):

  src/diffusers/loaders/lora_pipeline.py
  tests/lora/test_lora_layers_flux.py
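For context, a minimal usage sketch of the behavior this commit adds. The checkpoint IDs and the printed channel values below are illustrative examples, not taken from this commit:

import torch
from diffusers import FluxControlPipeline

# Illustrative checkpoints; any Flux Control LoRA that expands `x_embedder` behaves the same way.
pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
print(pipe.transformer.config.in_channels)   # original value, e.g. 64

# Loading a Control LoRA expands `x_embedder` (and `config.in_channels`) to accept the control latents.
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
print(pipe.transformer.config.in_channels)   # expanded, e.g. 128

# New in this commit: unloading restores the original parameter shapes and config attributes.
pipe.unload_lora_weights()
print(pipe.transformer.config.in_channels)   # back to the original value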

src/diffusers/loaders/lora_pipeline.py

Lines changed: 53 additions & 0 deletions
@@ -2283,6 +2283,50 @@ def unload_lora_weights(self):
             transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
             transformer._transformer_norm_layers = None

+        if getattr(transformer, "_overwritten_params", None) is not None:
+            overwritten_params = transformer._overwritten_params
+            module_names = set()
+
+            for param_name in overwritten_params:
+                if param_name.endswith(".weight"):
+                    module_names.add(param_name.replace(".weight", ""))
+
+            for name, module in transformer.named_modules():
+                if isinstance(module, torch.nn.Linear) and name in module_names:
+                    module_weight = module.weight.data
+                    module_bias = module.bias.data if module.bias is not None else None
+                    bias = module_bias is not None
+
+                    parent_module_name, _, current_module_name = name.rpartition(".")
+                    parent_module = transformer.get_submodule(parent_module_name)
+
+                    current_param_weight = overwritten_params[f"{name}.weight"]
+                    in_features, out_features = current_param_weight.shape[1], current_param_weight.shape[0]
+                    with torch.device("meta"):
+                        original_module = torch.nn.Linear(
+                            in_features,
+                            out_features,
+                            bias=bias,
+                            device=module_weight.device,
+                            dtype=module_weight.dtype,
+                        )
+
+                    original_module.weight.data.copy_(current_param_weight)
+                    if module_bias is not None:
+                        original_module.bias.data.copy_(overwritten_params[f"{name}.bias"])
+
+                    setattr(parent_module, current_module_name, original_module)
+
+                    if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
+                        attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
+                        new_value = int(current_param_weight.shape[1])
+                        old_value = getattr(transformer.config, attribute_name)
+                        setattr(transformer.config, attribute_name, new_value)
+                        logger.info(
+                            f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
+                        )
+
     @classmethod
     def _maybe_expand_transformer_param_shape_or_error_(
         cls,
@@ -2309,6 +2353,7 @@ def _maybe_expand_transformer_param_shape_or_error_(

         # Expand transformer parameter shapes if they don't match lora
         has_param_with_shape_update = False
+        overwritten_params = {}

         for name, module in transformer.named_modules():
             if isinstance(module, torch.nn.Linear):
@@ -2371,6 +2416,14 @@ def _maybe_expand_transformer_param_shape_or_error_(
                     setattr(transformer.config, attribute_name, new_value)
                     logger.info(f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}.")

+                # For `unload_lora_weights()`.
+                overwritten_params[f"{current_module_name}.weight"] = module_weight
+                if module_bias is not None:
+                    overwritten_params[f"{current_module_name}.bias"] = module_bias
+
+        if len(overwritten_params) > 0:
+            transformer._overwritten_params = overwritten_params
+
         return has_param_with_shape_update

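In isolation, the caching scheme above works like this: the expansion path stashes each overwritten weight/bias in `_overwritten_params`, and `unload_lora_weights()` later rebuilds a fresh `nn.Linear` from the cached tensors and swaps it back into the parent module. A minimal, self-contained sketch of that swap, using a toy module and hypothetical names rather than the actual Flux transformer:

import torch

class ToyModel(torch.nn.Module):
    # Hypothetical stand-in for the Flux transformer's input projection.
    def __init__(self):
        super().__init__()
        self.x_embedder = torch.nn.Linear(4, 8)

model = ToyModel()

# "Expansion": cache the original parameters, then swap in a wider layer whose
# leading columns hold the original weights (extra columns zero-initialized).
overwritten = {
    "x_embedder.weight": model.x_embedder.weight.data.clone(),
    "x_embedder.bias": model.x_embedder.bias.data.clone(),
}
expanded = torch.nn.Linear(8, 8)
expanded.weight.data.zero_()
expanded.weight.data[:, :4].copy_(overwritten["x_embedder.weight"])
expanded.bias.data.copy_(overwritten["x_embedder.bias"])
model.x_embedder = expanded

# "Unload": rebuild a Linear with the cached shape, copy the originals back,
# and reattach it -- mirroring what the new unload_lora_weights() code does.
cached_weight = overwritten["x_embedder.weight"]
restored = torch.nn.Linear(cached_weight.shape[1], cached_weight.shape[0])
restored.weight.data.copy_(cached_weight)
restored.bias.data.copy_(overwritten["x_embedder.bias"])
model.x_embedder = restored

assert model.x_embedder.weight.shape == (8, 4)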

tests/lora/test_lora_layers_flux.py

Lines changed: 62 additions & 0 deletions
@@ -430,6 +430,68 @@ def test_correct_lora_configs_with_different_ranks(self):
         self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
         self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))

+    def test_lora_unload_with_parameter_expanded_shapes(self):
+        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+
+        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+        logger.setLevel(logging.DEBUG)
+
+        # Change the transformer config to mimic a real use case.
+        num_channels_without_control = 4
+        transformer = FluxTransformer2DModel.from_config(
+            components["transformer"].config, in_channels=num_channels_without_control
+        ).to(torch_device)
+        self.assertTrue(
+            transformer.config.in_channels == num_channels_without_control,
+            f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
+        )
+
+        # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
+        components["transformer"] = transformer
+        pipe = FluxPipeline(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        control_image = inputs.pop("control_image")
+        original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        control_pipe = self.pipeline_class(**components)
+        out_features, in_features = control_pipe.transformer.x_embedder.weight.shape
+        rank = 4
+
+        dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
+        dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
+        lora_state_dict = {
+            "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
+            "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
+        }
+        with CaptureLogger(logger) as cap_logger:
+            control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
+            self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+        inputs["control_image"] = control_image
+        lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
+        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
+        self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
+        self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
+
+        control_pipe.unload_lora_weights()
+        loaded_pipe = FluxPipeline.from_pipe(control_pipe)
+        self.assertTrue(
+            loaded_pipe.transformer.config.in_channels == num_channels_without_control,
+            f"Expected {num_channels_without_control} channels in the modified transformer but has {loaded_pipe.transformer.config.in_channels=}",
+        )
+        inputs.pop("control_image")
+        unloaded_lora_out = loaded_pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertFalse(np.allclose(unloaded_lora_out, lora_out, rtol=1e-4, atol=1e-4))
+        self.assertTrue(np.allclose(unloaded_lora_out, original_out, atol=1e-4, rtol=1e-4))
+        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
+        self.assertTrue(pipe.transformer.config.in_channels == in_features)
+
     @unittest.skip("Not supported in Flux.")
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
