Commit f75f695

Merge branch 'main' into ds-support-sd3-lora
2 parents: 5e760cd + 01780c3

File tree: 8 files changed, +173 -12 lines

docs/source/en/api/pipelines/hunyuan_video.md

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@

 <Tip>

-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](https://huggingface.co/docs/diffusers/main/en/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.

 </Tip>

docs/source/en/api/pipelines/hunyuandit.md

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ HunyuanDiT has the following components:

 <Tip>

-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](https://huggingface.co/docs/diffusers/main/en/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.

 </Tip>
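Both tips point to the same scheduler guide. For context, exploring the speed/quality tradeoff the tip mentions is a one-line scheduler swap on any pipeline; the sketch below is illustrative only, and the checkpoint id is an example rather than part of this commit:

    import torch
    from diffusers import HunyuanDiTPipeline, DPMSolverMultistepScheduler

    # Example checkpoint id; substitute the model you actually use.
    pipe = HunyuanDiTPipeline.from_pretrained(
        "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", torch_dtype=torch.float16
    ).to("cuda")

    # Swap in a faster multistep solver, configured from the current scheduler's settings.
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

    image = pipe("a small cactus wearing a straw hat", num_inference_steps=25).images[0]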

src/diffusers/loaders/lora_pipeline.py

Lines changed: 56 additions & 0 deletions
@@ -2286,6 +2286,50 @@ def unload_lora_weights(self):
             transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
             transformer._transformer_norm_layers = None

+        if getattr(transformer, "_overwritten_params", None) is not None:
+            overwritten_params = transformer._overwritten_params
+            module_names = set()
+
+            for param_name in overwritten_params:
+                if param_name.endswith(".weight"):
+                    module_names.add(param_name.replace(".weight", ""))
+
+            for name, module in transformer.named_modules():
+                if isinstance(module, torch.nn.Linear) and name in module_names:
+                    module_weight = module.weight.data
+                    module_bias = module.bias.data if module.bias is not None else None
+                    bias = module_bias is not None
+
+                    parent_module_name, _, current_module_name = name.rpartition(".")
+                    parent_module = transformer.get_submodule(parent_module_name)
+
+                    current_param_weight = overwritten_params[f"{name}.weight"]
+                    in_features, out_features = current_param_weight.shape[1], current_param_weight.shape[0]
+                    with torch.device("meta"):
+                        original_module = torch.nn.Linear(
+                            in_features,
+                            out_features,
+                            bias=bias,
+                            dtype=module_weight.dtype,
+                        )
+
+                    tmp_state_dict = {"weight": current_param_weight}
+                    if module_bias is not None:
+                        tmp_state_dict.update({"bias": overwritten_params[f"{name}.bias"]})
+                    original_module.load_state_dict(tmp_state_dict, assign=True, strict=True)
+                    setattr(parent_module, current_module_name, original_module)
+
+                    del tmp_state_dict
+
+                    if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
+                        attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
+                        new_value = int(current_param_weight.shape[1])
+                        old_value = getattr(transformer.config, attribute_name)
+                        setattr(transformer.config, attribute_name, new_value)
+                        logger.info(
+                            f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
+                        )
+
     @classmethod
     def _maybe_expand_transformer_param_shape_or_error_(
         cls,
@@ -2312,6 +2356,8 @@ def _maybe_expand_transformer_param_shape_or_error_(

         # Expand transformer parameter shapes if they don't match lora
         has_param_with_shape_update = False
+        overwritten_params = {}
+
         is_peft_loaded = getattr(transformer, "peft_config", None) is not None
         for name, module in transformer.named_modules():
             if isinstance(module, torch.nn.Linear):
@@ -2386,6 +2432,16 @@ def _maybe_expand_transformer_param_shape_or_error_(
                        f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
                    )

+                # For `unload_lora_weights()`.
+                # TODO: this could lead to more memory overhead if the number of overwritten params
+                # are large. Should be revisited later and tackled through a `discard_original_layers` arg.
+                overwritten_params[f"{current_module_name}.weight"] = module_weight
+                if module_bias is not None:
+                    overwritten_params[f"{current_module_name}.bias"] = module_bias
+
+        if len(overwritten_params) > 0:
+            transformer._overwritten_params = overwritten_params
+
         return has_param_with_shape_update

     @classmethod
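Taken together, `_maybe_expand_transformer_param_shape_or_error_` now records the original (pre-expansion) weights in `transformer._overwritten_params`, and `unload_lora_weights` uses them to rebuild the original `nn.Linear` modules and reset the affected config attributes. A minimal sketch of the user-facing effect on a Flux pipeline follows; the checkpoint ids are examples, not part of this commit:

    import torch
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    in_features = pipe.transformer.x_embedder.weight.shape[1]

    # Loading a Control LoRA expands x_embedder (and config.in_channels) so the
    # transformer can accept the concatenated control latents.
    pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
    assert pipe.transformer.x_embedder.weight.shape[1] == 2 * in_features

    # With this change, unloading restores the original layer shapes and config
    # values instead of leaving the transformer permanently expanded.
    pipe.unload_lora_weights()
    assert pipe.transformer.x_embedder.weight.shape[1] == in_features
    assert pipe.transformer.config.in_channels == in_features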

src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py

Lines changed: 11 additions & 1 deletion
@@ -21,11 +21,18 @@
 from ...models import AuraFlowTransformer2DModel, AutoencoderKL
 from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
 from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import logging, replace_example_docstring
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput


+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@@ -564,6 +571,9 @@ def __call__(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()

+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if output_type == "latent":
             image = latents
         else:
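When torch_xla is installed, the pipeline now calls `xm.mark_step()` at the end of every denoising iteration, so the lazily traced XLA graph is cut and executed per step instead of accumulating across the whole loop. The same guard and per-step call are added to the Sana pipeline below. A minimal sketch of running the pipeline on an XLA device; it assumes torch_xla and a TPU host are available, and the checkpoint id is only an example:

    import torch
    import torch_xla.core.xla_model as xm
    from diffusers import AuraFlowPipeline

    device = xm.xla_device()  # XLA-backed device, e.g. a TPU core
    pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.bfloat16).to(device)

    # Each denoising step now ends with xm.mark_step(), so compilation/execution
    # happens per step rather than once for the entire sampling loop.
    image = pipe("a photo of an astronaut riding a horse", num_inference_steps=28).images[0]
    image.save("auraflow_xla.png")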

src/diffusers/pipelines/sana/pipeline_sana.py

Lines changed: 11 additions & 0 deletions
@@ -31,6 +31,7 @@
     USE_PEFT_BACKEND,
     is_bs4_available,
     is_ftfy_available,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -46,6 +47,13 @@
 from .pipeline_output import SanaPipelineOutput


+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 if is_bs4_available():
@@ -864,6 +872,9 @@ def __call__(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()

+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if output_type == "latent":
             image = latents
         else:

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py

Lines changed: 13 additions & 4 deletions
@@ -226,12 +226,21 @@ def __init__(
             transformer=transformer,
             scheduler=scheduler,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = (
+            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+        )
+        latent_channels = self.vae.config.latent_channels if hasattr(self, "vae") and self.vae is not None else 16
         self.image_processor = VaeImageProcessor(
-            vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels
+            vae_scale_factor=self.vae_scale_factor, vae_latent_channels=latent_channels
         )
-        self.tokenizer_max_length = self.tokenizer.model_max_length
-        self.default_sample_size = self.transformer.config.sample_size
+        self.tokenizer_max_length = (
+            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+        )
+        self.default_sample_size = (
+            self.transformer.config.sample_size
+            if hasattr(self, "transformer") and self.transformer is not None
+            else 128
+        )
         self.patch_size = (
             self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
         )
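The new fallbacks (vae_scale_factor 8, 16 latent channels, tokenizer length 77, sample size 128, i.e. the standard SD3 values) let the img2img and inpaint pipelines be constructed even when some components are passed as `None`. A hypothetical sketch of the kind of call this unblocks; the model id and the `vae=None` usage are assumptions about the caller's workflow, not something this diff prescribes:

    import torch
    from diffusers import StableDiffusion3Img2ImgPipeline

    # Skip loading the VAE, e.g. when only latents are needed or a VAE is attached later.
    # Before this change, __init__ raised AttributeError on `self.vae.config`.
    pipe = StableDiffusion3Img2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers",
        vae=None,
        torch_dtype=torch.float16,
    )

    assert pipe.vae_scale_factor == 8       # fallback instead of 2 ** (len(block_out_channels) - 1)
    assert pipe.tokenizer_max_length == 77  # still read from the tokenizer when one is present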

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py

Lines changed: 14 additions & 5 deletions
@@ -225,19 +225,28 @@ def __init__(
             transformer=transformer,
             scheduler=scheduler,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = (
+            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+        )
+        latent_channels = self.vae.config.latent_channels if hasattr(self, "vae") and self.vae is not None else 16
         self.image_processor = VaeImageProcessor(
-            vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels
+            vae_scale_factor=self.vae_scale_factor, vae_latent_channels=latent_channels
         )
         self.mask_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor,
-            vae_latent_channels=self.vae.config.latent_channels,
+            vae_latent_channels=latent_channels,
             do_normalize=False,
             do_binarize=True,
             do_convert_grayscale=True,
         )
-        self.tokenizer_max_length = self.tokenizer.model_max_length
-        self.default_sample_size = self.transformer.config.sample_size
+        self.tokenizer_max_length = (
+            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+        )
+        self.default_sample_size = (
+            self.transformer.config.sample_size
+            if hasattr(self, "transformer") and self.transformer is not None
+            else 128
+        )
         self.patch_size = (
             self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
         )

tests/lora/test_lora_layers_flux.py

Lines changed: 66 additions & 0 deletions
@@ -558,6 +558,72 @@ def test_load_regular_lora(self):
         self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
         self.assertFalse(np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3))

+    def test_lora_unload_with_parameter_expanded_shapes(self):
+        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+
+        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+        logger.setLevel(logging.DEBUG)
+
+        # Change the transformer config to mimic a real use case.
+        num_channels_without_control = 4
+        transformer = FluxTransformer2DModel.from_config(
+            components["transformer"].config, in_channels=num_channels_without_control
+        ).to(torch_device)
+        self.assertTrue(
+            transformer.config.in_channels == num_channels_without_control,
+            f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
+        )
+
+        # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
+        components["transformer"] = transformer
+        pipe = FluxPipeline(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        control_image = inputs.pop("control_image")
+        original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        control_pipe = self.pipeline_class(**components)
+        out_features, in_features = control_pipe.transformer.x_embedder.weight.shape
+        rank = 4
+
+        dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
+        dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
+        lora_state_dict = {
+            "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
+            "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
+        }
+        with CaptureLogger(logger) as cap_logger:
+            control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
+            self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+        inputs["control_image"] = control_image
+        lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
+        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
+        self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
+        self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
+
+        control_pipe.unload_lora_weights()
+        self.assertTrue(
+            control_pipe.transformer.config.in_channels == num_channels_without_control,
+            f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
+        )
+        loaded_pipe = FluxPipeline.from_pipe(control_pipe)
+        self.assertTrue(
+            loaded_pipe.transformer.config.in_channels == num_channels_without_control,
+            f"Expected {num_channels_without_control} channels in the modified transformer but has {loaded_pipe.transformer.config.in_channels=}",
+        )
+        inputs.pop("control_image")
+        unloaded_lora_out = loaded_pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertFalse(np.allclose(unloaded_lora_out, lora_out, rtol=1e-4, atol=1e-4))
+        self.assertTrue(np.allclose(unloaded_lora_out, original_out, atol=1e-4, rtol=1e-4))
+        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
+        self.assertTrue(pipe.transformer.config.in_channels == in_features)
+
     @unittest.skip("Not supported in Flux.")
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
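The new test exercises the expand-then-unload round trip end to end and can be run in isolation with `pytest tests/lora/test_lora_layers_flux.py -k test_lora_unload_with_parameter_expanded_shapes`.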
