
Commit 297c0e7

1. change code to be based on AutoencoderDC;
2. fix the bug in the new GLUMBConv;
3. runs successfully.
1 parent b76493f commit 297c0e7

6 files changed, +36 -48 lines

scripts/convert_sana_pag_to_diffusers.py

Lines changed: 6 additions & 9 deletions
@@ -59,12 +59,11 @@ def main(args):
     # y norm
     converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
 
+    flow_shift = 3.0
     if args.model_type == "SanaMS_1600M_P1_D20":
         layer_num = 20
-        flow_shift = 3.0
     elif args.model_type == "SanaMS_600M_P1_D28":
         layer_num = 28
-        flow_shift = 4.0
     else:
         raise ValueError(f"{args.model_type} is not supported.")
 
@@ -89,19 +88,19 @@ def main(args):
         )
 
         # Feed-forward.
-        converted_state_dict[f"transformer_blocks.{depth}.ff.inverted_conv.conv.weight"] = state_dict.pop(
+        converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.weight"] = state_dict.pop(
             f"blocks.{depth}.mlp.inverted_conv.conv.weight"
         )
-        converted_state_dict[f"transformer_blocks.{depth}.ff.inverted_conv.conv.bias"] = state_dict.pop(
+        converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.bias"] = state_dict.pop(
             f"blocks.{depth}.mlp.inverted_conv.conv.bias"
         )
-        converted_state_dict[f"transformer_blocks.{depth}.ff.depth_conv.conv.weight"] = state_dict.pop(
+        converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.weight"] = state_dict.pop(
             f"blocks.{depth}.mlp.depth_conv.conv.weight"
         )
-        converted_state_dict[f"transformer_blocks.{depth}.ff.depth_conv.conv.bias"] = state_dict.pop(
+        converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.bias"] = state_dict.pop(
             f"blocks.{depth}.mlp.depth_conv.conv.bias"
         )
-        converted_state_dict[f"transformer_blocks.{depth}.ff.point_conv.conv.weight"] = state_dict.pop(
+        converted_state_dict[f"transformer_blocks.{depth}.ff.conv_point.weight"] = state_dict.pop(
             f"blocks.{depth}.mlp.point_conv.conv.weight"
         )
 
@@ -156,8 +155,6 @@ def main(args):
         attention_type="default",
         use_pe=False,
         expand_ratio=2.5,
-        ff_bias=(True, True, False),
-        ff_norm=(None, None, None),
     )
     if is_accelerate_available():
         load_model_dict_into_meta(transformer, converted_state_dict)
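The feed-forward hunk above only renames keys: the old Sana checkpoint stores these weights under blocks.{depth}.mlp.*_conv.conv.*, while the new SanaGLUMBConv module exposes them as ff.conv_inverted / ff.conv_depth / ff.conv_point. A minimal sketch of the same remapping, factored into a helper (the helper name remap_ff_keys and the FF_RENAMES table are illustrative, not part of the script):

# Illustrative helper mirroring the feed-forward key renames performed inline above.
FF_RENAMES = {
    "inverted_conv.conv": "conv_inverted",
    "depth_conv.conv": "conv_depth",
    "point_conv.conv": "conv_point",
}

def remap_ff_keys(state_dict: dict, converted: dict, depth: int) -> None:
    """Move block `depth`'s MLP weights into the new `ff.conv_*` layout, in place."""
    for old, new in FF_RENAMES.items():
        for suffix in ("weight", "bias"):
            old_key = f"blocks.{depth}.mlp.{old}.{suffix}"
            if old_key in state_dict:  # point_conv carries no bias in either layout
                converted[f"transformer_blocks.{depth}.ff.{new}.{suffix}"] = state_dict.pop(old_key)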

scripts/convert_sana_to_diffusers.py

Lines changed: 8 additions & 11 deletions
@@ -59,12 +59,11 @@ def main(args):
     # y norm
     converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
 
+    flow_shift = 3.0
     if args.model_type == "SanaMS_1600M_P1_D20":
         layer_num = 20
-        flow_shift = 3.0
     elif args.model_type == "SanaMS_600M_P1_D28":
         layer_num = 28
-        flow_shift = 4.0
     else:
         raise ValueError(f"{args.model_type} is not supported.")
 
@@ -89,19 +88,19 @@ def main(args):
         )
 
         # Feed-forward.
-        converted_state_dict[f"transformer_blocks.{depth}.ff.inverted_conv.conv.weight"] = state_dict.pop(
+        converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.weight"] = state_dict.pop(
            f"blocks.{depth}.mlp.inverted_conv.conv.weight"
         )
-        converted_state_dict[f"transformer_blocks.{depth}.ff.inverted_conv.conv.bias"] = state_dict.pop(
+        converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.bias"] = state_dict.pop(
             f"blocks.{depth}.mlp.inverted_conv.conv.bias"
         )
-        converted_state_dict[f"transformer_blocks.{depth}.ff.depth_conv.conv.weight"] = state_dict.pop(
+        converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.weight"] = state_dict.pop(
             f"blocks.{depth}.mlp.depth_conv.conv.weight"
         )
-        converted_state_dict[f"transformer_blocks.{depth}.ff.depth_conv.conv.bias"] = state_dict.pop(
+        converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.bias"] = state_dict.pop(
             f"blocks.{depth}.mlp.depth_conv.conv.bias"
         )
-        converted_state_dict[f"transformer_blocks.{depth}.ff.point_conv.conv.weight"] = state_dict.pop(
+        converted_state_dict[f"transformer_blocks.{depth}.ff.conv_point.weight"] = state_dict.pop(
             f"blocks.{depth}.mlp.point_conv.conv.weight"
         )
 
@@ -156,8 +155,6 @@ def main(args):
         attention_type="default",
         use_pe=False,
         expand_ratio=2.5,
-        ff_bias=(True, True, False),
-        ff_norm=(None, None, None),
     )
     if is_accelerate_available():
         load_model_dict_into_meta(transformer, converted_state_dict)
@@ -188,8 +185,8 @@ def main(args):
     print(colored(f"Saving the whole SanaPipeline containing {args.model_type}", "green", attrs=["bold"]))
     # VAE
     ae = AutoencoderDC.from_pretrained(
-        "Efficient-Large-Model/dc_ae_f32c32_sana_1.0_diffusers",
-        torch_dtype=torch.float32,
+        "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers",
+        torch_dtype=torch.bfloat16,
     ).to(device)
 
     # Text Encoder
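Besides the key renames, this script also switches the VAE source from the Efficient-Large-Model float32 repo to the mit-han-lab bfloat16 one. A quick standalone check that the new repo id and dtype load as expected (a sketch, assuming AutoencoderDC is importable from the installed diffusers; it is not part of the converter):

import torch
from diffusers import AutoencoderDC

# Load the DC-AE checkpoint the converter now points at and confirm the dtype.
vae = AutoencoderDC.from_pretrained(
    "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers",
    torch_dtype=torch.bfloat16,
)
print(vae.dtype)  # torch.bfloat16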

src/diffusers/models/attention_processor.py

Lines changed: 5 additions & 1 deletion
@@ -5552,6 +5552,11 @@ def __call__(
     CustomDiffusionAttnProcessor2_0,
     SlicedAttnProcessor,
     SlicedAttnAddedKVProcessor,
+    SanaLinearAttnProcessor2_0,
+    SanaMultiscaleLinearAttention,
+    SanaMultiscaleAttnProcessor2_0,
+    SanaMultiscaleAttentionProjection,
+    PAGCFGSanaLinearAttnProcessor2_0,
     IPAdapterAttnProcessor,
     IPAdapterAttnProcessor2_0,
     IPAdapterXFormersAttnProcessor,
@@ -5562,5 +5567,4 @@ def __call__(
     LoRAXFormersAttnProcessor,
     LoRAAttnAddedKVProcessor,
     SanaLinearAttnProcessor2_0,
-    PAGCFGSanaLinearAttnProcessor2_0,
 ]

src/diffusers/models/transformers/sana_transformer_2d.py

Lines changed: 15 additions & 25 deletions
@@ -26,7 +26,6 @@
     FusedAttnProcessor2_0,
     SanaLinearAttnProcessor2_0,
 )
-from ..autoencoders.autoencoder_dc import GLUMBConv
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, SinusoidalPositionalEmbedding
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
@@ -58,40 +57,40 @@ def __init__(self, dim, eps: float, elementwise_affine: bool = True, scale_facto
         self.weight = nn.Parameter(torch.ones(dim) * scale_factor)
 
 
-# Modified from diffusers.models.autoencoders.ecae.GLUMBConv
+# Modified from diffusers.models.autoencoders.autoencoder_dc.GLUMBConv
 @maybe_allow_in_graph
 class SanaGLUMBConv(nn.Module):
-    def __init__(self, in_channels: int, out_channels: int) -> None:
+    def __init__(self, in_channels: int, out_channels: int, expand_ratio: float = 2.5) -> None:
         super().__init__()
 
-        hidden_channels = int(2.5 * in_channels)
+        hidden_channels = int(expand_ratio * in_channels)
 
         self.nonlinearity = nn.SiLU()
 
         self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
         self.conv_depth = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
         self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
-        self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)
 
-    def forward(self, x: torch.Tensor, HW=None) -> torch.Tensor:
-        B, N, C = x.shape
+    def forward(self, hidden_states: torch.Tensor, HW: Optional[tuple[int]] = None) -> torch.Tensor:
+        B, N, C = hidden_states.shape
         if HW is None:
             H = W = int(N**0.5)
         else:
             H, W = HW
 
-        x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
-        x = self.inverted_conv(x)
-        x = self.depth_conv(x)
+        hidden_states = hidden_states.reshape(B, H, W, C).permute(0, 3, 1, 2)
 
-        x, gate = torch.chunk(x, 2, dim=1)
-        gate = self.glu_act(gate)
-        x = x * gate
+        hidden_states = self.conv_inverted(hidden_states)
+        hidden_states = self.nonlinearity(hidden_states)
 
-        x = self.point_conv(x)
-        x = x.reshape(B, C, N).permute(0, 2, 1)
+        hidden_states = self.conv_depth(hidden_states)
+        hidden_states, gate = torch.chunk(hidden_states, 2, dim=1)
+        hidden_states = hidden_states * self.nonlinearity(gate)
 
-        return x
+        hidden_states = self.conv_point(hidden_states)
+        hidden_states = hidden_states.reshape(B, C, N).permute(0, 2, 1)
+
+        return hidden_states
 
 
 # Modified from diffusers.models.attention.BasicTransformerBlock
@@ -130,8 +129,6 @@ def __init__(
         use_pe: bool = False,
         num_positional_embeddings: Optional[int] = None,
         expand_ratio: float = 2.5,
-        ff_bias: tuple = (True, True, False),
-        ff_norm: tuple = (None, None, None),
     ):
         super().__init__()
         self.dim = dim
@@ -186,9 +183,6 @@ def __init__(
             in_channels=dim,
             out_channels=dim,
             expand_ratio=expand_ratio,
-            use_bias=ff_bias,
-            norm=ff_norm,
-            act_func=activation_fn,
         )
 
         # 5. Scale-shift for Sana.
@@ -362,8 +356,6 @@ def __init__(
         attention_type: Optional[str] = "default",
         use_pe: Optional[bool] = False,
         expand_ratio=2.5,
-        ff_bias: tuple = (True, True, False),
-        ff_norm: tuple = (None, None, None),
     ):
         super().__init__()
 
@@ -428,8 +420,6 @@ def __init__(
                 norm_eps=self.config.norm_eps,
                 use_pe=self.config.use_pe,
                 expand_ratio=self.config.expand_ratio,
-                ff_bias=self.config.ff_bias,
-                ff_norm=self.config.ff_norm,
             )
             for _ in range(self.config.num_layers)
         ]
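To make the fixed data flow easier to follow, here is a self-contained sketch of what the rewritten SanaGLUMBConv.forward does with a (batch, tokens, channels) input: tokens are reshaped into a feature map, expanded by a 1x1 convolution, mixed by a 3x3 depthwise convolution, gated GLU-style, projected back by a 1x1 convolution, and flattened back into tokens. The class name GLUMBConvSketch is illustrative; it mirrors the diff above rather than importing the library module.

import torch
import torch.nn as nn

class GLUMBConvSketch(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, expand_ratio: float = 2.5) -> None:
        super().__init__()
        hidden_channels = int(expand_ratio * in_channels)
        self.nonlinearity = nn.SiLU()
        self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
        self.conv_depth = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
        self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)

    def forward(self, hidden_states: torch.Tensor, HW=None) -> torch.Tensor:
        B, N, C = hidden_states.shape
        H, W = (int(N**0.5), int(N**0.5)) if HW is None else HW

        # (B, N, C) tokens -> (B, C, H, W) feature map
        hidden_states = hidden_states.reshape(B, H, W, C).permute(0, 3, 1, 2)

        hidden_states = self.conv_inverted(hidden_states)        # 1x1 expansion to 2 * hidden_channels
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv_depth(hidden_states)           # 3x3 depthwise mixing
        hidden_states, gate = torch.chunk(hidden_states, 2, dim=1)
        hidden_states = hidden_states * self.nonlinearity(gate)  # GLU gating
        hidden_states = self.conv_point(hidden_states)           # 1x1 projection back to out_channels

        # (B, out_channels, H, W) feature map -> (B, N, out_channels) tokens
        return hidden_states.flatten(2).permute(0, 2, 1)

x = torch.randn(2, 16 * 16, 64)          # (B, N, C) with N = H * W
print(GLUMBConvSketch(64, 64)(x).shape)   # torch.Size([2, 256, 64])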

src/diffusers/pipelines/pag/pipeline_pag_sana.py

Lines changed: 1 addition & 1 deletion
@@ -172,7 +172,7 @@ def __init__(
             tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
         )
 
-        self.vae_scale_factor = 2 ** (len(self.vae.config.encoder_width_list) - 1)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.encoder_block_out_channels) - 1)
         self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
         self.set_pag_applied_layers(

src/diffusers/pipelines/sana/pipeline_sana.py

Lines changed: 1 addition & 1 deletion
@@ -168,7 +168,7 @@ def __init__(
         )
 
         self.vae_scale_factor = (
-            2 ** (len(self.vae.config.encoder_width_list) - 1) if hasattr(self, "vae") and self.vae is not None else 32
+            2 ** (len(self.vae.config.encoder_block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 32
         )
         self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
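Both pipeline changes above derive the VAE's spatial compression from the length of encoder_block_out_channels instead of the removed encoder_width_list key. A small worked example of the arithmetic (the stage count of 6 is quoted for the f32 DC-AE the Sana pipelines use; only the list length matters here):

# Illustrative arithmetic, not pipeline code: a six-stage DC-AE encoder gives
# 2 ** 5 = 32, which matches the fallback value used in pipeline_sana.py.
num_encoder_stages = 6  # len(vae.config.encoder_block_out_channels)
vae_scale_factor = 2 ** (num_encoder_stages - 1)
print(vae_scale_factor)  # 32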
