Commit 5567438

Added support for single IPAdapter on SD3.5 pipeline
1 parent 0af910b

4 files changed: 583 additions, 13 deletions

src/diffusers/models/attention.py
Lines changed: 3 additions & 3 deletions

@@ -189,7 +189,7 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):

     def forward(
         self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor,
-        joint_attention_kwargs=None,
+        joint_attention_kwargs: Dict[str, Any] = {}
     ):
         if self.use_dual_attention:
             norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(

@@ -208,15 +208,15 @@ def forward(
         # Attention.
         attn_output, context_attn_output = self.attn(
             hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states,
-            **({} if joint_attention_kwargs is None else joint_attention_kwargs),
+            **joint_attention_kwargs
         )

         # Process attention outputs for the `hidden_states`.
         attn_output = gate_msa.unsqueeze(1) * attn_output
         hidden_states = hidden_states + attn_output

         if self.use_dual_attention:
-            attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **({} if joint_attention_kwargs is None else joint_attention_kwargs),)
+            attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
             attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
             hidden_states = hidden_states + attn_output2

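Note on the signature change: `joint_attention_kwargs` is now treated as a plain dict and unpacked straight into the attention call, so the SD3.5 pipeline can thread IP-Adapter inputs through `JointTransformerBlock` without the block inspecting them. A minimal sketch of the resulting contract (the stub and all shapes are hypothetical; the pipeline code that actually builds this dict belongs to the part of the commit not shown in this excerpt):

import torch

# Hypothetical IP-Adapter inputs; in the pipeline they would come from an image
# projection model and the timestep embedding (see the processor added below).
joint_attention_kwargs = {
    "ip_hidden_states": torch.randn(1, 64, 2432),
    "temb": torch.randn(1, 1280),
}

def attn_stub(hidden_states, encoder_hidden_states=None, ip_hidden_states=None, temb=None):
    # Stands in for self.attn(...); the extra kwargs are what reach the attention processor.
    return hidden_states, encoder_hidden_states

attn_output, context_attn_output = attn_stub(
    hidden_states=torch.randn(1, 4096, 2432),
    encoder_hidden_states=torch.randn(1, 333, 2432),
    **joint_attention_kwargs,
)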
src/diffusers/models/attention_processor.py
Lines changed: 140 additions & 0 deletions

@@ -18,6 +18,7 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
+from einops import rearrange

 from ..image_processor import IPAdapterMaskProcessor
 from ..utils import deprecate, logging

@@ -4800,6 +4801,144 @@ def __call__(
         hidden_states = hidden_states / attn.rescale_output_factor

         return hidden_states
+
+
+class IPAdapterJointAttnProcessor2_0(torch.nn.Module):
+    """Attention processor for IP-Adapter used typically in processing the SD3-like self-attention projections."""
+
+    def __init__(
+        self,
+        hidden_size: int,
+        ip_hidden_states_dim: int,
+        head_dim: int,
+        timesteps_emb_dim: int = 1280,
+        scale: float = 0.5
+    ):
+        super().__init__()
+
+        # To prevent circular import
+        from .normalization import RMSNorm, AdaLayerNorm
+
+        self.norm_ip = AdaLayerNorm(timesteps_emb_dim, output_dim=ip_hidden_states_dim * 2,
+                                    norm_eps=1e-6, chunk_dim=1)
+        self.to_k_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
+        self.to_v_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
+        self.norm_q = RMSNorm(head_dim, 1e-6)
+        self.norm_k = RMSNorm(head_dim, 1e-6)
+        self.norm_ip_k = RMSNorm(head_dim, 1e-6)
+        self.scale = scale
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        ip_hidden_states: torch.FloatTensor = None,
+        temb: torch.FloatTensor = None
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+
+        batch_size = hidden_states.shape[0]
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+        img_query = query
+        img_key = key
+        img_value = value
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)

+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
+            key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
+            value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
+
+        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            # Split the attention outputs.
+            hidden_states, encoder_hidden_states = (
+                hidden_states[:, : residual.shape[1]],
+                hidden_states[:, residual.shape[1] :],
+            )
+            if not attn.context_pre_only:
+                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        # IP Adapter
+        if self.scale != 0 and ip_hidden_states is not None:
+            # Norm image features
+            norm_ip_hidden_states = self.norm_ip(ip_hidden_states, temb=temb)
+
+            # To k and v
+            ip_key = self.to_k_ip(norm_ip_hidden_states)
+            ip_value = self.to_v_ip(norm_ip_hidden_states)
+
+            # Reshape
+            img_query = rearrange(img_query, 'b l (h d) -> b h l d', h=attn.heads)
+            img_key = rearrange(img_key, 'b l (h d) -> b h l d', h=attn.heads)
+            img_value = rearrange(img_value, 'b l (h d) -> b h l d', h=attn.heads)
+            ip_key = rearrange(ip_key, 'b l (h d) -> b h l d', h=attn.heads)
+            ip_value = rearrange(ip_value, 'b l (h d) -> b h l d', h=attn.heads)
+
+            # Norm
+            img_query = self.norm_q(img_query)
+            img_key = self.norm_k(img_key)
+            ip_key = self.norm_ip_k(ip_key)
+
+            # cat img
+            img_key = torch.cat([img_key, ip_key], dim=2)
+            img_value = torch.cat([img_value, ip_value], dim=2)
+
+            ip_hidden_states = F.scaled_dot_product_attention(img_query, img_key, img_value, dropout_p=0.0, is_causal=False)
+            ip_hidden_states = rearrange(ip_hidden_states, 'b h l d -> b l (h d)')
+            ip_hidden_states = ip_hidden_states.to(img_query.dtype)
+
+            hidden_states = hidden_states + ip_hidden_states * self.scale
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if encoder_hidden_states is not None:
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states


 class PAGIdentitySelfAttnProcessor2_0:

@@ -5089,6 +5228,7 @@ def __init__(self):
     IPAdapterAttnProcessor,
     IPAdapterAttnProcessor2_0,
     IPAdapterXFormersAttnProcessor,
+    IPAdapterJointAttnProcessor2_0,
     PAGIdentitySelfAttnProcessor2_0,
     PAGCFGIdentitySelfAttnProcessor2_0,
     LoRAAttnProcessor,

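The heart of `IPAdapterJointAttnProcessor2_0.__call__` is a decoupled attention branch: the latent-image queries attend over the concatenation of the regular image keys/values and an extra key/value pair projected from the timestep-normalized IP-Adapter tokens, and the result is blended into the joint-attention output with `scale` (0.5 by default; 0 disables image conditioning). A standalone, torch-only sketch of that branch; the head count and sequence lengths are illustrative assumptions, not the diffusers API:

import torch
import torch.nn.functional as F

b, heads, head_dim = 2, 38, 64       # 38 heads of dim 64 mimics SD3.5-Large's attention width (assumed)
img_len, ip_len = 4096, 64           # latent tokens, resampled IP-Adapter tokens (assumed)

img_query = torch.randn(b, heads, img_len, head_dim)   # stands in for rearranged to_q(hidden_states)
img_key = torch.randn(b, heads, img_len, head_dim)     # stands in for rearranged to_k(hidden_states)
img_value = torch.randn(b, heads, img_len, head_dim)   # stands in for rearranged to_v(hidden_states)
ip_key = torch.randn(b, heads, ip_len, head_dim)       # stands in for rearranged to_k_ip(norm_ip(ip_hidden_states, temb))
ip_value = torch.randn(b, heads, ip_len, head_dim)     # stands in for rearranged to_v_ip(norm_ip(ip_hidden_states, temb))

# Image queries attend jointly over the image keys and the IP-Adapter keys.
key = torch.cat([img_key, ip_key], dim=2)
value = torch.cat([img_value, ip_value], dim=2)
ip_out = F.scaled_dot_product_attention(img_query, key, value, dropout_p=0.0, is_causal=False)
ip_out = ip_out.transpose(1, 2).reshape(b, img_len, heads * head_dim)

scale = 0.5
# hidden_states = hidden_states + scale * ip_out   # as in the processor above

Presumably the pipeline/loader file of this commit (not shown in this excerpt) installs the processor on each joint attention layer, e.g. via the transformer's set_attn_processor hook.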
src/diffusers/models/embeddings.py
Lines changed: 148 additions & 0 deletions

@@ -1999,6 +1999,154 @@ def forward(self, id_embeds: torch.Tensor) -> torch.Tensor:
         return out


+# Modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
+class TimePerceiverAttention(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim: int,
+        dim_head: int = 64,
+        heads: int = 8,
+    ) -> None:
+        super().__init__()
+
+        self.scale = dim_head ** -0.5
+        self.dim_head = dim_head
+        self.heads = heads
+        inner_dim = dim_head * heads
+
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim)
+
+        self.to_q = nn.Linear(dim, inner_dim, bias=False)
+        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+        self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+    def forward(self, x, latents, shift=None, scale=None):
+        """
+        Args:
+            x (torch.Tensor): image features
+                shape (b, n1, D)
+            latent (torch.Tensor): latent features
+                shape (b, n2, D)
+        """
+        def reshape_tensor(x, heads):
+            bs, length, _ = x.shape
+            # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
+            x = x.view(bs, length, heads, -1)
+            # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
+            x = x.transpose(1, 2)
+            # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
+            return x.reshape(bs, heads, length, -1)
+
+        x = self.norm1(x)
+        latents = self.norm2(latents)
+
+        if shift is not None and scale is not None:
+            latents = latents * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+        b, l, _ = latents.shape
+
+        q = self.to_q(latents)
+        kv_input = torch.cat((x, latents), dim=-2)
+        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+
+        q = reshape_tensor(q, self.heads)
+        k = reshape_tensor(k, self.heads)
+        v = reshape_tensor(v, self.heads)
+
+        # attention
+        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
+        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+        out = weight @ v
+
+        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
+
+        return self.to_out(out)
+
+
+# Modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
+class TimePerceiverResampler(nn.Module):
+    def __init__(
+        self,
+        embed_dim: int = 1152,
+        output_dim: int = 2432,
+        hidden_dim: int = 1280,
+        depth: int = 4,
+        dim_head: int = 64,
+        heads: int = 20,
+        num_queries: int = 64,
+        ffn_ratio: int = 4,
+        timestep_in_dim: int = 320,
+        timestep_flip_sin_to_cos: bool = True,
+        timestep_freq_shift: int = 0,
+    ) -> None:
+        super().__init__()
+
+        self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dim) / hidden_dim ** 0.5)
+        self.proj_in = nn.Linear(embed_dim, hidden_dim)
+        self.proj_out = nn.Linear(hidden_dim, output_dim)
+        self.norm_out = nn.LayerNorm(output_dim)
+
+        ff_inner_dim = int(hidden_dim * ffn_ratio)
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(
+                nn.ModuleList(
+                    [
+                        # msa
+                        TimePerceiverAttention(dim=hidden_dim, dim_head=dim_head, heads=heads),
+                        # ff
+                        nn.Sequential(
+                            nn.LayerNorm(hidden_dim),
+                            nn.Linear(hidden_dim, ff_inner_dim, bias=False),
+                            nn.GELU(),
+                            nn.Linear(ff_inner_dim, hidden_dim, bias=False),
+                        ),
+                        # adaLN
+                        nn.Sequential(
+                            nn.SiLU(),
+                            nn.Linear(hidden_dim, ff_inner_dim, bias=True)
+                        )
+                    ]
+                )
+            )
+
+        # Time
+        self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift)
+        self.time_embedding = TimestepEmbedding(timestep_in_dim, hidden_dim, act_fn="silu")
+
+    def forward(self, x, timestep, need_temb=False):
+        timestep_emb = self.time_proj(timestep).to(dtype=x.dtype)
+        timestep_emb = self.time_embedding(timestep_emb, None)
+
+        latents = self.latents.repeat(x.size(0), 1, 1)
+
+        x = self.proj_in(x)
+        x = x + timestep_emb[:, None]
+
+        for attn, ff, adaLN_modulation in self.layers:
+            shift_msa, scale_msa, shift_mlp, scale_mlp = adaLN_modulation(timestep_emb).chunk(4, dim=1)
+            latents = attn(x, latents, shift_msa, scale_msa) + latents
+
+            res = latents
+            for idx_ff in range(len(ff)):
+                layer_ff = ff[idx_ff]
+                latents = layer_ff(latents)
+                if idx_ff == 0 and isinstance(layer_ff, nn.LayerNorm):  # adaLN
+                    latents = latents * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
+            latents = latents + res
+
+        latents = self.proj_out(latents)
+        latents = self.norm_out(latents)
+
+        if need_temb:
+            return latents, timestep_emb
+        else:
+            return latents
+
+
 class MultiIPAdapterImageProjection(nn.Module):
     def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
         super().__init__()

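TimePerceiverResampler is a timestep-conditioned perceiver resampler: learned query latents cross-attend to the projected image features, with each block's adaLN shift/scale derived from the timestep embedding, and that same embedding can be returned for the attention processor's norm_ip. A rough usage sketch against this commit's defaults; the 729-token, 1152-dim input is an assumption (it only has to match embed_dim), and the import works only with this branch of diffusers installed:

import torch
from diffusers.models.embeddings import TimePerceiverResampler  # added by this commit

resampler = TimePerceiverResampler()          # embed_dim=1152, output_dim=2432, num_queries=64
image_embeds = torch.randn(2, 729, 1152)      # assumed patch features from a 1152-dim image encoder
timestep = torch.tensor([999, 999])

ip_hidden_states, temb = resampler(image_embeds, timestep, need_temb=True)
print(ip_hidden_states.shape)   # torch.Size([2, 64, 2432]) -> ip_hidden_states for IPAdapterJointAttnProcessor2_0
print(temb.shape)               # torch.Size([2, 1280])     -> temb for the processor's AdaLayerNorm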