
[feat](kt-kernel): support avx2 only inference for bf16, fp8 and gptq int4 #1892

Open
mrhaoxx wants to merge 2 commits into main from avx2

Conversation

Collaborator

@mrhaoxx mrhaoxx commented Mar 18, 2026

What does this PR do?

Fixes # (issue)

Before submitting

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the kt-kernel library by adding AVX2-optimized kernels for various data types and introducing support for Mixture-of-Experts (MoE) models. These changes aim to improve performance and efficiency on a wider range of hardware, particularly benefiting systems without AVX512 or AMX capabilities. The inclusion of quantization techniques further optimizes memory usage and inference speed.

Highlights

  • AVX2 Optimization: This PR introduces AVX2-optimized kernels for BF16, FP8, and GPTQ INT4 data types, enhancing performance on x86_64 architectures without AVX512 or AMX support.
  • MoE Support: The changes add support for Mixture-of-Experts (MoE) models, enabling efficient handling of large models by distributing computation across multiple experts.
  • Quantization: The PR incorporates quantization techniques (FP8, GPTQ INT4) to reduce memory footprint and accelerate inference, balancing accuracy and speed.
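The expert-routing step behind the MoE highlight can be sketched in a few lines of pure Python (a hypothetical helper for illustration only, not part of the kt-kernel API): each token selects its top-k experts by gate score and normalizes the selected scores into mixing weights.

```python
def route_topk(logits, k):
    """Pick the k highest-scoring experts and normalize their scores
    into mixing weights. Illustrative sketch, not the kt-kernel API."""
    idx = sorted(range(len(logits)), key=lambda e: logits[e], reverse=True)[:k]
    scores = [logits[e] for e in idx]
    total = sum(scores)
    return idx, [s / total for s in scores]
```

Each token's output is then the weighted sum of its selected experts' FFN outputs, which is what lets only k of the experts run per token.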




@gemini-code-assist gemini-code-assist bot left a comment

Code Review

The pull request introduces AVX2 support for BF16, FP8, and GPTQ INT4 Mixture-of-Experts (MoE) inference, expanding the compatibility of the kt-kernel library to systems without AVX512/AMX extensions. This is a significant improvement for broader hardware support. The changes include new C++ kernel implementations, utility functions, and Python bindings to integrate these new backends. Accuracy tests for BF16 and FP8 AVX2 kernels have also been added, which is excellent for ensuring correctness. However, there are several instances of magic numbers that could be replaced with named constants for better readability and maintainability. Additionally, some C++ files use printf for logging, which could be improved by using a more robust logging framework. The GPTQ INT4 GPU offload functionality is noted as not yet implemented, which is a limitation to be aware of.

/**
* @Description : AVX2 BF16 GEMM kernel with trivial Buffer abstractions
* @Author : Claude
* @Date : 2026-03-18

medium

The @Date in the file header is set to a future date. Please update this to the current date or remove it if it's not meant to be dynamic.

static constexpr int K_STEP = 8; // Process 8 K elements at a time
static constexpr int N_BLOCK = 64; // N blocking for cache
static constexpr int K_BLOCK = 256; // K blocking for cache
static constexpr double ELEMENT_SIZE = 2.0; // BF16 = 2 bytes

medium

The ELEMENT_SIZE is defined as a double but represents a byte count. It would be more appropriate to use an int or size_t type for this constant, and potentially define BF16_BYTE_SIZE as a named constant.

  static constexpr size_t BF16_BYTE_SIZE = 2;
  static constexpr size_t ELEMENT_SIZE = BF16_BYTE_SIZE;
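For context on the K_STEP / N_BLOCK / K_BLOCK constants above, here is a scalar Python sketch of the cache-blocked GEMM loop structure they imply (illustrative only; the real kernel vectorizes the inner loop with AVX2 intrinsics):

```python
# Tile sizes mirroring the constants in the kernel header.
K_STEP, N_BLOCK, K_BLOCK = 8, 64, 256

def gemm_blocked(A, B, M, N, K):
    """C[m][n] = sum_k A[m][k] * B[k][n], tiled so each N_BLOCK x K_BLOCK
    panel of B stays hot in cache while all M rows are processed."""
    C = [[0.0] * N for _ in range(M)]
    for n0 in range(0, N, N_BLOCK):
        for k0 in range(0, K, K_BLOCK):
            for m in range(M):
                for n in range(n0, min(n0 + N_BLOCK, N)):
                    acc = 0.0
                    for k in range(k0, min(k0 + K_BLOCK, K)):
                        acc += A[m][k] * B[k][n]
                    C[m][n] += acc
    return C
```

K_STEP would additionally govern how many K elements the vectorized inner loop consumes per iteration.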

/**
* @Description : AVX2 BF16 utility functions (bf16<->fp32 conversion, activation)
* @Author : Claude
* @Date : 2026-03-18

medium

The @Date in the file header is set to a future date. Please update this to the current date or remove it if it's not meant to be dynamic.

static inline void store_fp32_to_bf16(ggml_bf16_t* dst, __m256 src) {
__m256i i32 = _mm256_castps_si256(src);
// Round-to-nearest-even: add 0x7FFF + ((val >> 16) & 1)
__m256i tie_bit = _mm256_and_si256(_mm256_srli_epi32(i32, 16), _mm256_set1_epi32(1));

medium

The literal 0x7FFF is a magic number. Please define it as a named constant (e.g., BF16_ROUND_MAGIC) to improve readability and maintainability.

Suggested change
__m256i tie_bit = _mm256_and_si256(_mm256_srli_epi32(i32, 16), _mm256_set1_epi32(1));
const __m256i BF16_ROUND_MAGIC = _mm256_set1_epi32(0x7FFF);
__m256i round = _mm256_add_epi32(BF16_ROUND_MAGIC, tie_bit);
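The rounding trick in this snippet can be modeled exactly in scalar Python: reinterpret the fp32 bits, add 0x7FFF plus the lowest bit that survives truncation, then keep the upper 16 bits. A small sketch for verification:

```python
import struct

def fp32_to_bf16_rne(x: float) -> int:
    """Scalar model of fp32 -> bf16 round-to-nearest-even via the
    bias trick: add 0x7FFF + tie bit, then truncate to 16 bits."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    tie = (bits >> 16) & 1                     # lowest surviving bit
    bits = (bits + 0x7FFF + tie) & 0xFFFFFFFF  # biased round, wrap to u32
    return bits >> 16                          # upper 16 bits = bf16
```

Adding the tie bit makes exact halfway cases round toward an even result instead of always rounding up.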

Comment on lines +104 to +105
__m256i clamped = _mm256_max_epi32(_mm256_min_epi32(int_part, _mm256_set1_epi32(127)),
_mm256_set1_epi32(-126));

medium

The literal values 127 and -126 are magic numbers. Please define them as named constants (e.g., MAX_EXP_CLAMP and MIN_EXP_CLAMP) to improve readability and maintainability.

Suggested change
__m256i clamped = _mm256_max_epi32(_mm256_min_epi32(int_part, _mm256_set1_epi32(127)),
_mm256_set1_epi32(-126));
const __m256i MAX_EXP_CLAMP = _mm256_set1_epi32(127);
const __m256i MIN_EXP_CLAMP = _mm256_set1_epi32(-126);
__m256i clamped = _mm256_max_epi32(_mm256_min_epi32(int_part, MAX_EXP_CLAMP),
MIN_EXP_CLAMP);
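The clamp range [-126, 127] is the normal-exponent range of fp32. Assuming the kernel builds 2^int_part by writing the clamped exponent directly into the fp32 exponent field (a common pattern in vectorized exp approximations; this is an inference, not confirmed by the diff), a scalar Python model of that step looks like:

```python
import struct

FP32_MAX_EXP = 127    # largest normal fp32 exponent
FP32_MIN_EXP = -126   # smallest normal fp32 exponent

def pow2_int(k: int) -> float:
    """Build 2^k by clamping k to the fp32 normal-exponent range and
    placing k + 127 (the biased exponent) into bits 23..30."""
    k = max(min(k, FP32_MAX_EXP), FP32_MIN_EXP)
    bits = (k + 127) << 23
    return struct.unpack("<f", struct.pack("<I", bits))[0]
```

Without the clamp, out-of-range exponents would produce Inf or denormal/zero bit patterns instead of saturating.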

if exp <= 0:
# Subnormal
man = int(round(val * (2**6) * 8))
man = min(man, 7)

medium

The literal 0x7E is a magic number representing a specific FP8 value. Please define it as a named constant (e.g., FP8_MAX_FINITE_BYTE) to improve readability.

Suggested change
man = min(man, 7)
FP8_MAX_FINITE_BYTE = 0x7E
if exp >= 15:
return (sign << 7) | FP8_MAX_FINITE_BYTE # clamp to max
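For reference, the E4M3 encoding path this test code models can be sketched as a scalar Python function (a simplified illustration using the same saturate-to-0x7E convention as the suggestion above: biased exponents at or above 15 clamp to the max finite byte):

```python
import math

E4M3_BIAS = 7
FP8_MAX_FINITE_BYTE = 0x7E  # sign 0, exp 1111, man 110 -> 448.0

def fp32_to_fp8_e4m3(val: float) -> int:
    """Simplified scalar sketch of float -> FP8 E4M3 encoding."""
    sign = 1 if math.copysign(1.0, val) < 0 else 0
    a = abs(val)
    if a == 0.0:
        return sign << 7
    exp = math.floor(math.log2(a))          # unbiased exponent
    biased = exp + E4M3_BIAS
    if biased <= 0:
        # Subnormal: value = man / 8 * 2^-6
        man = min(round(a * (2 ** 6) * 8), 7)
        return (sign << 7) | man
    if biased >= 15:
        return (sign << 7) | FP8_MAX_FINITE_BYTE  # clamp to max finite
    man = round((a / (2 ** exp) - 1.0) * 8)  # 3-bit mantissa
    if man == 8:                             # rounded past 1.0: bump exponent
        man = 0
        biased += 1
        if biased >= 15:
            return (sign << 7) | FP8_MAX_FINITE_BYTE
    return (sign << 7) | (biased << 3) | man
```

Note the saturation deliberately sacrifices the top of the representable range for simplicity, matching the excerpt's behavior rather than a fully general encoder.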

Comment on lines +160 to +162
gate_bf16 = (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / 10.0).to(torch.bfloat16)
up_bf16 = (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / 10.0).to(torch.bfloat16)
down_bf16 = (torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32) / 10.0).to(torch.bfloat16)

medium

The literal 10.0 used for scaling random weights is a magic number. Please define it as a named constant (e.g., WEIGHT_SCALING_FACTOR) to improve readability and configurability.

Suggested change
gate_bf16 = (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / 10.0).to(torch.bfloat16)
up_bf16 = (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / 10.0).to(torch.bfloat16)
down_bf16 = (torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32) / 10.0).to(torch.bfloat16)
WEIGHT_SCALING_FACTOR = 10.0
gate_bf16 = (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / WEIGHT_SCALING_FACTOR).to(torch.bfloat16)
up_bf16 = (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / WEIGHT_SCALING_FACTOR).to(torch.bfloat16)
down_bf16 = (torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32) / WEIGHT_SCALING_FACTOR).to(torch.bfloat16)

Comment on lines +202 to +204
config.quant_config.bits = 8
config.quant_config.group_size = group_size
config.quant_config.zero_point = False

medium

The literal 8 for bits and False for zero_point are magic numbers/booleans. Please define them as named constants (e.g., FP8_BITS, FP8_ZERO_POINT_ENABLED) to improve readability and configurability.

Suggested change
config.quant_config.bits = 8
config.quant_config.group_size = group_size
config.quant_config.zero_point = False
FP8_BITS = 8
FP8_ZERO_POINT_ENABLED = False
config.quant_config.bits = FP8_BITS
config.quant_config.group_size = group_size
config.quant_config.zero_point = FP8_ZERO_POINT_ENABLED

for i in range(validation_iter):
expert_ids = torch.stack([torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]).contiguous()
weights = torch.rand((qlen, num_experts_per_tok), dtype=torch.float32).contiguous()
input_data = (torch.randn((qlen, hidden_size), dtype=torch.float32) / 100.0).to(torch.bfloat16).contiguous()

medium

The literal 100.0 used for scaling input data is a magic number. Please define it as a named constant (e.g., INPUT_SCALING_FACTOR) to improve readability and configurability.

Suggested change
input_data = (torch.randn((qlen, hidden_size), dtype=torch.float32) / 100.0).to(torch.bfloat16).contiguous()
INPUT_SCALING_FACTOR = 100.0
input_data = (torch.randn((qlen, hidden_size), dtype=torch.float32) / INPUT_SCALING_FACTOR).to(torch.bfloat16).contiguous()


diff = torch.mean(torch.abs(output.float() - t_output.float())) / (torch.mean(torch.abs(t_output.float())) + 1e-8)
print(" Iteration %d: diff = %.6f" % (i, diff.item()))
assert diff < 0.1, "FP8 accuracy test failed: diff=%.6f >= 0.1" % diff.item()

medium

The literal 0.1 used as an accuracy threshold is a magic number. Please define it as a named constant (e.g., FP8_ACCURACY_THRESHOLD) to improve readability and configurability.

Suggested change
assert diff < 0.1, "FP8 accuracy test failed: diff=%.6f >= 0.1" % diff.item()
FP8_ACCURACY_THRESHOLD = 0.1
assert diff < FP8_ACCURACY_THRESHOLD, "FP8 accuracy test failed: diff=%.6f >= %.1f" % (diff.item(), FP8_ACCURACY_THRESHOLD)
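The quantity being thresholded here is a mean relative error: mean absolute difference normalized by the mean absolute reference value (with an epsilon to guard against division by zero). Stripped of torch, the same metric in plain Python is:

```python
def rel_mean_abs_diff(out, ref, eps=1e-8):
    """mean(|out - ref|) / (mean(|ref|) + eps), as used by the test."""
    num = sum(abs(a - b) for a, b in zip(out, ref)) / len(out)
    den = sum(abs(b) for b in ref) / len(ref) + eps
    return num / den
```

A threshold of 0.1 therefore allows up to 10% average deviation relative to the reference magnitude, which is loose enough to absorb FP8 quantization noise.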

@gemini-code-assist

Warning

Gemini encountered an error creating the review. You can try again by commenting /gemini review.
