Skip to content

Commit ba1bfac

Browse files
authored
[Core] Refactor IPAdapterPlusImageProjection a bit (#7994)
* use IPAdapterPlusImageProjectionBlock in IPAdapterPlusImageProjection
* reposition IPAdapterPlusImageProjection
* refactor complete?
* fix heads param retrieval.
* update test dict creation method.
1 parent 5edd0b3 commit ba1bfac

File tree

3 files changed

+133
-103
lines changed

3 files changed

+133
-103
lines changed

src/diffusers/loaders/unet.py

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -847,7 +847,12 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
847847
embed_dims = state_dict["proj_in.weight"].shape[1]
848848
output_dims = state_dict["proj_out.weight"].shape[0]
849849
hidden_dims = state_dict["latents"].shape[2]
850-
heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64
850+
attn_key_present = any("attn" in k for k in state_dict)
851+
heads = (
852+
state_dict["layers.0.attn.to_q.weight"].shape[0] // 64
853+
if attn_key_present
854+
else state_dict["layers.0.0.to_q.weight"].shape[0] // 64
855+
)
851856

852857
with init_context():
853858
image_projection = IPAdapterPlusImageProjection(
@@ -860,26 +865,53 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
860865

861866
for key, value in state_dict.items():
862867
diffusers_name = key.replace("0.to", "2.to")
863-
diffusers_name = diffusers_name.replace("1.0.weight", "3.0.weight")
864-
diffusers_name = diffusers_name.replace("1.0.bias", "3.0.bias")
865-
diffusers_name = diffusers_name.replace("1.1.weight", "3.1.net.0.proj.weight")
866-
diffusers_name = diffusers_name.replace("1.3.weight", "3.1.net.2.weight")
867868

868-
if "norm1" in diffusers_name:
869-
updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value
870-
elif "norm2" in diffusers_name:
871-
updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value
872-
elif "to_kv" in diffusers_name:
869+
diffusers_name = diffusers_name.replace("0.0.norm1", "0.ln0")
870+
diffusers_name = diffusers_name.replace("0.0.norm2", "0.ln1")
871+
diffusers_name = diffusers_name.replace("1.0.norm1", "1.ln0")
872+
diffusers_name = diffusers_name.replace("1.0.norm2", "1.ln1")
873+
diffusers_name = diffusers_name.replace("2.0.norm1", "2.ln0")
874+
diffusers_name = diffusers_name.replace("2.0.norm2", "2.ln1")
875+
diffusers_name = diffusers_name.replace("3.0.norm1", "3.ln0")
876+
diffusers_name = diffusers_name.replace("3.0.norm2", "3.ln1")
877+
878+
if "to_kv" in diffusers_name:
879+
parts = diffusers_name.split(".")
880+
parts[2] = "attn"
881+
diffusers_name = ".".join(parts)
873882
v_chunk = value.chunk(2, dim=0)
874883
updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0]
875884
updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1]
885+
elif "to_q" in diffusers_name:
886+
parts = diffusers_name.split(".")
887+
parts[2] = "attn"
888+
diffusers_name = ".".join(parts)
889+
updated_state_dict[diffusers_name] = value
876890
elif "to_out" in diffusers_name:
891+
parts = diffusers_name.split(".")
892+
parts[2] = "attn"
893+
diffusers_name = ".".join(parts)
877894
updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value
878895
else:
896+
diffusers_name = diffusers_name.replace("0.1.0", "0.ff.0")
897+
diffusers_name = diffusers_name.replace("0.1.1", "0.ff.1.net.0.proj")
898+
diffusers_name = diffusers_name.replace("0.1.3", "0.ff.1.net.2")
899+
900+
diffusers_name = diffusers_name.replace("1.1.0", "1.ff.0")
901+
diffusers_name = diffusers_name.replace("1.1.1", "1.ff.1.net.0.proj")
902+
diffusers_name = diffusers_name.replace("1.1.3", "1.ff.1.net.2")
903+
904+
diffusers_name = diffusers_name.replace("2.1.0", "2.ff.0")
905+
diffusers_name = diffusers_name.replace("2.1.1", "2.ff.1.net.0.proj")
906+
diffusers_name = diffusers_name.replace("2.1.3", "2.ff.1.net.2")
907+
908+
diffusers_name = diffusers_name.replace("3.1.0", "3.ff.0")
909+
diffusers_name = diffusers_name.replace("3.1.1", "3.ff.1.net.0.proj")
910+
diffusers_name = diffusers_name.replace("3.1.3", "3.ff.1.net.2")
879911
updated_state_dict[diffusers_name] = value
880912

881913
if not low_cpu_mem_usage:
882-
image_projection.load_state_dict(updated_state_dict)
914+
image_projection.load_state_dict(updated_state_dict, strict=True)
883915
else:
884916
load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
885917

src/diffusers/models/embeddings.py

