Commit 637c49c

[feat](kt-kernel): support qwen3-vl weights convert (#1648)
1 parent c256150 commit 637c49c

1 file changed: kt-kernel/scripts/convert_cpu_weights.py

Lines changed: 167 additions & 75 deletions
@@ -73,22 +73,30 @@ def load_model_config(input_path: str, input_type: str = None) -> Dict:
     with open(config_path, "r") as f:
         config = json.load(f)
 
+    if "text_config" in config:
+        text_cfg = config["text_config"]
+        kt_cvt_type = "vl"
+    else:
+        text_cfg = config
+        kt_cvt_type = "base"
+
     # Extract required fields with fallbacks
     model_config = {
-        "num_experts": config.get("n_routed_experts", config.get("num_experts")),
-        "num_experts_per_tok": config.get("num_experts_per_tok", 2),
-        "hidden_size": config.get("hidden_size"),
-        "moe_intermediate_size": config.get("moe_intermediate_size", config.get("intermediate_size")),
+        "num_experts": text_cfg.get("n_routed_experts", text_cfg.get("num_experts")),
+        "num_experts_per_tok": text_cfg.get("num_experts_per_tok", 2),
+        "hidden_size": text_cfg.get("hidden_size"),
+        "moe_intermediate_size": text_cfg.get("moe_intermediate_size", text_cfg.get("intermediate_size")),
+        "_kt_cvt_type": kt_cvt_type,
     }
 
     # Validate required fields
-    missing_fields = [k for k, v in model_config.items() if v is None]
+    missing_fields = [k for k, v in model_config.items() if k != "_kt_cvt_type" and v is None]
     if missing_fields:
         raise ValueError(f"Missing required config fields: {missing_fields}")
 
     # For FP8 input, extract and validate quantization_config
     if input_type == "fp8":
-        quant_config = config.get("quantization_config")
+        quant_config = config.get("quantization_config") or text_cfg.get("quantization_config")
         if quant_config is None:
             raise ValueError(
                 "FP8 input type specified but 'quantization_config' not found in config.json. "
@@ -113,6 +121,7 @@ def load_model_config(input_path: str, input_type: str = None) -> Dict:
         print(f" format: {quant_config.get('fmt', 'unknown')}")
         print(f" weight_block_size: {weight_block_size}")
 
+    print(f"Model Type: {model_config['_kt_cvt_type']}")
     return model_config
 
 
@@ -260,6 +269,7 @@ def __init__(
         self.num_experts_per_tok = model_config["num_experts_per_tok"]
         self.hidden_size = model_config["hidden_size"]
         self.moe_intermediate_size = model_config["moe_intermediate_size"]
+        self.kt_cvt_type = model_config.get("_kt_cvt_type", "base")
 
         # Load input safetensors files
         self._load_input_files()
@@ -302,6 +312,24 @@ def _load_tensor(self, key: str) -> torch.Tensor:
     def _find_expert_layers(self) -> Dict[int, List[int]]:
         """Find all layers and experts in the model"""
         layers = defaultdict(set)
+
+        # vl weights have a fused layout
+        # Pattern: model.language_model.layers.{layer}.mlp.experts.{proj}
+        if self.kt_cvt_type == "vl":
+            layers = set()
+            for key in self.tensor_file_map.keys():
+                if "model.language_model.layers." in key and ".mlp.experts." in key:
+                    parts = key.split(".")
+                    if len(parts) >= 7:
+                        layer_idx = int(parts[3])
+                        layers.add(layer_idx)
+
+            result: Dict[int, List[int]] = {}
+            for layer_idx in sorted(layers):
+                result[layer_idx] = [-1]
+
+            print(f"Found {len(result)} layers with fused MoE experts")
+            return result
 
         # Pattern: model.layers.{layer}.mlp.experts.{expert}.{proj}.{type}
         for key in self.tensor_file_map.keys():
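Note: the VL branch recovers layer indices purely from the fused tensor key names by splitting on "." and reading the fourth component; each such layer is then mapped to the sentinel expert list [-1], since its experts are stored fused rather than one tensor per expert. A quick sketch of the index arithmetic (the "gate_up_proj" suffix here is illustrative, not necessarily the exact name used in the checkpoint):

key = "model.language_model.layers.7.mlp.experts.gate_up_proj"
parts = key.split(".")
# parts -> ['model', 'language_model', 'layers', '7', 'mlp', 'experts', 'gate_up_proj']
assert len(parts) >= 7
layer_idx = int(parts[3])   # 7
proj_name = parts[6]        # 'gate_up_proj'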
@@ -675,76 +703,141 @@ def _remove_layer_folder(self, layer_idx: int):
     def _convert_layer_experts(self, layer_idx: int, expert_ids: List[int]) -> Dict[str, torch.Tensor]:
         """Convert all experts in a layer using online quantization via AMXMoEWrapper"""
         start_time = time.time()
-        print(f"Converting layer {layer_idx} with {len(expert_ids)} experts via online quantization...")
-
+        print(f"Converting layer {layer_idx} with {len(expert_ids) if self.kt_cvt_type == 'base' else 'fused'} experts via online quantization...")
         # Load all expert weights for this layer
-        gate_weights = []
-        up_weights = []
-        down_weights = []
-
-        for expert_id in expert_ids:
-            gate_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.gate_proj.weight"
-            up_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.up_proj.weight"
-            down_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.down_proj.weight"
-
-            if gate_key not in self.tensor_file_map:
-                raise KeyError(f"Missing gate weight for layer {layer_idx}, expert {expert_id}")
-            if up_key not in self.tensor_file_map:
-                raise KeyError(f"Missing up weight for layer {layer_idx}, expert {expert_id}")
-            if down_key not in self.tensor_file_map:
-                raise KeyError(f"Missing down weight for layer {layer_idx}, expert {expert_id}")
-
-            # Load weights based on input type
-            if self.input_type == "fp8":
-                # Load FP8 weights and their scale_inv tensors
-                gate_scale_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.gate_proj.weight_scale_inv"
-                up_scale_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.up_proj.weight_scale_inv"
-                down_scale_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.down_proj.weight_scale_inv"
-
-                if gate_scale_key not in self.tensor_file_map:
-                    raise KeyError(f"Missing gate weight_scale_inv for layer {layer_idx}, expert {expert_id}")
-                if up_scale_key not in self.tensor_file_map:
-                    raise KeyError(f"Missing up weight_scale_inv for layer {layer_idx}, expert {expert_id}")
-                if down_scale_key not in self.tensor_file_map:
-                    raise KeyError(f"Missing down weight_scale_inv for layer {layer_idx}, expert {expert_id}")
-
-                # Load FP8 weights and scales
-                gate_fp8 = self._load_tensor(gate_key).to("cuda")
-                up_fp8 = self._load_tensor(up_key).to("cuda")
-                down_fp8 = self._load_tensor(down_key).to("cuda")
-
-                gate_scale_inv = self._load_tensor(gate_scale_key).to("cuda")
-                up_scale_inv = self._load_tensor(up_scale_key).to("cuda")
-                down_scale_inv = self._load_tensor(down_scale_key).to("cuda")
-
-                # Dequantize FP8 to BF16 using block-wise scaling
-                gate_weight = weight_dequant(gate_fp8, gate_scale_inv).to("cpu").to(torch.bfloat16).contiguous()
-                up_weight = weight_dequant(up_fp8, up_scale_inv).to("cpu").to(torch.bfloat16).contiguous()
-                down_weight = weight_dequant(down_fp8, down_scale_inv).to("cpu").to(torch.bfloat16).contiguous()
-
-            elif self.input_type == "fp16":
-                # Load FP16 and convert to BF16
-                gate_weight = self._load_tensor(gate_key).to(torch.bfloat16)
-                up_weight = self._load_tensor(up_key).to(torch.bfloat16)
-                down_weight = self._load_tensor(down_key).to(torch.bfloat16)
-
-            elif self.input_type == "bf16":
-                # Load BF16 directly
-                gate_weight = self._load_tensor(gate_key)
-                up_weight = self._load_tensor(up_key)
-                down_weight = self._load_tensor(down_key)
-
-            else:
-                raise ValueError(f"Unsupported input_type for INT4 conversion: {self.input_type}")
+        if self.kt_cvt_type == "vl":
+            if self.input_type not in ["bf16", "fp16"]:
+                raise ValueError(f"VL path currently supports bf16/fp16 only, got input_type={self.input_type}")
+
+            proj_set = set()
+            prefix = f"model.language_model.layers.{layer_idx}.mlp.experts."
+            for key in self.tensor_file_map.keys():
+                if key.startswith(prefix):
+                    parts = key.split(".")
+                    if len(parts) >= 7:
+                        proj_set.add(parts[6])
+
+            if not proj_set:
+                raise ValueError(
+                    f"[VL] No fused MoE experts found for layer {layer_idx} under 'model.language_model.layers'"
+                )
+
+            projs = sorted(proj_set)
+            print(f" [VL] layer {layer_idx} fused proj keys: {projs}")
+
+            if len(projs) < 2:
+                raise ValueError(
+                    f"[VL] Expect at least 2 fused tensors (down & gate_up) in layer {layer_idx}, got {len(projs)}"
+                )
+
+            fused_tensors = []
+            for p in projs:
+                key = f"model.language_model.layers.{layer_idx}.mlp.experts.{p}"
+                if key not in self.tensor_file_map:
+                    raise KeyError(f"[VL] Missing fused tensor {key} for layer {layer_idx}")
+                w = self._load_tensor(key)
+                if self.input_type == "fp16":
+                    w = w.to(torch.bfloat16)
+                print(f" [VL] tensor {p} shape: {tuple(w.shape)}")
+                fused_tensors.append(w)
+
+            # fused_tensors[0] : down-like, [E, I, H]
+            # fused_tensors[1] : gate_up-like, [E, H, 2I]
+            down_fused = fused_tensors[0]
+            gate_up_fused = fused_tensors[1]
+
+            # gate_up_fused: [E, H, 2I] -> [E, 2I, H] -> gate / up
+            if gate_up_fused.dim() != 3:
+                raise ValueError(f"[VL] Expect gate_up fused tensor to be 3D, got shape {tuple(gate_up_fused.shape)}")
+            E, H, twoI = gate_up_fused.shape
+            if twoI % 2 != 0:
+                raise ValueError(f"[VL] gate_up last dim (2I) not even: {twoI}")
+            I = twoI // 2
+
+            gate_up_T = gate_up_fused.transpose(1, 2).contiguous()  # [E, 2I, H]
+            gate_proj = gate_up_T[:, :I, :]  # [E, I, H]
+            up_proj = gate_up_T[:, I:, :]  # [E, I, H]
+
+            if down_fused.dim() != 3:
+                raise ValueError(f"[VL] Expect down fused tensor to be 3D, got shape {tuple(down_fused.shape)}")
+            if down_fused.shape[0] != E:
+                raise ValueError(
+                    f"[VL] down_fused expert dim mismatch: {down_fused.shape[0]} vs gate_up {E}"
+                )
+            down_proj = down_fused.transpose(1, 2).contiguous()  # [E, H, I]
+            del fused_tensors
+            del gate_up_fused
+            del down_fused
+        else:
+            gate_weights = []
+            up_weights = []
+            down_weights = []
 
-            gate_weights.append(gate_weight)
-            up_weights.append(up_weight)
-            down_weights.append(down_weight)
+            for expert_id in expert_ids:
+                gate_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.gate_proj.weight"
+                up_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.up_proj.weight"
+                down_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.down_proj.weight"
+
+                if gate_key not in self.tensor_file_map:
+                    raise KeyError(f"Missing gate weight for layer {layer_idx}, expert {expert_id}")
+                if up_key not in self.tensor_file_map:
+                    raise KeyError(f"Missing up weight for layer {layer_idx}, expert {expert_id}")
+                if down_key not in self.tensor_file_map:
+                    raise KeyError(f"Missing down weight for layer {layer_idx}, expert {expert_id}")
+
+                # Load weights based on input type
+                if self.input_type == "fp8":
+                    # Load FP8 weights and their scale_inv tensors
+                    gate_scale_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.gate_proj.weight_scale_inv"
+                    up_scale_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.up_proj.weight_scale_inv"
+                    down_scale_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.down_proj.weight_scale_inv"
+
+                    if gate_scale_key not in self.tensor_file_map:
+                        raise KeyError(f"Missing gate weight_scale_inv for layer {layer_idx}, expert {expert_id}")
+                    if up_scale_key not in self.tensor_file_map:
+                        raise KeyError(f"Missing up weight_scale_inv for layer {layer_idx}, expert {expert_id}")
+                    if down_scale_key not in self.tensor_file_map:
+                        raise KeyError(f"Missing down weight_scale_inv for layer {layer_idx}, expert {expert_id}")
+
+                    # Load FP8 weights and scales
+                    gate_fp8 = self._load_tensor(gate_key).to("cuda")
+                    up_fp8 = self._load_tensor(up_key).to("cuda")
+                    down_fp8 = self._load_tensor(down_key).to("cuda")
+
+                    gate_scale_inv = self._load_tensor(gate_scale_key).to("cuda")
+                    up_scale_inv = self._load_tensor(up_scale_key).to("cuda")
+                    down_scale_inv = self._load_tensor(down_scale_key).to("cuda")
+
+                    # Dequantize FP8 to BF16 using block-wise scaling
+                    gate_weight = weight_dequant(gate_fp8, gate_scale_inv).to("cpu").to(torch.bfloat16).contiguous()
+                    up_weight = weight_dequant(up_fp8, up_scale_inv).to("cpu").to(torch.bfloat16).contiguous()
+                    down_weight = weight_dequant(down_fp8, down_scale_inv).to("cpu").to(torch.bfloat16).contiguous()
+
+                elif self.input_type == "fp16":
+                    # Load FP16 and convert to BF16
+                    gate_weight = self._load_tensor(gate_key).to(torch.bfloat16)
+                    up_weight = self._load_tensor(up_key).to(torch.bfloat16)
+                    down_weight = self._load_tensor(down_key).to(torch.bfloat16)
+
+                elif self.input_type == "bf16":
+                    # Load BF16 directly
+                    gate_weight = self._load_tensor(gate_key)
+                    up_weight = self._load_tensor(up_key)
+                    down_weight = self._load_tensor(down_key)
+
+                else:
+                    raise ValueError(f"Unsupported input_type for INT4 conversion: {self.input_type}")
+
+                gate_weights.append(gate_weight)
+                up_weights.append(up_weight)
+                down_weights.append(down_weight)
+
+            # Stack weights into single tensors: [num_experts, ...]
+            gate_proj = torch.stack(gate_weights, dim=0).contiguous()
+            up_proj = torch.stack(up_weights, dim=0).contiguous()
+            down_proj = torch.stack(down_weights, dim=0).contiguous()
+            del gate_weights, up_weights, down_weights
 
-        # Stack weights into single tensors: [num_experts, ...]
-        gate_proj = torch.stack(gate_weights, dim=0).contiguous()
-        up_proj = torch.stack(up_weights, dim=0).contiguous()
-        down_proj = torch.stack(down_weights, dim=0).contiguous()
 
         print(f" Loaded weights shapes:")
         print(f" gate_proj: {gate_proj.shape}")
@@ -784,8 +877,7 @@ def _convert_layer_experts(self, layer_idx: int, expert_ids: List[int]) -> Dict[
         # This triggers the quantization process and saves to disk
         wrapper.load_weights_from_tensors(gate_proj, up_proj, down_proj, physical_to_logical_map)
 
-        # Clean up to free memory
-        del gate_weights, up_weights, down_weights
+        # Clean up to free memory
         del gate_proj, up_proj, down_proj
         gc.collect()
 