Commit b0aa5cb

IPAdapterTimeImageProjectionBlock now uses the original attention implementation
1 parent f60751f commit b0aa5cb

File tree

2 files changed: 33 additions, 7 deletions

src/diffusers/models/embeddings.py

Lines changed: 31 additions & 6 deletions
@@ -2107,10 +2107,10 @@ def forward(self, id_embeds: torch.Tensor) -> torch.Tensor:
 class IPAdapterTimeImageProjectionBlock(nn.Module):
     def __init__(
         self,
-        hidden_dim: int = 768,
+        hidden_dim: int = 1280,
         dim_head: int = 64,
-        heads: int = 16,
-        ffn_ratio: float = 4,
+        heads: int = 20,
+        ffn_ratio: int = 4,
     ) -> None:
         super().__init__()
         from .attention import FeedForward
@@ -2124,7 +2124,6 @@ def __init__(
             heads=heads,
             bias=False,
             out_bias=False,
-            processor=FusedAttnProcessor2_0(),
         )
         self.ff = FeedForward(hidden_dim, hidden_dim, activation_fn="gelu", mult=ffn_ratio, bias=False)

@@ -2133,21 +2132,47 @@ def __init__(
         self.adaln_proj = nn.Linear(hidden_dim, 4 * hidden_dim)
         self.adaln_norm = nn.LayerNorm(hidden_dim)

-        # Set scale and fuse KV
+        # Set attention scale and fuse KV
         self.attn.scale = 1 / math.sqrt(math.sqrt(dim_head))
         self.attn.fuse_projections()
         self.attn.to_k = None
         self.attn.to_v = None

     def forward(self, x, latents, timestep_emb):
+        # Shift and scale for AdaLayerNorm
         emb = self.adaln_proj(self.adaln_silu(timestep_emb))
         shift_msa, scale_msa, shift_mlp, scale_mlp = emb.chunk(4, dim=1)

+        # Fused Attention
         residual = latents
         x = self.ln0(x)
         latents = self.ln1(latents) * (1 + scale_msa[:, None]) + shift_msa[:, None]
-        latents = self.attn(latents, torch.cat((x, latents), dim=-2)) + residual

+        batch_size = latents.shape[0]
+
+        query = self.attn.to_q(latents)
+        kv_input = torch.cat((x, latents), dim=-2)
+        kv = self.attn.to_kv(kv_input)
+        split_size = kv.shape[-1] // 2
+        key, value = torch.split(kv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // self.attn.heads
+
+        query = query.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
+
+        weight = (query * self.attn.scale) @ (key * self.attn.scale).transpose(-2, -1)
+        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+        latents = weight @ value
+
+        latents = latents.transpose(1, 2).reshape(batch_size, -1, self.attn.heads * head_dim)
+        latents = self.attn.to_out[0](latents)
+        latents = self.attn.to_out[1](latents)
+        latents = latents + residual
+
+        ## FeedForward
         residual = latents
         latents = self.adaln_norm(latents) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
         return self.ff(latents) + residual
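
For context on the manual attention path above: self.attn.scale = 1 / math.sqrt(math.sqrt(dim_head)) is applied to both the query and the key before the matmul, which is mathematically equivalent to the standard 1 / sqrt(dim_head) scaling of the attention logits, while keeping intermediate values in a smaller range under reduced precision. A minimal standalone sketch (shapes made up, not part of this commit) checking the equivalence:

```python
# Sketch only: verify that scaling query and key each by 1/sqrt(sqrt(dim_head))
# matches the usual 1/sqrt(dim_head) scaling of the attention logits.
import math

import torch

dim_head = 64
scale = 1 / math.sqrt(math.sqrt(dim_head))

# Illustrative shapes: (batch, heads, tokens, dim_head)
q = torch.randn(2, 20, 8, dim_head)
k = torch.randn(2, 20, 8, dim_head)

logits_split = (q * scale) @ (k * scale).transpose(-2, -1)      # scale applied to both operands
logits_once = (q @ k.transpose(-2, -1)) / math.sqrt(dim_head)   # standard single scaling

print(torch.allclose(logits_split, logits_once, atol=1e-5))  # True
```

As in the diff above, the softmax over these logits is then computed in float32 and cast back to the input dtype.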

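A hypothetical usage sketch of the block after this change, assuming the class is importable from the patched source tree; the batch size, token counts, and the roles noted in the comments are illustrative only:

```python
# Hypothetical usage sketch (shapes are illustrative, not taken from the commit).
import torch

from diffusers.models.embeddings import IPAdapterTimeImageProjectionBlock

block = IPAdapterTimeImageProjectionBlock()  # defaults after this commit: hidden_dim=1280, dim_head=64, heads=20

x = torch.randn(2, 577, 1280)        # image features (assumed role)
latents = torch.randn(2, 64, 1280)   # latent queries (assumed role)
timestep_emb = torch.randn(2, 1280)  # timestep embedding fed to the AdaLayerNorm projection

out = block(x, latents, timestep_emb)
print(out.shape)  # torch.Size([2, 64, 1280])
```
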
src/diffusers/models/transformers/transformer_sd3.py

Lines changed: 2 additions & 1 deletion
@@ -366,7 +366,8 @@ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool):
         # Convert image_proj state dict to diffusers
         image_proj_state_dict = {}
         for key, value in state_dict["image_proj"].items():
-            for idx in range(4):
+            if key.startswith("layers."):
+                idx = key.split(".")[1]
                 key = key.replace(f"layers.{idx}.0.norm1", f"layers.{idx}.ln0")
                 key = key.replace(f"layers.{idx}.0.norm2", f"layers.{idx}.ln1")
                 key = key.replace(f"layers.{idx}.0.to_q", f"layers.{idx}.attn.to_q")
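
To illustrate the revised key remapping: deriving idx from the key itself, rather than looping over range(4), removes the assumption of exactly four projection layers. A small standalone sketch with made-up example keys (proj_in.weight is just a placeholder for a key outside the layers.* namespace):

```python
# Sketch only: apply the new renaming logic to a few illustrative keys.
example_keys = ["layers.2.0.to_q.weight", "layers.7.0.norm1.weight", "proj_in.weight"]

for key in example_keys:
    if key.startswith("layers."):
        idx = key.split(".")[1]  # layer index read from the key, not from range(4)
        key = key.replace(f"layers.{idx}.0.norm1", f"layers.{idx}.ln0")
        key = key.replace(f"layers.{idx}.0.norm2", f"layers.{idx}.ln1")
        key = key.replace(f"layers.{idx}.0.to_q", f"layers.{idx}.attn.to_q")
    print(key)

# Output:
# layers.2.attn.to_q.weight
# layers.7.ln0.weight
# proj_in.weight
```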

0 commit comments
