@@ -275,6 +275,7 @@ def __init__(self, file_path: str, scale_suffix: str = None):
275275 self ._is_per_channel = False
276276 else :
277277 self ._is_per_channel = False # Will be updated in _detect_format if auto-detect
278+ self ._is_vl_model = False
278279 self ._detect_format ()
279280
280281 def _detect_format (self ):
@@ -313,6 +314,10 @@ def _detect_format(self):
313314 self ._scale_suffix = "weight_scale_inv"
314315 self ._is_per_channel = False
315316 print ("[FP8SafeTensorLoader] Detected scale format: block-wise (weight_scale_inv)" )
317+ if key .startswith ("model.language_model." ) and self ._detected_format == "deepseek" :
 318+                # VL models (Qwen3.5): model.layers.{N} -> model.language_model.layers.{N}
319+ self ._is_vl_model = True
320+ print ("[FP8SafeTensorLoader] Detected VL model" )
316321 return
317322 elif f".{ gate } .weight_scale" in key and "weight_scale_inv" not in key :
318323 self ._scale_suffix = "weight_scale"
@@ -331,6 +336,8 @@ def _detect_format(self):
331336 def _get_experts_prefix (self , base_key : str ) -> str :
332337 """Get the experts prefix based on detected format."""
333338 path_tpl , _ , _ , _ = self .MOE_FORMATS [self ._detected_format ]
339+ if self ._is_vl_model :
340+ base_key = base_key .replace ("model.layers" , "model.language_model.layers" )
334341 return path_tpl .format (base = base_key )
335342
336343 def _get_proj_names (self ):
@@ -416,108 +423,6 @@ def is_per_channel(self) -> bool:
416423 return self ._is_per_channel
417424
418425
419- class BF16SafeTensorLoader (SafeTensorLoader ):
420- """Loader for native BF16 expert weights (no quantization, no scales).
421-
422- Supported formats:
423- - DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight
424- - Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight
425-
426- The format is auto-detected during initialization.
427- """
428-
429- MOE_FORMATS = {
430- "deepseek" : ("{base}.mlp.experts" , "gate_proj" , "up_proj" , "down_proj" ),
431- "mixtral" : ("{base}.block_sparse_moe.experts" , "w1" , "w3" , "w2" ),
432- }
433-
434- def __init__ (self , file_path : str ):
435- super ().__init__ (file_path )
436- self ._detected_format = None
437- self ._detect_format ()
438-
439- def _detect_format (self ):
440- """Auto-detect the MoE naming format by checking tensor keys."""
441- sample_keys = list (self .tensor_file_map .keys ())[:1000 ]
442-
443- # Check for packed format first (Qwen3.5 MoE style: all experts in one 3D tensor)
444- for key in sample_keys :
445- if key .endswith (".mlp.experts.gate_up_proj" ):
446- self ._detected_format = "packed"
447- print ("[BF16SafeTensorLoader] Detected format: packed (Qwen3.5 MoE style)" )
448- return
449-
450- for fmt_name , (path_tpl , gate , up , down ) in self .MOE_FORMATS .items ():
451- for key in sample_keys :
452- if ".experts." in key and f".{ gate } .weight" in key :
453- if "block_sparse_moe.experts" in key and fmt_name == "mixtral" :
454- self ._detected_format = fmt_name
455- print (f"[BF16SafeTensorLoader] Detected format: { fmt_name } " )
456- return
457- elif "mlp.experts" in key and "block_sparse_moe" not in key and fmt_name == "deepseek" :
458- self ._detected_format = fmt_name
459- print (f"[BF16SafeTensorLoader] Detected format: { fmt_name } " )
460- return
461-
462- self ._detected_format = "deepseek"
463- print ("[BF16SafeTensorLoader] No MoE format detected, defaulting to: deepseek" )
464-
465- def _get_experts_prefix (self , base_key : str ) -> str :
466- """Get the experts prefix based on detected format."""
467- path_tpl , _ , _ , _ = self .MOE_FORMATS [self ._detected_format ]
468- return path_tpl .format (base = base_key )
469-
470- def _get_proj_names (self ):
471- """Get projection names (gate, up, down) based on detected format."""
472- _ , gate , up , down = self .MOE_FORMATS [self ._detected_format ]
473- return gate , up , down
474-
475- def load_tensor (self , key : str , device : str = "cpu" ):
476- if key not in self .tensor_file_map :
477- raise KeyError (f"Key { key } not found in Safetensor files" )
478- file = self .tensor_file_map [key ]
479- f = self .file_handle_map .get (file )
480- if f is None :
481- raise FileNotFoundError (f"File { file } not found in Safetensor files" )
482- tensor = f .get_tensor (key )
483- if device == "cpu" :
484- return tensor
485- return tensor .to (device )
486-
487- def load_experts (self , base_key : str , device : str = "cpu" ):
488- """Load BF16 expert weights (no scales needed)."""
489- if self ._detected_format == "packed" :
490- return self ._load_experts_packed (base_key , device )
491-
492- experts_prefix = self ._get_experts_prefix (base_key )
493- gate_name , up_name , down_name = self ._get_proj_names ()
494-
495- expert_count = 0
496- while self .has_tensor (f"{ experts_prefix } .{ expert_count } .{ gate_name } .weight" ):
497- expert_count += 1
498-
499- if expert_count == 0 :
500- raise ValueError (f"No experts found for key { experts_prefix } " )
501-
502- gate_weights = [None ] * expert_count
503- up_weights = [None ] * expert_count
504- down_weights = [None ] * expert_count
505-
506- for exp_id in range (expert_count ):
507- gate_w_key = f"{ experts_prefix } .{ exp_id } .{ gate_name } .weight"
508- up_w_key = f"{ experts_prefix } .{ exp_id } .{ up_name } .weight"
509- down_w_key = f"{ experts_prefix } .{ exp_id } .{ down_name } .weight"
510-
511- gate_weights [exp_id ] = self .load_tensor (gate_w_key , device ).contiguous ()
512- up_weights [exp_id ] = self .load_tensor (up_w_key , device ).contiguous ()
513- down_weights [exp_id ] = self .load_tensor (down_w_key , device ).contiguous ()
514-
515- return {
516- "gate" : gate_weights ,
517- "up" : up_weights ,
518- "down" : down_weights ,
519- }
520-
521426
522427class BF16SafeTensorLoader (SafeTensorLoader ):
523428 """Loader for native BF16 expert weights (no quantization, no scales).
0 commit comments