@@ -124,6 +124,7 @@ def __init__(
         context_pre_only=None,
         pre_only=False,
         elementwise_affine: bool = True,
+        is_causal: bool = False,
     ):
         super().__init__()

@@ -146,6 +147,7 @@ def __init__(
         self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
         self.context_pre_only = context_pre_only
         self.pre_only = pre_only
+        self.is_causal = is_causal

         # we make use of this private variable to know whether this class is loaded
         # with an deprecated state dict so that we can convert it on the fly
@@ -195,8 +197,8 @@ def __init__(
             self.norm_q = RMSNorm(dim_head, eps=eps)
             self.norm_k = RMSNorm(dim_head, eps=eps)
         elif qk_norm == "l2":
-            self.norm_q = LpNorm(p=2, eps=eps)
-            self.norm_k = LpNorm(p=2, eps=eps)
+            self.norm_q = LpNorm(p=2, dim=-1)
+            self.norm_k = LpNorm(p=2, dim=-1)
         else:
             raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None,'layer_norm','fp32_layer_norm','rms_norm'")

@@ -2720,6 +2722,91 @@ def __call__(
         return hidden_states


+class MochiVaeAttnProcessor2_0:
+    r"""
+    Attention processor used in Mochi VAE.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("MochiVaeAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        is_single_frame = hidden_states.shape[1] == 1
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if is_single_frame:
+            hidden_states = attn.to_v(hidden_states)
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+
+            if attn.residual_connection:
+                hidden_states = hidden_states + residual
+
+            hidden_states = hidden_states / attn.rescale_output_factor
+            return hidden_states
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=attn.is_causal
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 class StableAudioAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
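Usage sketch (not part of the commit above): a minimal example of wiring the new processor and the is_causal flag into an Attention module. It assumes MochiVaeAttnProcessor2_0 is exported from diffusers.models.attention_processor next to the other processors; the dimensions and the qk_norm choice are illustrative only.

import torch
from diffusers.models.attention_processor import Attention, MochiVaeAttnProcessor2_0

# Illustrative configuration: qk_norm="rms_norm" exercises the norm_q/norm_k path,
# and is_causal=True is forwarded to F.scaled_dot_product_attention by the processor.
attn = Attention(
    query_dim=128,
    heads=4,
    dim_head=32,
    qk_norm="rms_norm",
    is_causal=True,
    processor=MochiVaeAttnProcessor2_0(),
)

hidden_states = torch.randn(2, 16, 128)  # (batch, sequence_length, channels)
out = attn(hidden_states)  # self-attention, no attention_mask
print(out.shape)  # torch.Size([2, 16, 128])

A single-frame input (sequence_length == 1) would instead take the value-projection shortcut at the top of __call__, which is equivalent because softmax attention over a single key always returns that key's value.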