deepseek && qwen tp performance tuning#934
Conversation
There was a problem hiding this comment.
Summary of Changes
Hello @shihaobai, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request focuses on significant performance optimizations for Deepseek and Qwen models, particularly within Mixture of Experts (MoE) and Tensor Parallelism (TP) contexts. It introduces new, highly-tuned Triton kernels for grouped GEMM operations, refines weight fusion and data handling for MoE, and enhances FP8 quantization. Additionally, it includes specialized Rotary Embedding kernels and integrates FlashInfer for improved CUDA architecture support, all aimed at boosting inference speed and efficiency.
Highlights
- MoE Kernel Tuning: New kernel configurations are added for grouped MoE GEMM operations, specifically optimized for NVIDIA H100 and H200 GPUs across various model dimensions and quantization settings.
- MoE Weight Fusion Optimization: The process of fusing MoE expert weights and their scales is made more memory-efficient by pre-allocating tensors and using direct slicing instead of concatenation and flattening.
- Grouped MoE GEMM Data Handling: The
moe_align2andgrouped_matmulfunctions are refactored to improve data access patterns and reduce data movement during grouped GEMM computations. - SiLU and Mul Kernel Optimization: A new, faster Triton kernel (
_silu_and_mul_kernel_fast) is introduced for the SiLU and Mul activation function, specifically benefiting smaller batch sizes in MoE layers. - FP8 Quantization Enhancements: The FP8 input quantization process is made more flexible and efficient, particularly concerning the handling and alignment of scales.
- Rotary Embedding Kernel Split: The Rotary Embedding kernel is split into separate, optimized kernels for Query and Key tensors, allowing for more fine-grained performance tuning.
- FlashInfer Integration: Conditional import and usage of FlashInfer are added to dynamically set CUDA architecture flags, improving compatibility and performance with FlashInfer-enabled operations.
- Benchmark and Tuning Updates: New benchmarks are introduced for fused MoE Triton kernels, and existing tuning scripts are updated to support new FP8 quantization parameters and DeepseekV2 large model configurations.
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | /gemini review |
Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary |
Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | /gemini help |
Displays a list of available commands. |
Customization
To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
-
Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution. ↩
52dc131 to
6362c4a
Compare
There was a problem hiding this comment.
Code Review
This pull request focuses on performance tuning for MoE layers, particularly for DeepSeek and Qwen models, with significant enhancements to FP8 quantization and Triton kernel optimizations. Key changes include refactoring MoE data alignment and GEMM kernels, optimizing weight fusing, specializing rotary embedding kernels, and updating quantization utilities. A new benchmark for fused MoE has also been added. The changes are generally aimed at improving efficiency and memory usage.
| gate_out_dim, gate_in_dim = self.experts_gate_projs[0].shape | ||
| up_out_dim, up_in_dim = self.experts_up_projs[0].shape | ||
| assert gate_in_dim == up_in_dim | ||
| dtype = self.experts_gate_projs[0].dtype | ||
| total_expert_num = self.n_routed_experts | ||
|
|
||
| w1 = torch.empty((total_expert_num, gate_out_dim + up_out_dim, gate_in_dim), dtype=dtype, device="cpu") | ||
|
|
||
| for i_experts in range(self.n_routed_experts): | ||
| expert_gate_up_proj = torch.cat( | ||
| [self.experts_gate_projs[i_experts], self.experts_up_projs[i_experts]], dim=0 | ||
| ) | ||
| expert_gate_up_proj = expert_gate_up_proj | ||
| w1_list.append(expert_gate_up_proj) | ||
|
|
||
| inter_shape, hidden_size = w1_list[0].shape[0], w1_list[0].shape[1] | ||
| w1 = torch._utils._flatten_dense_tensors(w1_list).view(len(w1_list), inter_shape, hidden_size) | ||
| w1[i_experts, 0:gate_out_dim:, :] = self.experts_gate_projs[i_experts] | ||
| w1[i_experts, gate_out_dim:, :] = self.experts_up_projs[i_experts] |
There was a problem hiding this comment.
The change to pre-allocate w1 (and similarly w1_scale in _fuse_weight_scale) and then fill it using slicing is a good optimization. This approach avoids the creation of intermediate lists of tensors and the subsequent overhead of torch.cat and torch._utils._flatten_dense_tensors (for w1_list), which should lead to reduced peak memory usage and potentially faster execution.
|
|
||
| if expert_id == -1: | ||
| return | ||
|
|
||
| tile_m_idx = tl.load(mblocks_to_m_index + pid_m) | ||
| tile_n_idx = pid_n | ||
|
|
||
| # get the gemm size of the current problem | ||
| cur_m = tl.load(expert_to_token_num + expert_id, eviction_policy="evict_last") | ||
|
|
||
| # do regular gemm here | ||
| offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) | ||
| token_mask = offs_am < cur_m | ||
| offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) | ||
| # token_mask = offs_am < cur_m | ||
| a_m_index = tl.load( | ||
| expert_to_token_index + expert_id * expert_to_token_index_stride_0 + offs_am, | ||
| mask=token_mask, | ||
| other=0, | ||
| padded_expert_to_token_index + offs_am, | ||
| ) | ||
| if MUL_ROUTED_WEIGHT: | ||
| a_m_scale = tl.load( | ||
| expert_to_weights_ptr + expert_id * expert_to_weights_stride0 + offs_am, | ||
| mask=token_mask, | ||
| other=0.0, | ||
| ) | ||
|
|
||
| token_mask = a_m_index != -1 |
There was a problem hiding this comment.
The indexing logic for a_m_index and the subsequent token_mask derivation has been updated to use the new padded_expert_to_token_index.
Previously, tile_m_idx was loaded from mblocks_to_m_index, and cur_m (expert-specific token count) was used for masking. Now, offs_am is directly computed from pid_m, and a_m_index is loaded from padded_expert_to_token_index + offs_am. The token_mask is then derived from a_m_index != -1.
This simplification relies on padded_expert_to_token_index being correctly populated by moe_align2 (with -1 for padding). Ensure thorough testing, as this is a critical part of the GEMM computation.
| a_m_scale = tl.load( | ||
| padded_expert_to_weights + offs_am, | ||
| mask=token_mask, | ||
| other=0.0, | ||
| ) | ||
| accumulator *= a_m_scale[:, None] |
There was a problem hiding this comment.
| qinput_tensor = self.cache_manager.alloc_tensor( | ||
| (m, k), qweight.dtype, device=qweight.device, is_graph_out=False | ||
| ) | ||
| per_token_group_quant_fp8(input_tensor, self.block_size, qinput_tensor, input_scale) | ||
| input_scale = tma_align_input_scale(input_scale) | ||
| _, input_scale = per_token_group_quant_fp8( | ||
| input_tensor, self.block_size, qinput_tensor, column_major_scales=True, scale_tma_aligned=True | ||
| ) |
There was a problem hiding this comment.
The per_token_group_quant_fp8 function is now called with column_major_scales=True and scale_tma_aligned=True. The returned input_scale is used directly, and the previous explicit call to tma_align_input_scale is removed.
This is a good refactoring as it encapsulates the TMA alignment logic for scales within the per_token_group_quant_fp8 function itself.
| x: torch.Tensor, | ||
| group_size: int, | ||
| x_q: torch.Tensor, | ||
| x_s: torch.Tensor, | ||
| x_s: torch.Tensor = None, | ||
| eps: float = 1e-10, | ||
| dtype: torch.dtype = torch.float8_e4m3fn, | ||
| column_major_scales: bool = False, | ||
| scale_tma_aligned: bool = False, | ||
| alloc_func: Callable = torch.empty, | ||
| ): | ||
| # Adapted from | ||
| # https://github.com/sgl-project/sglang/blob/7e257cd666c0d639626487987ea8e590da1e9395/python/sglang/srt/layers/quantization/fp8_kernel.py#L290 | ||
| if HAS_SGL_KERNEL: | ||
| finfo = torch.finfo(dtype) | ||
| fp8_max, fp8_min = finfo.max, finfo.min | ||
| if column_major_scales: | ||
| if scale_tma_aligned: | ||
| # aligned to 4 * sizeof(float) | ||
| aligned_size = (x.shape[-2] + 3) // 4 * 4 | ||
| x_s = alloc_func( | ||
| x.shape[:-2] + (x.shape[-1] // group_size, aligned_size), | ||
| device=x.device, | ||
| dtype=torch.float32, | ||
| ).permute(-1, -2)[: x.shape[-2], :] | ||
| else: | ||
| x_s = alloc_func( | ||
| (x.shape[-1] // group_size,) + x.shape[:-1], | ||
| device=x.device, | ||
| dtype=torch.float32, | ||
| ).permute(-1, -2) | ||
| else: | ||
| if x_s is None: | ||
| x_s = alloc_func( | ||
| x.shape[:-1] + (x.shape[-1] // group_size,), | ||
| device=x.device, | ||
| dtype=torch.float32, | ||
| ) | ||
| sgl_ops.sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, 1e-10, fp8_min, fp8_max) | ||
| else: | ||
| lightllm_per_token_group_quant_fp8(x, group_size, x_q, x_s, eps=1e-10, dtype=torch.float8_e4m3fn) | ||
|
|
||
| return x_q, x_s |
There was a problem hiding this comment.
The per_token_group_quant_fp8 function has been significantly enhanced:
x_s(scales tensor) is now optional and can be allocated internally.- New parameters
column_major_scalesandscale_tma_alignedallow for more control over the scale tensor's layout and alignment. - The
alloc_funcparameter adds flexibility for tensor allocation. - The function now returns
(x_q, x_s).
These changes make the function more versatile and self-contained.
| w1_scale = w2_scale = None | ||
|
|
||
| if use_fp8_w8a8: | ||
| init_dtype = dtype | ||
| w1 = torch.randn(expert_num, 2 * n, k, dtype=init_dtype).cuda() | ||
| w2 = torch.randn(expert_num, k, 2 * n // 2, dtype=init_dtype).cuda() | ||
| w1 = w1.to(torch.float8_e4m3fn) | ||
| w2 = w2.to(torch.float8_e4m3fn) | ||
|
|
||
| if block_shape is None: | ||
| w1_scale = torch.randn(expert_num, dtype=torch.float32).cuda() | ||
| w2_scale = torch.randn(expert_num, dtype=torch.float32).cuda() | ||
| else: | ||
| block_n, block_k = block_shape[0], block_shape[1] | ||
| n_tiles_w1 = (2 * n + block_n - 1) // block_n | ||
| n_tiles_w2 = (k + block_n - 1) // block_n | ||
| k_tiles_w1 = (k + block_k - 1) // block_k | ||
| k_tiles_w2 = (2 * n // 2 + block_k - 1) // block_k | ||
| w1_scale = torch.rand((expert_num, n_tiles_w1, k_tiles_w1), dtype=torch.float32).cuda() | ||
| w2_scale = torch.rand((expert_num, n_tiles_w2, k_tiles_w2), dtype=torch.float32).cuda() | ||
| else: | ||
| w1 = torch.randn(expert_num, 2 * n, k, dtype=dtype).cuda() | ||
| w2 = torch.randn(expert_num, k, 2 * n // 2, dtype=dtype).cuda() |
There was a problem hiding this comment.
The weight and scale initialization logic for FP8 (use_fp8_w8a8=True) has been significantly revised. Weights (w1, w2) are now directly converted to torch.float8_e4m3fn, and their scales (w1_scale, w2_scale) are initialized based on block_shape (supporting per-expert or block-wise scales).
The explicit call to quantize_moe has been removed. This implies that the grouped_matmul kernel is now expected to consume FP8 weights and their corresponding scales directly.
| block_shape = getattr(model_config, "block_shape", None) | ||
| block_shape = [128, 128] |
There was a problem hiding this comment.
In the benchmark function, block_shape is first potentially retrieved from model_config. However, it's immediately overwritten by block_shape = [128, 128].
If the intention is to use a dynamic block_shape from the model configuration, the overwrite should be removed or made conditional. If [128, 128] is a specific override for this benchmark, it might be clearer to document this or use a different variable name for the override.
| block_shape = getattr(model_config, "block_shape", None) | |
| block_shape = [128, 128] | |
| dtype = model_config["dtype"] | |
| # block_shape = model_config.get("block_shape", None) # Get from config if needed | |
| # For this specific benchmark, we might be testing a fixed block_shape: | |
| block_shape_override = [128, 128] # Or use model_config["block_shape"] if available and desired | |
| x = torch.randn(num_tokens, hidden_size, dtype=dtype) |
47a243a to
587e0a4
Compare
d00e17e to
c5936e3
Compare
… into fused_moe_improve
Co-authored-by: wangzaijun <wzjhelloworld@qq.com> Co-authored-by: sufubao <47234901+sufubao@users.noreply.github.com> Co-authored-by: sufubao <sufubao@sensetime.com> Co-authored-by: none <none> Co-authored-by: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com>
No description provided.