Lines changed: 38 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -806,6 +806,39 @@ def forward(self, caption):
806806
return hidden_states
807807

808808

809+
class IPAdapterPlusImageProjectionBlock(nn.Module):
810+
def __init__(
811+
self,
812+
embed_dims: int = 768,
813+
dim_head: int = 64,
814+
heads: int = 16,
815+
ffn_ratio: float = 4,
816+
) -> None:
817+
super().__init__()
818+
from .attention import FeedForward
819+
820+
self.ln0 = nn.LayerNorm(embed_dims)
821+
self.ln1 = nn.LayerNorm(embed_dims)
822+
self.attn = Attention(
823+
query_dim=embed_dims,
824+
dim_head=dim_head,
825+
heads=heads,
826+
out_bias=False,
827+
)
828+
self.ff = nn.Sequential(
829+
nn.LayerNorm(embed_dims),
830+
FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
831+
)
832+
833+
def forward(self, x, latents, residual):
834+
encoder_hidden_states = self.ln0(x)
835+
latents = self.ln1(latents)
836+
encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
837+
latents = self.attn(latents, encoder_hidden_states) + residual
838+
latents = self.ff(latents) + latents
839+
return latents
840+
841+
809842
class IPAdapterPlusImageProjection(nn.Module):
810843
"""Resampler of IP-Adapter Plus.
811844
@@ -834,35 +867,16 @@ def __init__(
834867
ffn_ratio: float = 4,
835868
) -> None:
836869
super().__init__()
837-
from .attention import FeedForward # Lazy import to avoid circular import
838-
839870
self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5)
840871

841872
self.proj_in = nn.Linear(embed_dims, hidden_dims)
842873

843874
self.proj_out = nn.Linear(hidden_dims, output_dims)
844875
self.norm_out = nn.LayerNorm(output_dims)
845876

846-
self.layers = nn.ModuleList([])
847-
for _ in range(depth):
848-
self.layers.append(
849-
nn.ModuleList(
850-
[
851-
nn.LayerNorm(hidden_dims),
852-
nn.LayerNorm(hidden_dims),
853-
Attention(
854-
query_dim=hidden_dims,
855-
dim_head=dim_head,
856-
heads=heads,
857-
out_bias=False,
858-
),
859-
nn.Sequential(
860-
nn.LayerNorm(hidden_dims),
861-
FeedForward(hidden_dims, hidden_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
862-
),
863-
]
864-
)
865-
)
877+
self.layers = nn.ModuleList(
878+
[IPAdapterPlusImageProjectionBlock(hidden_dims, dim_head, heads, ffn_ratio) for _ in range(depth)]
879+
)
866880

867881
def forward(self, x: torch.Tensor) -> torch.Tensor:
868882
"""Forward pass.
@@ -876,52 +890,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
876890

877891
x = self.proj_in(x)
878892

879-
for ln0, ln1, attn, ff in self.layers:
893+
for block in self.layers:
880894
residual = latents
881-
882-
encoder_hidden_states = ln0(x)
883-
latents = ln1(latents)
884-
encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
885-
latents = attn(latents, encoder_hidden_states) + residual
886-
latents = ff(latents) + latents
895+
latents = block(x, latents, residual)
887896

888897
latents = self.proj_out(latents)
889898
return self.norm_out(latents)
890899

891900

