Skip to content

Commit c3e107f

Browse files
committed
up
1 parent c4d049c commit c3e107f

File tree

1 file changed

+7
-18
lines changed

1 file changed

+7
-18
lines changed

scripts/convert_sana_to_diffusers.py

Lines changed: 7 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -171,15 +171,6 @@ def main(args):
171171
f"blocks.{depth}.attn.proj.bias"
172172
)
173173

174-
# Add Q/K normalization for self-attention (attn1) - needed for Sana Sprint
175-
if args.model_type == "SanaSprint_1600M_P1_D20":
176-
converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_q.weight"] = state_dict.pop(
177-
f"blocks.{depth}.attn.q_norm.weight"
178-
)
179-
converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_k.weight"] = state_dict.pop(
180-
f"blocks.{depth}.attn.k_norm.weight"
181-
)
182-
183174
# Feed-forward.
184175
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.weight"] = state_dict.pop(
185176
f"blocks.{depth}.mlp.inverted_conv.conv.weight"
@@ -218,15 +209,6 @@ def main(args):
218209
f"blocks.{depth}.cross_attn.k_norm.weight"
219210
)
220211

221-
# Add Q/K normalization for cross-attention (attn2) - needed for Sana Sprint
222-
if args.model_type == "SanaSprint_1600M_P1_D20":
223-
converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_q.weight"] = state_dict.pop(
224-
f"blocks.{depth}.cross_attn.q_norm.weight"
225-
)
226-
converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_k.weight"] = state_dict.pop(
227-
f"blocks.{depth}.cross_attn.k_norm.weight"
228-
)
229-
230212
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
231213
f"blocks.{depth}.cross_attn.proj.weight"
232214
)
@@ -261,6 +243,13 @@ def main(args):
261243
}
262244

263245
# Add qk_norm parameter for Sana Sprint
246+
if args.model_type in [
247+
"SanaMS1.5_1600M_P1_D20",
248+
"SanaMS1.5_4800M_P1_D60",
249+
"SanaSprint_600M_P1_D28",
250+
"SanaSprint_1600M_P1_D20",
251+
]:
252+
transformer_kwargs["qk_norm"] = "rms_norm_across_heads"
264253
if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
265254
transformer_kwargs["guidance_embeds"] = True
266255

0 commit comments

Comments
 (0)