Skip to content

Commit 459628c

Browse files
committed
fix the Positional Embedding bug in 2K model;
1 parent 233dffd commit 459628c

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

scripts/convert_sana_to_diffusers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,18 @@ def main(args):
8888
# y norm
8989
converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
9090

91+
# scheduler
9192
flow_shift = 3.0
93+
94+
# model config
9295
if args.model_type == "SanaMS_1600M_P1_D20":
9396
layer_num = 20
9497
elif args.model_type == "SanaMS_600M_P1_D28":
9598
layer_num = 28
9699
else:
97100
raise ValueError(f"{args.model_type} is not supported.")
101+
# Positional embedding interpolation scale.
102+
interpolation_scale = {512: None, 1024: None, 2048: 1.0}
98103

99104
for depth in range(layer_num):
100105
# Transformer blocks.
@@ -176,6 +181,7 @@ def main(args):
176181
patch_size=1,
177182
norm_elementwise_affine=False,
178183
norm_eps=1e-6,
184+
interpolation_scale=interpolation_scale[args.image_size],
179185
)
180186

181187
if is_accelerate_available():

src/diffusers/models/transformers/sana_transformer.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,21 +242,26 @@ def __init__(
242242
patch_size: int = 1,
243243
norm_elementwise_affine: bool = False,
244244
norm_eps: float = 1e-6,
245+
interpolation_scale: Optional[int] = None,
245246
) -> None:
246247
super().__init__()
247248

248249
out_channels = out_channels or in_channels
249250
inner_dim = num_attention_heads * attention_head_dim
250251

251252
# 1. Patch Embedding
253+
interpolation_scale = (
254+
interpolation_scale
255+
if interpolation_scale is not None
256+
else max(sample_size // 64, 1)
257+
)
252258
self.patch_embed = PatchEmbed(
253259
height=sample_size,
254260
width=sample_size,
255261
patch_size=patch_size,
256262
in_channels=in_channels,
257263
embed_dim=inner_dim,
258-
interpolation_scale=None,
259-
pos_embed_type=None,
264+
interpolation_scale=interpolation_scale,
260265
)
261266

262267
# 2. Additional condition embeddings

0 commit comments

Comments
 (0)