Skip to content

Commit 9e69fcc

Browse files
Authored by the committer of 9e69fcc
[feat]: add mistral moe loader compatibility (#1873)
Co-authored-by: chenht2022 <chenht2022@users.noreply.github.com>
1 parent 19887e4 commit 9e69fcc

File tree

2 files changed

+73
-19
lines changed

2 files changed

+73
-19
lines changed

kt-kernel/python/utils/amx.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,10 @@ def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
448448
self.down_scales = [t.to(torch.float32).contiguous() for t in weights["down_scale"]]
449449
assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8"
450450
elif self.method == "FP8_PERCHANNEL":
451+
if self.gate_scales[0].dtype != torch.float32:
452+
self.gate_scales = [t.to(torch.float32).contiguous() for t in weights["gate_scale"]]
453+
self.up_scales = [t.to(torch.float32).contiguous() for t in weights["up_scale"]]
454+
self.down_scales = [t.to(torch.float32).contiguous() for t in weights["down_scale"]]
451455
assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8_PERCHANNEL"
452456

453457
t2 = time.time()

kt-kernel/python/utils/loader.py

Lines changed: 69 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ class FP8SafeTensorLoader(SafeTensorLoader):
243243
Supported formats:
244244
- DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight
245245
- Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight
246+
- Mistral style: {base}.experts.{id}.{w1,w3,w2}.weight
246247
247248
Supported scale formats (auto-detected):
248249
- Block-wise: weight_scale_inv (DeepSeek FP8)
@@ -255,6 +256,7 @@ class FP8SafeTensorLoader(SafeTensorLoader):
255256
MOE_FORMATS = {
256257
"deepseek": ("{base}.mlp.experts", "gate_proj", "up_proj", "down_proj"),
257258
"mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"),
259+
"mistral": ("{base}.experts", "w1", "w3", "w2"),
258260
}
259261

260262
def __init__(self, file_path: str, scale_suffix: str = None):
@@ -297,6 +299,10 @@ def _detect_format(self):
297299
self._detected_format = fmt_name
298300
print(f"[FP8SafeTensorLoader] Detected format: {fmt_name}")
299301
break
302+
elif fmt_name == "mistral" and ".mlp.experts" not in key and ".block_sparse_moe.experts" not in key:
303+
self._detected_format = fmt_name
304+
print(f"[FP8SafeTensorLoader] Detected format: {fmt_name}")
305+
break
300306
if self._detected_format:
301307
break
302308

@@ -321,8 +327,21 @@ def _detect_format(self):
321327
return
322328
elif f".{gate}.weight_scale" in key and "weight_scale_inv" not in key:
323329
self._scale_suffix = "weight_scale"
324-
self._is_per_channel = True
325-
print("[FP8SafeTensorLoader] Detected scale format: per-channel (weight_scale)")
330+
# Some models (e.g., Mistral) use block-wise FP8 scales but keep
331+
# the key suffix as `weight_scale` (without `_inv`). Infer format
332+
# from scale tensor shape instead of suffix alone:
333+
# - per-channel: [N] or [N, 1]
334+
# - block-wise: [N_block, K_block] (both dims > 1)
335+
scale_tensor = self.load_tensor(key, device="cpu")
336+
if scale_tensor.dim() == 1:
337+
self._is_per_channel = True
338+
elif scale_tensor.dim() == 2 and scale_tensor.shape[1] == 1:
339+
self._is_per_channel = True
340+
else:
341+
self._is_per_channel = False
342+
343+
scale_kind = "per-channel" if self._is_per_channel else "block-wise"
344+
print(f"[FP8SafeTensorLoader] Detected scale format: {scale_kind} (weight_scale)")
326345
return
327346
# Default to weight_scale_inv
328347
self._scale_suffix = "weight_scale_inv"
@@ -333,12 +352,20 @@ def _detect_format(self):
333352
scale_type = "per-channel" if self._is_per_channel else "block-wise"
334353
print(f"[FP8SafeTensorLoader] Using explicit scale format: {scale_type} ({self._scale_suffix})")
335354

336-
def _get_experts_prefix(self, base_key: str) -> str:
337-
"""Get the experts prefix based on detected format."""
355+
def _get_experts_prefix_candidates(self, base_key: str) -> list[str]:
356+
"""Get candidate experts prefixes based on detected format and base key variants."""
338357
path_tpl, _, _, _ = self.MOE_FORMATS[self._detected_format]
358+
candidates = []
339359
if self._is_vl_model:
340360
base_key = base_key.replace("model.layers", "model.language_model.layers")
341-
return path_tpl.format(base=base_key)
361+
candidates.append(path_tpl.format(base=base_key))
362+
363+
# Some model weights (e.g., Mistral native format) do not have "model." prefix.
364+
if base_key.startswith("model."):
365+
candidates.append(path_tpl.format(base=base_key[len("model.") :]))
366+
367+
# Deduplicate while preserving order.
368+
return list(dict.fromkeys(candidates))
342369