892-
class IPAdapterPlusImageProjectionBlock(nn.Module):
893-
def __init__(
894-
self,
895-
embed_dims: int = 768,
896-
dim_head: int = 64,
897-
heads: int = 16,
898-
ffn_ratio: float = 4,
899-
) -> None:
900-
super().__init__()
901-
from .attention import FeedForward
902-
903-
self.ln0 = nn.LayerNorm(embed_dims)
904-
self.ln1 = nn.LayerNorm(embed_dims)
905-
self.attn = Attention(
906-
query_dim=embed_dims,
907-
dim_head=dim_head,
908-
heads=heads,
909-
out_bias=False,
910-
)
911-
self.ff = nn.Sequential(
912-
nn.LayerNorm(embed_dims),
913-
FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
914-
)
915-
916-
def forward(self, x, latents, residual):
917-
encoder_hidden_states = self.ln0(x)
918-
latents = self.ln1(latents)
919-
encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
920-
latents = self.attn(latents, encoder_hidden_states) + residual
921-
latents = self.ff(latents) + latents
922-
return latents
923-
924-
925901
class IPAdapterFaceIDPlusImageProjection(nn.Module):
926902
"""FacePerceiverResampler of IP-Adapter Plus.
927903

tests/models/unets/test_models_unet_2d_condition.py

Lines changed: 52 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -146,42 +146,64 @@ def create_ip_adapter_plus_state_dict(model):
146146
)
147147

148148
ip_image_projection_state_dict = OrderedDict()
149+
keys = [k for k in image_projection.state_dict() if "layers." in k]
150+
print(keys)
149151
for k, v in image_projection.state_dict().items():
150152
if "2.to" in k:
151153
k = k.replace("2.to", "0.to")
152-
elif "3.0.weight" in k:
153-
k = k.replace("3.0.weight", "1.0.weight")
154-
elif "3.0.bias" in k:
155-
k = k.replace("3.0.bias", "1.0.bias")
156-
elif "3.0.weight" in k:
157-
k = k.replace("3.0.weight", "1.0.weight")
158-
elif "3.1.net.0.proj.weight" in k:
159-
k = k.replace("3.1.net.0.proj.weight", "1.1.weight")
160-
elif "3.net.2.weight" in k:
161-
k = k.replace("3.net.2.weight", "1.3.weight")
162-
elif "layers.0.0" in k:
163-
k = k.replace("layers.0.0", "layers.0.0.norm1")
164-
elif "layers.0.1" in k:
165-
k = k.replace("layers.0.1", "layers.0.0.norm2")
166-
elif "layers.1.0" in k:
167-
k = k.replace("layers.1.0", "layers.1.0.norm1")
168-
elif "layers.1.1" in k:
169-
k = k.replace("layers.1.1", "layers.1.0.norm2")
170-
elif "layers.2.0" in k:
171-
k = k.replace("layers.2.0", "layers.2.0.norm1")
172-
elif "layers.2.1" in k:
173-
k = k.replace("layers.2.1", "layers.2.0.norm2")
174-
175-
if "norm_cross" in k:
176-
ip_image_projection_state_dict[k.replace("norm_cross", "norm1")] = v
177-
elif "layer_norm" in k:
178-
ip_image_projection_state_dict[k.replace("layer_norm", "norm2")] = v
179-
elif "to_k" in k:
154+
elif "layers.0.ln0" in k:
155+
k = k.replace("layers.0.ln0", "layers.0.0.norm1")
156+
elif "layers.0.ln1" in k:
157+
k = k.replace("layers.0.ln1", "layers.0.0.norm2")
158+
elif "layers.1.ln0" in k:
159+
k = k.replace("layers.1.ln0", "layers.1.0.norm1")
160+
elif "layers.1.ln1" in k:
161+
k = k.replace("layers.1.ln1", "layers.1.0.norm2")
162+
elif "layers.2.ln0" in k:
163+
k = k.replace("layers.2.ln0", "layers.2.0.norm1")
164+
elif "layers.2.ln1" in k:
165+
k = k.replace("layers.2.ln1", "layers.2.0.norm2")
166+
elif "layers.3.ln0" in k:
167+
k = k.replace("layers.3.ln0", "layers.3.0.norm1")
168+
elif "layers.3.ln1" in k:
169+
k = k.replace("layers.3.ln1", "layers.3.0.norm2")
170+
elif "to_q" in k:
171+
parts = k.split(".")
172+
parts[2] = "attn"
173+
k = ".".join(parts)
174+
elif "to_out.0" in k:
175+
parts = k.split(".")
176+
parts[2] = "attn"
177+
k = ".".join(parts)
178+
k = k.replace("to_out.0", "to_out")
179+
else:
180+
k = k.replace("0.ff.0", "0.1.0")
181+
k = k.replace("0.ff.1.net.0.proj", "0.1.1")
182+
k = k.replace("0.ff.1.net.2", "0.1.3")
183+
184+
k = k.replace("1.ff.0", "1.1.0")
185+
k = k.replace("1.ff.1.net.0.proj", "1.1.1")
186+
k = k.replace("1.ff.1.net.2", "1.1.3")
187+
188+
k = k.replace("2.ff.0", "2.1.0")
189+
k = k.replace("2.ff.1.net.0.proj", "2.1.1")
190+
k = k.replace("2.ff.1.net.2", "2.1.3")
191+
192+
k = k.replace("3.ff.0", "3.1.0")
193+
k = k.replace("3.ff.1.net.0.proj", "3.1.1")
194+
k = k.replace("3.ff.1.net.2", "3.1.3")
195+
196+
# if "norm_cross" in k:
197+
# ip_image_projection_state_dict[k.replace("norm_cross", "norm1")] = v
198+
# elif "layer_norm" in k:
199+
# ip_image_projection_state_dict[k.replace("layer_norm", "norm2")] = v
200+
if "to_k" in k:
201+
parts = k.split(".")
202+
parts[2] = "attn"
203+
k = ".".join(parts)
180204
ip_image_projection_state_dict[k.replace("to_k", "to_kv")] = torch.cat([v, v], dim=0)
181205
elif "to_v" in k:
182206
continue
183-
elif "to_out.0" in k:
184-
ip_image_projection_state_dict[k.replace("to_out.0", "to_out")] = v
185207
else:
186208
ip_image_projection_state_dict[k] = v
187209

0 commit comments

Comments (0)