Conversation
Summary of Changes

Hello @hiworldwzj, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances the performance of the LightLLM inference engine by implementing CUDA graph optimization for the prefill stage. It introduces a specialized …
Code Review
This pull request introduces support for prefill CUDA graphs, which is a significant performance optimization. The changes are extensive, touching model execution logic, memory management, and introducing new infrastructure for graph capture and replay.
My review focuses on several key areas:
- Code Duplication and Redundancy: I've identified areas in `basemodel.py` and `transformer_layer_infer_template.py` where logic is repeated, and I suggest refactoring to improve maintainability.
- Correctness and Completeness: The implementation for prefill overlap mode appears incomplete and could lead to runtime errors. I've also found a couple of typos in argument names.
- Maintainability: Some new components are complex, and I've suggested adding comments or removing empty docstrings to improve clarity.
Overall, this is a great feature addition. Addressing these points will make the implementation more robust and easier to maintain.
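For reviewers less familiar with the mechanism, the capture-and-replay idea that the new prefill infrastructure builds on is the standard `torch.cuda.CUDAGraph` workflow. A minimal, generic sketch (placeholder model and shapes, not this PR's code):

```python
import torch

# Minimal capture/replay sketch; the model and shapes are placeholders,
# not the PR's prefill implementation.
model = torch.nn.Linear(512, 512).cuda()
static_input = torch.randn(16, 512, device="cuda")

# Warm up on a side stream before capture (recommended so lazy initialization
# does not end up inside the captured graph).
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(side_stream)

# Capture one forward pass into a CUDA graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

# Replay: refill the captured input buffer in place, then rerun the recorded kernels.
static_input.copy_(torch.randn(16, 512, device="cuda"))
graph.replay()
print(static_output.sum().item())
```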
This block of code for handling CUDA graph capturing of the attention mechanism is identical to the one in `context_forward` (lines 73-112). This duplication makes the code harder to maintain. Any changes to this logic would need to be applied in two places.
Please extract this logic into a private helper method to avoid repetition and improve code clarity. For example:
```python
def _graph_split_attention(self, q, cache_kv, infer_state, layer_weight):
    # Prefill CUDA graph path: the attention part is excluded from the captured graph.
    if torch.cuda.is_current_stream_capturing():
        _q, _cache_kv = (
            tensor_to_no_ref_tensor(q.contiguous()),
            tensor_to_no_ref_tensor(cache_kv.contiguous()),
        )
        pre_capture_graph = infer_state.prefill_cuda_graph_get_current_capture_graph()
        pre_capture_graph.__exit__(None, None, None)

        def get_o_shape_dtype_device():
            # Run once inside a throwaway graph; the goal is not to capture it,
            # but to learn the shape, dtype and device of o.
            with torch.cuda.graph(cuda_graph=torch.cuda.CUDAGraph()):
                # Use _q, _cache_kv to avoid capturing references to the original tensors.
                __o = self._context_attention_kernel(_q, _cache_kv, infer_state, layer_weight)
                o_shape = __o.shape
                o_dtype = __o.dtype
                o_device = __o.device
            del __o
            import gc

            gc.collect()
            torch.cuda.empty_cache()
            return o_shape, o_dtype, o_device

        o_shape, o_dtype, o_device = get_o_shape_dtype_device()
        infer_state.prefill_cuda_graph_create_graph_obj()
        infer_state.prefill_cuda_graph_get_current_capture_graph().__enter__()
        o = torch.empty(o_shape, dtype=o_dtype, device=o_device)
        _o = tensor_to_no_ref_tensor(o)

        def att_func(new_infer_state: InferStateInfo):
            tmp_o = self._context_attention_kernel(_q, _cache_kv, new_infer_state, layer_weight)
            assert tmp_o.shape == _o.shape
            _o.copy_(tmp_o)
            return

        infer_state.prefill_cuda_graph_add_cpu_runnning_func(func=att_func, after_graph=pre_capture_graph)
        return o
    else:
        return self._context_attention_kernel(q, cache_kv, infer_state, layer_weight)
```

Then both `context_forward` and `tpsp_context_forward` can call this helper method.
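For instance, the capturing branch in `context_forward` (and likewise `tpsp_context_forward`) would then reduce to a single call; the helper name below is just the one proposed above:

```python
# Sketch: both forward paths delegate to the shared helper proposed above.
o = self._graph_split_attention(q, cache_kv, infer_state, layer_weight)
```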
```python
# TODO
raise NotImplementedError("not impl")
```
The method `_capture_prefill_overlap` raises `NotImplementedError`. However, `warmup_overlap` calls `model.microbatch_overlap_prefill`, which seems to rely on this functionality via `capture_prefill`. If `enable_prefill_microbatch_overlap` is enabled, this will lead to a runtime crash during warmup.
The feature seems incomplete. It should either be fully implemented or guarded to prevent crashes.
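One way to guard it, as a sketch only (where exactly such a validation hook would live in LightLLM's argument handling is an assumption on my part):

```python
# Hypothetical startup validation; the flag names come from this PR, the hook location does not.
def validate_prefill_graph_args(args):
    if args.enable_prefill_cudagraph and args.enable_prefill_microbatch_overlap:
        raise ValueError(
            "prefill CUDA graph capture does not support prefill microbatch overlap yet; "
            "please disable one of the two options"
        )
```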
```python
if self.prefill_graph is not None and self.prefill_graph.can_run(handle_token_num=handle_token_num):
    finded_handle_token_num = self.prefill_graph.find_closest_graph_handle_token_num(
        handle_token_num=handle_token_num
    )
    if self.prefill_graph.need_capture(handle_token_num=finded_handle_token_num):
        output_tensors: List[torch.Tensor] = self.prefill_graph.capture_prefill(
            prefill_func=prefill_func,
            input_tensors=input_tensors,
            infer_state=infer_state,
        )
    else:
        output_tensors: List[torch.Tensor] = self.prefill_graph.replay(
            input_tensors=input_tensors, infer_state=infer_state
        )
```
There's redundant logic here for checking if a CUDA graph can be used. The calling function `_prefill` already performs these checks (`can_run`, `find_closest_graph_handle_token_num`) and pads the input accordingly. This block repeats the same checks on the already-padded input.
This duplicated logic can be simplified. The decision to use a graph, and the specific graph to use, should be determined once in `_prefill` and then passed to `_context_forward`, for example via the `infer_state` object. This would make the code cleaner and easier to maintain.
A potential refactoring could look like this:
In `_prefill`:

```python
# ...
if self.prefill_graph is not None and self.prefill_graph.can_run(handle_token_num=handle_token_num):
    finded_handle_token_num = self.prefill_graph.find_closest_graph_handle_token_num(
        handle_token_num=handle_token_num
    )
    model_input = self._create_padded_prefill_model_input(
        model_input=model_input, new_handle_token_num=finded_handle_token_num
    )
    infer_state = self._create_inferstate(model_input)
    infer_state.use_prefill_graph = True
    infer_state.prefill_graph_handle_token_num = finded_handle_token_num
# ...
```

In `_context_forward`:
```python
# ...
if getattr(infer_state, 'use_prefill_graph', False):
    handle_token_num = infer_state.prefill_graph_handle_token_num
    if self.prefill_graph.need_capture(handle_token_num=handle_token_num):
        # capture logic
        ...
    else:
        # replay logic
        ...
else:
    # non-graph logic
    ...
# ...
```

A quoted hunk from the diff, containing one of the empty docstrings mentioned in the summary above:

```python
    return g_cache_manager.alloc_tensor(
        shape, dtype, device=device, is_graph_out=is_graph_out, microbatch_index=microbatch_index
    )
""" """
```
| " currently only for llama and qwen model, not support ep moe model", | ||
| ) | ||
| parser.add_argument( | ||
| "--prefll_cudagraph_max_handle_token", type=int, default=512, help="max handle token num for prefill cudagraph" |
There is a typo in the argument name: `prefll_cudagraph_max_handle_token`. It should be `prefill_cudagraph_max_handle_token`.
| "--prefll_cudagraph_max_handle_token", type=int, default=512, help="max handle token num for prefill cudagraph" | |
| "--prefill_cudagraph_max_handle_token", type=int, default=512, help="max handle token num for prefill cudagraph" |
```python
enable_monitor_auth: bool = field(default=False)
disable_cudagraph: bool = field(default=False)
enable_prefill_cudagraph: bool = field(default=False)
prefll_cudagraph_max_handle_token: int = field(default=512)
```
There is a typo in the field name: `prefll_cudagraph_max_handle_token`. It should be `prefill_cudagraph_max_handle_token`. This should be corrected to match the argument name in `api_cli.py`.
```diff
- prefll_cudagraph_max_handle_token: int = field(default=512)
+ prefill_cudagraph_max_handle_token: int = field(default=512)
```
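The reason the two names must match exactly is presumably that argparse derives the attribute name directly from the flag, and the dataclass is populated from the parsed namespace. A small, self-contained illustration of the failure mode (not LightLLM code; the class name here is hypothetical):

```python
import argparse
from dataclasses import dataclass, field

@dataclass
class StartArgsSketch:
    # The field name must match the attribute name argparse derives from the flag.
    prefill_cudagraph_max_handle_token: int = field(default=512)

parser = argparse.ArgumentParser()
parser.add_argument("--prefill_cudagraph_max_handle_token", type=int, default=512)
ns = parser.parse_args(["--prefill_cudagraph_max_handle_token", "1024"])

# argparse exposes the value as ns.prefill_cudagraph_max_handle_token; a misspelled
# dataclass field would silently keep its default instead of receiving this value.
cfg = StartArgsSketch(
    **{k: v for k, v in vars(ns).items() if k in StartArgsSketch.__dataclass_fields__}
)
print(cfg.prefill_cudagraph_max_handle_token)  # 1024
```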