@@ -482,7 +482,7 @@ def forward(
         # For standard processors that are defined here, `**cross_attention_kwargs` is empty

         attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
-        quiet_attn_parameters = {"ip_adapter_masks"}
+        quiet_attn_parameters = {"ip_adapter_masks", "image_projection"}
         unused_kwargs = [
             k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
         ]
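For readers skimming the hunk above: keys listed in `quiet_attn_parameters` are the ones the base attention class silently drops when the active processor's `__call__` does not accept them, instead of emitting the unused-kwargs warning. A minimal, self-contained sketch of that filtering (the `filter_kwargs` helper name and the logger setup are illustrative, not part of diffusers):

```python
import inspect
import logging

logger = logging.getLogger(__name__)


def filter_kwargs(processor, cross_attention_kwargs):
    """Hypothetical helper mirroring the filtering the hunk above feeds into."""
    # Keyword arguments accepted by the active processor's __call__.
    attn_parameters = set(inspect.signature(processor.__call__).parameters.keys())
    # Keys that may be missing from the signature without triggering a warning.
    quiet_attn_parameters = {"ip_adapter_masks", "image_projection"}

    unused_kwargs = [
        k for k in cross_attention_kwargs if k not in attn_parameters and k not in quiet_attn_parameters
    ]
    if unused_kwargs:
        logger.warning(
            "cross_attention_kwargs %s are not expected by %s and will be ignored.",
            unused_kwargs,
            processor.__class__.__name__,
        )
    # Forward only the keys the processor actually accepts.
    return {k: v for k, v in cross_attention_kwargs.items() if k in attn_parameters}
```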
@@ -1893,31 +1893,43 @@ def __call__(
         return hidden_states


-class FluxIPAdapterAttnProcessor2_0:
+class FluxIPAdapterAttnProcessor2_0(torch.nn.Module):
     """Flux Attention processor for IP-Adapter."""

-    def __init__(self, hidden_size: int, cross_attention_dim: int, scale: float = 1.0):
+    def __init__(
+        self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
+    ):
         super().__init__()

-        r"""
-        Args:
-            hidden_size (`int`):
-                The hidden size of the attention layer.
-            cross_attention_dim (`int`):
-                The number of channels in the `encoder_hidden_states`.
-            scale (`float`, defaults to 1.0):
-                the weight scale of image prompt.
-        """
-
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError(
-                "FluxIPAdapterAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+                f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
             )

+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+
+        if not isinstance(num_tokens, (tuple, list)):
+            num_tokens = [num_tokens]
+
+        if not isinstance(scale, list):
+            scale = [scale] * len(num_tokens)
+        if len(scale) != len(num_tokens):
+            raise ValueError("`scale` should be a list of floats with the same length as `num_tokens`.")
         self.scale = scale

-        self.to_k_ip = nn.Linear(cross_attention_dim, hidden_size)
-        self.to_v_ip = nn.Linear(cross_attention_dim, hidden_size)
+        self.to_k_ip = nn.ModuleList(
+            [
+                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
+                for _ in range(len(num_tokens))
+            ]
+        )
+        self.to_v_ip = nn.ModuleList(
+            [
+                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
+                for _ in range(len(num_tokens))
+            ]
+        )

     def __call__(
         self,
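A short usage sketch of the new constructor, assuming two IP-adapters with 4 and 16 image tokens respectively; the dimension values are illustrative, and the import path is assumed to be the module this diff touches:

```python
import torch

# Assumed import path: the attention-processor module this diff modifies.
from diffusers.models.attention_processor import FluxIPAdapterAttnProcessor2_0

hidden_size = 3072          # illustrative: Flux attention width
cross_attention_dim = 4096  # illustrative: width of the projected image embeddings

processor = FluxIPAdapterAttnProcessor2_0(
    hidden_size=hidden_size,
    cross_attention_dim=cross_attention_dim,
    num_tokens=(4, 16),   # one entry per IP-adapter
    scale=[0.6, 1.0],     # must have the same length as num_tokens
    dtype=torch.bfloat16,
)

# One key/value projection per IP-adapter.
assert len(processor.to_k_ip) == len(processor.to_v_ip) == 2
```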
@@ -1926,24 +1938,27 @@ def __call__(
         encoder_hidden_states: torch.FloatTensor = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
-        image_projection: Optional[torch.Tensor] = None,
+        image_projection: Optional[List[torch.Tensor]] = None,
+        ip_adapter_masks: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
+        if image_projection is None:
+            raise ValueError("image_projection is None")
         batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

         # `sample` projections.
-        query = attn.to_q(hidden_states)
+        hidden_states_query_proj = attn.to_q(hidden_states)
         key = attn.to_k(hidden_states)
         value = attn.to_v(hidden_states)

         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads

-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

         if attn.norm_q is not None:
-            query = attn.norm_q(query)
+            hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
         if attn.norm_k is not None:
             key = attn.norm_k(key)

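For the shape bookkeeping in this hunk: the query projection is renamed to `hidden_states_query_proj` so it can be reused later as the IP-adapter query, and it is reshaped from `[batch, seq, heads * head_dim]` to `[batch, heads, seq, head_dim]` before the optional RMS norm. A standalone sketch with illustrative Flux-like sizes:

```python
import torch

batch_size, seq_len, heads, head_dim = 2, 4096, 24, 128  # illustrative sizes

# Output of attn.to_q(hidden_states): [batch, seq, heads * head_dim]
hidden_states_query_proj = torch.randn(batch_size, seq_len, heads * head_dim)

# Split heads and move them ahead of the sequence dimension,
# mirroring the .view(...).transpose(1, 2) calls in the hunk.
hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, heads, head_dim).transpose(1, 2)
assert hidden_states_query_proj.shape == (batch_size, heads, seq_len, head_dim)
```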
@@ -1970,7 +1985,7 @@ def __call__(
                 encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

             # attention
-            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
             key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
             value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

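The joint attention then concatenates the text (encoder) stream and the image stream along the sequence axis, text tokens first, which is why the renamed `hidden_states_query_proj` is what gets appended here. A small shape check with illustrative sizes:

```python
import torch

batch_size, heads, head_dim = 2, 24, 128
text_len, image_len = 512, 4096  # illustrative sequence lengths

encoder_hidden_states_query_proj = torch.randn(batch_size, heads, text_len, head_dim)
hidden_states_query_proj = torch.randn(batch_size, heads, image_len, head_dim)

# Text (encoder) tokens come first, image tokens second, as in the hunk above.
query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
assert query.shape == (batch_size, heads, text_len + image_len, head_dim)
```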
@@ -1997,19 +2012,104 @@ def __call__(
             encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

             # IP-adapter
-            ip_key = self.to_k_ip(image_projection)
-            ip_value = self.to_v_ip(image_projection)
+            ip_hidden_states = image_projection
+
+            if ip_adapter_masks is not None:
+                if not isinstance(ip_adapter_masks, List):
+                    # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
+                    ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
+                if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
+                    raise ValueError(
+                        f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
+                        f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
+                        f"({len(ip_hidden_states)})"
+                    )
+                else:
+                    for index, (mask, scale, ip_state) in enumerate(
+                        zip(ip_adapter_masks, self.scale, ip_hidden_states)
+                    ):
+                        if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
+                            raise ValueError(
+                                "Each element of the ip_adapter_masks array should be a tensor with shape "
+                                "[1, num_images_for_ip_adapter, height, width]."
+                                " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+                            )
+                        if mask.shape[1] != ip_state.shape[1]:
+                            raise ValueError(
+                                f"Number of masks ({mask.shape[1]}) does not match "
+                                f"number of ip images ({ip_state.shape[1]}) at index {index}"
+                            )
+                        if isinstance(scale, list) and not len(scale) == mask.shape[1]:
+                            raise ValueError(
+                                f"Number of masks ({mask.shape[1]}) does not match "
+                                f"number of scales ({len(scale)}) at index {index}"
+                            )
+            else:
+                ip_adapter_masks = [None] * len(self.scale)
+
+            ip_query = hidden_states_query_proj
+            # for ip-adapter
+            for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
+                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
+            ):
+                skip = False
+                if isinstance(scale, list):
+                    if all(s == 0 for s in scale):
+                        skip = True
+                elif scale == 0:
+                    skip = True
+                if not skip:
+                    if mask is not None:
+                        if not isinstance(scale, list):
+                            scale = [scale] * mask.shape[1]
+
+                        current_num_images = mask.shape[1]
+                        for i in range(current_num_images):
+                            ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+                            ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+
+                            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+                            # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                            # TODO: add support for attn.scale when we move to Torch 2.1
+                            _current_ip_hidden_states = F.scaled_dot_product_attention(
+                                ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                            )
+
+                            _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
+                                batch_size, -1, attn.heads * head_dim
+                            )
+                            _current_ip_hidden_states = _current_ip_hidden_states.to(ip_query.dtype)
+
+                            mask_downsample = IPAdapterMaskProcessor.downsample(
+                                mask[:, i, :, :],
+                                batch_size,
+                                _current_ip_hidden_states.shape[1],
+                                _current_ip_hidden_states.shape[2],
+                            )
+
+                            mask_downsample = mask_downsample.to(dtype=ip_query.dtype, device=ip_query.device)
+                            hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+                    else:
+                        ip_key = to_k_ip(current_ip_hidden_states)
+                        ip_value = to_v_ip(current_ip_hidden_states)

-            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-            # the output of sdp = (batch, num_heads, seq_len, head_dim)
-            # TODO: add support for attn.scale when we move to Torch 2.1
-            ip_hidden_states = F.scaled_dot_product_attention(
-                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-            )
-            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-            ip_hidden_states = ip_hidden_states.to(query.dtype)
-            hidden_states = hidden_states + self.scale * ip_hidden_states
+                        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+                        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                        # TODO: add support for attn.scale when we move to Torch 2.1
+                        current_ip_hidden_states = F.scaled_dot_product_attention(
+                            ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                        )
+
+                        current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
+                            batch_size, -1, attn.heads * head_dim
+                        )
+                        current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
+
+                        hidden_states = hidden_states + scale * current_ip_hidden_states

         return hidden_states, encoder_hidden_states
     else:
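To make the validation block above concrete, here is a hedged sketch of how masked inputs could be laid out for a single IP-adapter with two reference images. `IPAdapterMaskProcessor` is the diffusers helper the error message points to; the sizes, dims, and the reshape convention follow the shapes named in the checks and are otherwise illustrative:

```python
import torch
from PIL import Image
from diffusers.image_processor import IPAdapterMaskProcessor

# Two grayscale masks, one per reference image of a single IP-adapter (sizes illustrative).
mask_images = [Image.new("L", (1024, 1024), 255), Image.new("L", (1024, 1024), 128)]
masks = IPAdapterMaskProcessor().preprocess(mask_images, height=1024, width=1024)  # [num_images, 1, H, W]

# One entry per IP-adapter, each [1, num_images_for_ip_adapter, height, width].
ip_adapter_masks = [masks.reshape(1, masks.shape[0], masks.shape[2], masks.shape[3])]

# One image projection per IP-adapter; with masks, each is assumed to be
# [batch, num_images_for_ip_adapter, num_tokens, cross_attention_dim] (dims illustrative).
image_projection = [torch.randn(1, 2, 4, 4096)]

# Invariants enforced by the validation block above (self.scale must have the same length).
assert len(ip_adapter_masks) == len(image_projection) == 1
assert ip_adapter_masks[0].shape[1] == image_projection[0].shape[1] == 2
```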