 
 from ..utils import deprecate
 from .activations import FP32SiLU, get_activation
-from .attention_processor import Attention
+from .attention_processor import Attention, FusedAttnProcessor2_0
 
 
 def get_timestep_embedding(
@@ -2104,76 +2104,55 @@ def forward(self, id_embeds: torch.Tensor) -> torch.Tensor:
         return out
 
 
-# Modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
-class TimePerceiverAttention(nn.Module):
+class IPAdapterTimeImageProjectionBlock(nn.Module):
     def __init__(
         self,
-        *,
-        dim: int,
+        hidden_dim: int = 768,
         dim_head: int = 64,
-        heads: int = 8,
+        heads: int = 16,
+        ffn_ratio: float = 4,
     ) -> None:
         super().__init__()
+        from .attention import FeedForward
 
-        self.scale = dim_head**-0.5
-        self.dim_head = dim_head
-        self.heads = heads
-        inner_dim = dim_head * heads
-
-        self.norm1 = nn.LayerNorm(dim)
-        self.norm2 = nn.LayerNorm(dim)
-
-        self.to_q = nn.Linear(dim, inner_dim, bias=False)
-        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
-        self.to_out = nn.Linear(inner_dim, dim, bias=False)
-
-    def forward(self, x, latents, shift=None, scale=None):
-        """
-        Args:
-            x (torch.Tensor): image features
-                shape (b, n1, D)
-            latent (torch.Tensor): latent features
-                shape (b, n2, D)
-        """
-
-        def reshape_tensor(x, heads):
-            bs, length, _ = x.shape
-            # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
-            x = x.view(bs, length, heads, -1)
-            # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
-            x = x.transpose(1, 2)
-            # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
-            return x.reshape(bs, heads, length, -1)
-
-        x = self.norm1(x)
-        latents = self.norm2(latents)
-
-        if shift is not None and scale is not None:
-            latents = latents * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
-
-        b, l, _ = latents.shape
+        self.ln0 = nn.LayerNorm(hidden_dim)
+        self.ln1 = nn.LayerNorm(hidden_dim)
+        self.attn = Attention(
+            query_dim=hidden_dim,
+            cross_attention_dim=hidden_dim,
+            dim_head=dim_head,
+            heads=heads,
+            bias=False,
+            out_bias=False,
+            processor=FusedAttnProcessor2_0(),
+        )
+        self.ff = FeedForward(hidden_dim, hidden_dim, activation_fn="gelu", mult=ffn_ratio, bias=False)
 
-        q = self.to_q(latents)
-        kv_input = torch.cat((x, latents), dim=-2)
-        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+        # AdaLayerNorm
+        self.adaln_silu = nn.SiLU()
+        self.adaln_proj = nn.Linear(hidden_dim, 4 * hidden_dim)
+        self.adaln_norm = nn.LayerNorm(hidden_dim)
 
-        q = reshape_tensor(q, self.heads)
-        k = reshape_tensor(k, self.heads)
-        v = reshape_tensor(v, self.heads)
+        # Custom scale cannot be passed in constructor
+        self.attn.scale = 1 / math.sqrt(math.sqrt(dim_head))
+        self.attn.fuse_projections()
+        self.attn.to_k = None
+        self.attn.to_v = None
 
-        # attention
-        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
-        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
-        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
-        out = weight @ v
+    def forward(self, x, latents, timestep_emb):
+        shift_msa, scale_msa, shift_mlp, scale_mlp = self.adaln_proj(self.adaln_silu(timestep_emb)).chunk(4, dim=1)
 
-        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
+        x = self.ln0(x)
+        latents = self.ln1(latents) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        latents = self.attn(latents, torch.cat((x, latents), dim=-2)) + latents
 
-        return self.to_out(out)
+        residual = latents
+        latents = self.adaln_norm(latents) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        return self.ff(latents) + residual
 
 
 # Modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
-class TimePerceiverResampler(nn.Module):
+class IPAdapterTimeImageProjection(nn.Module):
     def __init__(
         self,
         embed_dim: int = 1152,
@@ -2189,65 +2168,32 @@ def __init__(
         timestep_freq_shift: int = 0,
     ) -> None:
         super().__init__()
-
         self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dim) / hidden_dim**0.5)
         self.proj_in = nn.Linear(embed_dim, hidden_dim)
         self.proj_out = nn.Linear(hidden_dim, output_dim)
         self.norm_out = nn.LayerNorm(output_dim)
-
-        ff_inner_dim = int(hidden_dim * ffn_ratio)
-        self.layers = nn.ModuleList([])
-        for _ in range(depth):
-            self.layers.append(
-                nn.ModuleList(
-                    [
-                        # msa
-                        TimePerceiverAttention(dim=hidden_dim, dim_head=dim_head, heads=heads),
-                        # ff
-                        nn.Sequential(
-                            nn.LayerNorm(hidden_dim),
-                            nn.Linear(hidden_dim, ff_inner_dim, bias=False),
-                            nn.GELU(),
-                            nn.Linear(ff_inner_dim, hidden_dim, bias=False),
-                        ),
-                        # adaLN
-                        nn.Sequential(nn.SiLU(), nn.Linear(hidden_dim, ff_inner_dim, bias=True)),
-                    ]
-                )
-            )
-
-        # Time
+        self.layers = nn.ModuleList(
+            [IPAdapterTimeImageProjectionBlock(hidden_dim, dim_head, heads, ffn_ratio) for _ in range(depth)]
+        )
         self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift)
         self.time_embedding = TimestepEmbedding(timestep_in_dim, hidden_dim, act_fn="silu")
 
-    def forward(self, x, timestep, need_temb=False):
+    def forward(self, x, timestep):
         timestep_emb = self.time_proj(timestep).to(dtype=x.dtype)
-        timestep_emb = self.time_embedding(timestep_emb, None)
+        timestep_emb = self.time_embedding(timestep_emb)
 
         latents = self.latents.repeat(x.size(0), 1, 1)
 
         x = self.proj_in(x)
         x = x + timestep_emb[:, None]
 
-        for attn, ff, adaLN_modulation in self.layers:
-            shift_msa, scale_msa, shift_mlp, scale_mlp = adaLN_modulation(timestep_emb).chunk(4, dim=1)
-            latents = attn(x, latents, shift_msa, scale_msa) + latents
-
-            res = latents
-            for idx_ff in range(len(ff)):
-                layer_ff = ff[idx_ff]
-                latents = layer_ff(latents)
-                if idx_ff == 0 and isinstance(layer_ff, nn.LayerNorm):  # adaLN
-                    latents = latents * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
-            latents = latents + res
+        for block in self.layers:
+            latents = block(x, latents, timestep_emb)
 
         latents = self.proj_out(latents)
         latents = self.norm_out(latents)
 
-        if need_temb:
-            return latents, timestep_emb
-        else:
-            return latents
+        return latents, timestep_emb
 
 
 class MultiIPAdapterImageProjection(nn.Module):
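
A minimal smoke-test sketch for the new projection module, for anyone trying the change locally. It assumes the branch is installed and that the classes live in diffusers.models.embeddings (the file under edit); the constructor values below are illustrative picks, not the real defaults, since most of the signature sits outside the visible hunks.

import torch

from diffusers.models.embeddings import IPAdapterTimeImageProjection  # assumed module path

# Illustrative sizes only; the actual defaults are defined outside the hunks shown above.
proj = IPAdapterTimeImageProjection(
    embed_dim=1152,   # width of the incoming image-encoder tokens
    output_dim=2048,  # width expected by the downstream cross-attention (assumed value)
    hidden_dim=1280,  # internal resampler width (assumed value)
    depth=2,
    dim_head=64,
    heads=16,
    num_queries=64,
    ffn_ratio=4,
)

image_embeds = torch.randn(2, 729, 1152)  # (batch, num_image_tokens, embed_dim)
timestep = torch.tensor([10, 10])         # one diffusion timestep per sample

# The refactored forward always returns both the resampled tokens and the timestep embedding.
ip_tokens, timestep_emb = proj(image_embeds, timestep)
print(ip_tokens.shape)     # torch.Size([2, 64, 2048]) with the sizes above
print(timestep_emb.shape)  # torch.Size([2, 1280])

Note that callers which previously called the resampler with need_temb=False now have to unpack the (latents, timestep_emb) tuple, since the flag is gone from the new forward.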