@@ -120,6 +120,7 @@ def __init__(
        _from_deprecated_attn_block: bool = False,
        processor: Optional["AttnProcessor"] = None,
        out_dim: int = None,
+        out_context_dim: int = None,
        context_pre_only=None,
        pre_only=False,
        elementwise_affine: bool = True,
@@ -142,6 +143,7 @@ def __init__(
        self.dropout = dropout
        self.fused_projections = False
        self.out_dim = out_dim if out_dim is not None else query_dim
+        self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
        self.context_pre_only = context_pre_only
        self.pre_only = pre_only

@@ -241,7 +243,7 @@ def __init__(
            self.to_out.append(nn.Dropout(dropout))

        if self.context_pre_only is not None and not self.context_pre_only:
-            self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
+            self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)

        if qk_norm is not None and added_kv_proj_dim is not None:
            if qk_norm == "fp32_layer_norm":
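The hunks above add an optional `out_context_dim` to `Attention`: when it is set, the context-stream output projection `to_add_out` maps `inner_dim` to `out_context_dim` instead of `out_dim`. Below is a minimal sketch (not part of this diff) of how the new argument is expected to behave; the remaining constructor arguments (`heads`, `dim_head`, `added_kv_proj_dim`, `context_pre_only`) come from the existing `Attention` signature, and the dimensions are illustrative only:

```python
# Hedged sketch: new `out_context_dim` behavior on the shared Attention class.
from diffusers.models.attention_processor import Attention

attn = Attention(
    query_dim=3072,
    added_kv_proj_dim=3072,   # enables the added (context) projections
    heads=24,
    dim_head=128,
    out_dim=3072,
    out_context_dim=1536,     # new: context output width may differ from `out_dim`
    context_pre_only=False,   # `to_add_out` is only built in this configuration
)

# Previously `to_add_out` always projected back to `out_dim`; with this change it
# targets `out_context_dim` when provided.
print(attn.to_add_out)  # expected: Linear(in_features=3072, out_features=1536, bias=True)
```
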
@@ -717,221 +719,6 @@ def fuse_projections(self, fuse=True):
        self.fused_projections = fuse


-class AsymmetricAttention(nn.Module):
-    def __init__(
-        self,
-        query_dim: int,
-        query_context_dim: int,
-        num_attention_heads: int = 8,
-        attention_head_dim: int = 64,
-        bias: bool = False,
-        context_bias: bool = False,
-        out_dim: Optional[int] = None,
-        out_context_dim: Optional[int] = None,
-        qk_norm: Optional[str] = None,
-        eps: float = 1e-5,
-        elementwise_affine: bool = True,
-        processor: Optional["AttnProcessor"] = None,
-    ) -> None:
-        super().__init__()
-
-        from .normalization import RMSNorm
-
-        self.query_dim = query_dim
-        self.query_context_dim = query_context_dim
-        self.inner_dim = out_dim if out_dim is not None else num_attention_heads * attention_head_dim
-        self.out_dim = out_dim if out_dim is not None else query_dim
-
-        self.scale = attention_head_dim**-0.5
-        self.num_attention_heads = out_dim // attention_head_dim if out_dim is not None else num_attention_heads
-
-        if qk_norm is None:
-            self.norm_q = None
-            self.norm_k = None
-            self.norm_context_q = None
-            self.norm_context_k = None
-        elif qk_norm == "rms_norm":
-            self.norm_q = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine)
-            self.norm_k = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine)
-            self.norm_context_q = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine)
-            self.norm_context_k = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine)
-        else:
-            raise ValueError((f"Unknown qk_norm: {qk_norm}. Should be None or `rms_norm`."))
-
-        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
-        self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
-        self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
-
-        self.to_context_q = nn.Linear(query_context_dim, self.inner_dim, bias=context_bias)
-        self.to_context_k = nn.Linear(query_context_dim, self.inner_dim, bias=context_bias)
-        self.to_context_v = nn.Linear(query_context_dim, self.inner_dim, bias=context_bias)
-
-        # TODO(aryan): Take care of dropouts for training purpose in future
-        self.to_out = nn.ModuleList([
-            nn.Linear(self.inner_dim, self.out_dim)
-        ])
-
-        if out_context_dim is not None:
-            self.to_context_out = nn.ModuleList([
-                nn.Linear(self.inner_dim, out_context_dim)
-            ])
-        else:
-            self.to_context_out = nn.ModuleList([
-                nn.Identity()
-            ])
-
-        if processor is None:
-            processor = AsymmetricAttnProcessor2_0()
-
-        self.set_processor(processor)
-
-    def set_processor(self, processor: "AttnProcessor") -> None:
-        r"""
-        Set the attention processor to use.
-
-        Args:
-            processor (`AttnProcessor`):
-                The attention processor to use.
-        """
-        # if current processor is in `self._modules` and if passed `processor` is not, we need to
-        # pop `processor` from `self._modules`
-        if (
-            hasattr(self, "processor")
-            and isinstance(self.processor, torch.nn.Module)
-            and not isinstance(processor, torch.nn.Module)
-        ):
-            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
-            self._modules.pop("processor")
-
-        self.processor = processor
-
-    def get_processor(self) -> "AttentionProcessor":
-        r"""
-        Get the attention processor in use.
-
-        Returns:
-            "AttentionProcessor": The attention processor in use.
-        """
-        return self.processor
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        **cross_attention_kwargs,
-    ) -> torch.Tensor:
-        r"""
-        The forward method of the `Attention` class.
-
-        Args:
-            hidden_states (`torch.Tensor`):
-                The hidden states of the query.
-            encoder_hidden_states (`torch.Tensor`, *optional*):
-                The hidden states of the encoder.
-            attention_mask (`torch.Tensor`, *optional*):
-                The attention mask to use. If `None`, no mask is applied.
-            **cross_attention_kwargs:
-                Additional keyword arguments to pass along to the cross attention.
-
-        Returns:
-            `torch.Tensor`: The output of the attention layer.
-        """
-        # The `Attention` class can call different attention processors / attention functions
-        # here we simply pass along all tensors to the selected processor class
-        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
-
-        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
-        quiet_attn_parameters = {"ip_adapter_masks"}
-        unused_kwargs = [
-            k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
-        ]
-        if len(unused_kwargs) > 0:
-            logger.warning(
-                f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
-            )
-        cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
-
-        return self.processor(
-            self,
-            hidden_states,
-            encoder_hidden_states=encoder_hidden_states,
-            attention_mask=attention_mask,
-            **cross_attention_kwargs,
-        )
-
-
-class AsymmetricAttnProcessor2_0:
-    r"""
-    Processor for implementing Asymmetric SDPA as described in Genmo/Mochi (TODO(aryan) add link).
-    """
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AsymmetricAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
-    def __call__(
-        self,
-        attn: AsymmetricAttention,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    ) -> torch.Tensor:
-        batch_size = hidden_states.size(0)
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        query_context = attn.to_context_q(encoder_hidden_states)
-        key_context = attn.to_context_k(encoder_hidden_states)
-        value_context = attn.to_context_v(encoder_hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim / attn.num_attention_heads
-
-        query = query.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2)
-        key = key.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2)
-        value = value.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2)
-
-        query_context = query_context.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2)
-        key_context = key_context.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2)
-        value_context = value_context.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        if attn.norm_context_q is not None:
-            query_context = attn.norm_context_q(query_context)
-        if attn.norm_context_k is not None:
-            key_context = attn.norm_context_k(key_context)
-
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        sequence_length = query.size(1)
-        context_sequence_length = query_context.size(1)
-
-        query = torch.cat([query, query_context], dim=1)
-        key = torch.cat([key, key_context], dim=1)
-        value = torch.cat([value, value_context], dim=1)
-
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
-        )
-        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
-        hidden_states = hidden_states.to(query.dtype)
-        hidden_states, encoder_hidden_states = hidden_states.split_with_sizes([sequence_length, context_sequence_length], dim=1)
-
-        hidden_states = attn.to_out[0](hidden_states)
-        encoder_hidden_states = attn.to_context_out[0](encoder_hidden_states)
-
-        return hidden_states, encoder_hidden_states
-
-
class AttnProcessor:
    r"""
    Default processor for performing attention-related computations.
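For reference, the removed `AsymmetricAttnProcessor2_0` implemented a joint scaled-dot-product attention: the latent and context streams are projected separately, attended to in a single SDPA call over the concatenated sequence, then split apart again before their separate output projections. Below is a standalone sketch of that pattern in plain PyTorch; the shapes are illustrative, the projection layers are replaced by random tensors, and concatenation is done along the sequence axis:

```python
# Sketch of the joint-SDPA pattern used by the removed processor (illustrative only).
import torch
import torch.nn.functional as F

batch, heads, head_dim = 2, 8, 64
latent_len, context_len = 256, 77

def qkv(seq_len):
    # Stand-in for the per-stream projections (to_q/to_k/to_v and to_context_*).
    return (torch.randn(batch, heads, seq_len, head_dim) for _ in range(3))

q, k, v = qkv(latent_len)
q_ctx, k_ctx, v_ctx = qkv(context_len)

# One attention call over both streams; dim=2 is the sequence axis here.
query = torch.cat([q, q_ctx], dim=2)
key = torch.cat([k, k_ctx], dim=2)
value = torch.cat([v, v_ctx], dim=2)
out = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

# Back to (batch, seq, heads * head_dim), then split into the two streams,
# which would each go through their own output projection.
out = out.transpose(1, 2).flatten(2, 3)
hidden_states, encoder_hidden_states = out.split_with_sizes([latent_len, context_len], dim=1)
print(hidden_states.shape, encoder_hidden_states.shape)
# torch.Size([2, 256, 512]) torch.Size([2, 77, 512])
```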