Skip to content

Commit 1ab6ab2

Browse files
committed
update pipelines
1 parent 2c828c2 commit 1ab6ab2

File tree

4 files changed

+140
-140
lines changed

4 files changed

+140
-140
lines changed

scripts/convert_sana_to_diffusers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def main(args):
174174
norm_elementwise_affine=False,
175175
norm_eps=1e-6,
176176
)
177-
177+
178178
if is_accelerate_available():
179179
load_model_dict_into_meta(transformer, converted_state_dict)
180180
else:

src/diffusers/models/transformers/sana_transformer.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
from ...configuration_utils import ConfigMixin, register_to_config
2121
from ...utils import is_torch_version, logging
22-
from ...utils.torch_utils import maybe_allow_in_graph
2322
from ..attention_processor import (
2423
Attention,
2524
AttentionProcessor,
@@ -36,15 +35,21 @@
3635

3736

3837
class GLUMBConv(nn.Module):
39-
def __init__(self, in_channels: int, out_channels: int, expand_ratio: float = 4, norm_type: Optional[str] = None, residual_connection: bool = True) -> None:
38+
def __init__(
39+
self,
40+
in_channels: int,
41+
out_channels: int,
42+
expand_ratio: float = 4,
43+
norm_type: Optional[str] = None,
44+
residual_connection: bool = True,
45+
) -> None:
4046
super().__init__()
4147

4248
hidden_channels = int(expand_ratio * in_channels)
4349
self.norm_type = norm_type
4450
self.residual_connection = residual_connection
4551

4652
self.nonlinearity = nn.SiLU()
47-
4853
self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
4954
self.conv_depth = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
5055
self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
@@ -55,7 +60,7 @@ def __init__(self, in_channels: int, out_channels: int, expand_ratio: float = 4,
5560

5661
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
5762
if self.residual_connection:
58-
residual = hidden_states
63+
residual = hidden_states
5964

6065
hidden_states = self.conv_inverted(hidden_states)
6166
hidden_states = self.nonlinearity(hidden_states)
@@ -65,22 +70,22 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
6570
hidden_states = hidden_states * self.nonlinearity(gate)
6671

6772
hidden_states = self.conv_point(hidden_states)
68-
73+
6974
if self.norm_type == "rms_norm":
7075
# move channel to the last dimension so we apply RMSnorm across channel dimension
7176
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
7277

7378
if self.residual_connection:
7479
hidden_states = hidden_states + residual
75-
80+
7681
return hidden_states
7782

7883

7984
class SanaTransformerBlock(nn.Module):
8085
r"""
8186
Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629).
8287
"""
83-
88+
8489
def __init__(
8590
self,
8691
dim: int = 2240,
@@ -149,6 +154,7 @@ def forward(
149154
# 2. Self Attention
150155
norm_hidden_states = self.norm1(hidden_states)
151156
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
157+
norm_hidden_states = norm_hidden_states.to(hidden_states.dtype)
152158

153159
attn_output = self.attn1(norm_hidden_states)
154160
hidden_states = hidden_states + gate_msa * attn_output
@@ -256,7 +262,7 @@ def __init__(
256262
self.time_embed = AdaLayerNormSingle(inner_dim)
257263

258264
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
259-
self.caption_norm = RMSNorm(inner_dim, eps=1e-5)
265+
self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)
260266

261267
# 3. Transformer blocks
262268
self.transformer_blocks = nn.ModuleList(

0 commit comments

Comments
 (0)