
from typing import Any, Dict, Optional, Tuple, Union

+import math
import torch
import torch.nn.functional as F
from torch import nn
@@ -184,6 +185,91 @@ def __call__(

        return hidden_states

+
+class SanaAttnProcessor3_0:
+    r"""
+    Processor implementing scaled dot-product attention with plain PyTorch matmul/softmax ops, as a vanilla
+    alternative to the fused `F.scaled_dot_product_attention` path used by `SanaAttnProcessor2_0`.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("SanaAttnProcessor3_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    @staticmethod
+    def scaled_dot_product_attention(
+        query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
+    ) -> torch.Tensor:
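+        # Vanilla attention: out = softmax(query @ key^T * scale_factor + attn_bias) @ value. Unlike the
+        # fused F.scaled_dot_product_attention, this materializes the full (B, H, L, S) score matrix.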
+        B, H, L, S = *query.size()[:-1], key.size(-2)
+        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+        attn_bias = torch.zeros(B, H, L, S, dtype=query.dtype, device=query.device)
+
+        if is_causal:
+            # Honor is_causal for parity with F.scaled_dot_product_attention; the cross-attention
+            # path below always passes is_causal=False.
+            assert attn_mask is None
+            causal_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
+            attn_bias.masked_fill_(causal_mask.logical_not(), float("-inf"))
+        if attn_mask is not None:
+            if attn_mask.dtype == torch.bool:
+                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+            else:
+                attn_bias += attn_mask
+        attn_weight = query @ key.transpose(-2, -1) * scale_factor
+        attn_weight += attn_bias
+        attn_weight = torch.softmax(attn_weight, dim=-1)
+        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+        return attn_weight @ value
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = self.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+

class SanaTransformerBlock(nn.Module):
    r"""
@@ -205,6 +291,7 @@ def __init__(
        attention_out_bias: bool = True,
        mlp_ratio: float = 2.5,
        qk_norm: Optional[str] = None,
+        cross_attention_type: str = "flash",
    ) -> None:
        super().__init__()

@@ -223,6 +310,12 @@ def __init__(
        )

        # 2. Cross Attention
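+        # "flash" keeps the fused F.scaled_dot_product_attention path (SanaAttnProcessor2_0);
+        # "vanilla" uses the explicit matmul/softmax implementation in SanaAttnProcessor3_0 above.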
+        if cross_attention_type == "flash":
+            cross_attention_processor = SanaAttnProcessor2_0()
+        elif cross_attention_type == "vanilla":
+            cross_attention_processor = SanaAttnProcessor3_0()
+        else:
+            raise ValueError(f"Cross attention type {cross_attention_type} is not defined.")
        if cross_attention_dim is not None:
            self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
            self.attn2 = Attention(
@@ -235,7 +328,7 @@ def __init__(
                dropout=dropout,
                bias=True,
                out_bias=attention_out_bias,
-                processor=SanaAttnProcessor2_0(),
+                processor=cross_attention_processor,
            )

        # 3. Feed-forward
@@ -360,6 +453,7 @@ def __init__(
        guidance_embeds_scale: float = 0.1,
        qk_norm: Optional[str] = None,
        timestep_scale: float = 1.0,
+        cross_attention_type: str = "flash",
    ) -> None:
        super().__init__()

@@ -402,6 +496,7 @@ def __init__(
                    norm_eps=norm_eps,
                    mlp_ratio=mlp_ratio,
                    qk_norm=qk_norm,
+                    cross_attention_type=cross_attention_type,
                )
                for _ in range(num_layers)
            ]
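
Not part of the diff: a quick sanity check of the vanilla attention math. `vanilla_sdpa` below is a local restatement of the `SanaAttnProcessor3_0.scaled_dot_product_attention` staticmethod added above (dropout and masking omitted), and the tensor shapes are purely illustrative; it should agree with PyTorch's fused `F.scaled_dot_product_attention` up to numerical tolerance.

```python
import math

import torch
import torch.nn.functional as F


def vanilla_sdpa(query, key, value, attn_mask=None, scale=None):
    # Local restatement of the vanilla processor's math (dropout omitted).
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    if attn_mask is not None:
        attn_weight = attn_weight + attn_mask
    return torch.softmax(attn_weight, dim=-1) @ value


# Illustrative shapes: (batch, heads, seq_len, head_dim); key/value use a longer "encoder" sequence.
query = torch.randn(2, 4, 16, 32)
key = torch.randn(2, 4, 24, 32)
value = torch.randn(2, 4, 24, 32)

out_vanilla = vanilla_sdpa(query, key, value)
out_fused = F.scaled_dot_product_attention(query, key, value)
print(torch.allclose(out_vanilla, out_fused, atol=1e-5))  # expected: True
```

The trade-off is memory rather than math: the vanilla path materializes the full attention matrix, which is why `cross_attention_type` defaults to `"flash"`.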