feat: add rope_in_place tilelang kernel for npu device. #964
zhang-minchao wants to merge 1 commit into jd-opensource:main
Conversation
Code Review
This pull request introduces a new TileLang-based RoPE kernel for Ascend NPUs, including the necessary build system infrastructure, C++ wrappers, and tests. The changes are extensive and add significant new functionality. However, there are several critical and high-severity issues that should be addressed. The build process relies on a brittle script patching mechanism in setup.py that could easily break. The C++ wrapper for the kernel has a critical performance issue due to inefficient tensor broadcasting and an unsafe use of const_cast. Additionally, the CMake configuration for building the kernel is not flexible, with hardcoded dimensions and a fragile sed-based code modification step. Addressing these issues will improve the robustness, performance, and maintainability of the new kernel and its build process.
```python
def _patch_tilelang_install_script(tilelang_root: str) -> None:
    script_path = os.path.join(tilelang_root, "install_ascend.sh")
    if not os.path.isfile(script_path):
        raise RuntimeError(
            "[ERROR] Missing tilelang install script: install_ascend.sh"
        )

    line_no = 145
    current_line = subprocess.run(
        ["sed", "-n", f"{line_no}p", script_path],
        check=True,
        text=True,
        capture_output=True,
    ).stdout.strip()

    if current_line == "make -j${MAKE_JOBS}":
        subprocess.check_call(["sed", "-i", f"{line_no}c\\make -j", script_path])
        print(f"[INFO] Applied tilelang install parallel patch at line {line_no}: make -j")
        return

    if current_line == "make -j":
        return

    raise RuntimeError(
        f"[ERROR] Unexpected install_ascend.sh content at line {line_no}: {current_line!r}"
    )
```
The function `_patch_tilelang_install_script` uses `sed` with a hardcoded line number (145) to patch the `install_ascend.sh` script from the tilelang-ascend submodule. This is very brittle: if the submodule is updated and the script changes, the patch targets the wrong line, and the resulting build failure would be hard to debug.

Instead of patching the script in place, consider one of these more robust approaches:
1. Fork the `tilelang-ascend` repository, apply the change there, and use your fork as the submodule. This is the safest option.
2. Use a proper patch file (`.patch`) and apply it with `git apply`, which is more resilient to line-number changes than `sed`.
3. If possible, control the `make` invocation through arguments or environment variables instead of editing the script file, e.g. `MAKE_JOBS="" bash install_ascend.sh`, depending on how `MAKE_JOBS` is used in the script.
3. If possible, modify the `make` command by passing arguments or environment variables instead of editing the script file. For example: `MAKE_JOBS="" bash install_ascend.sh`. This would depend on how `MAKE_JOBS` is used in the script.| auto sin_rows = sin_cache.unsqueeze(1) | ||
| .expand({input.size(0), input.size(1), sin_cache.size(1)}) | ||
| .contiguous() | ||
| .view({input_rows.size(0), sin_cache.size(1)}); | ||
| auto cos_rows = cos_cache.unsqueeze(1) | ||
| .expand({input.size(0), input.size(1), cos_cache.size(1)}) | ||
| .contiguous() | ||
| .view({input_rows.size(0), cos_cache.size(1)}); |
The `rope_in_place` function prepares `sin_rows` and `cos_rows` by expanding and then calling `.contiguous()`. The `.contiguous()` call creates a full copy of the expanded tensor on every invocation. For large inputs (e.g., long sequences), this means significant memory allocation and data copying, which can be a major performance bottleneck: with a sequence length of 2048 and 32 heads, it could allocate and copy over 30 MB of data on every call.

To avoid this overhead, modify the TileLang kernel to handle the broadcasting internally. The kernel can accept the original `sin_cache` and `cos_cache` tensors (of shape `[num_tokens, rope_dim]`) plus an additional `num_heads` parameter, and compute the correct cache index for each row as `token_idx = row_idx / num_heads`. This would eliminate the expensive `expand().contiguous()` pattern in the C++ wrapper.
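A small pure-Python sketch of the indexing this suggestion relies on: `expanded_rows` models the one-row-per-(token, head) layout the current wrapper materializes, while `cache_row_for` is the kernel-side mapping that reads the same values without any copy (the names and shapes are illustrative, not the PR's actual kernel interface):

```python
# Sketch: instead of expanding sin_cache/cos_cache to one row per
# (token, head) pair on the host, derive each flattened row's token
# index inside the kernel.

def expanded_rows(cache, num_heads):
    """Host-side layout the current wrapper materializes:
    each cache row repeated num_heads times (token-major order)."""
    return [row for row in cache for _ in range(num_heads)]


def cache_row_for(row_idx, num_heads):
    """Kernel-side alternative: map a flattened row back to its token."""
    return row_idx // num_heads


num_tokens, num_heads, rope_dim = 4, 3, 2
sin_cache = [[t * 0.1 + d for d in range(rope_dim)] for t in range(num_tokens)]

# Both schemes read identical values, but the second allocates nothing.
expanded = expanded_rows(sin_cache, num_heads)
for row_idx in range(num_tokens * num_heads):
    assert expanded[row_idx] == sin_cache[cache_row_for(row_idx, num_heads)]
```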
```cmake
        "${Python_EXECUTABLE}" "${TL_PY_SCRIPT}"
        --output "${_kernel_cpp}"
        ${TL_CODEGEN_ARGS}
COMMAND sed -i -e "s/\\<${TL_SOURCE_ENTRY_SYMBOL}\\>/${_entry_symbol}/g" "${_kernel_cpp}"
```
The custom command uses `sed` to replace the entry function name in the C++ source generated by TileLang. The word boundaries (`\\<`, `\\>`) make this safer than a plain string replacement, but the approach is still brittle: it relies on an external tool (`sed`) and can break if TileLang's code generation changes the emitted symbol name.

A more robust solution would be to modify the Python codegen script (`xllm/core/kernels/python/tilelang/rope.py`) to accept the desired function name as a command-line argument. The CMake function `tilelang_add_ascendc_kernel` can then pass `_entry_symbol` to the script, eliminating the need for `sed`.
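A sketch of what the codegen script's argument handling could look like; the `--entry-symbol` flag name and its default are assumptions, not the PR's actual interface:

```python
# Sketch: let the codegen script receive the entry symbol directly,
# so CMake no longer needs a sed post-processing step.
import argparse


def parse_args(argv=None):
    parser = argparse.ArgumentParser(description="TileLang RoPE codegen")
    parser.add_argument("--output", required=True,
                        help="Path of the generated C++ file")
    parser.add_argument("--entry-symbol", default="tl_rope_kernel",
                        help="Name emitted for the kernel entry function")
    return parser.parse_args(argv)


# CMake would then pass the symbol at generation time, e.g.:
#   COMMAND "${Python_EXECUTABLE}" "${TL_PY_SCRIPT}"
#           --output "${_kernel_cpp}" --entry-symbol "${_entry_symbol}"
args = parse_args(["--output", "rope_kernel.cpp",
                   "--entry-symbol", "xllm_tl_rope"])
assert args.entry_symbol == "xllm_tl_rope"
```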
```cmake
set(TILELANG_ROPE_HEAD_DIM 576)
set(TILELANG_ROPE_ROPE_DIM 64)
```
The RoPE dimensions `TILELANG_ROPE_HEAD_DIM` and `TILELANG_ROPE_ROPE_DIM` are hardcoded in this CMakeLists.txt, which makes it difficult to reuse the kernel for different model architectures without modifying the build scripts. Declaring them as cache variables would let users override them at configure time (e.g. with `-D` flags):

```cmake
set(TILELANG_ROPE_HEAD_DIM 576 CACHE STRING "RoPE head dimension for TileLang kernel")
set(TILELANG_ROPE_ROPE_DIM 64 CACHE STRING "RoPE dimension for TileLang kernel")
```
```cpp
reinterpret_cast<uint8_t*>(const_cast<void*>(sin_rows.data_ptr())),
reinterpret_cast<uint8_t*>(const_cast<void*>(cos_rows.data_ptr())),
```
The code uses `const_cast` to strip the `const` qualifier from the data pointers of `sin_rows` and `cos_rows` before passing them to the kernel. This violates const correctness and is unsafe: the sin and cos caches are read-only inputs and should be treated as such throughout the call stack.

To fix this:
- In `rope_wrapper.cpp`, declare the kernel entry point with `const` pointers for the read-only inputs:

```cpp
extern "C" void XLLM_TL_ROPE_ENTRY(uint8_t* x_handle,
                                   const uint8_t* sin_handle,
                                   const uint8_t* cos_handle,
                                   ...);
```

- Update the call to `XLLM_TL_ROPE_ENTRY` to remove the `const_cast`:

```cpp
reinterpret_cast<const uint8_t*>(sin_rows.data_ptr()),
reinterpret_cast<const uint8_t*>(cos_rows.data_ptr()),
```
Requires the CANN 8.5 image.