# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import numbers
from typing import Any, Dict, Optional, Union

import torch
    AttnProcessor2_0,
    FusedAttnProcessor2_0,
    SanaLinearAttnProcessor2_0,
+    SanaMultiscaleAttnProcessor2_0,
+    SanaMultiscaleLinearAttention,
)
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, SinusoidalPositionalEmbedding
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormSingle, RMSNorm
+from ..normalization import AdaLayerNormSingle, RMSNormScaled


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, HW: tuple = None):
-    # "feed_forward_chunk_size" can be used to save memory
-    if hidden_states.shape[chunk_dim] % chunk_size != 0:
-        raise ValueError(
-            f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
-        )
-
-    num_chunks = hidden_states.shape[chunk_dim] // chunk_size
-    ff_output = torch.cat(
-        [ff(hid_slice, HW) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
-        dim=chunk_dim,
-    )
-    return ff_output
-
-
-@maybe_allow_in_graph
-class RMSNormScaled(RMSNorm):
-    def __init__(self, dim, eps: float, elementwise_affine: bool = True, scale_factor: float = 1.0):
-        super().__init__(dim, eps, elementwise_affine)
-        self.weight = nn.Parameter(torch.ones(dim) * scale_factor)
-
-
# Modified from diffusers.models.autoencoders.autoencoder_dc.GLUMBConv
@maybe_allow_in_graph
class SanaGLUMBConv(nn.Module):
@@ -105,22 +86,19 @@ class SanaLinearTransformerBlock(nn.Module):
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
-        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
-            processing of `context` conditions.
    """

    def __init__(
        self,
-        dim: int,
-        num_attention_heads: int,
-        attention_head_dim: int,
-        dropout=0.0,
-        num_cross_attention_heads: Optional[int] = None,
-        cross_attention_head_dim: Optional[int] = None,
-        cross_attention_dim: Optional[int] = None,
-        activation_fn: tuple = ("silu", "silu", None),
-        num_embeds_ada_norm: Optional[int] = None,
-        attention_bias: bool = False,
+        dim: int = 2240,
+        num_attention_heads: int = 70,
+        attention_head_dim: int = 32,
+        dropout: float = 0.0,
+        num_cross_attention_heads: Optional[int] = 20,
+        cross_attention_head_dim: Optional[int] = 112,
+        cross_attention_dim: Optional[int] = 2240,
+        num_embeds_ada_norm: Optional[int] = 1000,
+        attention_bias: bool = True,
        upcast_attention: bool = False,
        norm_type: str = "ada_norm_single",
        norm_elementwise_affine: bool = False,
@@ -136,7 +114,6 @@ def __init__(
        self.attention_head_dim = attention_head_dim
        self.dropout = dropout
        self.cross_attention_dim = cross_attention_dim
-        self.activation_fn = activation_fn
        self.attention_bias = attention_bias
        self.norm_elementwise_affine = norm_elementwise_affine

@@ -205,8 +182,6 @@ def forward(
        encoder_attention_mask: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
-        class_labels: Optional[torch.LongTensor] = None,
-        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        HW: Optional[tuple[int]] = None,
    ) -> torch.Tensor:
        if cross_attention_kwargs is not None:
@@ -260,11 +235,7 @@ def forward(
        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

-        if self._chunk_size is not None:
-            # "feed_forward_chunk_size" can be used to save memory
-            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, HW=HW)
-        else:
-            ff_output = self.ff(norm_hidden_states, HW=HW)
+        ff_output = self.ff(norm_hidden_states, HW=HW)

        if self.norm_type == "ada_norm_single":
            ff_output = gate_mlp * ff_output
@@ -301,8 +272,6 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin):
            The width of the latent images. This parameter is fixed during training.
        patch_size (int, defaults to 1):
            Size of the patches the model processes, relevant for architectures working on non-sequential data.
-        activation_fn (str, optional, defaults to "gelu-approximate"):
-            Activation function to use in feed-forward networks within Transformer blocks.
        num_embeds_ada_norm (int, optional, defaults to 1000):
            Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during
            inference.
@@ -338,11 +307,10 @@ def __init__(
        norm_num_groups: int = 32,
        num_cross_attention_heads: Optional[int] = 20,
        cross_attention_head_dim: Optional[int] = 112,
-        cross_attention_dim: Optional[int] = 1152,
+        cross_attention_dim: Optional[int] = 2240,
        attention_bias: bool = True,
        sample_size: int = 32,
        patch_size: int = 1,
-        activation_fn: tuple = ("silu", "silu", None),
        num_embeds_ada_norm: Optional[int] = 1000,
        upcast_attention: bool = False,
        norm_type: str = "ada_norm_single",
@@ -371,7 +339,7 @@ def __init__(

        # Set some common variables used across the board.
        self.attention_head_dim = attention_head_dim
-        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.inner_dim = num_attention_heads * attention_head_dim
        self.out_channels = in_channels if out_channels is None else out_channels
        if use_additional_conditions is None:
            if sample_size == 128:
@@ -383,63 +351,66 @@ def __init__(
        self.gradient_checkpointing = False

        # 2. Initialize the position embedding and transformer blocks.
-        self.height = self.config.sample_size
-        self.width = self.config.sample_size
+        self.height = sample_size
+        self.width = sample_size
+
+        if use_pe:
+            interpolation_scale = (
+                interpolation_scale
+                if interpolation_scale is not None
+                else max(sample_size // 64, 1)
+            )
+        else:
+            interpolation_scale = None

-        interpolation_scale = (
-            self.config.interpolation_scale
-            if self.config.interpolation_scale is not None
-            else max(self.config.sample_size // 64, 1)
-        )
        self.pos_embed = PatchEmbed(
-            height=self.config.sample_size,
-            width=self.config.sample_size,
-            patch_size=self.config.patch_size,
-            in_channels=self.config.in_channels,
+            height=sample_size,
+            width=sample_size,
+            patch_size=patch_size,
+            in_channels=in_channels,
            embed_dim=self.inner_dim,
            interpolation_scale=interpolation_scale,
-            pos_embed_type="sincos" if self.config.use_pe else None
+            pos_embed_type="sincos" if use_pe else None
        )

        self.transformer_blocks = nn.ModuleList(
            [
                SanaLinearTransformerBlock(
                    self.inner_dim,
-                    self.config.num_attention_heads,
-                    self.config.attention_head_dim,
-                    dropout=self.config.dropout,
-                    num_cross_attention_heads=self.config.num_cross_attention_heads,
-                    cross_attention_head_dim=self.config.cross_attention_head_dim,
-                    cross_attention_dim=self.config.cross_attention_dim,
-                    activation_fn=self.config.activation_fn,
-                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
-                    attention_bias=self.config.attention_bias,
-                    upcast_attention=self.config.upcast_attention,
+                    num_attention_heads,
+                    attention_head_dim,
+                    dropout=dropout,
+                    num_cross_attention_heads=num_cross_attention_heads,
+                    cross_attention_head_dim=cross_attention_head_dim,
+                    cross_attention_dim=cross_attention_dim,
+                    num_embeds_ada_norm=num_embeds_ada_norm,
+                    attention_bias=attention_bias,
+                    upcast_attention=upcast_attention,
                    norm_type=norm_type,
-                    norm_elementwise_affine=self.config.norm_elementwise_affine,
-                    norm_eps=self.config.norm_eps,
-                    use_pe=self.config.use_pe,
-                    expand_ratio=self.config.expand_ratio,
+                    norm_elementwise_affine=norm_elementwise_affine,
+                    norm_eps=norm_eps,
+                    use_pe=use_pe,
+                    expand_ratio=expand_ratio,
                )
-                for _ in range(self.config.num_layers)
+                for _ in range(num_layers)
            ]
        )

        # 3. Output blocks.
        self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
-        self.proj_out = nn.Linear(self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels)
+        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels)

        self.adaln_single = AdaLayerNormSingle(
            self.inner_dim, use_additional_conditions=self.use_additional_conditions
        )
        self.caption_projection = None
-        if self.config.caption_channels is not None:
+        if caption_channels is not None:
            self.caption_projection = PixArtAlphaTextProjection(
-                in_features=self.config.caption_channels, hidden_size=self.inner_dim
+                in_features=caption_channels, hidden_size=self.inner_dim
            )
        self.caption_norm = None
-        if self.config.use_caption_norm:
+        if use_caption_norm:
            self.caption_norm = RMSNormScaled(self.inner_dim, eps=1e-5, scale_factor=caption_norm_scale_factor)

    def _set_gradient_checkpointing(self, module, value=False):
@@ -506,46 +477,6 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
-    def fuse_qkv_projections(self):
-        """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
-        are fused. For cross-attention modules, key and value projection matrices are fused.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-        """
-        self.original_attn_processors = None
-
-        for _, attn_processor in self.attn_processors.items():
-            if "Added" in str(attn_processor.__class__.__name__):
-                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
-        self.original_attn_processors = self.attn_processors
-
-        for module in self.modules():
-            if isinstance(module, Attention):
-                module.fuse_projections(fuse=True)
-
-        self.set_attn_processor(FusedAttnProcessor2_0())
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
-    def unfuse_qkv_projections(self):
-        """Disables the fused QKV projection if enabled.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-
-        """
-        if self.original_attn_processors is not None:
-            self.set_attn_processor(self.original_attn_processors)
-
    def forward(
        self,
        hidden_states: torch.Tensor,
@@ -556,7 +487,6 @@ def forward(
        cross_attention_kwargs: Dict[str, Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
-        **kwargs,
    ):
        """
        The [`PixArtTransformer2DModel`] forward method.
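
The diff above removes the local `RMSNormScaled` class and instead imports it from `..normalization`, but the relocated definition itself is not part of this diff. As a reference, here is a minimal sketch reconstructed from the removed code; the absolute import path is an assumption, and the class that actually lands in `..normalization` may differ.

```python
import torch
import torch.nn as nn

# Assumed import path; inside the model file it is imported as
# `from ..normalization import RMSNormScaled`.
from diffusers.models.normalization import RMSNorm


class RMSNormScaled(RMSNorm):
    """Sketch of RMSNormScaled: RMSNorm whose affine weight starts at `scale_factor` instead of 1."""

    def __init__(self, dim, eps: float, elementwise_affine: bool = True, scale_factor: float = 1.0):
        super().__init__(dim, eps, elementwise_affine)
        # Overwrite the weight so the normalized output is initially scaled by `scale_factor`.
        self.weight = nn.Parameter(torch.ones(dim) * scale_factor)


# Usage mirroring the `caption_norm` construction in the diff:
# caption_norm = RMSNormScaled(inner_dim, eps=1e-5, scale_factor=caption_norm_scale_factor)
```

The only difference from plain `RMSNorm` is the initial value of the affine weight, which is how `caption_norm_scale_factor` rescales the projected caption embeddings right after normalization.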