Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions kt-kernel/python/experts.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __new__(
cpu_save: bool = False,
max_deferred_experts_per_token: Optional[int] = None,
method: str = "AMXINT4",
numa_nodes: Optional[List[int]] = None,
):
"""
Factory method to create the appropriate backend implementation.
Expand All @@ -85,6 +86,7 @@ def __new__(
chunked_prefill_size: Maximum prefill chunk size
cpu_save: Whether to save weights to CPU memory
max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0.
method: Backend method ("AMXINT4", "AMXINT8", "RAWINT4", "FP8", "BF16", "LLAMAFILE", "MOE_INT4", "MOE_INT8")
numa_nodes: Explicit list of NUMA node IDs for subpool mapping. If None, defaults to sequential IDs [0, 1, ..., threadpool_count-1].

Returns:
Expand Down Expand Up @@ -117,6 +119,7 @@ def __new__(
cpu_save=cpu_save,
max_deferred_experts_per_token=max_deferred_experts_per_token,
method=method,
numa_nodes=numa_nodes,
)

# Forward static methods to the base class
Expand Down
13 changes: 12 additions & 1 deletion kt-kernel/python/experts_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def __init__(
cpu_save: bool = False,
max_deferred_experts_per_token: Optional[int] = None,
method: str = "AMXINT4",
numa_nodes: Optional[List[int]] = None,
):
"""
Initialize base MoE Wrapper.
Expand All @@ -185,6 +186,8 @@ def __init__(
cpu_save: Whether to save weights to CPU memory
max_deferred_experts_per_token: Number of experts per token to defer on this layer. Defaults to 0 (no defer).
method: Backend method string
numa_nodes: Explicit list of NUMA node IDs for subpool mapping.
If None, defaults to [0, 1, ..., threadpool_count-1].
"""
self.layer_idx = layer_idx
self.num_experts = num_experts
Expand Down Expand Up @@ -221,7 +224,15 @@ def __init__(
if BaseMoEWrapper._cpu_infer_instance is None:
worker_config = kt_kernel_ext.WorkerPoolConfig()

subpool_numa_map = list(range(threadpool_count))
if numa_nodes is not None:
if len(numa_nodes) != threadpool_count:
raise ValueError(
f"numa_nodes length ({len(numa_nodes)}) must match "
f"threadpool_count ({threadpool_count})"
)
Comment on lines +227 to +232
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The validation logic for numa_nodes could be improved by adding a check to ensure that the provided NUMA node IDs are valid for the system. This could prevent issues if a user provides an ID that doesn't exist.

            if numa_nodes is not None:
                if len(numa_nodes) != threadpool_count:
                    raise ValueError(
                        f"numa_nodes length ({len(numa_nodes)}) must match "
                        f"threadpool_count ({threadpool_count})"
                    )
                if any(node_id >= numa_num_configured_nodes() for node_id in numa_nodes):
                    raise ValueError(
                        f"Invalid NUMA node ID found in numa_nodes. "
                        f"Node IDs must be less than {numa_num_configured_nodes()}."
                    )
                subpool_numa_map = list(numa_nodes)

subpool_numa_map = list(numa_nodes)
else:
subpool_numa_map = list(range(threadpool_count))
subpool_thread_count = [
cpuinfer_threads // threadpool_count + (1 if i < cpuinfer_threads % threadpool_count else 0)
for i in range(threadpool_count)
Expand Down
18 changes: 15 additions & 3 deletions kt-kernel/python/utils/amx.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import torch
import ctypes
from typing import Optional
from typing import List, Optional

# Use relative imports for package structure
from ..experts_base import BaseMoEWrapper
Expand Down Expand Up @@ -47,6 +47,7 @@ def __init__(
cpu_save: bool = False,
max_deferred_experts_per_token: Optional[int] = None,
method: str = "AMXINT4",
numa_nodes: Optional[List[int]] = None,
):
"""
Initialize AMX MoE Wrapper.
Expand Down Expand Up @@ -97,6 +98,7 @@ def __init__(
cpu_save=cpu_save,
max_deferred_experts_per_token=max_deferred_experts_per_token,
method=method,
numa_nodes=numa_nodes,
)

# AMX-specific: Check if we should load merged safetensor weights
Expand Down Expand Up @@ -282,7 +284,11 @@ def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
moe_config.save = True
moe_config.load = False
base_key = f"model.layers.{self.layer_idx}"
w = self.safetensor_loader.load_experts(base_key)
try:
w = self.safetensor_loader.load_experts(base_key)
except (ValueError, KeyError):
base_key = f"model.language_model.layers.{self.layer_idx}"
w = self.safetensor_loader.load_experts(base_key)

self.gate_proj = torch.cat(w["gate_weight"], dim=0).contiguous()
self.up_proj = torch.cat(w["up_weight"], dim=0).contiguous()
Expand Down Expand Up @@ -379,6 +385,7 @@ def __init__(
cpu_save=cpu_save,
max_deferred_experts_per_token=max_deferred_experts_per_token,
method=method,
numa_nodes=numa_nodes,
)

if NativeMoEWrapper._native_loader_instance is None:
Expand Down Expand Up @@ -416,7 +423,12 @@ def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):

t0 = time.time()
base_key = f"model.layers.{self.layer_idx}"
weights = self.loader.load_experts(base_key)
try:
weights = self.loader.load_experts(base_key)
except (ValueError, KeyError):
# For VL/multimodal models (e.g. Qwen3.5) with 'language_model' prefix
base_key = f"model.language_model.layers.{self.layer_idx}"
weights = self.loader.load_experts(base_key)
t1 = time.time()

# Keep individual tensors instead of stacking - avoid expensive memory copy
Expand Down
3 changes: 2 additions & 1 deletion kt-kernel/python/utils/llamafile.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from typing import Optional
from typing import List, Optional
import os

# Use relative imports for package structure
Expand Down Expand Up @@ -133,6 +133,7 @@ def __init__(
cpu_save=cpu_save,
max_deferred_experts_per_token=max_deferred_experts_per_token,
method=method,
numa_nodes=numa_nodes,
)

self.weights_to_keep = None
Expand Down
3 changes: 2 additions & 1 deletion kt-kernel/python/utils/moe_kernel.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import torch
import ctypes
from typing import Optional
from typing import List, Optional

# Use relative imports for package structure
from ..experts_base import BaseMoEWrapper
Expand Down Expand Up @@ -97,6 +97,7 @@ def __init__(
cpu_save=cpu_save,
max_deferred_experts_per_token=max_deferred_experts_per_token,
method=method,
numa_nodes=numa_nodes,
)

# moe-specific: Check if we should load merged safetensor weights
Expand Down
Loading