@@ -171,15 +171,6 @@ def main(args):
             f"blocks.{depth}.attn.proj.bias"
         )
 
-        # Add Q/K normalization for self-attention (attn1) - needed for Sana Sprint
-        if args.model_type == "SanaSprint_1600M_P1_D20":
-            converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_q.weight"] = state_dict.pop(
-                f"blocks.{depth}.attn.q_norm.weight"
-            )
-            converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_k.weight"] = state_dict.pop(
-                f"blocks.{depth}.attn.k_norm.weight"
-            )
-
         # Feed-forward.
         converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.weight"] = state_dict.pop(
             f"blocks.{depth}.mlp.inverted_conv.conv.weight"
@@ -218,15 +209,6 @@ def main(args):
             f"blocks.{depth}.cross_attn.k_norm.weight"
         )
 
-        # Add Q/K normalization for cross-attention (attn2) - needed for Sana Sprint
-        if args.model_type == "SanaSprint_1600M_P1_D20":
-            converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_q.weight"] = state_dict.pop(
-                f"blocks.{depth}.cross_attn.q_norm.weight"
-            )
-            converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_k.weight"] = state_dict.pop(
-                f"blocks.{depth}.cross_attn.k_norm.weight"
-            )
-
         converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
             f"blocks.{depth}.cross_attn.proj.weight"
         )
@@ -261,6 +243,13 @@ def main(args):
         }
 
         # Add qk_norm parameter for Sana Sprint
+        if args.model_type in [
+            "SanaMS1.5_1600M_P1_D20",
+            "SanaMS1.5_4800M_P1_D60",
+            "SanaSprint_600M_P1_D28",
+            "SanaSprint_1600M_P1_D20",
+        ]:
+            transformer_kwargs["qk_norm"] = "rms_norm_across_heads"
         if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
             transformer_kwargs["guidance_embeds"] = True
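Note on the change: the Q/K RMSNorm keys are already remapped unconditionally in the conversion loop (see the cross_attn.q_norm/k_norm context lines above), so the SanaSprint-only duplicate blocks are dropped and the decision moves to the model config via qk_norm="rms_norm_across_heads". The snippet below is a minimal, self-contained sketch of that idea only; ToyAttention, its dimensions, and the use of torch.nn.RMSNorm (PyTorch >= 2.4) are illustrative assumptions, not diffusers' actual SanaTransformer2DModel internals.

# Minimal sketch: a config-level qk_norm flag creates norm_q/norm_k parameters,
# so converted keys like "...attn2.norm_q.weight" have a place to load into
# without per-model special-casing in the conversion loop.
import torch
import torch.nn as nn


class ToyAttention(nn.Module):
    def __init__(self, dim: int, qk_norm: str | None = None):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        if qk_norm == "rms_norm_across_heads":
            # One RMSNorm over the full projection width, i.e. shared across heads.
            self.norm_q = nn.RMSNorm(dim)
            self.norm_k = nn.RMSNorm(dim)
        else:
            self.norm_q = nn.Identity()
            self.norm_k = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q = self.norm_q(self.to_q(x))
        k = self.norm_k(self.to_k(x))
        v = self.to_v(x)
        attn = torch.softmax(q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5, dim=-1)
        return attn @ v


# With the flag enabled, the module exposes the norm weights the conversion produces.
block = ToyAttention(dim=64, qk_norm="rms_norm_across_heads")
print([name for name, _ in block.named_parameters() if name.startswith("norm_")])
# ['norm_q.weight', 'norm_k.weight']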