@@ -208,145 +208,10 @@ def __call__(
         return hidden_states
 
 
-# TODO: support IP Adapter for Flux.2 as well
-class FluxIPAdapterAttnProcessor(torch.nn.Module):
-    """Flux Attention processor for IP-Adapter."""
-
-    _attention_backend = None
-    _parallel_config = None
-
-    def __init__(
-        self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
-    ):
-        super().__init__()
-
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
-            )
-
-        self.hidden_size = hidden_size
-        self.cross_attention_dim = cross_attention_dim
-
-        if not isinstance(num_tokens, (tuple, list)):
-            num_tokens = [num_tokens]
-
-        if not isinstance(scale, list):
-            scale = [scale] * len(num_tokens)
-        if len(scale) != len(num_tokens):
-            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
-        self.scale = scale
-
-        self.to_k_ip = nn.ModuleList(
-            [
-                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
-                for _ in range(len(num_tokens))
-            ]
-        )
-        self.to_v_ip = nn.ModuleList(
-            [
-                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
-                for _ in range(len(num_tokens))
-            ]
-        )
-
-    def __call__(
-        self,
-        attn: "Flux2Attention",
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-        ip_hidden_states: Optional[List[torch.Tensor]] = None,
-        ip_adapter_masks: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        batch_size = hidden_states.shape[0]
-
-        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
-            attn, hidden_states, encoder_hidden_states
-        )
-
-        query = query.unflatten(-1, (attn.heads, -1))
-        key = key.unflatten(-1, (attn.heads, -1))
-        value = value.unflatten(-1, (attn.heads, -1))
-
-        query = attn.norm_q(query)
-        key = attn.norm_k(key)
-        ip_query = query
-
-        if encoder_hidden_states is not None:
-            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
-            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
-            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
-
-            encoder_query = attn.norm_added_q(encoder_query)
-            encoder_key = attn.norm_added_k(encoder_key)
-
-            query = torch.cat([encoder_query, query], dim=1)
-            key = torch.cat([encoder_key, key], dim=1)
-            value = torch.cat([encoder_value, value], dim=1)
-
-        if image_rotary_emb is not None:
-            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
-            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
-
-        hidden_states = dispatch_attention_fn(
-            query,
-            key,
-            value,
-            attn_mask=attention_mask,
-            dropout_p=0.0,
-            is_causal=False,
-            backend=self._attention_backend,
-            parallel_config=self._parallel_config,
-        )
-        hidden_states = hidden_states.flatten(2, 3)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
-                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
-            )
-            hidden_states = attn.to_out[0](hidden_states)
-            hidden_states = attn.to_out[1](hidden_states)
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            # IP-adapter
-            ip_attn_output = torch.zeros_like(hidden_states)
-
-            for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
-                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
-            ):
-                ip_key = to_k_ip(current_ip_hidden_states)
-                ip_value = to_v_ip(current_ip_hidden_states)
-
-                ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim)
-                ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim)
-
-                current_ip_hidden_states = dispatch_attention_fn(
-                    ip_query,
-                    ip_key,
-                    ip_value,
-                    attn_mask=None,
-                    dropout_p=0.0,
-                    is_causal=False,
-                    backend=self._attention_backend,
-                    parallel_config=self._parallel_config,
-                )
-                current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim)
-                current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
-                ip_attn_output += scale * current_ip_hidden_states
-
-            return hidden_states, encoder_hidden_states, ip_attn_output
-        else:
-            return hidden_states
-
-
 class Flux2Attention(torch.nn.Module, AttentionModuleMixin):
     _default_processor_cls = Flux2AttnProcessor
     _available_processors = [
         Flux2AttnProcessor,
-        FluxIPAdapterAttnProcessor,
     ]
 
     def __init__(
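The deleted processor's __call__ runs the regular joint attention and then, for each loaded IP-Adapter, projects the image-prompt embeddings through that adapter's own to_k_ip/to_v_ip pair, attends against the image-stream query taken before the text tokens are concatenated, and adds the result as a scale-weighted residual. Below is a minimal, self-contained sketch of just that mixing step, assuming plain F.scaled_dot_product_attention in place of diffusers' dispatch_attention_fn; the class name IPAdapterMixSketch, its forward signature, and the example dimensions are illustrative and not part of the library API.

from typing import List

import torch
import torch.nn.functional as F
from torch import nn


class IPAdapterMixSketch(nn.Module):
    """Per-adapter K/V projection plus scale-weighted mixing, mirroring the removed IP-Adapter branch."""

    def __init__(self, hidden_size: int, cross_attention_dim: int, num_adapters: int = 1, scale: float = 1.0):
        super().__init__()
        # One K/V projection pair per IP-Adapter, mapping image-prompt embeddings
        # (cross_attention_dim) into the transformer hidden size.
        self.to_k_ip = nn.ModuleList(
            [nn.Linear(cross_attention_dim, hidden_size, bias=True) for _ in range(num_adapters)]
        )
        self.to_v_ip = nn.ModuleList(
            [nn.Linear(cross_attention_dim, hidden_size, bias=True) for _ in range(num_adapters)]
        )
        self.scale = [scale] * num_adapters

    def forward(
        self, ip_query: torch.Tensor, ip_hidden_states: List[torch.Tensor], heads: int, head_dim: int
    ) -> torch.Tensor:
        # ip_query: (batch, image_seq_len, heads, head_dim), i.e. the image-stream query
        # captured before the text tokens are concatenated in the main attention branch.
        batch_size = ip_query.shape[0]
        ip_attn_output = ip_query.new_zeros((batch_size, ip_query.shape[1], heads * head_dim))

        for image_embeds, scale, to_k_ip, to_v_ip in zip(ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip):
            ip_key = to_k_ip(image_embeds).view(batch_size, -1, heads, head_dim)
            ip_value = to_v_ip(image_embeds).view(batch_size, -1, heads, head_dim)

            # scaled_dot_product_attention expects (batch, heads, seq, head_dim).
            out = F.scaled_dot_product_attention(
                ip_query.transpose(1, 2), ip_key.transpose(1, 2), ip_value.transpose(1, 2)
            )
            out = out.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
            # Each adapter contributes a scale-weighted residual over the image stream.
            ip_attn_output = ip_attn_output + scale * out

        return ip_attn_output


# Illustrative shapes only: 24 heads of dim 128 give hidden_size 3072.
mixer = IPAdapterMixSketch(hidden_size=3072, cross_attention_dim=768, num_adapters=1, scale=0.6)
ip_query = torch.randn(1, 4096, 24, 128)      # image-stream query
image_embeds = [torch.randn(1, 4, 768)]       # one adapter, 4 image-prompt tokens
residual = mixer(ip_query, image_embeds, heads=24, head_dim=128)   # (1, 4096, 3072)

Keeping one projection pair and one scale per adapter, as the removed code does with its to_k_ip/to_v_ip module lists, is what allows several IP-Adapters to be loaded at once and blended independently.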