Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request enhances the kt-kernel library by adding AVX2-optimized kernels for various data types and introducing support for Mixture-of-Experts (MoE) models. These changes aim to improve performance and efficiency on a wider range of hardware, particularly benefiting systems without AVX512 or AMX capabilities. The inclusion of quantization techniques further optimizes memory usage and inference speed.
Code Review
The pull request introduces AVX2 support for BF16, FP8, and GPTQ INT4 Mixture-of-Experts (MoE) inference, expanding the compatibility of the kt-kernel library to systems without AVX512/AMX extensions. This is a significant improvement for broader hardware support. The changes include new C++ kernel implementations, utility functions, and Python bindings to integrate these new backends. Accuracy tests for BF16 and FP8 AVX2 kernels have also been added, which is excellent for ensuring correctness. However, there are several instances of magic numbers that could be replaced with named constants for better readability and maintainability. Additionally, some C++ files use printf for logging, which could be improved by using a more robust logging framework. The GPTQ INT4 GPU offload functionality is noted as not yet implemented, which is a limitation to be aware of.
```cpp
/**
 * @Description : AVX2 BF16 GEMM kernel with trivial Buffer abstractions
 * @Author      : Claude
 * @Date        : 2026-03-18
 */
// ...
static constexpr int K_STEP = 8;             // Process 8 K elements at a time
static constexpr int N_BLOCK = 64;           // N blocking for cache
static constexpr int K_BLOCK = 256;          // K blocking for cache
static constexpr double ELEMENT_SIZE = 2.0;  // BF16 = 2 bytes
```
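For orientation, blocking constants like these usually shape a GEMM loop nest in which a K_BLOCK x N_BLOCK panel of B stays cache-resident while K_STEP corresponds to the 8-wide AVX2 accumulation step. A schematic Python sketch under those assumptions (illustrative only, not the actual AVX2 kernel):

```python
K_STEP, N_BLOCK, K_BLOCK = 8, 64, 256

def blocked_gemm(A, B, M, N, K):
    # C[m][n] += sum_k A[m][k] * B[k][n], tiled so that a K_BLOCK x N_BLOCK
    # panel of B is reused from cache; the innermost K_STEP chunk models the
    # 8-element vector FMA of the real kernel.
    C = [[0.0] * N for _ in range(M)]
    for k0 in range(0, K, K_BLOCK):
        for n0 in range(0, N, N_BLOCK):
            for m in range(M):
                for n in range(n0, min(n0 + N_BLOCK, N)):
                    acc = 0.0
                    for k in range(k0, min(k0 + K_BLOCK, K), K_STEP):
                        for kk in range(k, min(k + K_STEP, K)):
                            acc += A[m][kk] * B[kk][n]
                    C[m][n] += acc
    return C
```

Multiplying by the identity leaves the input unchanged, which makes a quick sanity check: `blocked_gemm([[1.0, 2.0], [3.0, 4.0]], [[1.0, 0.0], [0.0, 1.0]], 2, 2, 2)` returns `[[1.0, 2.0], [3.0, 4.0]]`.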
The ELEMENT_SIZE is defined as a double but represents a byte count. It would be more appropriate to use an int or size_t type for this constant, and potentially define BF16_BYTE_SIZE as a named constant.
```cpp
static constexpr size_t BF16_BYTE_SIZE = 2;
static constexpr size_t ELEMENT_SIZE = BF16_BYTE_SIZE;
```

```cpp
/**
 * @Description : AVX2 BF16 utility functions (bf16<->fp32 conversion, activation)
 * @Author      : Claude
 * @Date        : 2026-03-18
 */
// ...
static inline void store_fp32_to_bf16(ggml_bf16_t* dst, __m256 src) {
    __m256i i32 = _mm256_castps_si256(src);
    // Round-to-nearest-even: add 0x7FFF + ((val >> 16) & 1)
    __m256i tie_bit = _mm256_and_si256(_mm256_srli_epi32(i32, 16), _mm256_set1_epi32(1));
```
The literal 0x7FFF is a magic number. Please define it as a named constant (e.g., BF16_ROUND_MAGIC) to improve readability and maintainability.
```cpp
__m256i tie_bit = _mm256_and_si256(_mm256_srli_epi32(i32, 16), _mm256_set1_epi32(1));
const __m256i BF16_ROUND_MAGIC = _mm256_set1_epi32(0x7FFF);
__m256i round = _mm256_add_epi32(BF16_ROUND_MAGIC, tie_bit);
```
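The rounding trick quoted above (add 0x7FFF plus the tie bit, then keep the high 16 bits) implements round-to-nearest-even. A standalone scalar sketch of the same bit manipulation, for checking values by hand (not part of the PR):

```python
import struct

def bits_to_f32(b: int) -> float:
    """Reinterpret a 32-bit pattern as an IEEE-754 float."""
    return struct.unpack("<f", struct.pack("<I", b))[0]

def fp32_to_bf16_rne(x: float) -> int:
    """Scalar model of the kernel's bf16 conversion:
    bits + 0x7FFF + tie_bit, then truncate to the high 16 bits."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    tie_bit = (bits >> 16) & 1
    return ((bits + 0x7FFF + tie_bit) >> 16) & 0xFFFF

# 1.0 is exactly representable: bf16 bits 0x3F80
assert fp32_to_bf16_rne(1.0) == 0x3F80
# Exactly-halfway values round to the even bf16 result
assert fp32_to_bf16_rne(bits_to_f32(0x3F808000)) == 0x3F80  # rounds down to even
assert fp32_to_bf16_rne(bits_to_f32(0x3F818000)) == 0x3F82  # rounds up to even
```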
```cpp
__m256i clamped = _mm256_max_epi32(_mm256_min_epi32(int_part, _mm256_set1_epi32(127)),
                                   _mm256_set1_epi32(-126));
```
The literal values 127 and -126 are magic numbers. Please define them as named constants (e.g., MAX_EXP_CLAMP and MIN_EXP_CLAMP) to improve readability and maintainability.
```cpp
const __m256i MAX_EXP_CLAMP = _mm256_set1_epi32(127);
const __m256i MIN_EXP_CLAMP = _mm256_set1_epi32(-126);
__m256i clamped = _mm256_max_epi32(_mm256_min_epi32(int_part, MAX_EXP_CLAMP),
                                   MIN_EXP_CLAMP);
```
```python
if exp <= 0:
    # Subnormal
    man = int(round(val * (2**6) * 8))
    man = min(man, 7)
```
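The quoted subnormal branch is consistent with FP8 E4M3 (bias 7, 3 mantissa bits): the smallest normal exponent is -6, so subnormal values are multiples of 2**-6 / 8 = 2**-9 and the mantissa is round(val * 2**9), clamped to 7. A scalar sketch of that encoding, assuming E4M3 (helper names are illustrative, not from the PR):

```python
def encode_e4m3_subnormal(val: float) -> int:
    # E4M3 subnormals: value = man * 2**-9, since the minimum normal
    # exponent is 1 - bias = -6 and the 3-bit mantissa step is 2**-6 / 8.
    man = int(round(val * (2**6) * 8))
    return min(man, 7)  # 3-bit mantissa, so clamp at 7

def decode_e4m3_subnormal(man: int) -> float:
    return man * 2.0**-9

# Exact multiples of 2**-9 round-trip; larger values clamp to the max mantissa.
assert encode_e4m3_subnormal(3 * 2.0**-9) == 3
assert encode_e4m3_subnormal(1.0) == 7
```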
```python
gate_bf16 = (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / 10.0).to(torch.bfloat16)
up_bf16 = (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / 10.0).to(torch.bfloat16)
down_bf16 = (torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32) / 10.0).to(torch.bfloat16)
```
The literal 10.0 used for scaling random weights is a magic number. Please define it as a named constant (e.g., WEIGHT_SCALING_FACTOR) to improve readability and configurability.
```python
WEIGHT_SCALING_FACTOR = 10.0
gate_bf16 = (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / WEIGHT_SCALING_FACTOR).to(torch.bfloat16)
up_bf16 = (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / WEIGHT_SCALING_FACTOR).to(torch.bfloat16)
down_bf16 = (torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32) / WEIGHT_SCALING_FACTOR).to(torch.bfloat16)
```
```python
config.quant_config.bits = 8
config.quant_config.group_size = group_size
config.quant_config.zero_point = False
```
The literal 8 for bits and False for zero_point are magic numbers/booleans. Please define them as named constants (e.g., FP8_BITS, FP8_ZERO_POINT_ENABLED) to improve readability and configurability.
```python
FP8_BITS = 8
FP8_ZERO_POINT_ENABLED = False
config.quant_config.bits = FP8_BITS
config.quant_config.group_size = group_size
config.quant_config.zero_point = FP8_ZERO_POINT_ENABLED
```
```python
for i in range(validation_iter):
    expert_ids = torch.stack([torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]).contiguous()
    weights = torch.rand((qlen, num_experts_per_tok), dtype=torch.float32).contiguous()
    input_data = (torch.randn((qlen, hidden_size), dtype=torch.float32) / 100.0).to(torch.bfloat16).contiguous()
```
The literal 100.0 used for scaling input data is a magic number. Please define it as a named constant (e.g., INPUT_SCALING_FACTOR) to improve readability and configurability.
```python
INPUT_SCALING_FACTOR = 100.0
input_data = (torch.randn((qlen, hidden_size), dtype=torch.float32) / INPUT_SCALING_FACTOR).to(torch.bfloat16).contiguous()
```
```python
diff = torch.mean(torch.abs(output.float() - t_output.float())) / (torch.mean(torch.abs(t_output.float())) + 1e-8)
print(" Iteration %d: diff = %.6f" % (i, diff.item()))
assert diff < 0.1, "FP8 accuracy test failed: diff=%.6f >= 0.1" % diff.item()
```
The literal 0.1 used as an accuracy threshold is a magic number. Please define it as a named constant (e.g., FP8_ACCURACY_THRESHOLD) to improve readability and configurability.
```python
FP8_ACCURACY_THRESHOLD = 0.1
assert diff < FP8_ACCURACY_THRESHOLD, "FP8 accuracy test failed: diff=%.6f >= %.1f" % (diff.item(), FP8_ACCURACY_THRESHOLD)
```
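The assertion checks a mean relative error: mean(|out - ref|) divided by mean(|ref|), with 1e-8 added to guard against division by zero. A torch-free sketch of the same metric (the threshold name is illustrative):

```python
FP8_ACCURACY_THRESHOLD = 0.1

def mean_relative_diff(output, reference, eps=1e-8):
    # Mean absolute error, normalized by the mean magnitude of the reference;
    # eps keeps the denominator nonzero for an all-zero reference.
    abs_err = sum(abs(o - r) for o, r in zip(output, reference)) / len(output)
    ref_mag = sum(abs(r) for r in reference) / len(reference) + eps
    return abs_err / ref_mag

ref = [1.0, -2.0, 3.0]
out = [1.05, -1.9, 3.1]
assert mean_relative_diff(out, ref) < FP8_ACCURACY_THRESHOLD
```

Normalizing by the reference magnitude makes the threshold scale-free, so the same 0.1 bound applies regardless of how the activations are scaled.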