@@ -243,6 +243,7 @@ class FP8SafeTensorLoader(SafeTensorLoader):
243243 Supported formats:
244244 - DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight
245245 - Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight
246+ - Mistral style: {base}.experts.{id}.{w1,w3,w2}.weight
246247
247248 Supported scale formats (auto-detected):
248249 - Block-wise: weight_scale_inv (DeepSeek FP8)
@@ -255,6 +256,7 @@ class FP8SafeTensorLoader(SafeTensorLoader):
255256 MOE_FORMATS = {
256257 "deepseek" : ("{base}.mlp.experts" , "gate_proj" , "up_proj" , "down_proj" ),
257258 "mixtral" : ("{base}.block_sparse_moe.experts" , "w1" , "w3" , "w2" ),
259+ "mistral" : ("{base}.experts" , "w1" , "w3" , "w2" ),
258260 }
259261
260262 def __init__ (self , file_path : str , scale_suffix : str = None ):
@@ -297,6 +299,10 @@ def _detect_format(self):
297299 self ._detected_format = fmt_name
298300 print (f"[FP8SafeTensorLoader] Detected format: { fmt_name } " )
299301 break
302+ elif fmt_name == "mistral" and ".mlp.experts" not in key and ".block_sparse_moe.experts" not in key :
303+ self ._detected_format = fmt_name
304+ print (f"[FP8SafeTensorLoader] Detected format: { fmt_name } " )
305+ break
300306 if self ._detected_format :
301307 break
302308
@@ -321,8 +327,21 @@ def _detect_format(self):
321327 return
322328 elif f".{ gate } .weight_scale" in key and "weight_scale_inv" not in key :
323329 self ._scale_suffix = "weight_scale"
324- self ._is_per_channel = True
325- print ("[FP8SafeTensorLoader] Detected scale format: per-channel (weight_scale)" )
330+ # Some models (e.g., Mistral) use block-wise FP8 scales but keep
331+ # the key suffix as `weight_scale` (without `_inv`). Infer format
332+ # from scale tensor shape instead of suffix alone:
333+ # - per-channel: [N] or [N, 1]
334+ # - block-wise: [N_block, K_block] (both dims > 1)
335+ scale_tensor = self .load_tensor (key , device = "cpu" )
336+ if scale_tensor .dim () == 1 :
337+ self ._is_per_channel = True
338+ elif scale_tensor .dim () == 2 and scale_tensor .shape [1 ] == 1 :
339+ self ._is_per_channel = True
340+ else :
341+ self ._is_per_channel = False
342+
343+ scale_kind = "per-channel" if self ._is_per_channel else "block-wise"
344+ print (f"[FP8SafeTensorLoader] Detected scale format: { scale_kind } (weight_scale)" )
326345 return
327346 # Default to weight_scale_inv
328347 self ._scale_suffix = "weight_scale_inv"
@@ -333,12 +352,20 @@ def _detect_format(self):
333352 scale_type = "per-channel" if self ._is_per_channel else "block-wise"
334353 print (f"[FP8SafeTensorLoader] Using explicit scale format: { scale_type } ({ self ._scale_suffix } )" )
335354
336- def _get_experts_prefix (self , base_key : str ) -> str :
337- """Get the experts prefix based on detected format."""
355+ def _get_experts_prefix_candidates (self , base_key : str ) -> list [ str ] :
356+ """Get candidate experts prefixes based on detected format and base key variants ."""
338357 path_tpl , _ , _ , _ = self .MOE_FORMATS [self ._detected_format ]
358+ candidates = []
339359 if self ._is_vl_model :
340360 base_key = base_key .replace ("model.layers" , "model.language_model.layers" )
341- return path_tpl .format (base = base_key )
361+ candidates .append (path_tpl .format (base = base_key ))
362+
363+ # Some model weights (e.g., Mistral native format) do not have "model." prefix.
364+ if base_key .startswith ("model." ):
365+ candidates .append (path_tpl .format (base = base_key [len ("model." ) :]))
366+
367+ # Deduplicate while preserving order.
368+ return list (dict .fromkeys (candidates ))
342369
343370 def _get_proj_names (self ):
344371 """Get projection names (gate, up, down) based on detected format."""
@@ -363,15 +390,21 @@ def load_experts(self, base_key: str, device: str = "cpu"):
363390 Supports both block-wise (weight_scale_inv) and per-channel (weight_scale) formats.
364391 Per-channel scales are squeezed from [N, 1] to [N] if needed.
365392 """
366- experts_prefix = self ._get_experts_prefix (base_key )
393+ experts_prefix_candidates = self ._get_experts_prefix_candidates (base_key )
367394 gate_name , up_name , down_name = self ._get_proj_names ()
368395
369396 expert_count = 0
370- while self .has_tensor (f"{ experts_prefix } .{ expert_count } .{ gate_name } .weight" ):
371- expert_count += 1
397+ experts_prefix = None
398+ for prefix in experts_prefix_candidates :
399+ expert_count = 0
400+ while self .has_tensor (f"{ prefix } .{ expert_count } .{ gate_name } .weight" ):
401+ expert_count += 1
402+ if expert_count > 0 :
403+ experts_prefix = prefix
404+ break
372405
373- if expert_count == 0 :
374- raise ValueError (f"No experts found for key { experts_prefix } " )
406+ if expert_count == 0 or experts_prefix is None :
407+ raise ValueError (f"No experts found for keys: { experts_prefix_candidates } " )
375408
376409 gate_weights = [None ] * expert_count
377410 up_weights = [None ] * expert_count
@@ -423,20 +456,21 @@ def is_per_channel(self) -> bool:
423456 return self ._is_per_channel
424457
425458
426-
427459class BF16SafeTensorLoader (SafeTensorLoader ):
428460 """Loader for native BF16 expert weights (no quantization, no scales).
429461
430462 Supported formats:
431463 - DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight
432464 - Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight
465+ - Mistral style: {base}.experts.{id}.{w1,w3,w2}.weight
433466
434467 The format is auto-detected during initialization.
435468 """
436469
437470 MOE_FORMATS = {
438471 "deepseek" : ("{base}.mlp.experts" , "gate_proj" , "up_proj" , "down_proj" ),
439472 "mixtral" : ("{base}.block_sparse_moe.experts" , "w1" , "w3" , "w2" ),
473+ "mistral" : ("{base}.experts" , "w1" , "w3" , "w2" ),
440474 }
441475
442476 def __init__ (self , file_path : str ):
@@ -466,14 +500,24 @@ def _detect_format(self):
466500 self ._detected_format = fmt_name
467501 print (f"[BF16SafeTensorLoader] Detected format: { fmt_name } " )
468502 return
503+ elif fmt_name == "mistral" and ".mlp.experts" not in key and ".block_sparse_moe.experts" not in key :
504+ self ._detected_format = fmt_name
505+ print (f"[BF16SafeTensorLoader] Detected format: { fmt_name } " )
506+ return
469507
470508 self ._detected_format = "deepseek"
471509 print ("[BF16SafeTensorLoader] No MoE format detected, defaulting to: deepseek" )
472510
473- def _get_experts_prefix (self , base_key : str ) -> str :
474- """Get the experts prefix based on detected format."""
511+ def _get_experts_prefix_candidates (self , base_key : str ) -> list [ str ] :
512+ """Get candidate experts prefixes based on detected format and base key variants ."""
475513 path_tpl , _ , _ , _ = self .MOE_FORMATS [self ._detected_format ]
476- return path_tpl .format (base = base_key )
514+ candidates = [path_tpl .format (base = base_key )]
515+
516+ # Some model weights (e.g., Mistral native format) do not have "model." prefix.
517+ if base_key .startswith ("model." ):
518+ candidates .append (path_tpl .format (base = base_key [len ("model." ) :]))
519+
520+ return list (dict .fromkeys (candidates ))
477521
478522 def _get_proj_names (self ):
479523 """Get projection names (gate, up, down) based on detected format."""
@@ -497,15 +541,21 @@ def load_experts(self, base_key: str, device: str = "cpu"):
497541 if self ._detected_format == "packed" :
498542 return self ._load_experts_packed (base_key , device )
499543
500- experts_prefix = self ._get_experts_prefix (base_key )
544+ experts_prefix_candidates = self ._get_experts_prefix_candidates (base_key )
501545 gate_name , up_name , down_name = self ._get_proj_names ()
502546
503547 expert_count = 0
504- while self .has_tensor (f"{ experts_prefix } .{ expert_count } .{ gate_name } .weight" ):
505- expert_count += 1
548+ experts_prefix = None
549+ for prefix in experts_prefix_candidates :
550+ expert_count = 0
551+ while self .has_tensor (f"{ prefix } .{ expert_count } .{ gate_name } .weight" ):
552+ expert_count += 1
553+ if expert_count > 0 :
554+ experts_prefix = prefix
555+ break
506556
507- if expert_count == 0 :
508- raise ValueError (f"No experts found for key { experts_prefix } " )
557+ if expert_count == 0 or experts_prefix is None :
558+ raise ValueError (f"No experts found for keys: { experts_prefix_candidates } " )
509559
510560 gate_weights = [None ] * expert_count
511561 up_weights = [None ] * expert_count