
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import is_torch_version, logging
-from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention_processor import (
     Attention,
     AttentionProcessor,
@@ -36,15 +35,21 @@


 class GLUMBConv(nn.Module):
-    def __init__(self, in_channels: int, out_channels: int, expand_ratio: float = 4, norm_type: Optional[str] = None, residual_connection: bool = True) -> None:
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        expand_ratio: float = 4,
+        norm_type: Optional[str] = None,
+        residual_connection: bool = True,
+    ) -> None:
         super().__init__()

         hidden_channels = int(expand_ratio * in_channels)
         self.norm_type = norm_type
         self.residual_connection = residual_connection

         self.nonlinearity = nn.SiLU()
-
         self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
         self.conv_depth = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
         self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
@@ -55,7 +60,7 @@ def __init__(self, in_channels: int, out_channels: int, expand_ratio: float = 4,

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.residual_connection:
-            residual = hidden_states
+            residual = hidden_states

         hidden_states = self.conv_inverted(hidden_states)
         hidden_states = self.nonlinearity(hidden_states)
@@ -65,22 +70,22 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = hidden_states * self.nonlinearity(gate)

         hidden_states = self.conv_point(hidden_states)
-
+
         if self.norm_type == "rms_norm":
             # move channel to the last dimension so we apply RMSnorm across channel dimension
             hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)

         if self.residual_connection:
             hidden_states = hidden_states + residual
-
+
         return hidden_states


 class SanaTransformerBlock(nn.Module):
     r"""
     Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629).
     """
-
+
     def __init__(
         self,
         dim: int = 2240,
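
The reflowed `__init__` above is purely cosmetic: GLUMBConv still expands the input with a 1x1 convolution to `2 * hidden_channels`, applies a depthwise 3x3, splits the result into value and gate for a SiLU-gated product, and projects back with a 1x1 convolution. A minimal shape-check sketch, not part of the diff; the import path is an assumption based on this file's relative imports:

import torch

# Hypothetical import path -- GLUMBConv is defined in the file this diff edits,
# assumed here to be diffusers/models/transformers/sana_transformer.py.
from diffusers.models.transformers.sana_transformer import GLUMBConv

block = GLUMBConv(in_channels=32, out_channels=32, expand_ratio=4)
x = torch.randn(1, 32, 64, 64)
y = block(x)  # 1x1 expand -> depthwise 3x3 -> SiLU gate -> 1x1 project (+ residual)
print(y.shape)  # torch.Size([1, 32, 64, 64])

With the default `residual_connection=True`, the final skip addition expects `in_channels == out_channels`.
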
@@ -149,6 +154,7 @@ def forward(
         # 2. Self Attention
         norm_hidden_states = self.norm1(hidden_states)
         norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+        norm_hidden_states = norm_hidden_states.to(hidden_states.dtype)

         attn_output = self.attn1(norm_hidden_states)
         hidden_states = hidden_states + gate_msa * attn_output
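
The added cast back to `hidden_states.dtype` guards against PyTorch type promotion: if `scale_msa`/`shift_msa` are carried in float32 while the model runs in float16/bfloat16, the modulated activations would otherwise be promoted to float32 before entering `attn1`. A small, Sana-agnostic sketch of that promotion behaviour:

import torch

hidden = torch.randn(1, 16, 2240).to(torch.bfloat16)  # residual stream in bf16
scale = torch.randn(1, 1, 2240)                        # modulation kept in float32
shift = torch.randn(1, 1, 2240)

modulated = hidden * (1 + scale) + shift
print(modulated.dtype)                   # torch.float32 -- silently promoted
print(modulated.to(hidden.dtype).dtype)  # torch.bfloat16 -- matches hidden_states again
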
@@ -256,7 +262,7 @@ def __init__(
         self.time_embed = AdaLayerNormSingle(inner_dim)

         self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
-        self.caption_norm = RMSNorm(inner_dim, eps=1e-5)
+        self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)

         # 3. Transformer blocks
         self.transformer_blocks = nn.ModuleList(
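
The last hunk passes `elementwise_affine=True` explicitly to the caption `RMSNorm`; in diffusers' `RMSNorm` this flag should control whether a learnable per-channel scale is attached. A quick sanity-check sketch (import location and parameter name are assumptions, not confirmed by the diff):

import torch

from diffusers.models.normalization import RMSNorm  # assumed import location

norm = RMSNorm(2240, eps=1e-5, elementwise_affine=True)
caption = torch.randn(1, 300, 2240)
print(norm(caption).shape)  # torch.Size([1, 300, 2240])
print(norm.weight.shape)    # torch.Size([2240]) -- learnable scale (assumes the parameter is named `weight`)
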