parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str)
parser.add_argument("--output_path", type=str)
-parser.add_argument("--dtype", type=str, default="fp16")
+parser.add_argument("--dtype", type=str)

args = parser.parse_args()
-dtype = torch.float16 if args.dtype == "fp16" else torch.float32


def load_original_checkpoint(ckpt_path):
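The `--dtype` flag loses its "fp16" default so that, when the flag is omitted, the script can fall back to the precision the checkpoint was saved in instead of silently casting to fp16; the full resolution logic is added in `main()` further down this diff. A minimal sketch of that order (`resolve_dtype` is an illustrative helper, not part of the PR):

from typing import Optional

import torch


def resolve_dtype(requested: Optional[str], checkpoint_dtype: torch.dtype) -> torch.dtype:
    # No --dtype passed: keep the checkpoint's own precision.
    if requested is None:
        return checkpoint_dtype
    mapping = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
    if requested not in mapping:
        raise ValueError(f"Unsupported dtype: {requested}")
    return mapping[requested]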
@@ -40,7 +39,9 @@ def swap_scale_shift(weight, dim):
    return new_weight


-def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_layers, caption_projection_dim):
+def convert_sd3_transformer_checkpoint_to_diffusers(
+    original_state_dict, num_layers, caption_projection_dim, dual_attention_layers, has_qk_norm
+):
    converted_state_dict = {}

    # Positional and patch embeddings.
@@ -110,6 +111,21 @@ def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_lay
        converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
        converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])

+        # qk norm
+        if has_qk_norm:
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_q.weight"] = original_state_dict.pop(
+                f"joint_blocks.{i}.x_block.attn.ln_q.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_k.weight"] = original_state_dict.pop(
+                f"joint_blocks.{i}.x_block.attn.ln_k.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_q.weight"] = original_state_dict.pop(
+                f"joint_blocks.{i}.context_block.attn.ln_q.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_k.weight"] = original_state_dict.pop(
+                f"joint_blocks.{i}.context_block.attn.ln_k.weight"
+            )
+
        # output projections.
        converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = original_state_dict.pop(
            f"joint_blocks.{i}.x_block.attn.proj.weight"
@@ -125,6 +141,39 @@ def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_lay
            f"joint_blocks.{i}.context_block.attn.proj.bias"
        )

+        # attn2
+        if i in dual_attention_layers:
+            # Q, K, V
+            sample_q2, sample_k2, sample_v2 = torch.chunk(
+                original_state_dict.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0
+            )
+            sample_q2_bias, sample_k2_bias, sample_v2_bias = torch.chunk(
+                original_state_dict.pop(f"joint_blocks.{i}.x_block.attn2.qkv.bias"), 3, dim=0
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = torch.cat([sample_q2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = torch.cat([sample_q2_bias])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = torch.cat([sample_k2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = torch.cat([sample_k2_bias])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = torch.cat([sample_v2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = torch.cat([sample_v2_bias])
+
+            # qk norm
+            if has_qk_norm:
+                converted_state_dict[f"transformer_blocks.{i}.attn2.norm_q.weight"] = original_state_dict.pop(
+                    f"joint_blocks.{i}.x_block.attn2.ln_q.weight"
+                )
+                converted_state_dict[f"transformer_blocks.{i}.attn2.norm_k.weight"] = original_state_dict.pop(
+                    f"joint_blocks.{i}.x_block.attn2.ln_k.weight"
+                )
+
+            # output projections.
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = original_state_dict.pop(
+                f"joint_blocks.{i}.x_block.attn2.proj.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = original_state_dict.pop(
+                f"joint_blocks.{i}.x_block.attn2.proj.bias"
+            )
+
        # norms.
        converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = original_state_dict.pop(
            f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"
@@ -195,25 +244,79 @@ def is_vae_in_checkpoint(original_state_dict):
    )


+def get_attn2_layers(state_dict):
+    attn2_layers = []
+    for key in state_dict.keys():
+        if "attn2." in key:
+            # Extract the layer number from the key
+            layer_num = int(key.split(".")[1])
+            attn2_layers.append(layer_num)
+    return tuple(sorted(set(attn2_layers)))
+
+
+def get_pos_embed_max_size(state_dict):
+    num_patches = state_dict["pos_embed"].shape[1]
+    pos_embed_max_size = int(num_patches**0.5)
+    return pos_embed_max_size
+
+
+def get_caption_projection_dim(state_dict):
+    caption_projection_dim = state_dict["context_embedder.weight"].shape[0]
+    return caption_projection_dim
+
+
def main(args):
    original_ckpt = load_original_checkpoint(args.checkpoint_path)
+    original_dtype = next(iter(original_ckpt.values())).dtype
+
+    # Initialize dtype with a default value
+    dtype = None
+
+    if args.dtype is None:
+        dtype = original_dtype
+    elif args.dtype == "fp16":
+        dtype = torch.float16
+    elif args.dtype == "bf16":
+        dtype = torch.bfloat16
+    elif args.dtype == "fp32":
+        dtype = torch.float32
+    else:
+        raise ValueError(f"Unsupported dtype: {args.dtype}")
+
+    if dtype != original_dtype:
+        print(
+            f"Checkpoint dtype {original_dtype} does not match requested dtype {dtype}. This can lead to unexpected results, proceed with caution."
+        )
+
    num_layers = list(set(int(k.split(".", 2)[1]) for k in original_ckpt if "joint_blocks" in k))[-1] + 1  # noqa: C401
-    caption_projection_dim = 1536
+
+    caption_projection_dim = get_caption_projection_dim(original_ckpt)
+
+    # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
+    attn2_layers = get_attn2_layers(original_ckpt)
+
+    # sd3.5 uses qk norm ("rms_norm")
+    has_qk_norm = any("ln_q" in key for key in original_ckpt.keys())
+
+    # sd3.5 2b uses pos_embed_max_size=384; sd3.0 and sd3.5 8b use 192
+    pos_embed_max_size = get_pos_embed_max_size(original_ckpt)

    converted_transformer_state_dict = convert_sd3_transformer_checkpoint_to_diffusers(
-        original_ckpt, num_layers, caption_projection_dim
+        original_ckpt, num_layers, caption_projection_dim, attn2_layers, has_qk_norm
    )

    with CTX():
        transformer = SD3Transformer2DModel(
-            sample_size=64,
+            sample_size=128,
            patch_size=2,
            in_channels=16,
            joint_attention_dim=4096,
            num_layers=num_layers,
            caption_projection_dim=caption_projection_dim,
-            num_attention_heads=24,
-            pos_embed_max_size=192,
+            num_attention_heads=num_layers,
+            pos_embed_max_size=pos_embed_max_size,
+            qk_norm="rms_norm" if has_qk_norm else None,
+            dual_attention_layers=attn2_layers,
        )

    if is_accelerate_available():
        load_model_dict_into_meta(transformer, converted_transformer_state_dict)
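With the new helpers, the transformer config is no longer hardcoded: the layer count, caption projection width, dual-attention layer indices, qk-norm usage, and positional-embedding grid size are all sniffed from the checkpoint's tensor names and shapes. A quick sketch of how those helpers behave, assuming the three functions added in this diff are in scope and using a tiny fake state dict whose shapes are invented for illustration:

import torch

fake_ckpt = {
    "pos_embed": torch.zeros(1, 4, 64),               # 4 patches -> pos_embed_max_size 2
    "context_embedder.weight": torch.zeros(64, 128),  # out_features -> caption_projection_dim 64
    "joint_blocks.0.x_block.attn2.qkv.weight": torch.zeros(192, 64),
    "joint_blocks.1.x_block.attn.qkv.weight": torch.zeros(192, 64),
}

print(get_pos_embed_max_size(fake_ckpt))      # 2
print(get_caption_projection_dim(fake_ckpt))  # 64
print(get_attn2_layers(fake_ckpt))            # (0,)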