
Commit 20262b2

Fix Qwen3.5 FP8 load for VL detection (#1857)
* Fix Qwen3.5 FP8 load for VL detection:
  1. For VL models (Qwen3.5), remap base_key: model.layers.{N} -> model.language_model.layers.{N}.
  2. Remove the duplicated class BF16SafeTensorLoader(SafeTensorLoader); the earlier copy was shadowed by the later one, so only one definition is kept.
* Fix indentation typo.

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 786987a commit 20262b2

File tree: 1 file changed, +7 −102 lines


kt-kernel/python/utils/loader.py

Lines changed: 7 additions & 102 deletions
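For reference, the key-layout difference the fix targets, written out as a minimal Python sketch (the layer and expert indices are illustrative, not taken from a real checkpoint):

# Text-only checkpoint layout, which the loader previously assumed everywhere:
text_key = "model.layers.3.mlp.experts.0.gate_proj.weight_scale_inv"
# Qwen3.5 VL checkpoint layout, with the text stack nested under language_model:
vl_key = "model.language_model.layers.3.mlp.experts.0.gate_proj.weight_scale_inv"

assert not text_key.startswith("model.language_model.")
assert vl_key.startswith("model.language_model.")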
@@ -275,6 +275,7 @@ def __init__(self, file_path: str, scale_suffix: str = None):
             self._is_per_channel = False
         else:
             self._is_per_channel = False  # Will be updated in _detect_format if auto-detect
+        self._is_vl_model = False
         self._detect_format()
 
     def _detect_format(self):
@@ -313,6 +314,10 @@ def _detect_format(self):
                 self._scale_suffix = "weight_scale_inv"
                 self._is_per_channel = False
                 print("[FP8SafeTensorLoader] Detected scale format: block-wise (weight_scale_inv)")
+                if key.startswith("model.language_model.") and self._detected_format == "deepseek":
+                    # VL models(Qwen3.5): model.layers.{N} -> model.language_model.layers.{N}
+                    self._is_vl_model = True
+                    print("[FP8SafeTensorLoader] Detected VL model")
                 return
             elif f".{gate}.weight_scale" in key and "weight_scale_inv" not in key:
                 self._scale_suffix = "weight_scale"
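The added check only fires for DeepSeek-format checkpoints. For readers skimming the diff, a standalone rendering of the predicate (detect_vl is a hypothetical helper name; the real check lives inline in _detect_format):

def detect_vl(key: str, detected_format: str) -> bool:
    # Mirror of the added condition: a key nested under model.language_model.
    # in a deepseek-format checkpoint marks the model as VL.
    return key.startswith("model.language_model.") and detected_format == "deepseek"

assert detect_vl("model.language_model.layers.0.mlp.experts.0.gate_proj.weight_scale_inv", "deepseek")
assert not detect_vl("model.layers.0.mlp.experts.0.gate_proj.weight_scale_inv", "deepseek")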
@@ -331,6 +336,8 @@ def _detect_format(self):
     def _get_experts_prefix(self, base_key: str) -> str:
         """Get the experts prefix based on detected format."""
         path_tpl, _, _, _ = self.MOE_FORMATS[self._detected_format]
+        if self._is_vl_model:
+            base_key = base_key.replace("model.layers", "model.language_model.layers")
         return path_tpl.format(base=base_key)
 
     def _get_proj_names(self):
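Worked through by hand for a DeepSeek-format VL checkpoint (the layer index is illustrative), the remap composes with the "{base}.mlp.experts" path template from MOE_FORMATS as follows:

base_key = "model.layers.3"
# _is_vl_model branch: shift the base key under the language_model scope.
base_key = base_key.replace("model.layers", "model.language_model.layers")
prefix = "{base}.mlp.experts".format(base=base_key)
assert prefix == "model.language_model.layers.3.mlp.experts"
# Expert 0's gate weight then resolves to:
# model.language_model.layers.3.mlp.experts.0.gate_proj.weight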
@@ -416,108 +423,6 @@ def is_per_channel(self) -> bool:
         return self._is_per_channel
 
 
-class BF16SafeTensorLoader(SafeTensorLoader):
-    """Loader for native BF16 expert weights (no quantization, no scales).
-
-    Supported formats:
-    - DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight
-    - Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight
-
-    The format is auto-detected during initialization.
-    """
-
-    MOE_FORMATS = {
-        "deepseek": ("{base}.mlp.experts", "gate_proj", "up_proj", "down_proj"),
-        "mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"),
-    }
-
-    def __init__(self, file_path: str):
-        super().__init__(file_path)
-        self._detected_format = None
-        self._detect_format()
-
-    def _detect_format(self):
-        """Auto-detect the MoE naming format by checking tensor keys."""
-        sample_keys = list(self.tensor_file_map.keys())[:1000]
-
-        # Check for packed format first (Qwen3.5 MoE style: all experts in one 3D tensor)
-        for key in sample_keys:
-            if key.endswith(".mlp.experts.gate_up_proj"):
-                self._detected_format = "packed"
-                print("[BF16SafeTensorLoader] Detected format: packed (Qwen3.5 MoE style)")
-                return
-
-        for fmt_name, (path_tpl, gate, up, down) in self.MOE_FORMATS.items():
-            for key in sample_keys:
-                if ".experts." in key and f".{gate}.weight" in key:
-                    if "block_sparse_moe.experts" in key and fmt_name == "mixtral":
-                        self._detected_format = fmt_name
-                        print(f"[BF16SafeTensorLoader] Detected format: {fmt_name}")
-                        return
-                    elif "mlp.experts" in key and "block_sparse_moe" not in key and fmt_name == "deepseek":
-                        self._detected_format = fmt_name
-                        print(f"[BF16SafeTensorLoader] Detected format: {fmt_name}")
-                        return
-
-        self._detected_format = "deepseek"
-        print("[BF16SafeTensorLoader] No MoE format detected, defaulting to: deepseek")
-
-    def _get_experts_prefix(self, base_key: str) -> str:
-        """Get the experts prefix based on detected format."""
-        path_tpl, _, _, _ = self.MOE_FORMATS[self._detected_format]
-        return path_tpl.format(base=base_key)
-
-    def _get_proj_names(self):
-        """Get projection names (gate, up, down) based on detected format."""
-        _, gate, up, down = self.MOE_FORMATS[self._detected_format]
-        return gate, up, down
-
-    def load_tensor(self, key: str, device: str = "cpu"):
-        if key not in self.tensor_file_map:
-            raise KeyError(f"Key {key} not found in Safetensor files")
-        file = self.tensor_file_map[key]
-        f = self.file_handle_map.get(file)
-        if f is None:
-            raise FileNotFoundError(f"File {file} not found in Safetensor files")
-        tensor = f.get_tensor(key)
-        if device == "cpu":
-            return tensor
-        return tensor.to(device)
-
-    def load_experts(self, base_key: str, device: str = "cpu"):
-        """Load BF16 expert weights (no scales needed)."""
-        if self._detected_format == "packed":
-            return self._load_experts_packed(base_key, device)
-
-        experts_prefix = self._get_experts_prefix(base_key)
-        gate_name, up_name, down_name = self._get_proj_names()
-
-        expert_count = 0
-        while self.has_tensor(f"{experts_prefix}.{expert_count}.{gate_name}.weight"):
-            expert_count += 1
-
-        if expert_count == 0:
-            raise ValueError(f"No experts found for key {experts_prefix}")
-
-        gate_weights = [None] * expert_count
-        up_weights = [None] * expert_count
-        down_weights = [None] * expert_count
-
-        for exp_id in range(expert_count):
-            gate_w_key = f"{experts_prefix}.{exp_id}.{gate_name}.weight"
-            up_w_key = f"{experts_prefix}.{exp_id}.{up_name}.weight"
-            down_w_key = f"{experts_prefix}.{exp_id}.{down_name}.weight"
-
-            gate_weights[exp_id] = self.load_tensor(gate_w_key, device).contiguous()
-            up_weights[exp_id] = self.load_tensor(up_w_key, device).contiguous()
-            down_weights[exp_id] = self.load_tensor(down_w_key, device).contiguous()
-
-        return {
-            "gate": gate_weights,
-            "up": up_weights,
-            "down": down_weights,
-        }
-
 
 class BF16SafeTensorLoader(SafeTensorLoader):
     """Loader for native BF16 expert weights (no quantization, no scales).
