
Commit 4ba374a

Refactor of image_proj (testing)
1 parent a87895e commit 4ba374a

File tree

src/diffusers/models/attention_processor.py
src/diffusers/models/embeddings.py
src/diffusers/models/transformers/transformer_sd3.py

3 files changed (+70, -110 lines)

src/diffusers/models/attention_processor.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -3927,9 +3927,8 @@ def __call__(
         key = attn.norm_k(key)

         # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
         hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
         )

         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
```
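This change drops the old TODO and forwards the processor's `attn.scale` straight to `F.scaled_dot_product_attention`, whose `scale` keyword exists from PyTorch 2.1 onward, presumably so a non-default scale set on the module reaches the fused kernel instead of being replaced by the implicit `1 / sqrt(head_dim)`. A minimal sketch of what that keyword does; the shapes and the `custom_scale` value are illustrative, not taken from the diff:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, num_heads, seq_len, head_dim)
query = torch.randn(2, 8, 16, 64)
key = torch.randn(2, 8, 16, 64)
value = torch.randn(2, 8, 16, 64)

custom_scale = 1 / 64**0.25  # stand-in for a custom attn.scale on the Attention module

# PyTorch >= 2.1: the softmax scale can be passed directly
out_sdpa = F.scaled_dot_product_attention(query, key, value, scale=custom_scale)

# Reference formulation it replaces: softmax(Q @ K^T * scale) @ V
weight = torch.softmax(query @ key.transpose(-2, -1) * custom_scale, dim=-1)
out_ref = weight @ value

assert torch.allclose(out_sdpa, out_ref, atol=1e-5)
```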

src/diffusers/models/embeddings.py

Lines changed: 44 additions & 98 deletions
```diff
@@ -21,7 +21,7 @@

 from ..utils import deprecate
 from .activations import FP32SiLU, get_activation
-from .attention_processor import Attention
+from .attention_processor import Attention, FusedAttnProcessor2_0


 def get_timestep_embedding(
```
```diff
@@ -2104,76 +2104,55 @@ def forward(self, id_embeds: torch.Tensor) -> torch.Tensor:
         return out


-# Modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
-class TimePerceiverAttention(nn.Module):
+class IPAdapterTimeImageProjectionBlock(nn.Module):
     def __init__(
         self,
-        *,
-        dim: int,
+        hidden_dim: int = 768,
         dim_head: int = 64,
-        heads: int = 8,
+        heads: int = 16,
+        ffn_ratio: float = 4,
     ) -> None:
         super().__init__()
+        from .attention import FeedForward

-        self.scale = dim_head**-0.5
-        self.dim_head = dim_head
-        self.heads = heads
-        inner_dim = dim_head * heads
-
-        self.norm1 = nn.LayerNorm(dim)
-        self.norm2 = nn.LayerNorm(dim)
-
-        self.to_q = nn.Linear(dim, inner_dim, bias=False)
-        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
-        self.to_out = nn.Linear(inner_dim, dim, bias=False)
-
-    def forward(self, x, latents, shift=None, scale=None):
-        """
-        Args:
-            x (torch.Tensor): image features
-                shape (b, n1, D)
-            latent (torch.Tensor): latent features
-                shape (b, n2, D)
-        """
-
-        def reshape_tensor(x, heads):
-            bs, length, _ = x.shape
-            # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
-            x = x.view(bs, length, heads, -1)
-            # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
-            x = x.transpose(1, 2)
-            # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
-            return x.reshape(bs, heads, length, -1)
-
-        x = self.norm1(x)
-        latents = self.norm2(latents)
-
-        if shift is not None and scale is not None:
-            latents = latents * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
-
-        b, l, _ = latents.shape
+        self.ln0 = nn.LayerNorm(hidden_dim)
+        self.ln1 = nn.LayerNorm(hidden_dim)
+        self.attn = Attention(
+            query_dim=hidden_dim,
+            cross_attention_dim=hidden_dim,
+            dim_head=dim_head,
+            heads=heads,
+            bias=False,
+            out_bias=False,
+            processor=FusedAttnProcessor2_0(),
+        )
+        self.ff = FeedForward(hidden_dim, hidden_dim, activation_fn="gelu", mult=ffn_ratio, bias=False)

-        q = self.to_q(latents)
-        kv_input = torch.cat((x, latents), dim=-2)
-        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+        # AdaLayerNorm
+        self.adaln_silu = nn.SiLU()
+        self.adaln_proj = nn.Linear(hidden_dim, 4 * hidden_dim)
+        self.adaln_norm = nn.LayerNorm(hidden_dim)

-        q = reshape_tensor(q, self.heads)
-        k = reshape_tensor(k, self.heads)
-        v = reshape_tensor(v, self.heads)
+        # Custom scale cannot be passed in constructor
+        self.attn.scale = 1 / math.sqrt(math.sqrt(dim_head))
+        self.attn.fuse_projections()
+        self.attn.to_k = None
+        self.attn.to_v = None

-        # attention
-        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
-        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
-        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
-        out = weight @ v
+    def forward(self, x, latents, timestep_emb):
+        shift_msa, scale_msa, shift_mlp, scale_mlp = self.adaln_proj(self.adaln_silu(timestep_emb)).chunk(4, dim=1)

-        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
+        x = self.ln0(x)
+        latents = self.ln1(latents) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        latents = self.attn(x, latents) + latents

-        return self.to_out(out)
+        residual = latents
+        latents = self.adaln_norm(latents) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        return self.ff(latents) + residual


 # Modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
-class TimePerceiverResampler(nn.Module):
+class IPAdapterTimeImageProjection(nn.Module):
     def __init__(
         self,
         embed_dim: int = 1152,
```
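The hand-rolled `TimePerceiverAttention` is folded into `IPAdapterTimeImageProjectionBlock`, which keeps the `1 / sqrt(sqrt(dim_head))` softmax scale but routes attention through diffusers' `Attention` with fused `to_kv` projections, and applies the timestep conditioning as an AdaLayerNorm-style shift/scale on both the attention and feed-forward branches. A standalone sketch of that modulation pattern only, with illustrative sizes (it is not the block itself):

```python
import torch
import torch.nn as nn

hidden_dim, batch, num_latents = 768, 2, 64

# Timestep embedding -> four modulation tensors: shift/scale for the attention
# branch (msa) and for the feed-forward branch (mlp)
adaln_silu = nn.SiLU()
adaln_proj = nn.Linear(hidden_dim, 4 * hidden_dim)
norm = nn.LayerNorm(hidden_dim)

timestep_emb = torch.randn(batch, hidden_dim)
latents = torch.randn(batch, num_latents, hidden_dim)

shift_msa, scale_msa, shift_mlp, scale_mlp = adaln_proj(adaln_silu(timestep_emb)).chunk(4, dim=1)

# AdaLayerNorm: normalize, then modulate per sample; [:, None] broadcasts the
# (batch, hidden_dim) modulation over the latent tokens
modulated = norm(latents) * (1 + scale_msa[:, None]) + shift_msa[:, None]
print(modulated.shape)  # torch.Size([2, 64, 768])
```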
```diff
@@ -2189,65 +2168,32 @@ def __init__(
         timestep_freq_shift: int = 0,
     ) -> None:
         super().__init__()
-
         self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dim) / hidden_dim**0.5)
         self.proj_in = nn.Linear(embed_dim, hidden_dim)
         self.proj_out = nn.Linear(hidden_dim, output_dim)
         self.norm_out = nn.LayerNorm(output_dim)
-
-        ff_inner_dim = int(hidden_dim * ffn_ratio)
-        self.layers = nn.ModuleList([])
-        for _ in range(depth):
-            self.layers.append(
-                nn.ModuleList(
-                    [
-                        # msa
-                        TimePerceiverAttention(dim=hidden_dim, dim_head=dim_head, heads=heads),
-                        # ff
-                        nn.Sequential(
-                            nn.LayerNorm(hidden_dim),
-                            nn.Linear(hidden_dim, ff_inner_dim, bias=False),
-                            nn.GELU(),
-                            nn.Linear(ff_inner_dim, hidden_dim, bias=False),
-                        ),
-                        # adaLN
-                        nn.Sequential(nn.SiLU(), nn.Linear(hidden_dim, ff_inner_dim, bias=True)),
-                    ]
-                )
-            )
-
-        # Time
+        self.layers = nn.ModuleList(
+            [IPAdapterTimeImageProjectionBlock(hidden_dim, dim_head, heads, ffn_ratio) for _ in range(depth)]
+        )
         self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift)
         self.time_embedding = TimestepEmbedding(timestep_in_dim, hidden_dim, act_fn="silu")

-    def forward(self, x, timestep, need_temb=False):
+    def forward(self, x, timestep):
         timestep_emb = self.time_proj(timestep).to(dtype=x.dtype)
-        timestep_emb = self.time_embedding(timestep_emb, None)
+        timestep_emb = self.time_embedding(timestep_emb)

         latents = self.latents.repeat(x.size(0), 1, 1)

         x = self.proj_in(x)
         x = x + timestep_emb[:, None]

-        for attn, ff, adaLN_modulation in self.layers:
-            shift_msa, scale_msa, shift_mlp, scale_mlp = adaLN_modulation(timestep_emb).chunk(4, dim=1)
-            latents = attn(x, latents, shift_msa, scale_msa) + latents
-
-            res = latents
-            for idx_ff in range(len(ff)):
-                layer_ff = ff[idx_ff]
-                latents = layer_ff(latents)
-                if idx_ff == 0 and isinstance(layer_ff, nn.LayerNorm):  # adaLN
-                    latents = latents * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
-            latents = latents + res
+        for block in self.layers:
+            latents = block(x, latents, timestep_emb)

         latents = self.proj_out(latents)
         latents = self.norm_out(latents)

-        if need_temb:
-            return latents, timestep_emb
-        else:
-            return latents
+        return latents, timestep_emb


 class MultiIPAdapterImageProjection(nn.Module):
```
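With the refactor, the per-layer `nn.ModuleList` triples become a list of `IPAdapterTimeImageProjectionBlock`s, and `forward` always returns `(latents, timestep_emb)` now that the `need_temb` flag is gone. A hedged usage sketch, assuming a diffusers build that contains this refactor; the batch size, token count, and timestep values are illustrative:

```python
import torch
from diffusers.models.embeddings import IPAdapterTimeImageProjection

proj = IPAdapterTimeImageProjection(embed_dim=1152)  # other arguments left at their defaults

image_embeds = torch.randn(2, 729, 1152)  # (batch, num_image_tokens, embed_dim), illustrative
timestep = torch.tensor([10, 10])

# The refactored forward unconditionally returns the projected latents and the
# timestep embedding that conditioned the blocks
latents, timestep_emb = proj(image_embeds, timestep)
print(latents.shape, timestep_emb.shape)
```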

src/diffusers/models/transformers/transformer_sd3.py

Lines changed: 25 additions & 10 deletions
```diff
@@ -31,7 +31,7 @@
 from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
 from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed, TimePerceiverResampler
+from ..embeddings import CombinedTimestepTextProjEmbeddings, IPAdapterTimeImageProjection, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput


@@ -363,16 +363,31 @@ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool):

         self.set_attn_processor(attn_procs)

+        # Convert image_proj state dict to diffusers
+        image_proj_state_dict = {}
+        for key, value in state_dict["image_proj"].items():
+            for idx in range(4):
+                key = key.replace(f"layers.{idx}.0.norm1", f"layers.{idx}.ln0")
+                key = key.replace(f"layers.{idx}.0.norm2", f"layers.{idx}.ln1")
+                key = key.replace(f"layers.{idx}.0.to_q", f"layers.{idx}.attn.to_q")
+                key = key.replace(f"layers.{idx}.0.to_kv", f"layers.{idx}.attn.to_kv")
+                key = key.replace(f"layers.{idx}.0.to_out", f"layers.{idx}.attn.to_out.0")
+                key = key.replace(f"layers.{idx}.1.0", f"layers.{idx}.adaln_norm")
+                key = key.replace(f"layers.{idx}.1.1", f"layers.{idx}.ff.net.0.proj")
+                key = key.replace(f"layers.{idx}.1.3", f"layers.{idx}.ff.net.2")
+                key = key.replace(f"layers.{idx}.2.1", f"layers.{idx}.adaln_proj")
+            image_proj_state_dict[key] = value
+
         # Image projetion parameters
-        embed_dim = state_dict["image_proj"]["proj_in.weight"].shape[1]
-        output_dim = state_dict["image_proj"]["proj_out.weight"].shape[0]
-        hidden_dim = state_dict["image_proj"]["latents"].shape[2]
-        heads = state_dict["image_proj"]["layers.0.0.to_q.weight"].shape[0] // 64
-        num_queries = state_dict["image_proj"]["latents"].shape[1]
-        timestep_in_dim = state_dict["image_proj"]["time_embedding.linear_1.weight"].shape[1]
+        embed_dim = image_proj_state_dict["proj_in.weight"].shape[1]
+        output_dim = image_proj_state_dict["proj_out.weight"].shape[0]
+        hidden_dim = image_proj_state_dict["proj_in.weight"].shape[0]
+        heads = image_proj_state_dict["layers.0.attn.to_q.weight"].shape[0] // 64
+        num_queries = image_proj_state_dict["latents"].shape[1]
+        timestep_in_dim = image_proj_state_dict["time_embedding.linear_1.weight"].shape[1]

         # Image projection
-        self.image_proj = TimePerceiverResampler(
+        self.image_proj = IPAdapterTimeImageProjection(
             embed_dim=embed_dim,
             output_dim=output_dim,
             hidden_dim=hidden_dim,
@@ -382,9 +397,9 @@ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool):
         ).to(device=self.device, dtype=self.dtype)

         if not low_cpu_mem_usage:
-            self.image_proj.load_state_dict(state_dict["image_proj"], strict=True)
+            self.image_proj.load_state_dict(image_proj_state_dict, strict=True)
         else:
-            load_model_dict_into_meta(self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype)
+            load_model_dict_into_meta(self.image_proj, image_proj_state_dict, device=self.device, dtype=self.dtype)

     def forward(
         self,
```
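The loader now remaps the original checkpoint's `image_proj` keys onto the refactored module layout before calling `load_state_dict`. A small standalone sketch of the same renaming applied to a few sample keys; the sample keys are illustrative, not taken from a real checkpoint:

```python
# Same substitution table as in _load_ip_adapter_weights, applied to made-up keys
sample_keys = [
    "layers.0.0.norm1.weight",  # old block 0 attention norm
    "layers.0.0.to_q.weight",   # old block 0 query projection
    "layers.0.1.1.weight",      # old block 0 first feed-forward linear
    "layers.2.2.1.bias",        # old block 2 adaLN linear
    "proj_in.weight",           # untouched
]

renamed = []
for key in sample_keys:
    for idx in range(4):
        key = key.replace(f"layers.{idx}.0.norm1", f"layers.{idx}.ln0")
        key = key.replace(f"layers.{idx}.0.norm2", f"layers.{idx}.ln1")
        key = key.replace(f"layers.{idx}.0.to_q", f"layers.{idx}.attn.to_q")
        key = key.replace(f"layers.{idx}.0.to_kv", f"layers.{idx}.attn.to_kv")
        key = key.replace(f"layers.{idx}.0.to_out", f"layers.{idx}.attn.to_out.0")
        key = key.replace(f"layers.{idx}.1.0", f"layers.{idx}.adaln_norm")
        key = key.replace(f"layers.{idx}.1.1", f"layers.{idx}.ff.net.0.proj")
        key = key.replace(f"layers.{idx}.1.3", f"layers.{idx}.ff.net.2")
        key = key.replace(f"layers.{idx}.2.1", f"layers.{idx}.adaln_proj")
    renamed.append(key)

print(renamed)
# -> ['layers.0.ln0.weight', 'layers.0.attn.to_q.weight',
#     'layers.0.ff.net.0.proj.weight', 'layers.2.adaln_proj.bias', 'proj_in.weight']
```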
