@@ -2501,152 +2501,6 @@ def __call__(
         return hidden_states


-class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
-    """Flux Attention processor for IP-Adapter."""
-
-    def __init__(
-        self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
-    ):
-        super().__init__()
-
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
-            )
-
-        self.hidden_size = hidden_size
-        self.cross_attention_dim = cross_attention_dim
-
-        if not isinstance(num_tokens, (tuple, list)):
-            num_tokens = [num_tokens]
-
-        if not isinstance(scale, list):
-            scale = [scale] * len(num_tokens)
-        if len(scale) != len(num_tokens):
-            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
-        self.scale = scale
-
-        self.to_k_ip = nn.ModuleList(
-            [
-                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
-                for _ in range(len(num_tokens))
-            ]
-        )
-        self.to_v_ip = nn.ModuleList(
-            [
-                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
-                for _ in range(len(num_tokens))
-            ]
-        )
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-        ip_hidden_states: Optional[List[torch.Tensor]] = None,
-        ip_adapter_masks: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        # `sample` projections.
-        hidden_states_query_proj = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
-        if encoder_hidden_states is not None:
-            # `context` projections.
-            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-            )
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            # IP-adapter
-            ip_query = hidden_states_query_proj
-            ip_attn_output = torch.zeros_like(hidden_states)
-
-            for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
-                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
-            ):
-                ip_key = to_k_ip(current_ip_hidden_states)
-                ip_value = to_v_ip(current_ip_hidden_states)
-
-                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                # the output of sdp = (batch, num_heads, seq_len, head_dim)
-                # TODO: add support for attn.scale when we move to Torch 2.1
-                current_ip_hidden_states = F.scaled_dot_product_attention(
-                    ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-                )
-                current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
-                    batch_size, -1, attn.heads * head_dim
-                )
-                current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
-                ip_attn_output += scale * current_ip_hidden_states
-
-            return hidden_states, encoder_hidden_states, ip_attn_output
-        else:
-            return hidden_states
-
-
 class CogVideoXAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
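The deleted `__call__` above first runs the usual Flux joint attention (text and latent tokens concatenated along the sequence dimension), then performs one extra scaled-dot-product attention pass per IP-adapter image stream: the sample-token query attends to that stream's dedicated `to_k_ip`/`to_v_ip` projections, and the results are accumulated weighted by the per-adapter scale. A minimal standalone sketch of just that accumulation loop, using illustrative shapes and random tensors rather than the diffusers API:

# Illustrative only: shapes, scales, and module names below are made up to show
# the per-adapter accumulation step, not the diffusers processor itself.
import torch
import torch.nn.functional as F
from torch import nn

batch, heads, head_dim, seq_len, num_ip_tokens = 2, 4, 8, 16, 4
hidden_size, cross_attention_dim = heads * head_dim, 32

to_k_ip = nn.ModuleList([nn.Linear(cross_attention_dim, hidden_size)])
to_v_ip = nn.ModuleList([nn.Linear(cross_attention_dim, hidden_size)])
scales = [1.0]

# query from the sample (latent) tokens, already split into heads
ip_query = torch.randn(batch, heads, seq_len, head_dim)
# one projected image-embedding stream per IP-adapter
ip_hidden_states = [torch.randn(batch, num_ip_tokens, cross_attention_dim)]

ip_attn_output = torch.zeros(batch, seq_len, hidden_size)
for ip_states, scale, k_proj, v_proj in zip(ip_hidden_states, scales, to_k_ip, to_v_ip):
    ip_key = k_proj(ip_states).view(batch, -1, heads, head_dim).transpose(1, 2)
    ip_value = v_proj(ip_states).view(batch, -1, heads, head_dim).transpose(1, 2)
    # attend the sample query to this adapter's image tokens
    attn_out = F.scaled_dot_product_attention(ip_query, ip_key, ip_value)
    ip_attn_output += scale * attn_out.transpose(1, 2).reshape(batch, seq_len, hidden_size)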
@@ -6019,6 +5873,16 @@ def __new__(cls, *args, **kwargs):
         return FluxAttnProcessor(*args, **kwargs)


+class FluxIPAdapterJointAttnProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "`FluxIPAdapterJointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxIPAdapterAttnProcessor`"
+        deprecate("FluxIPAdapterJointAttnProcessor2_0", "1.0.0", deprecation_message)
+
+        from .transformers.transformer_flux import FluxIPAdapterAttnProcessor
+
+        return FluxIPAdapterAttnProcessor(*args, **kwargs)
+
+
 ADDED_KV_ATTENTION_PROCESSORS = (
     AttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,
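After this change, constructing the old class name still works: `__new__` emits the deprecation message and hands back the relocated `FluxIPAdapterAttnProcessor` instead. A rough usage sketch; the import paths follow the relative import shown in the hunk above, and the constructor arguments are placeholders that assume the new class keeps the old `(hidden_size, cross_attention_dim, num_tokens, scale, ...)` signature:

# Assumed signature and example dimensions; only the forwarding behaviour is
# taken from the diff above.
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor

processor = FluxIPAdapterJointAttnProcessor2_0(
    hidden_size=3072, cross_attention_dim=4096, num_tokens=(4,), scale=1.0
)

# Because __new__ forwards *args/**kwargs to the new class, the returned object
# is a FluxIPAdapterAttnProcessor; the old name only triggers the deprecation warning.
assert isinstance(processor, FluxIPAdapterAttnProcessor)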