@@ -533,6 +533,75 @@ def copy_weights_qwen_2_5(
        pbar.update(progress_per_file)


+def copy_weights_qwen_3(
+    config: Config,
+    qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],
+    state_dict: Dict[str, torch.Tensor],
+    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
+    saver: Optional[incremental_save] = None,
+    dtype: Optional[torch.dtype] = None,
+    pbar: Optional[tqdm] = None,
+    progress_per_file: Optional[float] = None,
+    debug_mode: Optional[bool] = False,
+) -> None:
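+    # map Hugging Face Qwen3 parameter names to litgpt names; the q/k/v projection
+    # weights map to None because they are fused into a single qkv tensor further down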
+    weight_map = {
+        "model.embed_tokens.weight": "transformer.wte.weight",
+        "model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight",
+        "model.layers.{}.self_attn.q_proj.weight": None,
+        "model.layers.{}.self_attn.k_proj.weight": None,
+        "model.layers.{}.self_attn.v_proj.weight": None,
+        "model.layers.{}.self_attn.q_norm.weight": "transformer.h.{}.attn.norm_q.weight",
+        "model.layers.{}.self_attn.k_norm.weight": "transformer.h.{}.attn.norm_k.weight",
+        "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight",
+        "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight",
+        "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight",
+        "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight",
+        "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight",
+        "model.norm.weight": "transformer.ln_f.weight",
+        "lm_head.weight": "lm_head.weight",
+    }
+
+    if progress_per_file is not None:
+        progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))
+
+    for from_name, param in hf_weights.items():
+        name_template, *ids = layer_template(from_name, num_matches=2)
+        to_name = weight_map[name_template]
+        param = load_param(param, from_name, dtype, verbose=debug_mode)
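+        # buffer the q/k/v projections per layer so they can be fused below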
+        if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")):
+            qkv = qkv_weights.setdefault(ids[0], defaultdict(dict))
+            weight_name, weight_type = from_name.split(".")[-2:]
+            qkv[weight_type][weight_name] = param
+        if to_name is None:
+            continue
+        to_name = to_name.format(*ids)
+        if saver is not None:
+            param = saver.store_early(param)
+        state_dict[to_name] = param
+
+        if progress_per_file is not None:
+            pbar.update(progress_per_file)
+
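+    # checkpoints with tied embeddings ship no separate lm_head.weight, so fall back to the embedding matrix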
+ if "lm_head.weight" not in state_dict :
586
+ state_dict ["lm_head.weight" ] = state_dict ["transformer.wte.weight" ]
587
+
588
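+    # concatenate the buffered q, k and v projections into litgpt's single fused qkv parameter per layer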
+    for i in list(qkv_weights):
+        for weight_type in list(qkv_weights[i]):
+            qkv = qkv_weights[i][weight_type]
+            if len(qkv) != 3:
+                # qkv is split across different .bin files
+                continue
+            q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode)
+            k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode)
+            v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode)
+            qkv = torch.cat((q, k, v))
+            state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv
+            del qkv_weights[i][weight_type]
+
+            if progress_per_file is not None:
+                pbar.update(progress_per_file)
+
+
def qkv_reassemble(
    param: Union[torch.Tensor, NotYetLoadedTensor], config: Config
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -624,6 +693,10 @@ def convert_hf_checkpoint(
        # holder to reconstitute the split q, k, v
        qkv_weights = {}
        copy_fn = partial(copy_weights_qwen_2_5, config, qkv_weights)
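+    # Qwen3 checkpoints carry extra q/k norm weights (norm_q / norm_k), handled by the dedicated copy function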
+    elif model_name.lower().startswith("qwen3"):
+        # holder to reconstitute the split q, k, v
+        qkv_weights = {}
+        copy_fn = partial(copy_weights_qwen_3, config, qkv_weights)
    elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"):
        # holder to reconstitute the split q, k, v
        qkv_weights = {}