343370
def _get_proj_names(self):
344371
"""Get projection names (gate, up, down) based on detected format."""
@@ -363,15 +390,21 @@ def load_experts(self, base_key: str, device: str = "cpu"):
363390
Supports both block-wise (weight_scale_inv) and per-channel (weight_scale) formats.
364391
Per-channel scales are squeezed from [N, 1] to [N] if needed.
365392
"""
366-
experts_prefix = self._get_experts_prefix(base_key)
393+
experts_prefix_candidates = self._get_experts_prefix_candidates(base_key)
367394
gate_name, up_name, down_name = self._get_proj_names()
368395

369396
expert_count = 0
370-
while self.has_tensor(f"{experts_prefix}.{expert_count}.{gate_name}.weight"):
371-
expert_count += 1
397+
experts_prefix = None
398+
for prefix in experts_prefix_candidates:
399+
expert_count = 0
400+
while self.has_tensor(f"{prefix}.{expert_count}.{gate_name}.weight"):
401+
expert_count += 1
402+
if expert_count > 0:
403+
experts_prefix = prefix
404+
break
372405

373-
if expert_count == 0:
374-
raise ValueError(f"No experts found for key {experts_prefix}")
406+
if expert_count == 0 or experts_prefix is None:
407+
raise ValueError(f"No experts found for keys: {experts_prefix_candidates}")
375408

376409
gate_weights = [None] * expert_count
377410
up_weights = [None] * expert_count
@@ -423,20 +456,21 @@ def is_per_channel(self) -> bool:
423456
return self._is_per_channel
424457

425458

426-
427459
class BF16SafeTensorLoader(SafeTensorLoader):
428460
"""Loader for native BF16 expert weights (no quantization, no scales).
429461
430462
Supported formats:
431463
- DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight
432464
- Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight
465+
- Mistral style: {base}.experts.{id}.{w1,w3,w2}.weight
433466
434467
The format is auto-detected during initialization.
435468
"""
436469

437470
MOE_FORMATS = {
438471
"deepseek": ("{base}.mlp.experts", "gate_proj", "up_proj", "down_proj"),
439472
"mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"),
473+
"mistral": ("{base}.experts", "w1", "w3", "w2"),
440474
}
441475

442476
def __init__(self, file_path: str):
@@ -466,14 +500,24 @@ def _detect_format(self):
466500
self._detected_format = fmt_name
467501
print(f"[BF16SafeTensorLoader] Detected format: {fmt_name}")
468502
return
503+
elif fmt_name == "mistral" and ".mlp.experts" not in key and ".block_sparse_moe.experts" not in key:
504+
self._detected_format = fmt_name
505+
print(f"[BF16SafeTensorLoader] Detected format: {fmt_name}")
506+
return
469507

470508
self._detected_format = "deepseek"
471509
print("[BF16SafeTensorLoader] No MoE format detected, defaulting to: deepseek")
472510

473-
def _get_experts_prefix(self, base_key: str) -> str:
474-
"""Get the experts prefix based on detected format."""
511+
def _get_experts_prefix_candidates(self, base_key: str) -> list[str]:
512+
"""Get candidate experts prefixes based on detected format and base key variants."""
475513
path_tpl, _, _, _ = self.MOE_FORMATS[self._detected_format]
476-
return path_tpl.format(base=base_key)
514+
candidates = [path_tpl.format(base=base_key)]
515+
516+
# Some model weights (e.g., Mistral native format) do not have "model." prefix.
517+
if base_key.startswith("model."):
518+
candidates.append(path_tpl.format(base=base_key[len("model.") :]))
519+
520+
return list(dict.fromkeys(candidates))
477521

478522
def _get_proj_names(self):
479523
"""Get projection names (gate, up, down) based on detected format."""
@@ -497,15 +541,21 @@ def load_experts(self, base_key: str, device: str = "cpu"):
497541
if self._detected_format == "packed":
498542
return self._load_experts_packed(base_key, device)
499543

500-
experts_prefix = self._get_experts_prefix(base_key)
544+
experts_prefix_candidates = self._get_experts_prefix_candidates(base_key)
501545
gate_name, up_name, down_name = self._get_proj_names()
502546

503547
expert_count = 0
504-
while self.has_tensor(f"{experts_prefix}.{expert_count}.{gate_name}.weight"):
505-
expert_count += 1
548+
experts_prefix = None
549+
for prefix in experts_prefix_candidates:
550+
expert_count = 0
551+
while self.has_tensor(f"{prefix}.{expert_count}.{gate_name}.weight"):
552+
expert_count += 1
553+
if expert_count > 0:
554+
experts_prefix = prefix
555+
break
506556

507-
if expert_count == 0:
508-
raise ValueError(f"No experts found for key {experts_prefix}")
557+
if expert_count == 0 or experts_prefix is None:
558+
raise ValueError(f"No experts found for keys: {experts_prefix_candidates}")
509559

510560
gate_weights = [None] * expert_count
511561
up_weights = [None] * expert_count

0 commit comments

Comments (0)