from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle, RMSNorm
+from ..embeddings import TimestepEmbedding, Timesteps

+import torch.nn.functional as F

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -96,6 +98,102 @@ def forward(
        return hidden_states


+class SanaCombinedTimestepGuidanceEmbeddings(nn.Module):
+    """
+    For Sana.
+
+    Reference:
+    """
+
+    def __init__(self, embedding_dim):
+        super().__init__()
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+        self.guidance_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+        self.silu = nn.SiLU()
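+        # As with AdaLayerNormSingle, the 6 * embedding_dim output presumably corresponds to the
+        # shift/scale/gate modulation chunks consumed by each transformer block.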
+        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
+
+    def forward(self, timestep: torch.Tensor, guidance: torch.Tensor = None, hidden_dtype: torch.dtype = None):
+        timesteps_proj = self.time_proj(timestep)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)
+
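+        # The guidance signal (e.g. a classifier-free guidance scale) is embedded exactly like a
+        # timestep and summed with the timestep embedding into a single conditioning vector.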
+        guidance_proj = self.guidance_condition_proj(guidance)
+        guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=hidden_dtype))
+        conditioning = timesteps_emb + guidance_emb
+
+        return self.linear(self.silu(conditioning)), conditioning
+
+
+class SanaAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("SanaAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+
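+        # sequence_length follows the key/value source: the encoder states for cross-attention,
+        # the hidden states for self-attention. It is only needed to prepare the attention mask.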
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
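+        # Optional QK normalization (configured through `qk_norm` on the Attention module) is
+        # applied to the projected queries and keys before they are split into heads.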
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
class SanaTransformerBlock(nn.Module):
    r"""
    Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629).
@@ -115,6 +213,7 @@ def __init__(
        norm_eps: float = 1e-6,
        attention_out_bias: bool = True,
        mlp_ratio: float = 2.5,
+        qk_norm: Optional[str] = None,
    ) -> None:
        super().__init__()

@@ -124,6 +223,8 @@ def __init__(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
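+            # When qk_norm is requested, kv_heads is kept equal to the query heads (standard
+            # multi-head attention, no grouped-query sharing); otherwise it is left unset.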
+            kv_heads=num_attention_heads if qk_norm is not None else None,
+            qk_norm=qk_norm,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=None,
@@ -135,13 +236,15 @@ def __init__(
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
        self.attn2 = Attention(
            query_dim=dim,
+            qk_norm=qk_norm,
+            kv_heads=num_cross_attention_heads if qk_norm is not None else None,
            cross_attention_dim=cross_attention_dim,
            heads=num_cross_attention_heads,
            dim_head=cross_attention_head_dim,
            dropout=dropout,
            bias=True,
            out_bias=attention_out_bias,
-            processor=AttnProcessor2_0(),
+            processor=SanaAttnProcessor2_0(),
        )

        # 3. Feed-forward
@@ -258,6 +361,8 @@ def __init__(
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-6,
        interpolation_scale: Optional[int] = None,
+        guidance_embeds: bool = False,
+        qk_norm: Optional[str] = None,
    ) -> None:
        super().__init__()

@@ -276,7 +381,10 @@ def __init__(
        )

        # 2. Additional condition embeddings
-        self.time_embed = AdaLayerNormSingle(inner_dim)
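+        # guidance_embeds switches the conditioning module: the combined timestep+guidance
+        # embedder when guidance is used, the original AdaLayerNormSingle otherwise.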
+        if guidance_embeds:
+            self.time_embed = SanaCombinedTimestepGuidanceEmbeddings(inner_dim)
+        else:
+            self.time_embed = AdaLayerNormSingle(inner_dim)

        self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
        self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)
@@ -296,6 +404,7 @@ def __init__(
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                    mlp_ratio=mlp_ratio,
+                    qk_norm=qk_norm,
                )
                for _ in range(num_layers)
            ]
@@ -372,7 +481,8 @@ def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
-        timestep: torch.LongTensor,
+        timestep: torch.Tensor,
+        guidance: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -423,9 +533,14 @@ def forward(

        hidden_states = self.patch_embed(hidden_states)

-        timestep, embedded_timestep = self.time_embed(
-            timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
-        )
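+        # A supplied guidance tensor routes conditioning through the combined embedder;
+        # otherwise the original AdaLayerNormSingle path with batch_size is used unchanged.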
+        if guidance is not None:
+            timestep, embedded_timestep = self.time_embed(
+                timestep, guidance=guidance, hidden_dtype=hidden_states.dtype
+            )
+        else:
+            timestep, embedded_timestep = self.time_embed(
+                timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+            )

        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])