@@ -2107,10 +2107,10 @@ def forward(self, id_embeds: torch.Tensor) -> torch.Tensor:
 class IPAdapterTimeImageProjectionBlock(nn.Module):
     def __init__(
         self,
-        hidden_dim: int = 768,
+        hidden_dim: int = 1280,
         dim_head: int = 64,
-        heads: int = 16,
-        ffn_ratio: float = 4,
+        heads: int = 20,
+        ffn_ratio: int = 4,
     ) -> None:
         super().__init__()
         from .attention import FeedForward
@@ -2124,7 +2124,6 @@ def __init__(
             heads=heads,
             bias=False,
             out_bias=False,
-            processor=FusedAttnProcessor2_0(),
         )
         self.ff = FeedForward(hidden_dim, hidden_dim, activation_fn="gelu", mult=ffn_ratio, bias=False)

@@ -2133,21 +2132,47 @@ def __init__(
         self.adaln_proj = nn.Linear(hidden_dim, 4 * hidden_dim)
         self.adaln_norm = nn.LayerNorm(hidden_dim)

-        # Set scale and fuse KV
+        # Set attention scale and fuse KV
         self.attn.scale = 1 / math.sqrt(math.sqrt(dim_head))
         self.attn.fuse_projections()
         self.attn.to_k = None
         self.attn.to_v = None

     def forward(self, x, latents, timestep_emb):
+        # Shift and scale for AdaLayerNorm
         emb = self.adaln_proj(self.adaln_silu(timestep_emb))
         shift_msa, scale_msa, shift_mlp, scale_mlp = emb.chunk(4, dim=1)

+        # Fused attention
         residual = latents
         x = self.ln0(x)
         latents = self.ln1(latents) * (1 + scale_msa[:, None]) + shift_msa[:, None]
-        latents = self.attn(latents, torch.cat((x, latents), dim=-2)) + residual

+        batch_size = latents.shape[0]
+
+        query = self.attn.to_q(latents)
+        kv_input = torch.cat((x, latents), dim=-2)
+        kv = self.attn.to_kv(kv_input)
+        split_size = kv.shape[-1] // 2
+        key, value = torch.split(kv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // self.attn.heads
+
+        query = query.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
+
+        weight = (query * self.attn.scale) @ (key * self.attn.scale).transpose(-2, -1)
+        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+        latents = weight @ value
+
+        latents = latents.transpose(1, 2).reshape(batch_size, -1, self.attn.heads * head_dim)
+        latents = self.attn.to_out[0](latents)
+        latents = self.attn.to_out[1](latents)
+        latents = latents + residual
+
+        # FeedForward
         residual = latents
         latents = self.adaln_norm(latents) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
         return self.ff(latents) + residual
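For orientation, here is a minimal usage sketch of the updated block. It is not part of this commit: the import path, batch size, and sequence lengths are assumptions chosen to match the new defaults (`hidden_dim=1280`, `heads=20`, `dim_head=64`), and it only checks that the manually unrolled fused-KV attention path runs end to end.

```python
# Minimal sketch; import path and tensor shapes are illustrative assumptions.
import torch

from diffusers.models.embeddings import IPAdapterTimeImageProjectionBlock  # assumed location

block = IPAdapterTimeImageProjectionBlock()  # new defaults: hidden_dim=1280, dim_head=64, heads=20

batch_size, num_image_tokens, num_latents = 2, 577, 64   # hypothetical sequence lengths
x = torch.randn(batch_size, num_image_tokens, 1280)      # encoded image features
latents = torch.randn(batch_size, num_latents, 1280)     # query latents
timestep_emb = torch.randn(batch_size, 1280)             # timestep embedding for AdaLayerNorm

out = block(x, latents, timestep_emb)
print(out.shape)  # expected: torch.Size([2, 64, 1280])
```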