Skip to content

Commit a09d104

Browse files
committed
Merge branch 'aritra/qunat-blog' of github.com:ariG23498/diffusers into aritra/qunat-blog
OK
2 parents daccd75 + ee79bf5 commit a09d104

File tree

7 files changed

+27
-15
lines changed

7 files changed

+27
-15
lines changed

docs/source/en/using-diffusers/loading_adapters.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,14 +134,16 @@ The [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method loads L
134134
- the LoRA weights don't have separate identifiers for the UNet and text encoder
135135
- the LoRA weights have separate identifiers for the UNet and text encoder
136136

137-
But if you only need to load LoRA weights into the UNet, then you can use the [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method. Let's load the [jbilcke-hf/sdxl-cinematic-1](https://huggingface.co/jbilcke-hf/sdxl-cinematic-1) LoRA:
137+
To directly load (and save) a LoRA adapter at the *model-level*, use [`~PeftAdapterMixin.load_lora_adapter`], which builds and prepares the necessary model configuration for the adapter. Like [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`], [`PeftAdapterMixin.load_lora_adapter`] can load LoRAs for both the UNet and text encoder. For example, if you're loading a LoRA for the UNet, [`PeftAdapterMixin.load_lora_adapter`] ignores the keys for the text encoder.
138+
139+
Use the `weight_name` parameter to specify the specific weight file and the `prefix` parameter to filter for the appropriate state dicts (`"unet"` in this case) to load.
138140

139141
```py
140142
from diffusers import AutoPipelineForText2Image
141143
import torch
142144

143145
pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
144-
pipeline.unet.load_attn_procs("jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors")
146+
pipeline.unet.load_lora_adapter("jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", prefix="unet")
145147

146148
# use cnmt in the prompt to trigger the LoRA
147149
prompt = "A cute cnmt eating a slice of pizza, stunning color scheme, masterpiece, illustration"
@@ -153,6 +155,8 @@ image
153155
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_attn_proc.png" />
154156
</div>
155157

158+
Save an adapter with [`~PeftAdapterMixin.save_lora_adapter`].
159+
156160
To unload the LoRA weights, use the [`~loaders.StableDiffusionLoraLoaderMixin.unload_lora_weights`] method to discard the LoRA weights and restore the model to its original weights:
157161

158162
```py

src/diffusers/loaders/ip_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def load_ip_adapter(
187187
state_dict = pretrained_model_name_or_path_or_dict
188188

189189
keys = list(state_dict.keys())
190-
if keys != ["image_proj", "ip_adapter"]:
190+
if "image_proj" not in keys and "ip_adapter" not in keys:
191191
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
192192

193193
state_dicts.append(state_dict)

src/diffusers/models/attention_processor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1908,7 +1908,9 @@ def __call__(
19081908
query = apply_rotary_emb(query, image_rotary_emb)
19091909
key = apply_rotary_emb(key, image_rotary_emb)
19101910

1911-
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
1911+
hidden_states = F.scaled_dot_product_attention(
1912+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1913+
)
19121914
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
19131915
hidden_states = hidden_states.to(query.dtype)
19141916

src/diffusers/models/controlnets/controlnet_sd3.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -393,13 +393,19 @@ def custom_forward(*inputs):
393393
return custom_forward
394394

395395
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
396-
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
397-
create_custom_forward(block),
398-
hidden_states,
399-
encoder_hidden_states,
400-
temb,
401-
**ckpt_kwargs,
402-
)
396+
if self.context_embedder is not None:
397+
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
398+
create_custom_forward(block),
399+
hidden_states,
400+
encoder_hidden_states,
401+
temb,
402+
**ckpt_kwargs,
403+
)
404+
else:
405+
# SD3.5 8b controlnet use single transformer block, which does not use `encoder_hidden_states`
406+
hidden_states = torch.utils.checkpoint.checkpoint(
407+
create_custom_forward(block), hidden_states, temb, **ckpt_kwargs
408+
)
403409

404410
else:
405411
if self.context_embedder is not None:

src/diffusers/models/model_loading_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,8 @@ def load_model_dict_into_meta(
176176
hf_quantizer=None,
177177
keep_in_fp32_modules=None,
178178
) -> List[str]:
179+
if device is not None and not isinstance(device, (str, torch.device)):
180+
raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")
179181
if hf_quantizer is None:
180182
device = device or torch.device("cpu")
181183
dtype = dtype or torch.float32

src/diffusers/models/modeling_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -836,7 +836,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
836836
param_device = "cpu"
837837
# TODO (sayakpaul, SunMarc): remove this after model loading refactor
838838
elif is_quant_method_bnb:
839-
param_device = torch.cuda.current_device()
839+
param_device = torch.device(torch.cuda.current_device())
840840
state_dict = load_state_dict(model_file, variant=variant)
841841
model._convert_deprecated_attention_blocks(state_dict)
842842

src/diffusers/models/transformers/transformer_sd3.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
from typing import Any, Dict, List, Optional, Tuple, Union
1717

18-
import numpy as np
1918
import torch
2019
import torch.nn as nn
2120
import torch.nn.functional as F
@@ -424,8 +423,7 @@ def custom_forward(*inputs):
424423
# controlnet residual
425424
if block_controlnet_hidden_states is not None and block.context_pre_only is False:
426425
interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states)
427-
interval_control = int(np.ceil(interval_control))
428-
hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]
426+
hidden_states = hidden_states + block_controlnet_hidden_states[int(index_block / interval_control)]
429427

430428
hidden_states = self.norm_out(hidden_states, temb)
431429
hidden_states = self.proj_out(hidden_states)

0 commit comments

Comments
 (0)