Commit b791e13

Merge branch 'main' into dreambooth-lora-flux-exploration
2 parents de3e2a5 + 31058cd

File tree

14 files changed: +489 -56 lines

docs/source/en/tutorials/using_peft_for_inference.md

Lines changed: 6 additions & 0 deletions
@@ -75,6 +75,12 @@ image
 
 ![pixel-art](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_12_1.png)
 
+<Tip>
+
+By default, if the most up-to-date versions of PEFT and Transformers are detected, `low_cpu_mem_usage` is set to `True` to speed up the loading time of LoRA checkpoints.
+
+</Tip>
+
 ## Merge adapters
 
 You can also merge different adapter checkpoints for inference to blend their styles together.
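
For context, the new tip corresponds to an argument on the LoRA loaders. A minimal sketch of opting in explicitly, reusing the pixel-art checkpoint this tutorial already loads (with recent PEFT and Transformers installed, the flag defaults to `True` anyway):

import torch
from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Explicit opt-in; with up-to-date PEFT/Transformers this is already the default.
pipe.load_lora_weights(
    "nerijs/pixel-art-xl",
    weight_name="pixel-art-xl.safetensors",
    adapter_name="pixel",
    low_cpu_mem_usage=True,
)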

src/diffusers/loaders/lora_pipeline.py

Lines changed: 264 additions & 26 deletions
Large diffs are not rendered by default.

src/diffusers/loaders/unet.py

Lines changed: 18 additions & 3 deletions
@@ -115,6 +115,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                 `default_{i}` where i is the total number of adapters being loaded.
             weight_name (`str`, *optional*, defaults to None):
                 Name of the serialized state dict file.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
 
         Example:
 
@@ -142,8 +145,14 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         adapter_name = kwargs.pop("adapter_name", None)
         _pipeline = kwargs.pop("_pipeline", None)
         network_alphas = kwargs.pop("network_alphas", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
         allow_pickle = False
 
+        if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
         if use_safetensors is None:
             use_safetensors = True
             allow_pickle = True
@@ -209,6 +218,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                 network_alphas=network_alphas,
                 adapter_name=adapter_name,
                 _pipeline=_pipeline,
+                low_cpu_mem_usage=low_cpu_mem_usage,
             )
         else:
             raise ValueError(
@@ -268,7 +278,9 @@ def _process_custom_diffusion(self, state_dict):
 
         return attn_processors
 
-    def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline):
+    def _process_lora(
+        self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline, low_cpu_mem_usage
+    ):
         # This method does the following things:
         # 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy
         #    format. For legacy format no filtering is applied.
@@ -335,9 +347,12 @@ def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter
         # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
         # otherwise loading LoRA weights will lead to an error
         is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
+        peft_kwargs = {}
+        if is_peft_version(">=", "0.13.1"):
+            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
 
-        inject_adapter_in_model(lora_config, self, adapter_name=adapter_name)
-        incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name)
+        inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
+        incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
 
         if incompatible_keys is not None:
             # check only for unexpected keys
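
Since the flag is popped from `**kwargs`, it can be passed straight through `load_attn_procs` at the model level. A hedged sketch, assuming an SD-style UNet; the repo id and LoRA path below are placeholders, not values from this commit:

from diffusers import UNet2DConditionModel

# Placeholder base model; any UNet2DConditionModel checkpoint works the same way.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Placeholder LoRA path. On peft <= 0.13.0 this raises the ValueError added
# in the diff above; on newer peft it skips the throwaway random init.
unet.load_attn_procs("path/to/lora-checkpoint", low_cpu_mem_usage=True)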

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 8 additions & 2 deletions
@@ -83,14 +83,16 @@ def forward(
         hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         image_rotary_emb=None,
+        joint_attention_kwargs=None,
     ):
         residual = hidden_states
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
-
+        joint_attention_kwargs = joint_attention_kwargs or {}
         attn_output = self.attn(
             hidden_states=norm_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            **joint_attention_kwargs,
         )
 
         hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
@@ -161,18 +163,20 @@ def forward(
         encoder_hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         image_rotary_emb=None,
+        joint_attention_kwargs=None,
     ):
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
 
         norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
             encoder_hidden_states, emb=temb
         )
-
+        joint_attention_kwargs = joint_attention_kwargs or {}
         # Attention.
         attn_output, context_attn_output = self.attn(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            **joint_attention_kwargs,
         )
 
         # Process attention outputs for the `hidden_states`.
@@ -497,6 +501,7 @@ def custom_forward(*inputs):
                     encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    joint_attention_kwargs=joint_attention_kwargs,
                 )
 
             # controlnet residual
@@ -533,6 +538,7 @@ def custom_forward(*inputs):
                 hidden_states=hidden_states,
                 temb=temb,
                 image_rotary_emb=image_rotary_emb,
+                joint_attention_kwargs=joint_attention_kwargs,
             )
 
             # controlnet residual
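
Threading `joint_attention_kwargs` through both block types is what lets a pipeline-level dict reach every attention call. A sketch of the user-facing effect, assuming a Flux LoRA at a placeholder path and the `scale` key that the transformer pops off before forwarding the rest:

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("path/to/flux-lora")  # placeholder LoRA checkpoint

# After this change, whatever remains in the dict is forwarded into each
# single and double transformer block's attention call.
image = pipe(
    "a pixel-art corgi",
    joint_attention_kwargs={"scale": 0.8},
).images[0]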

src/diffusers/pipelines/deepfloyd_if/pipeline_output.py

Lines changed: 6 additions & 5 deletions
@@ -9,16 +9,17 @@
 
 @dataclass
 class IFPipelineOutput(BaseOutput):
-    """
-    Args:
+    r"""
     Output class for Stable Diffusion pipelines.
-        images (`List[PIL.Image.Image]` or `np.ndarray`)
+
+    Args:
+        images (`List[PIL.Image.Image]` or `np.ndarray`):
             List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
             num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
-        nsfw_detected (`List[bool]`)
+        nsfw_detected (`List[bool]`):
             List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
             (nsfw) content or a watermark. `None` if safety checking could not be performed.
-        watermark_detected (`List[bool]`)
+        watermark_detected (`List[bool]`):
             List of flags denoting whether the corresponding generated image likely has a watermark. `None` if safety
             checking could not be performed.
     """

src/diffusers/pipelines/pag/pag_utils.py

Lines changed: 8 additions & 2 deletions
@@ -98,7 +98,9 @@ def _get_pag_scale(self, t):
         else:
             return self.pag_scale
 
-    def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t):
+    def _apply_perturbed_attention_guidance(
+        self, noise_pred, do_classifier_free_guidance, guidance_scale, t, return_pred_text=False
+    ):
         r"""
         Apply perturbed attention guidance to the noise prediction.
 
@@ -107,9 +109,11 @@ def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_gui
             do_classifier_free_guidance (bool): Whether to apply classifier-free guidance.
             guidance_scale (float): The scale factor for the guidance term.
             t (int): The current time step.
+            return_pred_text (bool): Whether to return the text noise prediction.
 
         Returns:
-            torch.Tensor: The updated noise prediction tensor after applying perturbed attention guidance.
+            Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: The updated noise prediction tensor after applying
+            perturbed attention guidance and the text noise prediction.
         """
         pag_scale = self._get_pag_scale(t)
         if do_classifier_free_guidance:
@@ -122,6 +126,8 @@ def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_gui
         else:
             noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)
             noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)
+        if return_pred_text:
+            return noise_pred, noise_pred_text
         return noise_pred
 
     def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance):
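
As a standalone illustration of the `else` branch visible above (no classifier-free guidance), with dummy tensors:

import torch

# Dummy batched prediction: [text, perturbed] along the batch dim, the chunk
# order used by the no-CFG branch shown in the diff.
noise_pred = torch.randn(2, 4, 64, 64)
pag_scale = 3.0

noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)
guided = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)

# With return_pred_text=True the method now also hands back noise_pred_text,
# so callers can reuse it downstream without re-chunking.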

src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py

Lines changed: 2 additions & 2 deletions
@@ -893,8 +893,8 @@ def __call__(
 
                 # perform guidance
                 if self.do_perturbed_attention_guidance:
-                    noise_pred = self._apply_perturbed_attention_guidance(
-                        noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
+                    noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
+                        noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True
                     )
                 elif self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

src/diffusers/pipelines/pag/pipeline_pag_sd.py

Lines changed: 2 additions & 2 deletions
@@ -993,8 +993,8 @@ def __call__(
 
                 # perform guidance
                 if self.do_perturbed_attention_guidance:
-                    noise_pred = self._apply_perturbed_attention_guidance(
-                        noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
+                    noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
+                        noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True
                     )
 
                 elif self.do_classifier_free_guidance:

src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py

Lines changed: 2 additions & 2 deletions
@@ -1237,8 +1237,8 @@ def __call__(
 
                 # perform guidance
                 if self.do_perturbed_attention_guidance:
-                    noise_pred = self._apply_perturbed_attention_guidance(
-                        noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
+                    noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
+                        noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True
                    )
 
                 elif self.do_classifier_free_guidance:

src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py

Lines changed: 2 additions & 2 deletions
@@ -1437,8 +1437,8 @@ def denoising_value_valid(dnv):
 
                 # perform guidance
                 if self.do_perturbed_attention_guidance:
-                    noise_pred = self._apply_perturbed_attention_guidance(
-                        noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
+                    noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
+                        noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True
                     )
                 elif self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
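
All four PAG pipelines follow the same pattern: the PAG branch now also yields `noise_pred_text`, so downstream code that consumes it works whether or not the CFG branch ran. The likely consumer is the guidance-rescale helper these pipelines use; it is reproduced below as a sketch of why `noise_pred_text` is needed (treat the exact wiring as an assumption, since not every diff in this commit is rendered here):

import torch


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """Rescale `noise_cfg` toward the std of the text prediction, per
    "Common Diffusion Noise Schedules and Sample Steps are Flawed"
    (https://arxiv.org/abs/2305.08891), Section 3.4."""
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # Rescale the guided result (fixes overexposure) ...
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # ... then mix back with the original to avoid "plain looking" images.
    return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg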
