# See the License for the specific language governing permissions and
# limitations under the License.

- from typing import Dict, Optional, Union
+ from typing import Dict, Optional, Tuple, Union

import torch
from torch import nn
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


- # Modified from diffusers.models.autoencoders.autoencoder_dc.GLUMBConv
- @maybe_allow_in_graph
- class SanaGLUMBConv(nn.Module):
-     def __init__(self, in_channels: int, out_channels: int, mlp_ratio: float = 2.5) -> None:
+ class GLUMBConv(nn.Module):
+     def __init__(self, in_channels: int, out_channels: int, expand_ratio: float = 4, norm_type: Optional[str] = None, residual_connection: bool = True) -> None:
        super().__init__()

-         hidden_channels = int(mlp_ratio * in_channels)
+         hidden_channels = int(expand_ratio * in_channels)
+         self.norm_type = norm_type
+         self.residual_connection = residual_connection

        self.nonlinearity = nn.SiLU()

        self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
        self.conv_depth = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
        self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)

-     def forward(self, hidden_states: torch.Tensor, HW: Optional[tuple[int]] = None) -> torch.Tensor:
-         B, N, C = hidden_states.shape
-         if HW is None:
-             H = W = int(N**0.5)
-         else:
-             H, W = HW
+         self.norm = None
+         if norm_type == "rms_norm":
+             self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)

-         hidden_states = hidden_states.reshape(B, H, W, C).permute(0, 3, 1, 2)
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         if self.residual_connection:
+             residual = hidden_states

        hidden_states = self.conv_inverted(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
@@ -66,23 +65,22 @@ def forward(self, hidden_states: torch.Tensor, HW: Optional[tuple[int]] = None)
        hidden_states = hidden_states * self.nonlinearity(gate)

        hidden_states = self.conv_point(hidden_states)
-         hidden_states = hidden_states.reshape(B, C, N).permute(0, 2, 1)
-
+
+         if self.norm_type == "rms_norm":
+             # move channel to the last dimension so we apply RMSnorm across channel dimension
+             hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
+
+         if self.residual_connection:
+             hidden_states = hidden_states + residual
+
        return hidden_states
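For orientation, a minimal usage sketch of the refactored GLUMBConv on the added side of this hunk: it now consumes a channel-first feature map directly instead of a flattened token sequence plus an HW tuple. The sizes below are illustrative assumptions, and the snippet presumes GLUMBConv (and its RMSNorm dependency) can be imported from this module.

import torch

# Hypothetical import; after this change the class is defined in this file.
# from diffusers.models.transformers.sana_transformer import GLUMBConv

block = GLUMBConv(in_channels=32, out_channels=32, expand_ratio=4,
                  norm_type="rms_norm", residual_connection=True)

x = torch.randn(2, 32, 16, 16)  # (batch, channels, height, width)
y = block(x)                    # channel-first layout preserved
assert y.shape == x.shape

With residual_connection=True, in_channels and out_channels must match so the skip addition is well defined; the transformer block below constructs it with residual_connection=False and no norm.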


class SanaTransformerBlock(nn.Module):
    r"""
-     A Transformer block following the Linear Transformer architecture, introduced in Sana
-
-     Reference: https://arxiv.org/abs/2410.10629
-
-     Parameters:
-         dim (`int`): The number of channels in the input and output.
-         num_attention_heads (`int`): The number of heads to use for multi-head attention.
-         attention_head_dim (`int`): The number of channels in each head.
+     Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629).
    """
-
+
    def __init__(
        self,
        dim: int = 2240,
@@ -127,11 +125,7 @@ def __init__(
        )

        # 3. Feed-forward
-         self.ff = SanaGLUMBConv(
-             in_channels=dim,
-             out_channels=dim,
-             mlp_ratio=mlp_ratio,
-         )
+         self.ff = GLUMBConv(dim, dim, mlp_ratio, norm_type=None, residual_connection=False)

        self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

@@ -142,7 +136,8 @@ def forward(
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
-         HW: Optional[tuple[int]] = None,
+         height: int = None,
+         width: int = None,
    ) -> torch.Tensor:
        batch_size = hidden_states.shape[0]

@@ -171,15 +166,17 @@
        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

-         ff_output = self.ff(norm_hidden_states, HW=HW)
+         norm_hidden_states = norm_hidden_states.unflatten(1, (height, width)).permute(0, 3, 1, 2)
+         ff_output = self.ff(norm_hidden_states)
+         ff_output = ff_output.flatten(2, 3).permute(0, 2, 1)
        hidden_states = hidden_states + gate_mlp * ff_output

        return hidden_states
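The explicit height/width arguments exist so the block can hop between the attention-friendly token layout and the convolution-friendly spatial layout around the feed-forward. A small sketch of that round trip, with illustrative (assumed) sizes, to show it is lossless:

import torch

batch, height, width, dim = 2, 16, 16, 2240
tokens = torch.randn(batch, height * width, dim)                    # (B, N, C)

spatial = tokens.unflatten(1, (height, width)).permute(0, 3, 1, 2)  # (B, C, H, W) for the conv FF
# ... the GLUMBConv feed-forward would run here ...
tokens_back = spatial.flatten(2, 3).permute(0, 2, 1)                # back to (B, N, C)

assert torch.equal(tokens, tokens_back)  # token ordering is preserved by the reshape pair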


class SanaTransformer2DModel(ModelMixin, ConfigMixin):
    r"""
-     A 2D Transformer model as introduced in [Sana](https://arxiv.org/abs/2410.10629) family of models.
+     A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models.

    Args:
        in_channels (`int`, defaults to `32`):
@@ -204,7 +201,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin):
            The expansion ratio to use in the GLUMBConv layer.
        dropout (`float`, defaults to `0.0`):
            The dropout probability.
-         attention_bias (`bool`, defaults to `True`):
+         attention_bias (`bool`, defaults to `False`):
            Whether to use bias in the attention layer.
        sample_size (`int`, defaults to `32`):
            The base size of the input latent.
@@ -233,7 +230,7 @@ def __init__(
        caption_channels: int = 2304,
        mlp_ratio: float = 2.5,
        dropout: float = 0.0,
-         attention_bias: bool = True,
+         attention_bias: bool = False,
        sample_size: int = 32,
        patch_size: int = 1,
        norm_elementwise_affine: bool = False,
@@ -245,7 +242,7 @@
        inner_dim = num_attention_heads * attention_head_dim

        # 1. Patch Embedding
-         self.pos_embed = PatchEmbed(
+         self.patch_embed = PatchEmbed(
            height=sample_size,
            width=sample_size,
            patch_size=patch_size,
@@ -255,7 +252,9 @@ def __init__(
            pos_embed_type=None,
        )

-         # 2. Caption Embedding
+         # 2. Additional condition embeddings
+         self.time_embed = AdaLayerNormSingle(inner_dim)
+
        self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
        self.caption_norm = RMSNorm(inner_dim, eps=1e-5)

@@ -285,8 +284,6 @@ def __init__(
        self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)

-         self.adaln_single = AdaLayerNormSingle(inner_dim)
-
        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
@@ -361,7 +358,7 @@ def forward(
        encoder_attention_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
-     ):
+     ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
@@ -387,11 +384,12 @@

        # 1. Input
        batch_size, num_channels, height, width = hidden_states.shape
-         post_patch_height = height // self.config.patch_size
-         post_patch_width = width // self.config.patch_size
-         hidden_states = self.pos_embed(hidden_states)
+         p = self.config.patch_size
+         post_patch_height, post_patch_width = height // p, width // p
+
+         hidden_states = self.patch_embed(hidden_states)

-         timestep, embedded_timestep = self.adaln_single(
+         timestep, embedded_timestep = self.time_embed(
            timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
        )

@@ -418,7 +416,8 @@ def create_block_forward(block):
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
-                     (post_patch_height, post_patch_width),
+                     post_patch_height,
+                     post_patch_width,
                )

        # 3. Normalization
@@ -436,14 +435,7 @@ def create_block_forward(block):
            batch_size, post_patch_height, post_patch_width, self.config.patch_size, self.config.patch_size, -1
        )
        hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4)
-         output = hidden_states.reshape(
-             shape=(
-                 batch_size,
-                 -1,
-                 post_patch_height * self.config.patch_size,
-                 post_patch_width * self.config.patch_size,
-             )
-         )
+         output = hidden_states.reshape(batch_size, -1, post_patch_height * p, post_patch_width * p)

        if not return_dict:
            return (output,)
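The collapsed reshape at the end of forward() is the usual unpatchify step: tokens of length p * p * out_channels are folded back into a (B, C, H, W) latent. A worked sketch with small, assumed sizes (Sana's default patch_size is 1; p=2 is used here only to make the fold visible):

import torch

batch, out_channels, p = 2, 32, 2
post_patch_height, post_patch_width = 8, 8
tokens = torch.randn(batch, post_patch_height * post_patch_width, p * p * out_channels)

x = tokens.reshape(batch, post_patch_height, post_patch_width, p, p, -1)    # (B, ph, pw, p, p, C)
x = x.permute(0, 5, 1, 3, 2, 4)                                             # (B, C, ph, p, pw, p)
output = x.reshape(batch, -1, post_patch_height * p, post_patch_width * p)  # (B, C, H, W)

assert output.shape == (batch, out_channels, post_patch_height * p, post_patch_width * p)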