@@ -482,7 +482,7 @@ def forward(
         # For standard processors that are defined here, `**cross_attention_kwargs` is empty

         attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
-        quiet_attn_parameters = {"ip_adapter_masks"}
+        quiet_attn_parameters = {"ip_adapter_masks", "image_projection"}
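+        # `image_projection` is consumed only by IP-Adapter attention processors, so its presence
+        # in `cross_attention_kwargs` should not trigger the unused-kwargs warning below.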
         unused_kwargs = [
             k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
         ]
@@ -1987,31 +1987,43 @@ def __call__(
         return hidden_states


-class FluxIPAdapterAttnProcessor2_0:
+class FluxIPAdapterAttnProcessor2_0(torch.nn.Module):
     """Flux Attention processor for IP-Adapter."""

-    def __init__(self, hidden_size: int, cross_attention_dim: int, scale: float = 1.0):
+    def __init__(
+        self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
+    ):
         super().__init__()

-        r"""
-        Args:
-            hidden_size (`int`):
-                The hidden size of the attention layer.
-            cross_attention_dim (`int`):
-                The number of channels in the `encoder_hidden_states`.
-            scale (`float`, defaults to 1.0):
-                the weight scale of image prompt.
-        """
-
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError(
-                "FluxIPAdapterAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+                f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
             )

+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+
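+        # One processor can serve several IP-Adapters: `num_tokens` holds the number of image tokens
+        # per adapter, and `scale` is broadcast below to one weight per adapter.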
+        if not isinstance(num_tokens, (tuple, list)):
+            num_tokens = [num_tokens]
+
+        if not isinstance(scale, list):
+            scale = [scale] * len(num_tokens)
+        if len(scale) != len(num_tokens):
+            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
         self.scale = scale

-        self.to_k_ip = nn.Linear(cross_attention_dim, hidden_size)
-        self.to_v_ip = nn.Linear(cross_attention_dim, hidden_size)
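+        # One key/value projection pair per IP-Adapter, in the same order as `num_tokens` / `scale`.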
+        self.to_k_ip = nn.ModuleList(
+            [
+                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
+                for _ in range(len(num_tokens))
+            ]
+        )
+        self.to_v_ip = nn.ModuleList(
+            [
+                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
+                for _ in range(len(num_tokens))
+            ]
+        )

     def __call__(
         self,
@@ -2020,24 +2032,27 @@ def __call__(
         encoder_hidden_states: torch.FloatTensor = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
-        image_projection: Optional[torch.Tensor] = None,
+        image_projection: Optional[List[torch.Tensor]] = None,
+        ip_adapter_masks: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
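+        # `image_projection` carries one image-embedding tensor per IP-Adapter and must be provided.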
+        if image_projection is None:
+            raise ValueError("image_projection is None")
         batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

         # `sample` projections.
-        query = attn.to_q(hidden_states)
+        hidden_states_query_proj = attn.to_q(hidden_states)
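+        # The sample-stream query keeps its own name here because it is reused further down as the
+        # IP-Adapter query (`ip_query`), before being concatenated with the text-stream query.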
         key = attn.to_k(hidden_states)
         value = attn.to_v(hidden_states)

         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads

-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

         if attn.norm_q is not None:
-            query = attn.norm_q(query)
+            hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
         if attn.norm_k is not None:
             key = attn.norm_k(key)

@@ -2064,7 +2079,7 @@ def __call__(
                 encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

             # attention
-            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
             key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
             value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
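+            # For joint attention, the text (encoder) tokens are prepended to the image-stream tokens.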

@@ -2091,19 +2106,104 @@ def __call__(
             encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

             # IP-adapter
-            ip_key = self.to_k_ip(image_projection)
-            ip_value = self.to_v_ip(image_projection)
+            ip_hidden_states = image_projection
+
+            if ip_adapter_masks is not None:
+                if not isinstance(ip_adapter_masks, List):
+                    # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
+                    ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
+                if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
+                    raise ValueError(
+                        f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
+                        f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
+                        f"({len(ip_hidden_states)})"
+                    )
+                else:
+                    for index, (mask, scale, ip_state) in enumerate(
+                        zip(ip_adapter_masks, self.scale, ip_hidden_states)
+                    ):
+                        if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
+                            raise ValueError(
+                                "Each element of the ip_adapter_masks array should be a tensor with shape "
+                                "[1, num_images_for_ip_adapter, height, width]."
+                                " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+                            )
+                        if mask.shape[1] != ip_state.shape[1]:
+                            raise ValueError(
+                                f"Number of masks ({mask.shape[1]}) does not match "
+                                f"number of ip images ({ip_state.shape[1]}) at index {index}"
+                            )
+                        if isinstance(scale, list) and not len(scale) == mask.shape[1]:
+                            raise ValueError(
+                                f"Number of masks ({mask.shape[1]}) does not match "
+                                f"number of scales ({len(scale)}) at index {index}"
+                            )
+            else:
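+                # No masks were supplied, so each IP-Adapter conditions the full token sequence.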
+                ip_adapter_masks = [None] * len(self.scale)
+
+            ip_query = hidden_states_query_proj
+            # for ip-adapter
+            for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
+                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
+            ):
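+                # Adapters whose scale is zero (or whose per-image scales are all zero) contribute nothing and are skipped.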
+                skip = False
+                if isinstance(scale, list):
+                    if all(s == 0 for s in scale):
+                        skip = True
+                elif scale == 0:
+                    skip = True
+                if not skip:
+                    if mask is not None:
+                        if not isinstance(scale, list):
+                            scale = [scale] * mask.shape[1]
+
+                        current_num_images = mask.shape[1]
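+                        # Attend to each reference image separately, then blend its contribution into the
+                        # sample stream through the corresponding downsampled spatial mask.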
+                        for i in range(current_num_images):
+                            ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+                            ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+
+                            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+                            # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                            # TODO: add support for attn.scale when we move to Torch 2.1
+                            _current_ip_hidden_states = F.scaled_dot_product_attention(
+                                ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                            )
+
+                            _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
+                                batch_size, -1, attn.heads * head_dim
+                            )
+                            _current_ip_hidden_states = _current_ip_hidden_states.to(ip_query.dtype)
+
+                            mask_downsample = IPAdapterMaskProcessor.downsample(
+                                mask[:, i, :, :],
+                                batch_size,
+                                _current_ip_hidden_states.shape[1],
+                                _current_ip_hidden_states.shape[2],
+                            )
+
+                            mask_downsample = mask_downsample.to(dtype=ip_query.dtype, device=ip_query.device)
+                            hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+                    else:
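+                        # Unmasked path: a single attention pass over all of this adapter's image tokens.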
+                        ip_key = to_k_ip(current_ip_hidden_states)
+                        ip_value = to_v_ip(current_ip_hidden_states)

-            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-            # the output of sdp = (batch, num_heads, seq_len, head_dim)
-            # TODO: add support for attn.scale when we move to Torch 2.1
-            ip_hidden_states = F.scaled_dot_product_attention(
-                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-            )
-            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-            ip_hidden_states = ip_hidden_states.to(query.dtype)
-            hidden_states = hidden_states + self.scale * ip_hidden_states
+                        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+                        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                        # TODO: add support for attn.scale when we move to Torch 2.1
+                        current_ip_hidden_states = F.scaled_dot_product_attention(
+                            ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                        )
+
+                        current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
+                            batch_size, -1, attn.heads * head_dim
+                        )
+                        current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
+
+                        hidden_states = hidden_states + scale * current_ip_hidden_states

             return hidden_states, encoder_hidden_states
         else:
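For context, here is a minimal construction sketch of the processor introduced in this diff. It assumes `FluxIPAdapterAttnProcessor2_0` (as defined above) and `torch` are in scope; the dimensions, the two-adapter setup, and the per-adapter scales are illustrative assumptions, not values taken from the commit.

```python
import torch

# Hypothetical example: two IP-Adapters, 4 image tokens each, with different strengths.
ip_attn_proc = FluxIPAdapterAttnProcessor2_0(
    hidden_size=3072,          # attention hidden size of the transformer block (assumed value)
    cross_attention_dim=4096,  # dimension of the projected image embeddings (assumed value)
    num_tokens=(4, 4),         # one entry per IP-Adapter
    scale=[1.0, 0.6],          # one strength per IP-Adapter; 0 disables that adapter
    dtype=torch.bfloat16,
)

# At call time the processor expects, alongside the usual attention inputs:
#   image_projection: list with one image-embedding tensor per IP-Adapter (required)
#   ip_adapter_masks: optional list of [1, num_images, height, width] masks, one per adapter
```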