@@ -717,6 +717,218 @@ def fuse_projections(self, fuse=True):
         self.fused_projections = fuse
 
 
+class AsymmetricAttention(nn.Module):
+    r"""
+    Attention module with separate QKV projections for a query stream and a context stream. Both
+    streams are attended to jointly and then projected back to their respective output dimensions.
+    """
+
+    def __init__(
+        self,
+        query_dim: int,
+        query_context_dim: int,
+        num_attention_heads: int = 8,
+        attention_head_dim: int = 64,
+        bias: bool = False,
+        context_bias: bool = False,
+        out_dim: Optional[int] = None,
+        out_context_dim: Optional[int] = None,
+        qk_norm: Optional[str] = None,
+        eps: float = 1e-5,
+        elementwise_affine: bool = True,
+        processor: Optional["AttnProcessor"] = None,
+    ) -> None:
+        super().__init__()
+
+        from .normalization import RMSNorm
+
+        self.query_dim = query_dim
+        self.query_context_dim = query_context_dim
+        self.inner_dim = out_dim if out_dim is not None else num_attention_heads * attention_head_dim
+        self.out_dim = out_dim if out_dim is not None else query_dim
+
+        self.scale = attention_head_dim**-0.5
+        self.num_attention_heads = out_dim // attention_head_dim if out_dim is not None else num_attention_heads
+
+        if qk_norm is None:
+            self.norm_q = None
+            self.norm_k = None
+            self.norm_context_q = None
+            self.norm_context_k = None
+        elif qk_norm == "rms_norm":
+            self.norm_q = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine)
+            self.norm_k = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine)
+            self.norm_context_q = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine)
+            self.norm_context_k = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine)
+        else:
+            raise ValueError(f"Unknown qk_norm: {qk_norm}. Should be None or `rms_norm`.")
+
+        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+        self.to_context_q = nn.Linear(query_context_dim, self.inner_dim, bias=context_bias)
+        self.to_context_k = nn.Linear(query_context_dim, self.inner_dim, bias=context_bias)
+        self.to_context_v = nn.Linear(query_context_dim, self.inner_dim, bias=context_bias)
+
+        # TODO(aryan): Take care of dropouts for training purposes in the future
+        self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, self.out_dim)])
+
+        self.to_context_out = None
+        if out_context_dim is not None:
+            self.to_context_out = nn.ModuleList([nn.Linear(self.inner_dim, out_context_dim)])
+
+        if processor is None:
+            processor = AsymmetricAttnProcessor2_0()
+
+        self.set_processor(processor)
+
+    def set_processor(self, processor: "AttnProcessor") -> None:
+        r"""
+        Set the attention processor to use.
+
+        Args:
+            processor (`AttnProcessor`):
+                The attention processor to use.
+        """
+        # if current processor is in `self._modules` and if passed `processor` is not, we need to
+        # pop `processor` from `self._modules`
+        if (
+            hasattr(self, "processor")
+            and isinstance(self.processor, torch.nn.Module)
+            and not isinstance(processor, torch.nn.Module)
+        ):
+            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
+            self._modules.pop("processor")
+
+        self.processor = processor
+
+    def get_processor(self) -> "AttentionProcessor":
+        r"""
+        Get the attention processor in use.
+
+        Returns:
+            "AttentionProcessor": The attention processor in use.
+        """
+        return self.processor
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **cross_attention_kwargs,
+    ) -> torch.Tensor:
+        r"""
+        The forward method of the `AsymmetricAttention` class.
+
+        Args:
+            hidden_states (`torch.Tensor`):
+                The hidden states of the query.
+            encoder_hidden_states (`torch.Tensor`, *optional*):
+                The hidden states of the encoder.
+            attention_mask (`torch.Tensor`, *optional*):
+                The attention mask to use. If `None`, no mask is applied.
+            **cross_attention_kwargs:
+                Additional keyword arguments to pass along to the cross attention.
+
+        Returns:
+            `torch.Tensor`: The output of the attention layer. With the default
+            `AsymmetricAttnProcessor2_0`, this is a tuple of the updated `hidden_states` and
+            `encoder_hidden_states`.
+        """
+        # The `AsymmetricAttention` class can call different attention processors / attention functions.
+        # Here we simply pass along all tensors to the selected processor class. For standard
+        # processors that are defined here, `**cross_attention_kwargs` is empty.
+
+        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+        quiet_attn_parameters = {"ip_adapter_masks"}
+        unused_kwargs = [
+            k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
+        ]
+        if len(unused_kwargs) > 0:
+            logger.warning(
+                f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+            )
+        cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
+
+        return self.processor(
+            self,
+            hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            attention_mask=attention_mask,
+            **cross_attention_kwargs,
+        )
+
+
+class AsymmetricAttnProcessor2_0:
+    r"""
+    Processor for implementing Asymmetric SDPA as described in Genmo/Mochi (TODO(aryan) add link).
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "AsymmetricAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0 or higher."
+            )
+
+    def __call__(
+        self,
+        attn: AsymmetricAttention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        batch_size = hidden_states.size(0)
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        query_context = attn.to_context_q(encoder_hidden_states)
+        key_context = attn.to_context_k(encoder_hidden_states)
+        value_context = attn.to_context_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.num_attention_heads
+
+        # [batch, seq_len, heads * head_dim] -> [batch, heads, seq_len, head_dim]
+        query = query.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2)
+        key = key.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2)
+        value = value.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2)
+
+        query_context = query_context.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2)
+        key_context = key_context.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2)
+        value_context = value_context.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        if attn.norm_context_q is not None:
+            query_context = attn.norm_context_q(query_context)
+        if attn.norm_context_k is not None:
+            key_context = attn.norm_context_k(key_context)
+
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        sequence_length = query.size(2)
+        context_sequence_length = query_context.size(2)
+
+        # Joint attention over the concatenated visual + context tokens. Tensors are laid out as
+        # [batch, heads, seq_len, head_dim] here, so the sequence dimension is dim=2. If an
+        # `attention_mask` is provided, it must cover the concatenated sequence.
+        query = torch.cat([query, query_context], dim=2)
+        key = torch.cat([key, key_context], dim=2)
+        value = torch.cat([value, value_context], dim=2)
+
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
+            [sequence_length, context_sequence_length], dim=1
+        )
+
+        hidden_states = attn.to_out[0](hidden_states)
+        if attn.to_context_out is not None:
+            encoder_hidden_states = attn.to_context_out[0](encoder_hidden_states)
+
+        return hidden_states, encoder_hidden_states
+
+
 class AttnProcessor:
     r"""
     Default processor for performing attention-related computations.
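For reference, here is a minimal usage sketch of the module added in this diff (not part of the diff itself). It assumes the change has been applied to `diffusers.models.attention_processor`, so `AsymmetricAttention` is importable from there; the dimensions below are illustrative, not the values of any particular model.

```python
import torch

# Assumption: this diff is applied to diffusers.models.attention_processor.
from diffusers.models.attention_processor import AsymmetricAttention

attn = AsymmetricAttention(
    query_dim=1024,            # width of the visual (query) stream
    query_context_dim=768,     # width of the text/conditioning (context) stream
    num_attention_heads=16,
    attention_head_dim=64,     # inner_dim = 16 * 64 = 1024
    out_context_dim=768,       # project the context stream back to its own width
    qk_norm="rms_norm",
)

hidden_states = torch.randn(2, 4096, 1024)       # (batch, visual tokens, query_dim)
encoder_hidden_states = torch.randn(2, 77, 768)  # (batch, text tokens, query_context_dim)

# The default AsymmetricAttnProcessor2_0 projects each stream with its own QKV weights,
# runs one SDPA call over the concatenated token sequence, then splits the result back
# into the two streams and applies their separate output projections.
out_hidden, out_context = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out_hidden.shape)   # torch.Size([2, 4096, 1024])
print(out_context.shape)  # torch.Size([2, 77, 768])
```

If `out_context_dim` is left as `None`, `to_context_out` is not created and the context stream is returned at `inner_dim` width.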