
Commit 00df9c9

Merge branch 'main' into patch-1
2 parents 11847cc + 07bd2fa commit 00df9c9

20 files changed (+1866 −29 lines)

docs/source/en/api/pipelines/pag.md

Lines changed: 5 additions & 0 deletions

```diff
@@ -53,6 +53,11 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
   - all
   - __call__
 
+## StableDiffusionPAGImg2ImgPipeline
+[[autodoc]] StableDiffusionPAGImg2ImgPipeline
+  - all
+  - __call__
+
 ## StableDiffusionControlNetPAGPipeline
 [[autodoc]] StableDiffusionControlNetPAGPipeline
```
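The new entry documents `StableDiffusionPAGImg2ImgPipeline`. A hedged sketch of reaching it through the `enable_pag` pattern used by the other PAG integrations; the checkpoint, `pag_applied_layers`, and `pag_scale` values here are illustrative, not taken from this commit:

```python
import torch
from diffusers import AutoPipelineForImage2Image
from diffusers.utils import load_image

# Illustrative values: PAG pipelines take `pag_applied_layers` (RegEx layer
# identifiers, as the doc above notes) at load time and `pag_scale` per call.
pipe = AutoPipelineForImage2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    enable_pag=True,
    pag_applied_layers=["mid"],
    torch_dtype=torch.float16,
).to("cuda")

init_image = load_image(
    "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
)
image = pipe(
    "a fantasy landscape, trending on artstation",
    image=init_image,
    pag_scale=3.0,
).images[0]
```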

docs/source/en/tutorials/using_peft_for_inference.md

Lines changed: 6 additions & 0 deletions

```diff
@@ -75,6 +75,12 @@ image
 
 ![pixel-art](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_12_1.png)
 
+<Tip>
+
+By default, if the most up-to-date versions of PEFT and Transformers are detected, `low_cpu_mem_usage` is set to `True` to speed up the loading time of LoRA checkpoints.
+
+</Tip>
+
 ## Merge adapters
 
 You can also merge different adapter checkpoints for inference to blend their styles together.
```
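A minimal sketch of what the tip describes, assuming the flag is forwarded through `load_lora_weights` (the checkpoint and adapter names mirror the ones this tutorial uses elsewhere):

```python
import torch
from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Passing the flag explicitly; with a recent PEFT install it skips the random
# initialization of the LoRA weights and loads the checkpoint directly.
pipe.load_lora_weights(
    "nerijs/pixel-art-xl",
    weight_name="pixel-art-xl.safetensors",
    adapter_name="pixel",
    low_cpu_mem_usage=True,
)
```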

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -344,6 +344,7 @@
         "StableDiffusionLatentUpscalePipeline",
         "StableDiffusionLDM3DPipeline",
         "StableDiffusionModelEditingPipeline",
+        "StableDiffusionPAGImg2ImgPipeline",
         "StableDiffusionPAGPipeline",
         "StableDiffusionPanoramaPipeline",
         "StableDiffusionParadigmsPipeline",
@@ -795,6 +796,7 @@
         StableDiffusionLatentUpscalePipeline,
         StableDiffusionLDM3DPipeline,
         StableDiffusionModelEditingPipeline,
+        StableDiffusionPAGImg2ImgPipeline,
         StableDiffusionPAGPipeline,
         StableDiffusionPanoramaPipeline,
         StableDiffusionParadigmsPipeline,
```

src/diffusers/loaders/lora_pipeline.py

Lines changed: 231 additions & 20 deletions
Large diffs are not rendered by default.

src/diffusers/loaders/unet.py

Lines changed: 18 additions & 3 deletions

```diff
@@ -115,6 +115,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                 `default_{i}` where i is the total number of adapters being loaded.
             weight_name (`str`, *optional*, defaults to None):
                 Name of the serialized state dict file.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
 
         Example:
 
@@ -142,8 +145,14 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         adapter_name = kwargs.pop("adapter_name", None)
         _pipeline = kwargs.pop("_pipeline", None)
         network_alphas = kwargs.pop("network_alphas", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
         allow_pickle = False
 
+        if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
         if use_safetensors is None:
             use_safetensors = True
             allow_pickle = True
@@ -209,6 +218,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                 network_alphas=network_alphas,
                 adapter_name=adapter_name,
                 _pipeline=_pipeline,
+                low_cpu_mem_usage=low_cpu_mem_usage,
             )
         else:
             raise ValueError(
@@ -268,7 +278,9 @@ def _process_custom_diffusion(self, state_dict):
 
         return attn_processors
 
-    def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline):
+    def _process_lora(
+        self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline, low_cpu_mem_usage
+    ):
         # This method does the following things:
         # 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy
         #    format. For legacy format no filtering is applied.
@@ -335,9 +347,12 @@ def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter
         # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
         # otherwise loading LoRA weights will lead to an error
         is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
+        peft_kwargs = {}
+        if is_peft_version(">=", "0.13.1"):
+            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
 
-        inject_adapter_in_model(lora_config, self, adapter_name=adapter_name)
-        incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name)
+        inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
+        incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
 
         if incompatible_keys is not None:
             # check only for unexpected keys
```
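A short sketch of the new flag at the `load_attn_procs` level; the LoRA repository and weight file names are hypothetical, for illustration only:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Hypothetical LoRA repo and weight file. Per the guard above, peft <= 0.13.0
# raises a ValueError; with peft >= 0.13.1 the flag is forwarded to
# inject_adapter_in_model() and set_peft_model_state_dict().
pipe.unet.load_attn_procs(
    "some-user/some-lora",
    weight_name="pytorch_lora_weights.safetensors",
    low_cpu_mem_usage=True,
)
```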

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -164,6 +164,7 @@
         "HunyuanDiTPAGPipeline",
         "StableDiffusion3PAGPipeline",
         "StableDiffusionPAGPipeline",
+        "StableDiffusionPAGImg2ImgPipeline",
         "StableDiffusionControlNetPAGPipeline",
         "StableDiffusionXLPAGPipeline",
         "StableDiffusionXLPAGInpaintPipeline",
@@ -569,6 +570,7 @@
         StableDiffusion3PAGPipeline,
         StableDiffusionControlNetPAGInpaintPipeline,
         StableDiffusionControlNetPAGPipeline,
+        StableDiffusionPAGImg2ImgPipeline,
         StableDiffusionPAGPipeline,
         StableDiffusionXLControlNetPAGImg2ImgPipeline,
         StableDiffusionXLControlNetPAGPipeline,
```

src/diffusers/pipelines/auto_pipeline.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -63,6 +63,7 @@
     StableDiffusion3PAGPipeline,
     StableDiffusionControlNetPAGInpaintPipeline,
     StableDiffusionControlNetPAGPipeline,
+    StableDiffusionPAGImg2ImgPipeline,
     StableDiffusionPAGPipeline,
     StableDiffusionXLControlNetPAGImg2ImgPipeline,
     StableDiffusionXLControlNetPAGPipeline,
@@ -131,6 +132,7 @@
         ("kandinsky22", KandinskyV22Img2ImgCombinedPipeline),
         ("kandinsky3", Kandinsky3Img2ImgPipeline),
         ("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline),
+        ("stable-diffusion-pag", StableDiffusionPAGImg2ImgPipeline),
         ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline),
         ("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline),
         ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),
```

src/diffusers/pipelines/controlnet/pipeline_controlnet.py

Lines changed: 8 additions & 0 deletions

```diff
@@ -893,6 +893,10 @@ def cross_attention_kwargs(self):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def interrupt(self):
+        return self._interrupt
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -1089,6 +1093,7 @@ def __call__(
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False
 
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -1235,6 +1240,9 @@ def __call__(
         is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # Relevant thread:
                 # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
                 if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
```
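The same three hunks (an `interrupt` property, a `_interrupt = False` reset, and a `continue` guard in the denoising loop) are applied to the img2img and inpaint variants below. A sketch of how the flag is meant to be used, following the step-end callback pattern of the other diffusers pipelines that support interruption; the stop index and placeholder control image are arbitrary:

```python
import torch
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

# Stop denoising after 10 steps by flipping the private flag that the new
# `interrupt` property exposes; the loop then skips the remaining steps.
def stop_early(pipe, step_index, timestep, callback_kwargs):
    if step_index == 10:
        pipe._interrupt = True
    return callback_kwargs

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

image = pipe(
    "a bird perched on a branch",
    image=Image.new("RGB", (512, 512)),  # placeholder control image
    callback_on_step_end=stop_early,
).images[0]
```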

src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py

Lines changed: 8 additions & 0 deletions

```diff
@@ -891,6 +891,10 @@ def cross_attention_kwargs(self):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def interrupt(self):
+        return self._interrupt
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -1081,6 +1085,7 @@ def __call__(
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False
 
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -1211,6 +1216,9 @@ def __call__(
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
```

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py

Lines changed: 8 additions & 0 deletions

```diff
@@ -976,6 +976,10 @@ def cross_attention_kwargs(self):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def interrupt(self):
+        return self._interrupt
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -1191,6 +1195,7 @@ def __call__(
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False
 
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -1375,6 +1380,9 @@ def __call__(
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
```
