
support GLM 4.7 #1791

Merged
ouqingliang merged 6 commits into main from fp8_channel_moe
Jan 13, 2026

Conversation

@ouqingliang
Contributor

What does this PR do?

Support GLM 4.7, including FP8 and BF16.

Before submitting

@gemini-code-assist
Contributor

Summary of Changes

Hello @ouqingliang, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the kt-kernel by adding comprehensive support for GLM 4.7, focusing on the implementation of an FP8 Per-Channel Mixture-of-Experts (MoE) kernel. This new capability allows for more efficient inference with FP8 weights and BF16 activations. The changes span across C++ kernel implementations, Python bindings, and associated benchmarking and testing infrastructure, ensuring the new quantization mode is fully integrated and validated.

Highlights

  • GLM 4.7 Support: Introduced support for GLM 4.7, specifically enabling FP8 and BF16 quantization modes.
  • FP8 Per-Channel MoE Kernel: Added a new FP8 Per-Channel Mixture-of-Experts (MoE) kernel, designed for GLM-4.7-FP8 style quantization, which uses FP8 weights with per-channel scaling and BF16 activations, leveraging AVX-512 DPBF16 compute paths.
  • Benchmarking and Testing: Integrated new benchmark and test scripts for the FP8 Per-Channel MoE kernel, alongside updates to existing MoE benchmarks to include the new quantization mode and adjust parameters for more realistic scenarios.
  • C++ Backend Enhancements: Modified C++ bindings to register the new FP8 Per-Channel MoE operator and refactored the write_weight_scale_to_buffer_task binding for greater generality. Core AMX buffer and kernel implementations were extended to support per-channel scaling.
  • Python Loader Updates: Updated Python utility classes (amx.py, loader.py) to recognize and correctly load models using the new FP8 Per-Channel quantization scheme, including auto-detection of scale formats.
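The per-channel scheme in the highlights can be illustrated with a short sketch (not from the PR; NumPy stands in for the kernel, and the FP8 codes are simulated with coarsely rounded floats — the real kernel stores e4m3 bytes and computes on BF16 activations via AVX-512):

```python
import numpy as np

n, k = 4, 8                                    # output channels, input features
rng = np.random.default_rng(0)

# Simulated FP8 weight codes (the real kernel stores raw e4m3 bytes)
w_codes = np.round(rng.standard_normal((n, k)) * 8) / 8
scale = rng.random(n) + 0.5                    # one FP32 scale per output row
x = rng.standard_normal(k).astype(np.float32)  # activation (BF16 in the kernel)

# Per-channel dequantization: every element of row i shares scale[i]
w_deq = w_codes * scale[:, None]
y = w_deq @ x                                  # reference matmul for one expert
```

Per-channel scaling keeps one FP32 scale per output row instead of one per block, which is what the new buffer and loader code paths must recognize and store.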



@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for GLM 4.7, including fp8 per-channel quantization and bf16. The changes are comprehensive, touching C++ kernels, Python bindings, and adding new benchmarks and tests. The C++ implementation shows good use of modern features and performance considerations. However, I've identified a few issues that should be addressed. There's a duplicated class definition in kt-kernel/python/utils/loader.py which needs to be resolved. Additionally, a C++ kernel function in kt-kernel/operators/amx/fp8-perchannel-moe.hpp appears to have a non-vectorized implementation that could be a performance bottleneck. I've also left some comments on the new test and benchmark scripts regarding minor improvements for efficiency and code style.

Comment on lines +204 to +242
static inline void unpack_4nk_blocks(const uint8_t* src[4], uint8_t* dst, size_t dst_row_stride) {
  static constexpr int row_map[8] = {0, 16, 4, 20, 8, 24, 12, 28};
  constexpr int K_STEP = T::K_STEP;  // 32

  // Reinterpret as uint64 arrays for efficient access
  const uint64_t* src0 = reinterpret_cast<const uint64_t*>(src[0]);
  const uint64_t* src1 = reinterpret_cast<const uint64_t*>(src[1]);
  const uint64_t* src2 = reinterpret_cast<const uint64_t*>(src[2]);
  const uint64_t* src3 = reinterpret_cast<const uint64_t*>(src[3]);

  // Process all 32 rows, writing 128 bytes (4 x 32) per row
  for (int packed_i = 0; packed_i < 8; packed_i++) {
    const int base_row = row_map[packed_i];

    // Process 4 rows at a time
    for (int r = 0; r < 4; r++) {
      uint16_t* row_dst = reinterpret_cast<uint16_t*>(dst + (size_t)(base_row + r) * dst_row_stride);
      const int shift = r * 16;

      // Unroll: process all 4 blocks x 16 columns = 64 uint16 values
      // Block 0: columns 0-15
      for (int j = 0; j < 16; j++) {
        row_dst[j] = static_cast<uint16_t>(src0[8 * j + packed_i] >> shift);
      }
      // Block 1: columns 16-31
      for (int j = 0; j < 16; j++) {
        row_dst[16 + j] = static_cast<uint16_t>(src1[8 * j + packed_i] >> shift);
      }
      // Block 2: columns 32-47
      for (int j = 0; j < 16; j++) {
        row_dst[32 + j] = static_cast<uint16_t>(src2[8 * j + packed_i] >> shift);
      }
      // Block 3: columns 48-63
      for (int j = 0; j < 16; j++) {
        row_dst[48 + j] = static_cast<uint16_t>(src3[8 * j + packed_i] >> shift);
      }
    }
  }
}
Severity: high

The implementation of unpack_4nk_blocks uses scalar loops, which is likely to be a performance bottleneck. This function should be vectorized using AVX512 intrinsics, similar to how unpack_nk_block is implemented, to improve performance. The current implementation reads and writes data element by element, which underutilizes the CPU's vector processing capabilities.

Comment on lines +419 to +509
class BF16SafeTensorLoader(SafeTensorLoader):
    """Loader for native BF16 expert weights (no quantization, no scales).

    Supported formats:
    - DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight
    - Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight

    The format is auto-detected during initialization.
    """

    MOE_FORMATS = {
        "deepseek": ("{base}.mlp.experts", "gate_proj", "up_proj", "down_proj"),
        "mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"),
    }

    def __init__(self, file_path: str):
        super().__init__(file_path)
        self._detected_format = None
        self._detect_format()

    def _detect_format(self):
        """Auto-detect the MoE naming format by checking tensor keys."""
        sample_keys = list(self.tensor_file_map.keys())[:1000]

        for fmt_name, (path_tpl, gate, up, down) in self.MOE_FORMATS.items():
            for key in sample_keys:
                if ".experts." in key and f".{gate}.weight" in key:
                    if "block_sparse_moe.experts" in key and fmt_name == "mixtral":
                        self._detected_format = fmt_name
                        print(f"[BF16SafeTensorLoader] Detected format: {fmt_name}")
                        return
                    elif "mlp.experts" in key and "block_sparse_moe" not in key and fmt_name == "deepseek":
                        self._detected_format = fmt_name
                        print(f"[BF16SafeTensorLoader] Detected format: {fmt_name}")
                        return

        self._detected_format = "deepseek"
        print("[BF16SafeTensorLoader] No MoE format detected, defaulting to: deepseek")

    def _get_experts_prefix(self, base_key: str) -> str:
        """Get the experts prefix based on detected format."""
        path_tpl, _, _, _ = self.MOE_FORMATS[self._detected_format]
        return path_tpl.format(base=base_key)

    def _get_proj_names(self):
        """Get projection names (gate, up, down) based on detected format."""
        _, gate, up, down = self.MOE_FORMATS[self._detected_format]
        return gate, up, down

    def load_tensor(self, key: str, device: str = "cpu"):
        if key not in self.tensor_file_map:
            raise KeyError(f"Key {key} not found in Safetensor files")
        file = self.tensor_file_map[key]
        f = self.file_handle_map.get(file)
        if f is None:
            raise FileNotFoundError(f"File {file} not found in Safetensor files")
        tensor = f.get_tensor(key)
        if device == "cpu":
            return tensor
        return tensor.to(device)

    def load_experts(self, base_key: str, device: str = "cpu"):
        """Load BF16 expert weights (no scales needed)."""
        experts_prefix = self._get_experts_prefix(base_key)
        gate_name, up_name, down_name = self._get_proj_names()

        expert_count = 0
        while self.has_tensor(f"{experts_prefix}.{expert_count}.{gate_name}.weight"):
            expert_count += 1

        if expert_count == 0:
            raise ValueError(f"No experts found for key {experts_prefix}")

        gate_weights = [None] * expert_count
        up_weights = [None] * expert_count
        down_weights = [None] * expert_count

        for exp_id in range(expert_count):
            gate_w_key = f"{experts_prefix}.{exp_id}.{gate_name}.weight"
            up_w_key = f"{experts_prefix}.{exp_id}.{up_name}.weight"
            down_w_key = f"{experts_prefix}.{exp_id}.{down_name}.weight"

            gate_weights[exp_id] = self.load_tensor(gate_w_key, device).contiguous()
            up_weights[exp_id] = self.load_tensor(up_w_key, device).contiguous()
            down_weights[exp_id] = self.load_tensor(down_w_key, device).contiguous()

        return {
            "gate": gate_weights,
            "up": up_weights,
            "down": down_weights,
        }
Severity: high

This adds a new definition for BF16SafeTensorLoader, but another definition for the same class already exists starting at line 512. This duplication will lead to runtime errors or unexpected behavior. Please remove one of the definitions, presumably the older one if this new one is the intended implementation.
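As a quick illustration of why this matters (example names are hypothetical, not from the PR), Python binds a class name to whichever definition runs last, so the earlier definition is silently discarded:

```python
class Loader:
    def which(self):
        return "first"

class Loader:  # re-definition: silently replaces the class above
    def which(self):
        return "second"

# Only the later definition is reachable at runtime
assert Loader().which() == "second"
```

With two BF16SafeTensorLoader definitions in one module, whichever appears later in the file wins, and any behavioral differences in the shadowed one are lost without warning.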

Comment on lines +105 to +110
fp8_weights = torch.randint(0, 256, (e, n, k), dtype=torch.uint8, device="cuda").to("cpu").contiguous()

# Generate random per-channel scales (one per output row)
# Use reasonable scale range (e.g., 2^-8 to 2^8)
exponents = torch.randint(-8, 9, (e, n), dtype=torch.int32, device="cuda").to("cpu").contiguous()
scales = (2.0 ** exponents.float()).to(torch.float32).contiguous()
Severity: medium

Creating tensors on a CUDA device and then immediately moving them to the CPU is inefficient. It's better to create these tensors directly on the CPU to avoid unnecessary device-to-host data transfers, especially since they are used for CPU-side operations.

Suggested change
- fp8_weights = torch.randint(0, 256, (e, n, k), dtype=torch.uint8, device="cuda").to("cpu").contiguous()
+ fp8_weights = torch.randint(0, 256, (e, n, k), dtype=torch.uint8, device="cpu").contiguous()
  # Generate random per-channel scales (one per output row)
  # Use reasonable scale range (e.g., 2^-8 to 2^8)
- exponents = torch.randint(-8, 9, (e, n), dtype=torch.int32, device="cuda").to("cpu").contiguous()
+ exponents = torch.randint(-8, 9, (e, n), dtype=torch.int32, device="cpu").contiguous()
  scales = (2.0 ** exponents.float()).to(torch.float32).contiguous()

Comment on lines +157 to +161
gate_q = (
    torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8, device="cuda")
    .to("cpu")
    .contiguous()
)
Severity: medium

Creating a tensor on the CUDA device and then immediately moving it to the CPU is inefficient. It's better to create the tensor directly on the CPU to avoid an unnecessary device-to-host data transfer. This applies to all tensor creations in this function.

    gate_q = (
        torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8, device="cpu")
        .contiguous()
    )

with torch.inference_mode(mode=True):
    for i in range(validation_iter):
        torch.manual_seed(100 + i)
        torch.manual_seed(114514 + i)
Severity: medium

The random seed 114514 appears to be a meme number. While it doesn't affect correctness, using more standard or neutral numbers for seeds (like 42 or 2024) improves the professionalism and maintainability of the code.

Suggested change
- torch.manual_seed(114514 + i)
+ torch.manual_seed(42 + i)

import os
import sys

sys.path.insert(0, os.path.dirname(__file__) + "/../build")
Severity: medium

Using os.path.join is more robust for constructing file paths as it handles path separators correctly across different operating systems.

Suggested change
- sys.path.insert(0, os.path.dirname(__file__) + "/../build")
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))

val = min(val, FP8_E4M3_MAX)

# Find exponent
import math
Severity: medium

It's a best practice to place all imports at the top of the file. This makes it easier to see what modules the script depends on.

# Find exponent
import math

if val < 2**-9: # Subnormal threshold
Severity: medium

The value 2**-9 is a magic number. It would be clearer to define it as a named constant with a comment explaining its significance (e.g., FP8_E4M3_SUBNORMAL_THRESHOLD). This also applies to its use on line 200.
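A sketch of what such a named constant could look like (the names are illustrative, not from the PR; the values follow the FP8 e4m3fn format: exponent bias 7, smallest normal magnitude 2^-6, smallest subnormal magnitude 2^-9):

```python
FP8_E4M3_MAX = 448.0                    # largest finite e4m3fn value
FP8_E4M3_MIN_NORMAL = 2 ** -6           # smallest normal magnitude (bias 7)
FP8_E4M3_SUBNORMAL_THRESHOLD = 2 ** -9  # smallest subnormal magnitude

def classify(val: float) -> str:
    """Classify a magnitude before e4m3fn encoding."""
    val = min(abs(val), FP8_E4M3_MAX)   # clamp to the representable range
    if val < FP8_E4M3_SUBNORMAL_THRESHOLD:
        return "zero"       # below the smallest subnormal: flushes to zero
    if val < FP8_E4M3_MIN_NORMAL:
        return "subnormal"
    return "normal"
```

With the constant named, both this branch and the later use at line 200 read as the same documented threshold rather than two independent magic numbers.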

@ouqingliang ouqingliang merged commit 6277da4 into main Jan 13, 2026
7 of 9 checks passed
@ouqingliang ouqingliang deleted the fp8_channel_moe branch January 23, 2026 06:18