 parser = argparse.ArgumentParser()
 parser.add_argument("--checkpoint_path", type=str)
 parser.add_argument("--output_path", type=str)
-parser.add_argument("--dtype", type=str, default="fp16")
+parser.add_argument("--dtype", type=str)

 args = parser.parse_args()
-dtype = torch.float16 if args.dtype == "fp16" else torch.float32


 def load_original_checkpoint(ckpt_path):
@@ -40,7 +39,9 @@ def swap_scale_shift(weight, dim):
     return new_weight


-def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_layers, caption_projection_dim):
+def convert_sd3_transformer_checkpoint_to_diffusers(
+    original_state_dict, num_layers, caption_projection_dim, dual_attention_layers, has_qk_norm
+):
     converted_state_dict = {}

     # Positional and patch embeddings.
@@ -110,6 +111,21 @@ def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_lay
         converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
         converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])

+        # qk norm
+        if has_qk_norm:
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_q.weight"] = original_state_dict.pop(
+                f"joint_blocks.{i}.x_block.attn.ln_q.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_k.weight"] = original_state_dict.pop(
+                f"joint_blocks.{i}.x_block.attn.ln_k.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_q.weight"] = original_state_dict.pop(
+                f"joint_blocks.{i}.context_block.attn.ln_q.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_k.weight"] = original_state_dict.pop(
+                f"joint_blocks.{i}.context_block.attn.ln_k.weight"
+            )
+
         # output projections.
         converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = original_state_dict.pop(
             f"joint_blocks.{i}.x_block.attn.proj.weight"
@@ -125,6 +141,39 @@ def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_lay
             f"joint_blocks.{i}.context_block.attn.proj.bias"
         )

+        # attn2
+        if i in dual_attention_layers:
+            # Q, K, V
+            sample_q2, sample_k2, sample_v2 = torch.chunk(
+                original_state_dict.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0
+            )
+            sample_q2_bias, sample_k2_bias, sample_v2_bias = torch.chunk(
+                original_state_dict.pop(f"joint_blocks.{i}.x_block.attn2.qkv.bias"), 3, dim=0
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = torch.cat([sample_q2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = torch.cat([sample_q2_bias])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = torch.cat([sample_k2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = torch.cat([sample_k2_bias])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = torch.cat([sample_v2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = torch.cat([sample_v2_bias])
+
+            # qk norm
+            if has_qk_norm:
+                converted_state_dict[f"transformer_blocks.{i}.attn2.norm_q.weight"] = original_state_dict.pop(
+                    f"joint_blocks.{i}.x_block.attn2.ln_q.weight"
+                )
+                converted_state_dict[f"transformer_blocks.{i}.attn2.norm_k.weight"] = original_state_dict.pop(
+                    f"joint_blocks.{i}.x_block.attn2.ln_k.weight"
+                )
+
+            # output projections.
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = original_state_dict.pop(
+                f"joint_blocks.{i}.x_block.attn2.proj.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = original_state_dict.pop(
+                f"joint_blocks.{i}.x_block.attn2.proj.bias"
+            )
+
         # norms.
         converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = original_state_dict.pop(
             f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"
@@ -195,25 +244,79 @@ def is_vae_in_checkpoint(original_state_dict):
     )


+def get_attn2_layers(state_dict):
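+    # Checkpoint keys look like "joint_blocks.{i}.x_block.attn2.qkv.weight", so the block index sits at position 1 of the dotted key.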
+    attn2_layers = []
+    for key in state_dict.keys():
+        if "attn2." in key:
+            # Extract the layer number from the key
+            layer_num = int(key.split(".")[1])
+            attn2_layers.append(layer_num)
+    return tuple(sorted(set(attn2_layers)))
+
+
+def get_pos_embed_max_size(state_dict):
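+    # pos_embed stores one entry per patch of a square grid, so the maximum grid size is sqrt(num_patches).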
+    num_patches = state_dict["pos_embed"].shape[1]
+    pos_embed_max_size = int(num_patches**0.5)
+    return pos_embed_max_size
+
+
+def get_caption_projection_dim(state_dict):
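+    # context_embedder.weight has shape (out_features, in_features); out_features is the caption projection dim.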
+    caption_projection_dim = state_dict["context_embedder.weight"].shape[0]
+    return caption_projection_dim
+
+
 def main(args):
     original_ckpt = load_original_checkpoint(args.checkpoint_path)
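+    # Take the dtype of the first tensor in the checkpoint as its native dtype.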
+    original_dtype = next(iter(original_ckpt.values())).dtype
+
+    # Initialize dtype with a default value
+    dtype = None
+
+    if args.dtype is None:
+        dtype = original_dtype
+    elif args.dtype == "fp16":
+        dtype = torch.float16
+    elif args.dtype == "bf16":
+        dtype = torch.bfloat16
+    elif args.dtype == "fp32":
+        dtype = torch.float32
+    else:
+        raise ValueError(f"Unsupported dtype: {args.dtype}")
+
+    if dtype != original_dtype:
+        print(
+            f"Checkpoint dtype {original_dtype} does not match requested dtype {dtype}. This can lead to unexpected results, proceed with caution."
+        )
+
     num_layers = list(set(int(k.split(".", 2)[1]) for k in original_ckpt if "joint_blocks" in k))[-1] + 1  # noqa: C401
-    caption_projection_dim = 1536
+
+    caption_projection_dim = get_caption_projection_dim(original_ckpt)
+
+    # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
+    attn2_layers = get_attn2_layers(original_ckpt)
+
+    # sd3.5 uses qk norm ("rms_norm")
+    has_qk_norm = any("ln_q" in key for key in original_ckpt.keys())
+
+    # sd3.5 2b uses pos_embed_max_size=384; sd3.0 and sd3.5 8b use 192
+    pos_embed_max_size = get_pos_embed_max_size(original_ckpt)

     converted_transformer_state_dict = convert_sd3_transformer_checkpoint_to_diffusers(
-        original_ckpt, num_layers, caption_projection_dim
+        original_ckpt, num_layers, caption_projection_dim, attn2_layers, has_qk_norm
     )

     with CTX():
         transformer = SD3Transformer2DModel(
-            sample_size=64,
+            sample_size=128,
             patch_size=2,
             in_channels=16,
             joint_attention_dim=4096,
             num_layers=num_layers,
             caption_projection_dim=caption_projection_dim,
-            num_attention_heads=24,
-            pos_embed_max_size=192,
+            num_attention_heads=num_layers,
+            pos_embed_max_size=pos_embed_max_size,
+            qk_norm="rms_norm" if has_qk_norm else None,
+            dual_attention_layers=attn2_layers,
         )
         if is_accelerate_available():
             load_model_dict_into_meta(transformer, converted_transformer_state_dict)
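For reference, a minimal standalone sketch of the detection logic this diff introduces, run against a hypothetical key set that mimics the "joint_blocks.{i}..." naming used above; the keys are illustrative only, not taken from a real checkpoint.

# Standalone sketch (not part of the diff): detecting dual-attention layers and qk norm
# from checkpoint keys, mirroring get_attn2_layers() and the has_qk_norm check above.
toy_keys = [
    "joint_blocks.0.x_block.attn.qkv.weight",
    "joint_blocks.0.x_block.attn2.qkv.weight",
    "joint_blocks.0.x_block.attn.ln_q.weight",
    "joint_blocks.1.x_block.attn.qkv.weight",
]

attn2_layers = tuple(sorted({int(k.split(".")[1]) for k in toy_keys if "attn2." in k}))
has_qk_norm = any("ln_q" in k for k in toy_keys)

print(attn2_layers)  # (0,)
print(has_qk_norm)   # True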