Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
681 changes: 681 additions & 0 deletions GPT_OSS_KERNEL_WORK.md

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions kt-kernel/ext_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,12 @@ PYBIND11_MODULE(kt_kernel_ext, m) {
.DEF_PTR_PROPERTY(GeneralMOEConfig, up_zero)
.DEF_PTR_PROPERTY(GeneralMOEConfig, down_zero)

.DEF_PTR_PROPERTY(GeneralMOEConfig, gate_bias)
.DEF_PTR_PROPERTY(GeneralMOEConfig, up_bias)
.DEF_PTR_PROPERTY(GeneralMOEConfig, down_bias)
.def_readwrite("gemm1_alpha", &GeneralMOEConfig::gemm1_alpha)
.def_readwrite("gemm1_clamp_limit", &GeneralMOEConfig::gemm1_clamp_limit)

.def_readwrite("quant_config", &GeneralMOEConfig::quant_config)

.def_readwrite("max_len", &GeneralMOEConfig::max_len)
Expand Down
21 changes: 21 additions & 0 deletions kt-kernel/operators/amx/la/amx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,27 @@ static inline __m512 act_fn(__m512 gate_val, __m512 up_val) {
return _mm512_mul_ps(act_val, up_val);
}

// GPT-OSS gated activation: out = g * sigmoid(g * alpha) * (u + 1), where
// g = min(gate, limit) and u = clamp(up, -limit, limit). Note the clamp is
// asymmetric: the gate is capped only from above.
static inline __m512 act_fn_alpha(__m512 gate_val, __m512 up_val,
                                  __m512 alpha, __m512 limit) {
  const __m512 one = _mm512_set1_ps(1.0f);
  const __m512 zero = _mm512_setzero_ps();

  // Asymmetric clamping: gate capped at +limit, up clamped to [-limit, limit].
  gate_val = _mm512_min_ps(gate_val, limit);
  __m512 lower = _mm512_sub_ps(zero, limit);
  up_val = _mm512_max_ps(_mm512_min_ps(up_val, limit), lower);

  // sigmoid(g * alpha) = 1 / (1 + exp(-g * alpha)). The exponent is capped at
  // 88 so exp() stays finite in fp32 (exp(88) ~ 1.65e38 < FLT_MAX); a capped
  // lane yields sigmoid ~ 0, i.e. the activation saturates to ~0 there.
  __m512 exponent = _mm512_sub_ps(zero, _mm512_mul_ps(gate_val, alpha));
  exponent = _mm512_min_ps(exponent, _mm512_set1_ps(88.0f));
  __m512 denom = _mm512_add_ps(one, exp_avx512(exponent));
  __m512 sig = _mm512_div_ps(one, denom);

  // g * sigmoid(g * alpha) * (u + 1)
  __m512 u_plus_one = _mm512_add_ps(up_val, one);
  return _mm512_mul_ps(_mm512_mul_ps(gate_val, sig), u_plus_one);
}

#define AMX_DISPATCH_QTYPES(QA, QB, ...) \
[&] { \
switch (QB) { \
Expand Down
82 changes: 79 additions & 3 deletions kt-kernel/operators/amx/moe_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,8 @@ class AMX_MOE_BASE {
}
#endif

apply_down_bias(activated_expert, qlen);

pool->do_work_stealing_job(
qlen, nullptr,
[this, output, k, expert_ids, weights](int i) {
Expand Down Expand Up @@ -603,6 +605,8 @@ class AMX_MOE_BASE {
}
#endif

apply_down_bias(activated_expert, qlen);

for (int e = 0; e < config_.hidden_size; e += 32) {
__m512 x0 = _mm512_setzero_ps();
__m512 x1 = _mm512_setzero_ps();
Expand Down Expand Up @@ -674,19 +678,61 @@ class AMX_MOE_BASE {

void apply_activation(int activated_expert, int nth, int qlen) {
auto pool = config_.pool->get_subpool(tp_part_idx);
auto fn = [this, nth](int task_id) {
const bool has_bias = (config_.gate_bias != nullptr);
const bool use_alpha = (config_.gemm1_alpha > 0.0f);

auto fn = [this, nth, has_bias, use_alpha](int task_id) {
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);

// Bias pointers for this expert (if present)
// intermediate_size is the sharded (per-TP-part) size; bias tensors are unsharded,
// so stride by the full intermediate size and offset to this TP part's slice.
int tp_count = config_.pool->config.subpool_count;
size_t full_intermediate_size = (size_t)config_.intermediate_size * tp_count;
size_t tp_offset = (size_t)tp_part_idx * config_.intermediate_size;

ggml_bf16_t* gate_bias_ptr = has_bias ?
(ggml_bf16_t*)config_.gate_bias + (size_t)expert_idx * full_intermediate_size + tp_offset : nullptr;
ggml_bf16_t* up_bias_ptr = has_bias ?
(ggml_bf16_t*)config_.up_bias + (size_t)expert_idx * full_intermediate_size + tp_offset : nullptr;

// Alpha activation constants (set once, reused per element)
__m512 alpha_vec, limit_vec;
if (use_alpha) {
alpha_vec = _mm512_set1_ps(config_.gemm1_alpha);
limit_vec = _mm512_set1_ps(config_.gemm1_clamp_limit);
}

for (int i = 0; i < m_local_num_[expert_idx]; i++) {
ggml_bf16_t* gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
ggml_bf16_t* up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
for (int j = n_start; j < n_end; j += 32) {
__m512 gate_val0, gate_val1, up_val0, up_val1;
avx512_32xbf16_to_32xfp32((__m512i*)(gate_output_ptr + j), &gate_val0, &gate_val1);
avx512_32xbf16_to_32xfp32((__m512i*)(up_output_ptr + j), &up_val0, &up_val1);
__m512 result0 = amx::act_fn(gate_val0, up_val0);
__m512 result1 = amx::act_fn(gate_val1, up_val1);

// Add biases
if (has_bias) {
__m512 gb0, gb1, ub0, ub1;
avx512_32xbf16_to_32xfp32((__m512i*)(gate_bias_ptr + j), &gb0, &gb1);
avx512_32xbf16_to_32xfp32((__m512i*)(up_bias_ptr + j), &ub0, &ub1);
gate_val0 = _mm512_add_ps(gate_val0, gb0);
gate_val1 = _mm512_add_ps(gate_val1, gb1);
up_val0 = _mm512_add_ps(up_val0, ub0);
up_val1 = _mm512_add_ps(up_val1, ub1);
}

// Activation
__m512 result0, result1;
if (use_alpha) {
result0 = amx::act_fn_alpha(gate_val0, up_val0, alpha_vec, limit_vec);
result1 = amx::act_fn_alpha(gate_val1, up_val1, alpha_vec, limit_vec);
} else {
result0 = amx::act_fn(gate_val0, up_val0);
result1 = amx::act_fn(gate_val1, up_val1);
}
avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i*)(gate_output_ptr + j));
}
}
Expand All @@ -704,6 +750,36 @@ class AMX_MOE_BASE {
pool->do_work_stealing_job(nth * activated_expert, nullptr, fn, nullptr);
}
}

void apply_down_bias(int activated_expert, int qlen) {
// Only apply down_bias on tp_part 0 — merge_results sums all TP parts,
// so applying on every part would multiply bias by tp_count.
if (config_.down_bias == nullptr || tp_part_idx != 0) return;
auto pool = config_.pool->get_subpool(tp_part_idx);

auto fn = [this](int task_id) {
int expert_idx = m_expert_id_map_[task_id];
ggml_bf16_t* bias_ptr = (ggml_bf16_t*)config_.down_bias +
(size_t)expert_idx * config_.hidden_size;
Comment on lines +759 to +760
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-medium medium

Similar to the issue in apply_activation, this function accesses config_.down_bias using an offset derived from expert_idx without validating the underlying buffer size. This can lead to an out-of-bounds read if the bias tensor provided by the model is smaller than expected.

for (int i = 0; i < m_local_num_[expert_idx]; i++) {
ggml_bf16_t* out = m_local_down_output_ptr_[expert_idx] + i * config_.hidden_size;
for (int j = 0; j < config_.hidden_size; j += 32) {
__m512 o0, o1, b0, b1;
avx512_32xbf16_to_32xfp32((__m512i*)(out + j), &o0, &o1);
avx512_32xbf16_to_32xfp32((__m512i*)(bias_ptr + j), &b0, &b1);
o0 = _mm512_add_ps(o0, b0);
o1 = _mm512_add_ps(o1, b1);
avx512_32xfp32_to_32xbf16(&o0, &o1, (__m512i*)(out + j));
}
}
};

if (qlen < 10) {
for (int i = 0; i < activated_expert; i++) fn(i);
} else {
pool->do_work_stealing_job(activated_expert, nullptr, fn, nullptr);
}
}
};

// ============================================================================
Expand Down
10 changes: 10 additions & 0 deletions kt-kernel/operators/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,16 @@ struct GeneralMOEConfig {
void* up_zero;
void* down_zero;

// Expert biases (nullable — only used by models like gpt-oss)
// Layout: [expert_num, size] contiguous BF16, same expert ordering as weights
void* gate_bias = nullptr; // [expert_num, intermediate_size] bf16
void* up_bias = nullptr; // [expert_num, intermediate_size] bf16
void* down_bias = nullptr; // [expert_num, hidden_size] bf16

// Activation parameters (0 = standard SiLU)
float gemm1_alpha = 0.0f; // GPT-OSS: 1.702
float gemm1_clamp_limit = 0.0f; // GPT-OSS: 7.0

QuantConfig quant_config;

// for amx
Expand Down
21 changes: 19 additions & 2 deletions kt-kernel/python/utils/amx.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,10 @@ def load_weights_from_tensors(
def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
import time

# Propagate interleaved gate/up flag to loader (set by kt_ep_wrapper for gpt-oss)
if getattr(self, '_interleaved_gate_up', False):
self.loader._force_interleaved = True

t0 = time.time()
base_key = f"model.layers.{self.layer_idx}"
weights = self.loader.load_experts(base_key)
Expand Down Expand Up @@ -469,11 +473,16 @@ def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
down_scale_ptrs = [[t.data_ptr() for t in self.down_scales]]
t3 = time.time()

# Use actual tensor dimensions for MOEConfig — hidden_size may have been
# padded by GPU kernels (e.g. FlashInfer MXFP4 rounds up to 256), but
# CPU weights use the real model dimensions.
actual_hidden = self.gate_weights[0].shape[1]
actual_intermediate = self.gate_weights[0].shape[0]
moe_config = MOEConfig(
self.num_experts,
self.num_experts_per_tok,
self.hidden_size,
self.moe_intermediate_size,
actual_hidden,
actual_intermediate,
self.gpu_experts_mask.data_ptr(),
)
moe_config.layer_idx = self.layer_idx
Expand All @@ -488,6 +497,14 @@ def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
moe_config.up_scales = up_scale_ptrs
moe_config.down_scales = down_scale_ptrs

# Pass expert biases if set on this wrapper (by kt_ep_wrapper)
if hasattr(self, '_gate_bias_tensor') and self._gate_bias_tensor is not None:
moe_config.gate_bias = self._gate_bias_tensor.data_ptr()
moe_config.up_bias = self._up_bias_tensor.data_ptr()
moe_config.down_bias = self._down_bias_tensor.data_ptr()
moe_config.gemm1_alpha = self._gemm1_alpha
moe_config.gemm1_clamp_limit = self._gemm1_clamp_limit

# Infer group_size from scale shape (column-major layout)
# For gate/up projection: in_features = hidden_size
# So: group_size = hidden_size / scale.shape[1]
Expand Down
110 changes: 107 additions & 3 deletions kt-kernel/python/utils/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,46 @@
from gguf.gguf_reader import GGUFReader


def _dequant_mxfp4_to_bf16(w_blocks, w_scales):
"""Dequantize MXFP4 packed weights to BF16.

Args:
w_blocks: [*, num_blocks, 16] uint8 — packed 4-bit values (2 per byte)
w_scales: [*, num_blocks] uint8 — E8M0 block scales
Returns:
[*, num_blocks * 32] bfloat16 tensor
"""
# Unfuse uint8 → two uint4 values
low = w_blocks & 0x0F # even indices
high = (w_blocks >> 4) & 0x0F # odd indices
# Interleave: [*, num_blocks, 32]
shape = list(w_blocks.shape)
shape[-1] = shape[-1] * 2
unfused = torch.zeros(shape, dtype=torch.uint8, device=w_blocks.device)
unfused[..., 0::2] = low
unfused[..., 1::2] = high
del low, high

# E2M1 lookup: 3-bit magnitude → float value (use int32 not int64 to save memory)
E2M1_values = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.float32, device=w_blocks.device)
sign = 1.0 - 2.0 * ((unfused & 0b1000) >> 3).float()
magnitude = (unfused & 0b0111).to(torch.int32)
del unfused
x_float = E2M1_values[magnitude]
del magnitude
x_float = sign * x_float
del sign

# Apply E8M0 block scales: scale = 2^(e8m0 - 127)
scale_factor = torch.exp2(w_scales.float() - 127.0)
# Broadcast: scale is [*, num_blocks], x_float is [*, num_blocks, 32]
x_float = x_float * scale_factor.unsqueeze(-1)
del scale_factor

# Flatten blocks dimension: [*, num_blocks * 32]
return x_float.reshape(list(w_blocks.shape[:-2]) + [w_blocks.shape[-2] * 32]).to(torch.bfloat16)


class GGMLQuantizationType(IntEnum):
"""GGML quantization type enumeration"""

Expand Down Expand Up @@ -543,6 +583,13 @@ def _detect_format(self):
"""Auto-detect the MoE naming format by checking tensor keys."""
sample_keys = list(self.tensor_file_map.keys())[:1000]

# Check for MXFP4 packed format (gpt-oss style: gate_up_proj_blocks)
for key in sample_keys:
if key.endswith(".mlp.experts.gate_up_proj_blocks"):
self._detected_format = "mxfp4_packed"
print("[BF16SafeTensorLoader] Detected format: mxfp4_packed (gpt-oss style)")
return

# Check for packed format first (Qwen3.5 MoE style: all experts in one 3D tensor)
for key in sample_keys:
if key.endswith(".mlp.experts.gate_up_proj"):
Expand Down Expand Up @@ -589,6 +636,8 @@ def load_tensor(self, key: str, device: str = "cpu"):

def load_experts(self, base_key: str, device: str = "cpu"):
"""Load BF16 expert weights (no scales needed)."""
if self._detected_format == "mxfp4_packed":
return self._load_experts_mxfp4_packed(base_key, device)
if self._detected_format == "packed":
return self._load_experts_packed(base_key, device)

Expand Down Expand Up @@ -621,6 +670,49 @@ def load_experts(self, base_key: str, device: str = "cpu"):
"down": down_weights,
}

def _load_experts_mxfp4_packed(self, base_key: str, device: str = "cpu"):
    """Load MXFP4-packed expert weights (gpt-oss style), dequantize to BF16.

    Works one expert at a time so only a single expert's dequantized BF16
    intermediate is ever alive, avoiding a large transient memory spike.
    """
    experts_prefix = f"{base_key}.mlp.experts"

    # VL checkpoints nest the decoder under a `language_model.` segment;
    # retry there before giving up.
    if not self.has_tensor(f"{experts_prefix}.gate_up_proj_blocks"):
        parts = base_key.split(".", 1)
        if len(parts) == 2:
            alt_base = f"{parts[0]}.language_model.{parts[1]}"
            experts_prefix = f"{alt_base}.mlp.experts"
        if not self.has_tensor(f"{experts_prefix}.gate_up_proj_blocks"):
            raise ValueError(f"No MXFP4 packed experts found for base_key '{base_key}'")

    # Fused gate+up weights stay in their compact MXFP4 form (~1.7 GB total);
    # only the per-expert slice being processed is expanded to BF16.
    gu_blocks = self.load_tensor(f"{experts_prefix}.gate_up_proj_blocks", device)
    gu_scales = self.load_tensor(f"{experts_prefix}.gate_up_proj_scales", device)

    n_experts = gu_blocks.shape[0]
    half = gu_blocks.shape[1] // 2  # first half = gate rows, second half = up rows

    gate_list = []
    up_list = []
    for e in range(n_experts):
        dense = _dequant_mxfp4_to_bf16(gu_blocks[e], gu_scales[e])
        gate_list.append(dense[:half, :].contiguous())
        up_list.append(dense[half:, :].contiguous())
        del dense
    del gu_blocks, gu_scales

    # Down projection, same per-expert strategy.
    down_blocks = self.load_tensor(f"{experts_prefix}.down_proj_blocks", device)
    down_scales = self.load_tensor(f"{experts_prefix}.down_proj_scales", device)
    down_list = []
    for e in range(n_experts):
        down_list.append(_dequant_mxfp4_to_bf16(down_blocks[e], down_scales[e]).contiguous())
    del down_blocks, down_scales

    return {"gate": gate_list, "up": up_list, "down": down_list}

def _resolve_packed_experts_prefix(self, base_key: str) -> str:
"""Resolve the experts prefix for packed format, trying fallbacks."""
# Direct: model.layers.{N}.mlp.experts
Expand All @@ -644,6 +736,11 @@ def _load_experts_packed(self, base_key: str, device: str = "cpu"):
Packed format stores all experts in stacked 3D tensors:
- gate_up_proj: [num_experts, 2 * intermediate_size, hidden_size]
- down_proj: [num_experts, hidden_size, intermediate_size]

Two layouts exist for gate_up_proj:
- Concatenated (Qwen3.5): first half = gate, second half = up
- Interleaved (gpt-oss): even rows = gate, odd rows = up
Detected by presence of gate_up_proj_bias in the tensor index.
"""
experts_prefix = self._resolve_packed_experts_prefix(base_key)

Expand All @@ -653,9 +750,16 @@ def _load_experts_packed(self, base_key: str, device: str = "cpu"):
gate_up = self.load_tensor(gate_up_key, device) # [E, 2*I, H]
down = self.load_tensor(down_key, device) # [E, H, I]

mid = gate_up.shape[1] // 2
gate_list = [gate_up[i, :mid, :].contiguous() for i in range(gate_up.shape[0])]
up_list = [gate_up[i, mid:, :].contiguous() for i in range(gate_up.shape[0])]
# Detect interleaved layout (gpt-oss): set by wrapper via _force_interleaved flag
interleaved = getattr(self, '_force_interleaved', False)

if interleaved:
gate_list = [gate_up[i, ::2, :].contiguous() for i in range(gate_up.shape[0])]
up_list = [gate_up[i, 1::2, :].contiguous() for i in range(gate_up.shape[0])]
else:
mid = gate_up.shape[1] // 2
gate_list = [gate_up[i, :mid, :].contiguous() for i in range(gate_up.shape[0])]
up_list = [gate_up[i, mid:, :].contiguous() for i in range(gate_up.shape[0])]
down_list = [down[i].contiguous() for i in range(down.shape[0])]

return {
Expand Down