@@ -73,22 +73,30 @@ def load_model_config(input_path: str, input_type: str = None) -> Dict:
     with open(config_path, "r") as f:
         config = json.load(f)
 
+    if "text_config" in config:
+        text_cfg = config["text_config"]
+        kt_cvt_type = "vl"
+    else:
+        text_cfg = config
+        kt_cvt_type = "base"
+
     # Extract required fields with fallbacks
     model_config = {
-        "num_experts": config.get("n_routed_experts", config.get("num_experts")),
-        "num_experts_per_tok": config.get("num_experts_per_tok", 2),
-        "hidden_size": config.get("hidden_size"),
-        "moe_intermediate_size": config.get("moe_intermediate_size", config.get("intermediate_size")),
+        "num_experts": text_cfg.get("n_routed_experts", text_cfg.get("num_experts")),
+        "num_experts_per_tok": text_cfg.get("num_experts_per_tok", 2),
+        "hidden_size": text_cfg.get("hidden_size"),
+        "moe_intermediate_size": text_cfg.get("moe_intermediate_size", text_cfg.get("intermediate_size")),
+        "_kt_cvt_type": kt_cvt_type,
     }
 
     # Validate required fields
-    missing_fields = [k for k, v in model_config.items() if v is None]
+    missing_fields = [k for k, v in model_config.items() if k != "_kt_cvt_type" and v is None]
     if missing_fields:
         raise ValueError(f"Missing required config fields: {missing_fields}")
 
     # For FP8 input, extract and validate quantization_config
     if input_type == "fp8":
-        quant_config = config.get("quantization_config")
+        quant_config = config.get("quantization_config") or text_cfg.get("quantization_config")
         if quant_config is None:
             raise ValueError(
                 "FP8 input type specified but 'quantization_config' not found in config.json. "
@@ -113,6 +121,7 @@ def load_model_config(input_path: str, input_type: str = None) -> Dict:
113121 print (f" format: { quant_config .get ('fmt' , 'unknown' )} " )
114122 print (f" weight_block_size: { weight_block_size } " )
115123
124+ print (f"Model Type: { model_config ['_kt_cvt_type' ]} " )
116125 return model_config
117126
118127
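For reference, VL checkpoints typically nest the language-model fields under a `text_config` key, which is why the loader now falls back to that sub-dict. A minimal sketch of the lookup, with hypothetical config values (the keys mirror the ones the function reads; the numbers are made up):

```python
# Hypothetical VL-style config.json contents; all values are illustrative only.
config = {
    "model_type": "some-vl-moe",
    "text_config": {
        "n_routed_experts": 128,
        "num_experts_per_tok": 8,
        "hidden_size": 4096,
        "moe_intermediate_size": 1024,
    },
}

# Same resolution logic as the patched load_model_config: prefer the nested
# text_config when present, otherwise read from the top-level config.
text_cfg = config.get("text_config", config)
kt_cvt_type = "vl" if "text_config" in config else "base"

num_experts = text_cfg.get("n_routed_experts", text_cfg.get("num_experts"))
print(kt_cvt_type, num_experts)  # -> vl 128
```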
@@ -260,6 +269,7 @@ def __init__(
         self.num_experts_per_tok = model_config["num_experts_per_tok"]
         self.hidden_size = model_config["hidden_size"]
         self.moe_intermediate_size = model_config["moe_intermediate_size"]
+        self.kt_cvt_type = model_config.get("_kt_cvt_type", "base")
 
         # Load input safetensors files
         self._load_input_files()
@@ -302,6 +312,24 @@ def _load_tensor(self, key: str) -> torch.Tensor:
     def _find_expert_layers(self) -> Dict[int, List[int]]:
         """Find all layers and experts in the model"""
         layers = defaultdict(set)
+
+        # vl weights have a fused layout
+        # Pattern: model.language_model.layers.{layer}.mlp.experts.{proj}
+        if self.kt_cvt_type == "vl":
+            layers = set()
+            for key in self.tensor_file_map.keys():
+                if "model.language_model.layers." in key and ".mlp.experts." in key:
+                    parts = key.split(".")
+                    if len(parts) >= 7:
+                        layer_idx = int(parts[3])
+                        layers.add(layer_idx)
+
+            result: Dict[int, List[int]] = {}
+            for layer_idx in sorted(layers):
+                result[layer_idx] = [-1]
+
+            print(f"Found {len(result)} layers with fused MoE experts")
+            return result
 
         # Pattern: model.layers.{layer}.mlp.experts.{expert}.{proj}.{type}
         for key in self.tensor_file_map.keys():
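The index arithmetic in the VL branch is easier to see on a concrete key. A small sketch, using a hypothetical fused tensor name (`gate_up_proj`) purely for illustration:

```python
# Hypothetical fused expert key; the proj name is illustrative only.
key = "model.language_model.layers.12.mlp.experts.gate_up_proj"
parts = key.split(".")
# parts: ['model', 'language_model', 'layers', '12', 'mlp', 'experts', 'gate_up_proj']
layer_idx = int(parts[3])  # -> 12, what _find_expert_layers collects
proj_name = parts[6]       # -> 'gate_up_proj', what the VL convert path groups by
```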
@@ -675,76 +703,141 @@ def _remove_layer_folder(self, layer_idx: int):
     def _convert_layer_experts(self, layer_idx: int, expert_ids: List[int]) -> Dict[str, torch.Tensor]:
         """Convert all experts in a layer using online quantization via AMXMoEWrapper"""
         start_time = time.time()
-        print(f"Converting layer {layer_idx} with {len(expert_ids)} experts via online quantization...")
-
+        print(f"Converting layer {layer_idx} with {len(expert_ids) if self.kt_cvt_type == 'base' else 'fused'} experts via online quantization...")
         # Load all expert weights for this layer
-        gate_weights = []
-        up_weights = []
-        down_weights = []
-
-        for expert_id in expert_ids:
-            gate_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.gate_proj.weight"
-            up_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.up_proj.weight"
-            down_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.down_proj.weight"
-
-            if gate_key not in self.tensor_file_map:
-                raise KeyError(f"Missing gate weight for layer {layer_idx}, expert {expert_id}")
-            if up_key not in self.tensor_file_map:
-                raise KeyError(f"Missing up weight for layer {layer_idx}, expert {expert_id}")
-            if down_key not in self.tensor_file_map:
-                raise KeyError(f"Missing down weight for layer {layer_idx}, expert {expert_id}")
-
-            # Load weights based on input type
-            if self.input_type == "fp8":
-                # Load FP8 weights and their scale_inv tensors
-                gate_scale_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.gate_proj.weight_scale_inv"
-                up_scale_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.up_proj.weight_scale_inv"
-                down_scale_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.down_proj.weight_scale_inv"
-
-                if gate_scale_key not in self.tensor_file_map:
-                    raise KeyError(f"Missing gate weight_scale_inv for layer {layer_idx}, expert {expert_id}")
-                if up_scale_key not in self.tensor_file_map:
-                    raise KeyError(f"Missing up weight_scale_inv for layer {layer_idx}, expert {expert_id}")
-                if down_scale_key not in self.tensor_file_map:
-                    raise KeyError(f"Missing down weight_scale_inv for layer {layer_idx}, expert {expert_id}")
-
-                # Load FP8 weights and scales
-                gate_fp8 = self._load_tensor(gate_key).to("cuda")
-                up_fp8 = self._load_tensor(up_key).to("cuda")
-                down_fp8 = self._load_tensor(down_key).to("cuda")
-
-                gate_scale_inv = self._load_tensor(gate_scale_key).to("cuda")
-                up_scale_inv = self._load_tensor(up_scale_key).to("cuda")
-                down_scale_inv = self._load_tensor(down_scale_key).to("cuda")
-
-                # Dequantize FP8 to BF16 using block-wise scaling
-                gate_weight = weight_dequant(gate_fp8, gate_scale_inv).to("cpu").to(torch.bfloat16).contiguous()
-                up_weight = weight_dequant(up_fp8, up_scale_inv).to("cpu").to(torch.bfloat16).contiguous()
-                down_weight = weight_dequant(down_fp8, down_scale_inv).to("cpu").to(torch.bfloat16).contiguous()
-
-            elif self.input_type == "fp16":
-                # Load FP16 and convert to BF16
-                gate_weight = self._load_tensor(gate_key).to(torch.bfloat16)
-                up_weight = self._load_tensor(up_key).to(torch.bfloat16)
-                down_weight = self._load_tensor(down_key).to(torch.bfloat16)
-
-            elif self.input_type == "bf16":
-                # Load BF16 directly
-                gate_weight = self._load_tensor(gate_key)
-                up_weight = self._load_tensor(up_key)
-                down_weight = self._load_tensor(down_key)
-
-            else:
-                raise ValueError(f"Unsupported input_type for INT4 conversion: {self.input_type}")
+        if self.kt_cvt_type == "vl":
+            if self.input_type not in ["bf16", "fp16"]:
+                raise ValueError(f"VL path currently supports bf16/fp16 only, got input_type={self.input_type}")
+
+            proj_set = set()
+            prefix = f"model.language_model.layers.{layer_idx}.mlp.experts."
+            for key in self.tensor_file_map.keys():
+                if key.startswith(prefix):
+                    parts = key.split(".")
+                    if len(parts) >= 7:
+                        proj_set.add(parts[6])
+
+            if not proj_set:
+                raise ValueError(
+                    f"[VL] No fused MoE experts found for layer {layer_idx} under 'model.language_model.layers'"
+                )
+
+            projs = sorted(proj_set)
+            print(f"  [VL] layer {layer_idx} fused proj keys: {projs}")
+
+            if len(projs) < 2:
+                raise ValueError(
+                    f"[VL] Expect at least 2 fused tensors (down & gate_up) in layer {layer_idx}, got {len(projs)}"
+                )
+
+            fused_tensors = []
+            for p in projs:
+                key = f"model.language_model.layers.{layer_idx}.mlp.experts.{p}"
+                if key not in self.tensor_file_map:
+                    raise KeyError(f"[VL] Missing fused tensor {key} for layer {layer_idx}")
+                w = self._load_tensor(key)
+                if self.input_type == "fp16":
+                    w = w.to(torch.bfloat16)
+                print(f"  [VL] tensor {p} shape: {tuple(w.shape)}")
+                fused_tensors.append(w)
+
+            # fused_tensors[0]: down-like,    [E, I, H]
+            # fused_tensors[1]: gate_up-like, [E, H, 2I]
+            down_fused = fused_tensors[0]
+            gate_up_fused = fused_tensors[1]
+
+            # gate_up_fused: [E, H, 2I] -> [E, 2I, H] -> gate / up
+            if gate_up_fused.dim() != 3:
+                raise ValueError(f"[VL] Expect gate_up fused tensor to be 3D, got shape {tuple(gate_up_fused.shape)}")
+            E, H, twoI = gate_up_fused.shape
+            if twoI % 2 != 0:
+                raise ValueError(f"[VL] gate_up last dim (2I) not even: {twoI}")
+            I = twoI // 2
+
+            gate_up_T = gate_up_fused.transpose(1, 2).contiguous()  # [E, 2I, H]
+            gate_proj = gate_up_T[:, :I, :]  # [E, I, H]
+            up_proj = gate_up_T[:, I:, :]  # [E, I, H]
+
+            if down_fused.dim() != 3:
+                raise ValueError(f"[VL] Expect down fused tensor to be 3D, got shape {tuple(down_fused.shape)}")
+            if down_fused.shape[0] != E:
+                raise ValueError(
+                    f"[VL] down_fused expert dim mismatch: {down_fused.shape[0]} vs gate_up {E}"
+                )
+            down_proj = down_fused.transpose(1, 2).contiguous()  # [E, H, I]
+            del fused_tensors
+            del gate_up_fused
+            del down_fused
+        else:
+            gate_weights = []
+            up_weights = []
+            down_weights = []
 
-            gate_weights.append(gate_weight)
-            up_weights.append(up_weight)
-            down_weights.append(down_weight)
+            for expert_id in expert_ids:
+                gate_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.gate_proj.weight"
+                up_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.up_proj.weight"
+                down_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.down_proj.weight"
+
+                if gate_key not in self.tensor_file_map:
+                    raise KeyError(f"Missing gate weight for layer {layer_idx}, expert {expert_id}")
+                if up_key not in self.tensor_file_map:
+                    raise KeyError(f"Missing up weight for layer {layer_idx}, expert {expert_id}")
+                if down_key not in self.tensor_file_map:
+                    raise KeyError(f"Missing down weight for layer {layer_idx}, expert {expert_id}")
+
+                # Load weights based on input type
+                if self.input_type == "fp8":
+                    # Load FP8 weights and their scale_inv tensors
+                    gate_scale_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.gate_proj.weight_scale_inv"
+                    up_scale_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.up_proj.weight_scale_inv"
+                    down_scale_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.down_proj.weight_scale_inv"
+
+                    if gate_scale_key not in self.tensor_file_map:
+                        raise KeyError(f"Missing gate weight_scale_inv for layer {layer_idx}, expert {expert_id}")
+                    if up_scale_key not in self.tensor_file_map:
+                        raise KeyError(f"Missing up weight_scale_inv for layer {layer_idx}, expert {expert_id}")
+                    if down_scale_key not in self.tensor_file_map:
+                        raise KeyError(f"Missing down weight_scale_inv for layer {layer_idx}, expert {expert_id}")
+
+                    # Load FP8 weights and scales
+                    gate_fp8 = self._load_tensor(gate_key).to("cuda")
+                    up_fp8 = self._load_tensor(up_key).to("cuda")
+                    down_fp8 = self._load_tensor(down_key).to("cuda")
+
+                    gate_scale_inv = self._load_tensor(gate_scale_key).to("cuda")
+                    up_scale_inv = self._load_tensor(up_scale_key).to("cuda")
+                    down_scale_inv = self._load_tensor(down_scale_key).to("cuda")
+
+                    # Dequantize FP8 to BF16 using block-wise scaling
+                    gate_weight = weight_dequant(gate_fp8, gate_scale_inv).to("cpu").to(torch.bfloat16).contiguous()
+                    up_weight = weight_dequant(up_fp8, up_scale_inv).to("cpu").to(torch.bfloat16).contiguous()
+                    down_weight = weight_dequant(down_fp8, down_scale_inv).to("cpu").to(torch.bfloat16).contiguous()
+
+                elif self.input_type == "fp16":
+                    # Load FP16 and convert to BF16
+                    gate_weight = self._load_tensor(gate_key).to(torch.bfloat16)
+                    up_weight = self._load_tensor(up_key).to(torch.bfloat16)
+                    down_weight = self._load_tensor(down_key).to(torch.bfloat16)
+
+                elif self.input_type == "bf16":
+                    # Load BF16 directly
+                    gate_weight = self._load_tensor(gate_key)
+                    up_weight = self._load_tensor(up_key)
+                    down_weight = self._load_tensor(down_key)
+
+                else:
+                    raise ValueError(f"Unsupported input_type for INT4 conversion: {self.input_type}")
+
+                gate_weights.append(gate_weight)
+                up_weights.append(up_weight)
+                down_weights.append(down_weight)
+
+            # Stack weights into single tensors: [num_experts, ...]
+            gate_proj = torch.stack(gate_weights, dim=0).contiguous()
+            up_proj = torch.stack(up_weights, dim=0).contiguous()
+            down_proj = torch.stack(down_weights, dim=0).contiguous()
+            del gate_weights, up_weights, down_weights
 
-        # Stack weights into single tensors: [num_experts, ...]
-        gate_proj = torch.stack(gate_weights, dim=0).contiguous()
-        up_proj = torch.stack(up_weights, dim=0).contiguous()
-        down_proj = torch.stack(down_weights, dim=0).contiguous()
 
         print(f"  Loaded weights shapes:")
         print(f"    gate_proj: {gate_proj.shape}")
@@ -784,8 +877,7 @@ def _convert_layer_experts(self, layer_idx: int, expert_ids: List[int]) -> Dict[
         # This triggers the quantization process and saves to disk
         wrapper.load_weights_from_tensors(gate_proj, up_proj, down_proj, physical_to_logical_map)
 
-        # Clean up to free memory
-        del gate_weights, up_weights, down_weights
+        # Clean up to free memory
         del gate_proj, up_proj, down_proj
         gc.collect()
 
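As a sanity check on the fused split above, here is a minimal sketch with toy dimensions (E=2 experts, H=4 hidden, I=3 intermediate). It assumes, as the VL branch does, that the fused gate_up tensor is stored as [E, H, 2I] with gate occupying the first I columns; the tensor values are random and purely illustrative:

```python
import torch

E, H, I = 2, 4, 3
gate_up_fused = torch.randn(E, H, 2 * I, dtype=torch.bfloat16)  # [E, H, 2I]
down_fused = torch.randn(E, I, H, dtype=torch.bfloat16)         # [E, I, H]

# Same transforms as the VL branch of _convert_layer_experts
gate_up_T = gate_up_fused.transpose(1, 2).contiguous()  # [E, 2I, H]
gate_proj = gate_up_T[:, :I, :]                         # [E, I, H]
up_proj = gate_up_T[:, I:, :]                           # [E, I, H]
down_proj = down_fused.transpose(1, 2).contiguous()     # [E, H, I]

# Shapes now match the stacked per-expert layout the base path produces
# (gate/up: [E, intermediate, hidden], down: [E, hidden, intermediate]).
assert gate_proj.shape == (E, I, H) and down_proj.shape == (E, H, I)

# Expert 0's gate weight equals the first I columns of its fused slab, transposed.
assert torch.equal(gate_proj[0], gate_up_fused[0, :, :I].T)
```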