
Commit 08b1aeb

Merge branch 'main' into ipadapter-flux
2 parents 188a515 + 2739241 commit 08b1aeb

9 files changed: +355 additions, -169 deletions


docs/source/en/tutorials/using_peft_for_inference.md

Lines changed: 15 additions & 6 deletions
@@ -56,7 +56,7 @@ image
 
 With the `adapter_name` parameter, it is really easy to use another adapter for inference! Load the [nerijs/pixel-art-xl](https://huggingface.co/nerijs/pixel-art-xl) adapter that has been fine-tuned to generate pixel art images and call it `"pixel"`.
 
-The pipeline automatically sets the first loaded adapter (`"toy"`) as the active adapter, but you can activate the `"pixel"` adapter with the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method:
+The pipeline automatically sets the first loaded adapter (`"toy"`) as the active adapter, but you can activate the `"pixel"` adapter with the [`~PeftAdapterMixin.set_adapters`] method:
 
 ```python
 pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
@@ -85,7 +85,7 @@ By default, if the most up-to-date versions of PEFT and Transformers are detecte
 
 You can also merge different adapter checkpoints for inference to blend their styles together.
 
-Once again, use the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method to activate the `pixel` and `toy` adapters and specify the weights for how they should be merged.
+Once again, use the [`~PeftAdapterMixin.set_adapters`] method to activate the `pixel` and `toy` adapters and specify the weights for how they should be merged.
 
 ```python
 pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])
@@ -114,7 +114,7 @@ Impressive! As you can see, the model generated an image that mixed the characte
 > [!TIP]
 > Through its PEFT integration, Diffusers also offers more efficient merging methods which you can learn about in the [Merge LoRAs](../using-diffusers/merge_loras) guide!
 
-To return to only using one adapter, use the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method to activate the `"toy"` adapter:
+To return to only using one adapter, use the [`~PeftAdapterMixin.set_adapters`] method to activate the `"toy"` adapter:
 
 ```python
 pipe.set_adapters("toy")
@@ -127,7 +127,7 @@ image = pipe(
 image
 ```
 
-Or to disable all adapters entirely, use the [`~diffusers.loaders.UNet2DConditionLoadersMixin.disable_lora`] method to return the base model.
+Or to disable all adapters entirely, use the [`~PeftAdapterMixin.disable_lora`] method to return the base model.
 
 ```python
 pipe.disable_lora()
@@ -140,7 +140,8 @@ image
 ![no-lora](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_20_1.png)
 
 ### Customize adapters strength
-For even more customization, you can control how strongly the adapter affects each part of the pipeline. For this, pass a dictionary with the control strengths (called "scales") to [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`].
+
+For even more customization, you can control how strongly the adapter affects each part of the pipeline. For this, pass a dictionary with the control strengths (called "scales") to [`~PeftAdapterMixin.set_adapters`].
 
 For example, here's how you can turn on the adapter for the `down` parts, but turn it off for the `mid` and `up` parts:
 ```python
@@ -195,7 +196,7 @@ image
 
 ![block-lora-mixed](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_block_mixed.png)
 
-## Manage active adapters
+## Manage adapters
 
 You have attached multiple adapters in this tutorial, and if you're feeling a bit lost on what adapters have been attached to the pipeline's components, use the [`~diffusers.loaders.StableDiffusionLoraLoaderMixin.get_active_adapters`] method to check the list of active adapters:
 
@@ -212,3 +213,11 @@ list_adapters_component_wise = pipe.get_list_adapters()
 list_adapters_component_wise
 {"text_encoder": ["toy", "pixel"], "unet": ["toy", "pixel"], "text_encoder_2": ["toy", "pixel"]}
 ```
+
+The [`~PeftAdapterMixin.delete_adapters`] function completely removes an adapter and their LoRA layers from a model.
+
+```py
+pipe.delete_adapters("toy")
+pipe.get_active_adapters()
+["pixel"]
+```
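
As an aside (not part of this commit), here is a minimal sketch of how the renamed methods in the diff above fit together. The repo ids and weight names are the ones used elsewhere in this tutorial and should be treated as placeholders; only the `load_lora_weights`, `set_adapters`, `delete_adapters`, `get_active_adapters`, and `disable_lora` calls are the point.

```python
# Minimal sketch of the adapter-management flow documented in the diff above.
import torch
from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Attach two LoRA adapters under explicit names.
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")

# Blend the two adapters with per-adapter weights.
pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])

# Remove one adapter (and its LoRA layers) entirely; "pixel" stays attached.
pipe.delete_adapters("toy")
print(pipe.get_active_adapters())  # ["pixel"], matching the snippet added above

# Drop back to the base model without deleting the remaining adapter.
pipe.disable_lora()
```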

src/diffusers/models/attention_processor.py

Lines changed: 172 additions & 89 deletions
@@ -906,6 +906,177 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return self.processor(self, hidden_states)
 
 
+class MochiAttention(nn.Module):
+    def __init__(
+        self,
+        query_dim: int,
+        added_kv_proj_dim: int,
+        processor: "MochiAttnProcessor2_0",
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = False,
+        added_proj_bias: bool = True,
+        out_dim: Optional[int] = None,
+        out_context_dim: Optional[int] = None,
+        out_bias: bool = True,
+        context_pre_only: bool = False,
+        eps: float = 1e-5,
+    ):
+        super().__init__()
+        from .normalization import MochiRMSNorm
+
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.out_dim = out_dim if out_dim is not None else query_dim
+        self.out_context_dim = out_context_dim if out_context_dim else query_dim
+        self.context_pre_only = context_pre_only
+
+        self.heads = out_dim // dim_head if out_dim is not None else heads
+
+        self.norm_q = MochiRMSNorm(dim_head, eps, True)
+        self.norm_k = MochiRMSNorm(dim_head, eps, True)
+        self.norm_added_q = MochiRMSNorm(dim_head, eps, True)
+        self.norm_added_k = MochiRMSNorm(dim_head, eps, True)
+
+        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+        self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+        self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+        if self.context_pre_only is not None:
+            self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+
+        self.to_out = nn.ModuleList([])
+        self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+        self.to_out.append(nn.Dropout(dropout))
+
+        if not self.context_pre_only:
+            self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
+
+        self.processor = processor
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        return self.processor(
+            self,
+            hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            attention_mask=attention_mask,
+            **kwargs,
+        )
+
+
+class MochiAttnProcessor2_0:
+    """Attention processor used in Mochi."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: "MochiAttention",
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        encoder_query = attn.add_q_proj(encoder_hidden_states)
+        encoder_key = attn.add_k_proj(encoder_hidden_states)
+        encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+        encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
+        encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
+        encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
+
+        if attn.norm_added_q is not None:
+            encoder_query = attn.norm_added_q(encoder_query)
+        if attn.norm_added_k is not None:
+            encoder_key = attn.norm_added_k(encoder_key)
+
+        if image_rotary_emb is not None:
+
+            def apply_rotary_emb(x, freqs_cos, freqs_sin):
+                x_even = x[..., 0::2].float()
+                x_odd = x[..., 1::2].float()
+
+                cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
+                sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
+
+                return torch.stack([cos, sin], dim=-1).flatten(-2)
+
+            query = apply_rotary_emb(query, *image_rotary_emb)
+            key = apply_rotary_emb(key, *image_rotary_emb)
+
+        query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
+        encoder_query, encoder_key, encoder_value = (
+            encoder_query.transpose(1, 2),
+            encoder_key.transpose(1, 2),
+            encoder_value.transpose(1, 2),
+        )
+
+        sequence_length = query.size(2)
+        encoder_sequence_length = encoder_query.size(2)
+        total_length = sequence_length + encoder_sequence_length
+
+        batch_size, heads, _, dim = query.shape
+        attn_outputs = []
+        for idx in range(batch_size):
+            mask = attention_mask[idx][None, :]
+            valid_prompt_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
+
+            valid_encoder_query = encoder_query[idx : idx + 1, :, valid_prompt_token_indices, :]
+            valid_encoder_key = encoder_key[idx : idx + 1, :, valid_prompt_token_indices, :]
+            valid_encoder_value = encoder_value[idx : idx + 1, :, valid_prompt_token_indices, :]
+
+            valid_query = torch.cat([query[idx : idx + 1], valid_encoder_query], dim=2)
+            valid_key = torch.cat([key[idx : idx + 1], valid_encoder_key], dim=2)
+            valid_value = torch.cat([value[idx : idx + 1], valid_encoder_value], dim=2)
+
+            attn_output = F.scaled_dot_product_attention(
+                valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False
+            )
+            valid_sequence_length = attn_output.size(2)
+            attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length))
+            attn_outputs.append(attn_output)
+
+        hidden_states = torch.cat(attn_outputs, dim=0)
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+
+        hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
+            (sequence_length, encoder_sequence_length), dim=1
+        )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if hasattr(attn, "to_add_out"):
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        return hidden_states, encoder_hidden_states
+
+
 class AttnProcessor:
     r"""
     Default processor for performing attention-related computations.
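
A note on the inline `apply_rotary_emb` helper added in this hunk: it treats each (even, odd) feature pair as the real and imaginary parts of a complex number and rotates the pair by the angle encoded in `freqs_cos`/`freqs_sin`. Below is a small standalone check of that equivalence; it is an editorial sketch, not part of the commit, and the tensor shapes are invented rather than Mochi's actual ones.

```python
# Standalone check that the interleaved rotary helper equals a plain complex rotation.
import torch


def apply_rotary_emb(x, freqs_cos, freqs_sin):
    # identical to the nested helper in MochiAttnProcessor2_0.__call__
    x_even = x[..., 0::2].float()
    x_odd = x[..., 1::2].float()
    cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
    sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
    return torch.stack([cos, sin], dim=-1).flatten(-2)


x = torch.randn(1, 3, 2, 4)   # (batch, seq, heads, head_dim), head_dim even
theta = torch.randn(3, 1, 2)  # one angle per (position, head, feature pair) -- hypothetical layout
rotated = apply_rotary_emb(x, theta.cos(), theta.sin())

# Reference: rotate each (even, odd) pair as a complex number by e^{i*theta}.
z = torch.complex(x[..., 0::2], x[..., 1::2])
ref = z * torch.polar(torch.ones_like(theta), theta)
assert torch.allclose(rotated[..., 0::2], ref.real, atol=1e-5)
assert torch.allclose(rotated[..., 1::2], ref.imag, atol=1e-5)
```
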
@@ -4013,94 +4184,6 @@ def __call__(
         return hidden_states
 
 
-class MochiAttnProcessor2_0:
-    """Attention processor used in Mochi."""
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        query = query.unflatten(2, (attn.heads, -1))
-        key = key.unflatten(2, (attn.heads, -1))
-        value = value.unflatten(2, (attn.heads, -1))
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        encoder_query = attn.add_q_proj(encoder_hidden_states)
-        encoder_key = attn.add_k_proj(encoder_hidden_states)
-        encoder_value = attn.add_v_proj(encoder_hidden_states)
-
-        encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
-        encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
-        encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
-
-        if attn.norm_added_q is not None:
-            encoder_query = attn.norm_added_q(encoder_query)
-        if attn.norm_added_k is not None:
-            encoder_key = attn.norm_added_k(encoder_key)
-
-        if image_rotary_emb is not None:
-
-            def apply_rotary_emb(x, freqs_cos, freqs_sin):
-                x_even = x[..., 0::2].float()
-                x_odd = x[..., 1::2].float()
-
-                cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
-                sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
-
-                return torch.stack([cos, sin], dim=-1).flatten(-2)
-
-            query = apply_rotary_emb(query, *image_rotary_emb)
-            key = apply_rotary_emb(key, *image_rotary_emb)
-
-        query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
-        encoder_query, encoder_key, encoder_value = (
-            encoder_query.transpose(1, 2),
-            encoder_key.transpose(1, 2),
-            encoder_value.transpose(1, 2),
-        )
-
-        sequence_length = query.size(2)
-        encoder_sequence_length = encoder_query.size(2)
-
-        query = torch.cat([query, encoder_query], dim=2)
-        key = torch.cat([key, encoder_key], dim=2)
-        value = torch.cat([value, encoder_value], dim=2)
-
-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
-        hidden_states = hidden_states.to(query.dtype)
-
-        hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
-            (sequence_length, encoder_sequence_length), dim=1
-        )
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if getattr(attn, "to_add_out", None) is not None:
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-        return hidden_states, encoder_hidden_states
-
-
 class FusedAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
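
The processor removed in this hunk concatenated the full (padded) prompt sequence onto the visual tokens and ran a single batched `scaled_dot_product_attention`, so padding tokens took part in attention. The replacement added earlier in this commit instead loops over the batch, keeps only each sample's valid prompt tokens, and pads the outputs back to a common length. Below is a standalone sketch of that pattern; the shapes and mask values are made up for illustration.

```python
# Per-sample masked attention pattern used by the re-added MochiAttnProcessor2_0.
import torch
import torch.nn.functional as F

batch, heads, dim = 2, 4, 8
seq_len, prompt_len = 6, 5                       # visual tokens, padded prompt tokens
query = torch.randn(batch, heads, seq_len, dim)
key = torch.randn(batch, heads, seq_len, dim)
value = torch.randn(batch, heads, seq_len, dim)
encoder_query = torch.randn(batch, heads, prompt_len, dim)
encoder_key = torch.randn(batch, heads, prompt_len, dim)
encoder_value = torch.randn(batch, heads, prompt_len, dim)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],  # sample 0: only 3 valid prompt tokens
                               [1, 1, 1, 1, 1]]) # sample 1: all 5 valid
total_length = seq_len + prompt_len

outputs = []
for idx in range(batch):
    # indices of this sample's non-padding prompt tokens
    valid = torch.nonzero(attention_mask[idx], as_tuple=False).flatten()
    q = torch.cat([query[idx : idx + 1], encoder_query[idx : idx + 1, :, valid, :]], dim=2)
    k = torch.cat([key[idx : idx + 1], encoder_key[idx : idx + 1, :, valid, :]], dim=2)
    v = torch.cat([value[idx : idx + 1], encoder_value[idx : idx + 1, :, valid, :]], dim=2)
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    # pad the token dimension back so every sample has total_length tokens again
    outputs.append(F.pad(out, (0, 0, 0, total_length - out.size(2))))

hidden_states = torch.cat(outputs, dim=0)        # (batch, heads, total_length, dim)
print(hidden_states.shape)
```
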
@@ -5814,13 +5897,13 @@ def __call__(
     AttnProcessorNPU,
     AttnProcessor2_0,
     MochiVaeAttnProcessor2_0,
+    MochiAttnProcessor2_0,
     StableAudioAttnProcessor2_0,
     HunyuanAttnProcessor2_0,
    FusedHunyuanAttnProcessor2_0,
     PAGHunyuanAttnProcessor2_0,
     PAGCFGHunyuanAttnProcessor2_0,
     LuminaAttnProcessor2_0,
-    MochiAttnProcessor2_0,
     FusedAttnProcessor2_0,
     CustomDiffusionXFormersAttnProcessor,
     CustomDiffusionAttnProcessor2_0,

src/diffusers/models/embeddings.py

Lines changed: 0 additions & 1 deletion
@@ -542,7 +542,6 @@ def forward(self, latent):
             height, width = latent.shape[-2:]
         else:
             height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
-
         latent = self.proj(latent)
         if self.flatten:
             latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
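
For context, the `forward` shown here patchifies a latent and flattens it from `BCHW` to `BNC`. A tiny standalone sketch of that step (the layer sizes are invented, not the model's):

```python
# BCHW -> BNC patchification sketch; sizes are illustrative only.
import torch
import torch.nn as nn

patch_size, in_channels, embed_dim = 2, 4, 32
proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

latent = torch.randn(1, in_channels, 8, 8)             # (B, C, H, W)
height, width = latent.shape[-2] // patch_size, latent.shape[-1] // patch_size
tokens = proj(latent).flatten(2).transpose(1, 2)       # (B, C', H/p, W/p) -> (B, N, C')
assert tokens.shape == (1, height * width, embed_dim)  # N = (H/p) * (W/p)
```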
