Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 110 additions & 1 deletion kt-kernel/python/experts_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,93 @@
from abc import ABC, abstractmethod
import os
import ctypes
import warnings

import kt_kernel_ext


def _get_allowed_numa_nodes() -> Optional[List[int]]:
"""
Detect the NUMA nodes that the current process is allowed to access,
honouring any ``numactl --membind`` / ``--cpunodebind`` constraints.

Implementation notes
--------------------
``/proc/self/status`` fields ``Mems_allowed`` / ``Mems_allowed_list`` and
the ``get_mempolicy(MPOL_F_MEMS_ALLOWED)`` flag all reflect the *cpuset*
(cgroup) constraints, **not** the per-process NUMA memory policy set by
``numactl``. They always show all nodes when no cpuset restriction is
active, even if the process was launched with ``numactl --membind=1``.

The correct approach is to call ``get_mempolicy(flags=0)`` which returns
the thread's current default memory policy (mode + nodemask) as set by
``set_mempolicy`` / ``numactl``. ``numactl --membind=N`` sets the policy
to ``MPOL_BIND`` (mode=2) with the requested nodemask. When the policy is
``MPOL_DEFAULT`` (mode=0) or the nodemask is empty, no binding is active
and we return ``None`` to let the caller fall back to sequential IDs.

Returns:
Sorted list of allowed NUMA node IDs reflecting the active membind
policy, or ``None`` if no binding is active / cannot be determined
(e.g. non-Linux OS, no NUMA support, syscall failure).
"""
import platform

if platform.system() != "Linux":
return None

# SYS_get_mempolicy: 239 on x86-64
SYS_get_mempolicy = 239
Comment on lines +52 to +53
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The syscall number for get_mempolicy is hardcoded for the x86-64 architecture. This will cause the function to fail on other architectures, such as aarch64, which uses a different syscall number. To ensure portability, you should detect the machine's architecture and use the appropriate syscall number.

    # SYS_get_mempolicy: 239 on x86-64, 236 on aarch64
    arch = platform.machine()
    if arch == "x86_64":
        SYS_get_mempolicy = 239
    elif arch == "aarch64":
        SYS_get_mempolicy = 236
    else:
        warnings.warn(
            f"NUMA node detection via get_mempolicy is not supported on "
            f"architecture '{arch}'. Falling back to sequential NUMA IDs."
        )
        return None

# MPOL_DEFAULT means "no policy / local alloc" – nodemask is meaningless.
MPOL_DEFAULT = 0

# We support up to 2048 NUMA nodes (256 bytes = 2048 bits).
MAX_NODES = 2048
nodemask_ulongs = (MAX_NODES + ctypes.sizeof(ctypes.c_ulong) * 8 - 1) // (ctypes.sizeof(ctypes.c_ulong) * 8)

try:
libc = ctypes.CDLL("libc.so.6", use_errno=True)

mode = ctypes.c_int(0)
nodemask = (ctypes.c_ulong * nodemask_ulongs)()

# flags=0 : return the calling thread's default memory policy
# addr must be NULL when flags=0
ret = libc.syscall(
ctypes.c_long(SYS_get_mempolicy),
ctypes.byref(mode),
nodemask,
ctypes.c_ulong(MAX_NODES),
ctypes.c_void_p(0),
ctypes.c_ulong(0), # flags = 0
)

if ret != 0:
return None

# MPOL_DEFAULT means no explicit binding – return None so the caller
# falls back to sequential NUMA IDs.
if mode.value == MPOL_DEFAULT:
return None

# Decode the nodemask bitmask.
bits_per_ulong = ctypes.sizeof(ctypes.c_ulong) * 8
allowed: List[int] = []
for word_idx, word_val in enumerate(nodemask):
if word_val == 0:
continue
bit_base = word_idx * bits_per_ulong
for bit in range(bits_per_ulong):
if word_val & (1 << bit):
allowed.append(bit_base + bit)

# If nodemask is empty despite a non-default policy, fall back.
return sorted(allowed) if allowed else None

except Exception:
return None
Comment on lines +100 to +101
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The try...except block catches a broad Exception and silently returns None. This can hide underlying issues, making debugging difficult. It would be better to at least log a warning to inform the user that NUMA detection failed and why.

    except Exception as e:
        warnings.warn(f"Failed to get NUMA policy via syscall: {e}. Falling back to sequential NUMA IDs.")
        return None



class KExpertsCPUBuffer:
"""
CPU buffer management for expert computation.
Expand Down Expand Up @@ -148,7 +231,33 @@ def __init__(
if BaseMoEWrapper._cpu_infer_instance is None:
worker_config = kt_kernel_ext.WorkerPoolConfig()

subpool_numa_map = list(range(threadpool_count))
# Detect the NUMA nodes actually available to this process so that
# subpool_numa_map reflects the real node IDs rather than always
# starting from 0. For example, when the process is bound to NUMA
# node 1 via numactl and threadpool_count==1, the map must be [1],
# not [0].
# _get_allowed_numa_nodes() returns None when:
# - not on Linux, or syscall fails → fall back silently
# - policy is MPOL_DEFAULT (no membind active) → sequential is correct
allowed_nodes = _get_allowed_numa_nodes()
if allowed_nodes is not None and len(allowed_nodes) >= threadpool_count:
# Explicit membind is active; use the bound node IDs directly.
subpool_numa_map = allowed_nodes[:threadpool_count]
else:
# No membind active, or cannot detect: fall back to sequential
# 0-based IDs. Warn only when we detected fewer nodes than
# threadpool_count (genuine misconfiguration).
if allowed_nodes is not None and len(allowed_nodes) < threadpool_count:
warnings.warn(
f"threadpool_count={threadpool_count} but only "
f"{len(allowed_nodes)} NUMA node(s) are in the membind "
f"policy for this process ({allowed_nodes}). "
f"Falling back to sequential NUMA IDs [0 .. {threadpool_count - 1}].",
RuntimeWarning,
stacklevel=2,
)
subpool_numa_map = list(range(threadpool_count))

subpool_thread_count = [
cpuinfer_threads // threadpool_count + (1 if i < cpuinfer_threads % threadpool_count else 0)
for i in range(threadpool_count)
Expand Down
Loading