feat: add rope_in_place tilelang kernel for npu device. #964

Draft
zhang-minchao wants to merge 1 commit into jd-opensource:main from zhang-minchao:feat/tilelang

@zhang-minchao
Collaborator

Requires the CANN 8.5 image.

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a new TileLang-based RoPE kernel for Ascend NPUs, including the necessary build system infrastructure, C++ wrappers, and tests. The changes are extensive and add significant new functionality. However, there are several critical and high-severity issues that should be addressed. The build process relies on a brittle script patching mechanism in setup.py that could easily break. The C++ wrapper for the kernel has a critical performance issue due to inefficient tensor broadcasting and an unsafe use of const_cast. Additionally, the CMake configuration for building the kernel is not flexible, with hardcoded dimensions and a fragile sed-based code modification step. Addressing these issues will improve the robustness, performance, and maintainability of the new kernel and its build process.

Comment on lines +87 to +112
def _patch_tilelang_install_script(tilelang_root: str) -> None:
    script_path = os.path.join(tilelang_root, "install_ascend.sh")
    if not os.path.isfile(script_path):
        raise RuntimeError(
            "[ERROR] Missing tilelang install script: install_ascend.sh"
        )

    line_no = 145
    current_line = subprocess.run(
        ["sed", "-n", f"{line_no}p", script_path],
        check=True,
        text=True,
        capture_output=True,
    ).stdout.strip()

    if current_line == "make -j${MAKE_JOBS}":
        subprocess.check_call(["sed", "-i", f"{line_no}c\\make -j", script_path])
        print(f"[INFO] Applied tilelang install parallel patch at line {line_no}: make -j")
        return

    if current_line == "make -j":
        return

    raise RuntimeError(
        f"[ERROR] Unexpected install_ascend.sh content at line {line_no}: {current_line!r}"
    )

critical

The function _patch_tilelang_install_script uses sed with a hardcoded line number to patch the install_ascend.sh script from the tilelang-ascend submodule. This is very brittle and likely to break if the submodule is updated and the script changes. A build failure due to this would be hard to debug.

Instead of patching the script in-place, consider one of these more robust approaches:
1.  Fork the `tilelang-ascend` repository, apply the change there, and use your fork as the submodule. This is the safest option.
2.  Use a proper patch file (`.patch`) and apply it with `git apply`. This is more resilient to line number changes than `sed`.
3.  If possible, modify the `make` command by passing arguments or environment variables instead of editing the script file. For example: `MAKE_JOBS="" bash install_ascend.sh`. This would depend on how `MAKE_JOBS` is used in the script.
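
If the in-place patch has to stay for now, a lighter-weight variation (not one of the three options above) is to locate the line by its content instead of its position. The following is a minimal sketch under that assumption: `patch_make_line` is a hypothetical helper name, and the two strings are taken from the snippet under review. It drops the external `sed` dependency and keeps the patch idempotent.

```python
OLD_LINE = "make -j${MAKE_JOBS}"  # line the current patch expects at line 145
NEW_LINE = "make -j"              # replacement used by the current patch


def patch_make_line(script_path: str) -> bool:
    """Patch the make invocation wherever it appears.

    Returns True if the file was modified, False if it was already patched.
    Raises RuntimeError if the expected line is missing or ambiguous.
    """
    with open(script_path, encoding="utf-8") as f:
        lines = f.read().splitlines(keepends=True)

    # Match on stripped content so indentation changes upstream do not matter.
    hits = [i for i, line in enumerate(lines) if line.strip() == OLD_LINE]

    if not hits:
        if any(line.strip() == NEW_LINE for line in lines):
            return False  # already patched, nothing to do
        raise RuntimeError(f"[ERROR] {OLD_LINE!r} not found in {script_path}")
    if len(hits) > 1:
        raise RuntimeError(f"[ERROR] {OLD_LINE!r} is ambiguous in {script_path}")

    # Replace only the matched text, preserving indentation and the newline.
    index = hits[0]
    lines[index] = lines[index].replace(OLD_LINE, NEW_LINE)
    with open(script_path, "w", encoding="utf-8") as f:
        f.writelines(lines)
    return True
```

This still edits a file owned by the submodule, so a fork or a `.patch` file remains the more robust choice; the sketch only removes the dependency on line 145.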

Comment on lines +158 to +165
auto sin_rows = sin_cache.unsqueeze(1)
                    .expand({input.size(0), input.size(1), sin_cache.size(1)})
                    .contiguous()
                    .view({input_rows.size(0), sin_cache.size(1)});
auto cos_rows = cos_cache.unsqueeze(1)
                    .expand({input.size(0), input.size(1), cos_cache.size(1)})
                    .contiguous()
                    .view({input_rows.size(0), cos_cache.size(1)});

critical

The rope_in_place function prepares sin_rows and cos_rows by expanding and then calling .contiguous(). The .contiguous() call materializes a full copy of the expanded tensor on every invocation. For large inputs (e.g., long sequences), this results in significant memory allocation and data copying overhead, which can be a major performance bottleneck. For example, with 2048 tokens, 32 heads, and the rope_dim of 64 used in this PR, each expanded tensor holds 2048 × 32 × 64 ≈ 4.2M elements, so in float32 the sin and cos copies together take about 32 MiB on every call (about 16 MiB in a 16-bit dtype).

To avoid this overhead, modify the TileLang kernel to handle the broadcasting internally. The kernel can accept the original `sin_cache` and `cos_cache` tensors (of shape `[num_tokens, rope_dim]`) and an additional `num_heads` parameter. Inside the kernel, you can calculate the correct index into the cache for each row using `token_idx = row_idx / num_heads`. This would eliminate the need for the expensive `expand().contiguous()` pattern in the C++ wrapper.
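
The index mapping can be checked outside the kernel. The following is a plain-Python sketch, not TileLang code: `expanded_rows`, `lookup_row`, and `expanded_bytes` are hypothetical helper names, and nested lists stand in for tensors. It shows that looking up `row_idx // num_heads` in the unexpanded cache yields the same rows as the `expand().contiguous().view()` result, and how large that materialized copy is.

```python
def expanded_rows(cache, num_heads):
    """Mimic unsqueeze(1).expand(...).contiguous().view(...):
    every token row is copied once per head, token-major."""
    return [row for row in cache for _ in range(num_heads)]


def lookup_row(cache, row_idx, num_heads):
    """Kernel-side alternative: map a flattened (token, head) row
    back to its token and read the original cache directly."""
    token_idx = row_idx // num_heads
    return cache[token_idx]


def expanded_bytes(num_tokens, num_heads, rope_dim, itemsize):
    """Size in bytes of ONE materialized expanded tensor (sin or cos)."""
    return num_tokens * num_heads * rope_dim * itemsize


# Small toy cache of shape [num_tokens, rope_dim] with distinct values.
num_tokens, num_heads, rope_dim = 4, 3, 2
sin_cache = [[float(t * rope_dim + d) for d in range(rope_dim)]
             for t in range(num_tokens)]

rows = expanded_rows(sin_cache, num_heads)
same = all(rows[r] == lookup_row(sin_cache, r, num_heads)
           for r in range(num_tokens * num_heads))
```

Here `same` is `True`, and `expanded_bytes(2048, 32, 64, 4)` evaluates to 16,777,216 bytes per tensor, which is the copy the kernel-side lookup would avoid.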

"${Python_EXECUTABLE}" "${TL_PY_SCRIPT}"
--output "${_kernel_cpp}"
${TL_CODEGEN_ARGS}
COMMAND sed -i -e "s/\\<${TL_SOURCE_ENTRY_SYMBOL}\\>/${_entry_symbol}/g" "${_kernel_cpp}"

high

The custom command uses sed to replace the entry function name in the C++ source file generated by TileLang. While using word boundaries (\\<, \\>) makes it safer than a simple string replacement, this approach is still brittle. It relies on an external tool (sed) and can break if the code generation logic in TileLang changes in a way that affects the call symbol.

A more robust solution would be to modify the Python codegen script (`xllm/core/kernels/python/tilelang/rope.py`) to accept the desired function name as a command-line argument. The CMake function `tilelang_add_ascendc_kernel` can then pass `_entry_symbol` to the script, eliminating the need for `sed`.
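
A rough sketch of what that could look like on the Python side follows. The `--entry-symbol` flag, `build_parser`, and `rename_entry` are hypothetical names; only `--output` appears in the CMake snippet above. How TileLang exposes the generated entry name is not shown in this PR, so the sketch assumes the rename is done on the generated source text before it is written.

```python
import argparse
import re

C_IDENTIFIER = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Generate the RoPE kernel source")
    parser.add_argument("--output", required=True, help="path of the generated .cpp")
    # Hypothetical flag: the symbol CMake wants the entry function to have.
    parser.add_argument("--entry-symbol", required=True,
                        help="name to give the generated entry function")
    return parser


def rename_entry(source: str, old_symbol: str, new_symbol: str) -> str:
    """Rename whole-word occurrences of old_symbol, like the sed call does."""
    if not C_IDENTIFIER.fullmatch(new_symbol):
        raise ValueError(f"not a valid C identifier: {new_symbol!r}")
    pattern = rf"\b{re.escape(old_symbol)}\b"
    renamed, count = re.subn(pattern, new_symbol, source)
    if count == 0:
        # Fail loudly at codegen time instead of producing a broken link step.
        raise ValueError(f"entry symbol {old_symbol!r} not found in generated source")
    return renamed
```

Raising when the symbol is not found turns a silent no-op substitution into a codegen-time error, which is the failure mode the `sed` step cannot report.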

Comment on lines +129 to +130
set(TILELANG_ROPE_HEAD_DIM 576)
set(TILELANG_ROPE_ROPE_DIM 64)

high

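The RoPE dimensions TILELANG_ROPE_HEAD_DIM and TILELANG_ROPE_ROPE_DIM are hardcoded in this CMakeLists.txt file. This makes it difficult to reuse the kernel for different model architectures without modifying the build scripts. Declaring them as cache variables keeps the current defaults while allowing an override at configure time with -D: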

set(TILELANG_ROPE_HEAD_DIM 576 CACHE STRING "RoPE head dimension for TileLang kernel")
set(TILELANG_ROPE_ROPE_DIM 64 CACHE STRING "RoPE dimension for TileLang kernel")

Comment on lines +140 to +141
reinterpret_cast<uint8_t*>(const_cast<void*>(sin_rows.data_ptr())),
reinterpret_cast<uint8_t*>(const_cast<void*>(cos_rows.data_ptr())),

high

The code uses const_cast to remove the const qualifier from the data pointers of sin_rows and cos_rows before passing them to the kernel. This violates const correctness and is unsafe. The sin and cos caches are read-only inputs and should be treated as such throughout the call stack.

To fix this:

  1. In rope_wrapper.cpp, declare the kernel entry point with const pointers for read-only inputs:
     extern "C" void XLLM_TL_ROPE_ENTRY(uint8_t* x_handle,
                                        const uint8_t* sin_handle,
                                        const uint8_t* cos_handle,
                                        ...);
  2. Update the call to XLLM_TL_ROPE_ENTRY to remove the const_cast:
     reinterpret_cast<const uint8_t*>(sin_rows.data_ptr()),
     reinterpret_cast<const uint8_t*>(cos_rows.data_ptr()),

@XuZhang99 changed the title from "feat: add rope_in_place tilelang kernel." to "feat: add rope_in_place tilelang kernel for npu device." on Feb 28, 2026