 parser = argparse.ArgumentParser()
 parser.add_argument("--checkpoint_path", type=str)
 parser.add_argument("--output_path", type=str)
-parser.add_argument("--dtype", type=str, default="fp16")
+parser.add_argument("--dtype", type=str)
 
 args = parser.parse_args()
-dtype = torch.float16 if args.dtype == "fp16" else torch.float32
 
 
 def load_original_checkpoint(ckpt_path):
@@ -40,7 +39,9 @@ def swap_scale_shift(weight, dim):
     return new_weight
 
 
-def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_layers, caption_projection_dim):
+def convert_sd3_transformer_checkpoint_to_diffusers(
+    original_state_dict, num_layers, caption_projection_dim, dual_attention_layers, has_qk_norm
+):
     converted_state_dict = {}
 
     # Positional and patch embeddings.
@@ -110,6 +111,21 @@ def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_lay
         converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
         converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])
 
+        # qk norm
+        if has_qk_norm:
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_q.weight"] = original_state_dict.pop(
+                f"joint_blocks.{i}.x_block.attn.ln_q.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_k.weight"] = original_state_dict.pop(
+                f"joint_blocks.{i}.x_block.attn.ln_k.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_q.weight"] = original_state_dict.pop(
+                f"joint_blocks.{i}.context_block.attn.ln_q.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_k.weight"] = original_state_dict.pop(
+                f"joint_blocks.{i}.context_block.attn.ln_k.weight"
+            )
+
         # output projections.
         converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = original_state_dict.pop(
             f"joint_blocks.{i}.x_block.attn.proj.weight"
@@ -125,6 +141,39 @@ def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_lay
                 f"joint_blocks.{i}.context_block.attn.proj.bias"
             )
 
+        # attn2
+        if i in dual_attention_layers:
+            # Q, K, V
+            sample_q2, sample_k2, sample_v2 = torch.chunk(
+                original_state_dict.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0
+            )
+            sample_q2_bias, sample_k2_bias, sample_v2_bias = torch.chunk(
+                original_state_dict.pop(f"joint_blocks.{i}.x_block.attn2.qkv.bias"), 3, dim=0
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = torch.cat([sample_q2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = torch.cat([sample_q2_bias])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = torch.cat([sample_k2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = torch.cat([sample_k2_bias])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = torch.cat([sample_v2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = torch.cat([sample_v2_bias])
+
+            # qk norm
+            if has_qk_norm:
+                converted_state_dict[f"transformer_blocks.{i}.attn2.norm_q.weight"] = original_state_dict.pop(
+                    f"joint_blocks.{i}.x_block.attn2.ln_q.weight"
+                )
+                converted_state_dict[f"transformer_blocks.{i}.attn2.norm_k.weight"] = original_state_dict.pop(
+                    f"joint_blocks.{i}.x_block.attn2.ln_k.weight"
+                )
+
+            # output projections.
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = original_state_dict.pop(
+                f"joint_blocks.{i}.x_block.attn2.proj.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = original_state_dict.pop(
+                f"joint_blocks.{i}.x_block.attn2.proj.bias"
+            )
+
         # norms.
         converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = original_state_dict.pop(
             f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"
@@ -195,25 +244,79 @@ def is_vae_in_checkpoint(original_state_dict):
     )
 
 
+def get_attn2_layers(state_dict):
+    attn2_layers = []
+    for key in state_dict.keys():
+        if "attn2." in key:
+            # Extract the layer number from the key
+            layer_num = int(key.split(".")[1])
+            attn2_layers.append(layer_num)
+    return tuple(sorted(set(attn2_layers)))
+
+
+def get_pos_embed_max_size(state_dict):
+    num_patches = state_dict["pos_embed"].shape[1]
+    pos_embed_max_size = int(num_patches**0.5)
+    return pos_embed_max_size
+
+
+def get_caption_projection_dim(state_dict):
+    caption_projection_dim = state_dict["context_embedder.weight"].shape[0]
+    return caption_projection_dim
+
+
 def main(args):
     original_ckpt = load_original_checkpoint(args.checkpoint_path)
+    original_dtype = next(iter(original_ckpt.values())).dtype
+
+    # Initialize dtype with a default value
+    dtype = None
+
+    if args.dtype is None:
+        dtype = original_dtype
+    elif args.dtype == "fp16":
+        dtype = torch.float16
+    elif args.dtype == "bf16":
+        dtype = torch.bfloat16
+    elif args.dtype == "fp32":
+        dtype = torch.float32
+    else:
+        raise ValueError(f"Unsupported dtype: {args.dtype}")
+
+    if dtype != original_dtype:
+        print(
+            f"Checkpoint dtype {original_dtype} does not match requested dtype {dtype}. This can lead to unexpected results, proceed with caution."
+        )
+
     num_layers = list(set(int(k.split(".", 2)[1]) for k in original_ckpt if "joint_blocks" in k))[-1] + 1  # noqa: C401
-    caption_projection_dim = 1536
+
+    caption_projection_dim = get_caption_projection_dim(original_ckpt)
+
+    # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
+    attn2_layers = get_attn2_layers(original_ckpt)
+
+    # sd3.5 uses qk norm ("rms_norm")
+    has_qk_norm = any("ln_q" in key for key in original_ckpt.keys())
+
+    # sd3.5 2b uses pos_embed_max_size=384; sd3.0 and sd3.5 8b use 192
+    pos_embed_max_size = get_pos_embed_max_size(original_ckpt)
 
     converted_transformer_state_dict = convert_sd3_transformer_checkpoint_to_diffusers(
-        original_ckpt, num_layers, caption_projection_dim
+        original_ckpt, num_layers, caption_projection_dim, attn2_layers, has_qk_norm
     )
 
     with CTX():
         transformer = SD3Transformer2DModel(
-            sample_size=64,
+            sample_size=128,
             patch_size=2,
             in_channels=16,
             joint_attention_dim=4096,
             num_layers=num_layers,
             caption_projection_dim=caption_projection_dim,
-            num_attention_heads=24,
-            pos_embed_max_size=192,
+            num_attention_heads=num_layers,
+            pos_embed_max_size=pos_embed_max_size,
+            qk_norm="rms_norm" if has_qk_norm else None,
+            dual_attention_layers=attn2_layers,
         )
     if is_accelerate_available():
         load_model_dict_into_meta(transformer, converted_transformer_state_dict)
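Note on the new auto-detection: instead of hard-coding SD3.0 values, main() now derives caption_projection_dim, pos_embed_max_size, the dual-attention layers, and the qk-norm setting from the checkpoint itself. Below is a minimal sketch of that detection logic on a toy state dict; dummy_ckpt and its tensor shapes are made up for illustration and are not those of a real SD3/SD3.5 checkpoint.

import torch

# Toy state dict that mimics only the key layout the helpers inspect
# (shapes are deliberately tiny; a real checkpoint is far larger).
dummy_ckpt = {
    "pos_embed": torch.zeros(1, 16, 8),                              # 16 positions -> 4x4 grid
    "context_embedder.weight": torch.zeros(8, 4096),                 # row count = caption_projection_dim
    "joint_blocks.0.x_block.attn.ln_q.weight": torch.zeros(8),       # "ln_q" present -> qk norm
    "joint_blocks.1.x_block.attn2.qkv.weight": torch.zeros(24, 8),   # "attn2." on layer 1 -> dual attention
}

# Mirrors get_attn2_layers: layer indices whose keys contain "attn2."
attn2_layers = tuple(sorted({int(k.split(".")[1]) for k in dummy_ckpt if "attn2." in k}))
# Mirrors get_pos_embed_max_size: sqrt of the positional-embedding token count
pos_embed_max_size = int(dummy_ckpt["pos_embed"].shape[1] ** 0.5)
# Mirrors get_caption_projection_dim: output width of the context embedder
caption_projection_dim = dummy_ckpt["context_embedder.weight"].shape[0]
# Same check the script uses before setting qk_norm="rms_norm"
has_qk_norm = any("ln_q" in k for k in dummy_ckpt)

print(attn2_layers, pos_embed_max_size, caption_projection_dim, has_qk_norm)
# (1,) 4 8 True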