prefill cuda graph. #1149

Merged
hiworldwzj merged 19 commits into main from wzj on Dec 18, 2025

Conversation

@hiworldwzj
Collaborator

No description provided.

@gemini-code-assist
Contributor

Summary of Changes

Hello @hiworldwzj, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request adds CUDA graph support for the prefill stage of the LightLLM inference engine. It introduces a PrefillCudaGraph mechanism that captures and replays computation graphs for a range of input lengths, reducing CPU launch overhead and improving GPU utilization. A key part of the change is a refined memory-management strategy built on "no-reference" tensors, which enables memory reuse inside the captured graphs. The optimization targets workloads with diverse prefill lengths and aims for more consistent, faster initial token generation.

Highlights

  • Prefill CUDA Graph Implementation: Introduces a dedicated PrefillCudaGraph class and integrates it into the BaseModel to accelerate the prefill stage of LLM inference, allowing for more efficient initial token generation.
  • Dynamic Graph Capture and Replay: The system captures and replays CUDA graphs for varying prefill token lengths, using power-of-two bucket sizes up to a configured maximum token count to handle diverse input sizes (a sketch of this bucketing follows the list).
  • Memory Management Refinement: Refactors the CacheTensorManager by removing CUDA graph-specific parameters and introduces a tensor_to_no_ref_tensor utility. This new mechanism optimizes memory reuse within CUDA graphs by bypassing PyTorch's reference counting, leading to reduced memory footprint and improved efficiency.
  • Attention Kernel Isolation for Graph Flexibility: During prefill CUDA graph capture, the attention kernel is now isolated and executed via a CPU function during replay. This allows for more flexible graph construction and better handling of dynamic shapes within the captured graph.
  • New CLI Arguments: Adds --enable_prefill_cudagraph and --prefll_cudagraph_max_handle_token command-line arguments, providing users with control and configuration options for the new prefill CUDA graph feature.
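
As a rough illustration of the bucketing referenced above, the following minimal sketch shows how power-of-two capture sizes and input padding could fit together. The names here (PrefillGraphBuckets, find_closest, pad_token_ids) are hypothetical; they only mimic the behavior of the PR's PrefillCudaGraph helpers such as can_run and find_closest_graph_handle_token_num, and are not the actual implementation.

import torch


class PrefillGraphBuckets:
    """Hypothetical stand-in for the PR's power-of-two graph bucketing."""

    def __init__(self, max_handle_token: int = 512):
        # Capture sizes are powers of two: 1, 2, 4, ..., up to the configured maximum.
        self.sizes = []
        n = 1
        while n <= max_handle_token:
            self.sizes.append(n)
            n *= 2

    def can_run(self, handle_token_num: int) -> bool:
        return handle_token_num <= self.sizes[-1]

    def find_closest(self, handle_token_num: int) -> int:
        # Smallest captured size that can hold this many prefill tokens.
        return next(s for s in self.sizes if s >= handle_token_num)


def pad_token_ids(token_ids: torch.Tensor, target_len: int, pad_id: int = 0) -> torch.Tensor:
    # Right-pad the flattened prefill tokens so they fit the chosen graph size.
    pad = target_len - token_ids.numel()
    return torch.nn.functional.pad(token_ids, (0, pad), value=pad_id)


buckets = PrefillGraphBuckets(max_handle_token=512)
tokens = torch.arange(37)                     # 37 real prefill tokens
size = buckets.find_closest(tokens.numel())   # -> 64
padded = pad_token_ids(tokens, size)
print(size, padded.shape)                     # 64 torch.Size([64])

Replay then runs on the padded batch, so a small amount of compute is wasted on padding positions in exchange for a fixed, capturable set of input shapes.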

@gemini-code-assist bot left a comment (Contributor)

Code Review

This pull request introduces support for prefill CUDA graphs, which is a significant performance optimization. The changes are extensive, touching model execution logic, memory management, and introducing new infrastructure for graph capture and replay.

My review focuses on several key areas:

  • Code Duplication and Redundancy: I've identified areas in basemodel.py and transformer_layer_infer_template.py where logic is repeated, and suggest refactoring to improve maintainability.
  • Correctness and Completeness: The implementation for prefill overlap mode appears incomplete and could lead to runtime errors. I've also found a couple of typos in argument names.
  • Maintainability: Some new components are complex, and I've suggested adding comments or removing empty docstrings to improve clarity.

Overall, this is a great feature addition. Addressing these points will make the implementation more robust and easier to maintain.

Comment on lines 156 to 194
Severity: high

This block of code for handling CUDA graph capturing of the attention mechanism is identical to the one in context_forward (lines 73-112). This duplication makes the code harder to maintain. Any changes to this logic would need to be applied in two places.

Please extract this logic into a private helper method to avoid repetition and improve code clarity. For example:

def _graph_split_attention(self, q, cache_kv, infer_state, layer_weight):
    # Prefill CUDA graph path: exclude the attention part from graph capture
    if torch.cuda.is_current_stream_capturing():
        _q, _cache_kv = (
            tensor_to_no_ref_tensor(q.contiguous()),
            tensor_to_no_ref_tensor(cache_kv.contiguous()),
        )
        pre_capture_graph = infer_state.prefill_cuda_graph_get_current_capture_graph()
        pre_capture_graph.__exit__(None, None, None)

        def get_o_shape_dtype_device():
            # Run once in a fresh throwaway graph, not to capture it, but to discover o's shape, dtype, and device
            with torch.cuda.graph(cuda_graph=torch.cuda.CUDAGraph()):
                # Use _q, _cache_kv to avoid capturing original tensors
                __o = self._context_attention_kernel(_q, _cache_kv, infer_state, layer_weight)
                o_shape = __o.shape
                o_dtype = __o.dtype
                o_device = __o.device
                del __o

                import gc

                gc.collect()
                torch.cuda.empty_cache()
            return o_shape, o_dtype, o_device

        o_shape, o_dtype, o_device = get_o_shape_dtype_device()
        infer_state.prefill_cuda_graph_create_graph_obj()
        infer_state.prefill_cuda_graph_get_current_capture_graph().__enter__()
        o = torch.empty(o_shape, dtype=o_dtype, device=o_device)
        _o = tensor_to_no_ref_tensor(o)

        def att_func(new_infer_state: InferStateInfo):
            tmp_o = self._context_attention_kernel(_q, _cache_kv, new_infer_state, layer_weight)
            assert tmp_o.shape == _o.shape
            _o.copy_(tmp_o)
            return

        infer_state.prefill_cuda_graph_add_cpu_runnning_func(func=att_func, after_graph=pre_capture_graph)
        return o
    else:
        return self._context_attention_kernel(q, cache_kv, infer_state, layer_weight)

Then both context_forward and tpsp_context_forward can call this helper method.
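
For context, the pattern being factored out here (capture a graph segment, run an uncapturable op eagerly on the CPU-launched path, then replay the rest) can be illustrated in isolation. The following is a minimal standalone sketch using plain torch.cuda.CUDAGraph objects and requires a CUDA device; eager_middle and the static buffers are made-up names, not code from this PR.

import torch

assert torch.cuda.is_available()


def eager_middle(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the attention kernel: executed eagerly between graph replays.
    return torch.nn.functional.relu(x)


# Static buffers shared between capture and replay.
x = torch.randn(8, 16, device="cuda")
mid_in = torch.empty_like(x)
mid_out = torch.empty_like(x)
out = torch.empty_like(x)

# Warm up on a side stream before capturing (recommended for CUDA graphs).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    mid_in.copy_(x * 2)
    out.copy_(mid_out + 1)
torch.cuda.current_stream().wait_stream(s)

g_before, g_after = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()
with torch.cuda.graph(g_before):
    mid_in.copy_(x * 2)      # segment before the excluded op
with torch.cuda.graph(g_after):
    out.copy_(mid_out + 1)   # segment after the excluded op


def run(new_x: torch.Tensor) -> torch.Tensor:
    x.copy_(new_x)
    g_before.replay()
    mid_out.copy_(eager_middle(mid_in))  # CPU-launched function between the two replays
    g_after.replay()
    return out.clone()


print(run(torch.randn(8, 16, device="cuda")).shape)  # torch.Size([8, 16])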

Comment on lines +86 to +87
# TODO
raise NotImplementedError("not impl")

Severity: high

The method _capture_prefill_overlap raises NotImplementedError. However, warmup_overlap calls model.microbatch_overlap_prefill, which seems to rely on this functionality via capture_prefill. If enable_prefill_microbatch_overlap is enabled, this will lead to a runtime crash during warmup.

The feature seems incomplete. It should either be fully implemented or guarded to prevent crashes.
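
Until the overlap path is implemented, one way to fail fast is a startup check along these lines. This is a sketch only: validate_args is a hypothetical hook, and the flag names follow --enable_prefill_cudagraph from this PR and the existing enable_prefill_microbatch_overlap option.

def validate_args(args) -> None:
    # Hypothetical startup guard: reject the unsupported flag combination early
    # instead of crashing later inside warmup_overlap / _capture_prefill_overlap.
    if getattr(args, "enable_prefill_cudagraph", False) and getattr(
        args, "enable_prefill_microbatch_overlap", False
    ):
        raise ValueError(
            "enable_prefill_cudagraph does not yet support prefill microbatch overlap; "
            "disable one of the two options."
        )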

Comment on lines +522 to +535
if self.prefill_graph is not None and self.prefill_graph.can_run(handle_token_num=handle_token_num):
    finded_handle_token_num = self.prefill_graph.find_closest_graph_handle_token_num(
        handle_token_num=handle_token_num
    )
    if self.prefill_graph.need_capture(handle_token_num=finded_handle_token_num):
        output_tensors: List[torch.Tensor] = self.prefill_graph.capture_prefill(
            prefill_func=prefill_func,
            input_tensors=input_tensors,
            infer_state=infer_state,
        )
    else:
        output_tensors: List[torch.Tensor] = self.prefill_graph.replay(
            input_tensors=input_tensors, infer_state=infer_state
        )

Severity: medium

There's redundant logic here for checking if a CUDA graph can be used. The calling function _prefill already performs these checks (can_run, find_closest_graph_handle_token_num) and pads the input accordingly. This block repeats the same checks on the already-padded input.

This duplicated logic can be simplified. The decision to use a graph and the specific graph to use should be determined once in _prefill and then passed to _context_forward, for example via the infer_state object. This would make the code cleaner and easier to maintain.

A potential refactoring could look like this:

In _prefill:

# ...
if self.prefill_graph is not None and self.prefill_graph.can_run(handle_token_num=handle_token_num):
    finded_handle_token_num = self.prefill_graph.find_closest_graph_handle_token_num(
        handle_token_num=handle_token_num
    )
    model_input = self._create_padded_prefill_model_input(
        model_input=model_input, new_handle_token_num=finded_handle_token_num
    )
    infer_state = self._create_inferstate(model_input)
    infer_state.use_prefill_graph = True
    infer_state.prefill_graph_handle_token_num = finded_handle_token_num
# ...

In _context_forward:

# ...
if getattr(infer_state, 'use_prefill_graph', False):
    handle_token_num = infer_state.prefill_graph_handle_token_num
    if self.prefill_graph.need_capture(handle_token_num=handle_token_num):
        # capture logic
    else:
        # replay logic
else:
    # non-graph logic
# ...

return g_cache_manager.alloc_tensor(
    shape, dtype, device=device, is_graph_out=is_graph_out, microbatch_index=microbatch_index
)
""" """

Severity: medium

This docstring is empty. It should either be removed or filled with a meaningful description of the function's purpose.

" currently only for llama and qwen model, not support ep moe model",
)
parser.add_argument(
"--prefll_cudagraph_max_handle_token", type=int, default=512, help="max handle token num for prefill cudagraph"

Severity: medium

There is a typo in the argument name: prefll_cudagraph_max_handle_token. It should be prefill_cudagraph_max_handle_token.

Suggested change
"--prefll_cudagraph_max_handle_token", type=int, default=512, help="max handle token num for prefill cudagraph"
"--prefill_cudagraph_max_handle_token", type=int, default=512, help="max handle token num for prefill cudagraph"

enable_monitor_auth: bool = field(default=False)
disable_cudagraph: bool = field(default=False)
enable_prefill_cudagraph: bool = field(default=False)
prefll_cudagraph_max_handle_token: int = field(default=512)

Severity: medium

There is a typo in the field name: prefll_cudagraph_max_handle_token. It should be prefill_cudagraph_max_handle_token. This should be corrected to match the argument name in api_cli.py.

Suggested change
prefll_cudagraph_max_handle_token: int = field(default=512)
prefill_cudagraph_max_handle_token: int = field(default=512)

@hiworldwzj merged commit ef28098 into main on Dec 18, 2025
1 check passed
@hiworldwzj deleted the wzj branch on December 18, 2025 05:23
