Conversation
Hello @shihaobai, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
Summary of Changes
Hello team, gemini-code-assist here to provide a summary of this pull request titled "DeepSeek MTP". Based on the title and the extensive code changes, it appears this PR introduces support for the DeepSeek MTP (Multi-Token Prediction) model, specifically implementing a speculative decoding approach. The changes involve adding new components for MTP-specific logic, integrating this into the core model inference flow, and updating the server backend and CLI to support the new model and speculative decoding parameters.
Highlights
- DeepSeek MTP Model Support: Adds a new model implementation (Deepseek3MTPModel) specifically for the DeepSeek MTP architecture, inheriting from the Deepseek2 model.
- Speculative Decoding Implementation: Introduces a framework for speculative decoding, including a new enum (SpeculativeDecodeAlgorithm), batch objects (ModelInput, ModelOutput), and pre-processing logic (mtp_pre_process.py). This allows the model to generate multiple tokens speculatively in a single step (a rough sketch of these objects follows this list).
- MTP-Specific Components: Adds MTP-specific memory management (Deepseek3MTPMemoryManager) and pre/post layer inference and weight handling (Deepseek3MTPPreLayerInfer, Deepseek3MTPPreAndPostLayerWeight) to handle the unique aspects of the DeepSeek MTP model.
- Integration into Inference Pipeline: Modifies the core BaseModel forward pass, InferStateInfo, and CUDA graph handling to accommodate the new ModelInput/ModelOutput structure and the speculative decoding algorithm (spec_algo). Updates the copy_kv_index_to_req kernel to support variable decode lengths.
- Backend and CLI Updates: Adds new backend implementations (ContinuesBatchWithMTPBackend, ChunkedPrefillWithMTPBackend) to manage the speculative decoding workflow. Introduces new CLI arguments (--spec_algo, --spec_model_dir, --spec_step) to control speculative decoding settings.
- Testing and Logging: Adds a new test script (test/model/model_infer_mtp.py) to test MTP model inference with speculative decoding. Updates logging in the HTTP server manager to include avg_token_per_step based on the MTP acceptance length.
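To make the new batch objects and algorithm enum concrete, here is a minimal sketch of what lightllm/common/basemodel/batch_objs.py and lightllm/common/spec_info.py might contain. Only the names called out in this summary come from the PR; every other field and value is an illustrative assumption.

```python
# Hypothetical sketch only: fields and values not named in the PR summary are assumptions.
from dataclasses import dataclass
from enum import IntEnum

import torch


@dataclass
class ModelInput:
    batch_size: int            # assumed field
    input_ids: torch.Tensor    # assumed field
    is_prefill: bool = False   # assumed field


@dataclass
class ModelOutput:
    logits: torch.Tensor
    hidden_states: torch.Tensor  # carried along so MTP draft modules can reuse them


class SpeculativeDecodeAlgorithm(IntEnum):
    NONE = 0
    MTP = 1
    MTP_MOUDLE = 2  # spelling as reported in this summary

    def is_none(self):
        return self == SpeculativeDecodeAlgorithm.NONE

    def is_mtp(self):
        return self == SpeculativeDecodeAlgorithm.MTP

    def is_mtp_module(self):
        return self == SpeculativeDecodeAlgorithm.MTP_MOUDLE

    @staticmethod
    def from_string(name: str):
        return {
            "NONE": SpeculativeDecodeAlgorithm.NONE,
            "MTP": SpeculativeDecodeAlgorithm.MTP,
            "MTP_MOUDLE": SpeculativeDecodeAlgorithm.MTP_MOUDLE,
        }[name.upper()]

    def decode_len(self):
        # Assumed: MTP decodes the last accepted token plus one draft token per step.
        return 1 if self.is_none() else 2
```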
Changelog
- lightllm/common/basemodel/basemodel.py
  - Imported SpeculativeDecodeAlgorithm, ModelInput, ModelOutput (L23-24).
  - Added a spec_algo attribute to BaseModel initialization (L77-78).
  - Refactored the forward method to accept a ModelInput object (L235).
  - Introduced the _create_inferstate helper method (L243-275).
  - Updated the _prefill and _decode methods to use ModelInput and _create_inferstate (L277-310).
  - Modified the return type of _context_forward and _token_forward to ModelOutput and included hidden states conditionally (L456-459, L484-487).
  - Renamed predict_logics to predict_logits in various methods (L360, L368, L372, L429, L433, L507, L512, L526, L530).
  - Updated _check_max_len_infer to use ModelInput and ModelOutput (L550-565).
- lightllm/common/basemodel/batch_objs.py
  - Added a new file defining the ModelInput and ModelOutput dataclasses (L1-23).
- lightllm/common/basemodel/cuda_graph.py
  - Imported ModelInput, ModelOutput, SpeculativeDecodeAlgorithm (L8-9).
  - Updated _capture_decode and _replay to handle ModelOutput (L51-54, L99-103).
  - Added a decode_len calculation in warmup based on spec_algo (L133).
  - Updated warmup to use ModelInput and ModelOutput and to handle potential hidden states for MTP modules (L145-166, L175-191).
- lightllm/common/basemodel/infer_struct.py
  - Imported SpeculativeDecodeAlgorithm (L8).
  - Added spec_algo and spec_info attributes to InferStateInfo (L59-60).
- lightllm/common/basemodel/triton_kernel/copy_kv_index_to_req.py
  - Modified the Triton kernel _fwd_kernel_copy_kv_index_to_req to handle decode_len and batch size in the grid calculation (L12-17).
  - Updated the copy_kv_index_to_req function signature to accept decode_len and adjusted the grid calculation (L24-30).
  - Added a test case for copy_kv_index_to_req with decode_len 1 and 2 (L46-63).
- lightllm/common/spec_info.py
  - Added a new file defining the SpeculativeDecodeAlgorithm IntEnum with NONE, MTP, MTP_MOUDLE values (L1-7).
  - Added the helper methods is_none, is_mtp, is_mtp_module (L9-16).
  - Added a from_string static method to parse algorithm names (L18-27).
  - Added a decode_len method to return the expected decode length for each algorithm (L29-35).
- lightllm/models/deepseek2/infer_struct.py
  - Imported SpeculativeDecodeAlgorithm (L6).
- lightllm/models/deepseek_mtp/deepseek3_mtp_mem_manager.py
  - Added a new file defining Deepseek3MTPMemoryManager, inheriting from Deepseek2MemoryManager (L10).
  - Initializes the memory manager with specific head parameters and shared memory for the token count (L11-41).
- lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py
  - Added a new file defining Deepseek3MTPPreLayerInfer, inheriting from LlamaPreLayerInfer (L16).
  - Implemented the mtp_context_forward and mtp_token_forward methods for MTP-specific pre-layer logic involving hidden states and projection (L25-58); a rough sketch of this idea appears after the changelog.
  - Overrode context_forward and token_forward to call the MTP-specific methods (L60-70).
- lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py
  - Added a new file defining Deepseek3MTPPreAndPostLayerWeight, inheriting from LlamaPreAndPostLayerWeight (L5).
  - Defined MTP-specific weights (eh_proj_weight_, enorm_weight_, hnorm_weight_) and the logic to load them from HF weights (L14-21).
  - Implemented verify_load for the new weights (L23-28).
- lightllm/models/deepseek_mtp/model.py
  - Added a new file defining Deepseek3MTPModel, inheriting from Deepseek2TpPartModel (L17).
  - Set MTP-specific pre/post weight and layer infer classes (L19-20).
  - Modified initialization to accept main_model and last_mtp_module (L22-26).
  - Overrode _init_req_manager to share the main model's request manager (L28-40).
  - Overrode _init_mem_manager to use Deepseek3MTPMemoryManager (L42-51).
  - Overrode _init_weights to share the main model's embedding, lm_head, and final norm weights (L53-57).
- lightllm/models/llama/layer_infer/post_layer_infer.py
  - Minor reordering of lines in token_forward (L68-69).
- lightllm/server/api_cli.py
  - Added the command-line arguments --spec_algo, --spec_model_dir, and --spec_step for speculative decoding configuration (L384-401).
- lightllm/server/core/objs/req.py
  - Added an mtp_accepted_len field to the Req ctypes structure (L98).
  - Initialized mtp_accepted_len to 0 in the init method (L150).
- lightllm/server/httpserver/manager.py
  - Added calculation and logging of avg_token_per_step using mtp_accepted_len (L544-545, L555).
  - Included mtp_accepted_len in the metadata passed to the detokenization process (L658).
- lightllm/server/router/manager.py
  - Added a spec_step attribute initialized from args (L65).
  - Passed spec_algo, spec_weight_dir, and spec_step to model_rpc_client.init_model (L180-182).
  - Modified prefill and decode batch handling to send spec_step + 1 None packages to detokenization (L391-393, L410-412).
- lightllm/server/router/model_infer/infer_batch.py
  - Added a cur_accepted_len attribute to InferReq (L260).
  - Added the get_chunked_input_token_ids_shift method (L323-328).
  - Added the set_total_accepted_len method to update the shared memory request object (L341-342).
- lightllm/server/router/model_infer/mode_backend/__init__.py
  - Imported ContinuesBatchWithMTPBackend (L15).
- lightllm/server/router/model_infer/mode_backend/base_backend.py
  - Added spec_algo to the model_kvargs dictionary passed to get_model (L113).
- lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
  - Updated the prepare_decode_inputs and prepare_prefill_inputs calls to use model_input instead of kwargs (L41, L58).
  - Updated model forward calls to use model_input and handle model_output (L42, L61).
  - Updated sample calls to use model_output.logits (L46, L65).
  - Deleted model_output after sampling (L52, L71).
- lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_mtp.py
  - Added a new file defining ChunkedPrefillWithMTPBackend, inheriting from ContinuesBatchWithMTPBackend (L27).
  - Implemented the decode method with MTP-specific prefill and decode logic, including calling draft models and verifying accepted tokens (L31-120).
  - Implemented the verify method to compare main model output with draft tokens and determine the accepted length (L122-145).
  - Implemented _save_draft_token_ids to store draft model outputs (L147-154).
- lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py
  - Updated the prepare_prefill_inputs and prepare_decode_inputs calls to use model_input instead of kwargs (L33, L51).
  - Updated model forward calls to use model_input (L36, L52).
- lightllm/server/router/model_infer/mode_backend/continues_batch/impl_mtp.py
  - Added a new file defining ContinuesBatchWithMTPBackend, inheriting from ModeBackend (L29).
  - Implemented init_model to load the main model and multiple draft models based on spec_step and spec_model_dir (L34-81).
  - Implemented prefill and decode methods with MTP-specific logic, similar to the chunked prefill version, including calling draft models, verifying, and saving draft tokens (L83-178).
  - Implemented the verify and _save_draft_token_ids methods for MTP speculative decoding (L180-215).
- lightllm/server/router/model_infer/mode_backend/mtp_pre_process.py
  - Added a new file defining MTP-specific pre-processing functions (L1-81).
  - Implemented prepare_mtp_prefill_inputs to prepare inputs for draft model prefill, potentially shifting input IDs (L11-29); see the sketch after this changelog.
  - Implemented prepare_draft_main_model_decode_inputs to prepare inputs for the main model decode step in MTP, including draft tokens (L32-81).
- lightllm/server/router/model_infer/model_rpc.py
  - Imported ContinuesBatchWithMTPBackend (L24).
  - Added logic in init_model to select ContinuesBatchWithMTPBackend if spec_algo is 'MTP' and disable_chunked_prefill is true (L157-159).
- lightllm/utils/dist_utils.py
  - Added a device0_print helper function for printing only on device 0 (L196-198).
- test/model/model_infer_mtp.py
  - Added a new file with MTP-specific model inference test logic (L1-282).
  - Includes init_mtp_model to load draft models (L20-53).
  - Includes run_forward_once to perform a single prefill and decode step with the main and draft models (L113-246).
  - Includes tppart_model_infer to run the test inference in a distributed environment (L249-282).
- test/model/test_model.py
  - Imported test_model_inference_mtp (L8).
  - Modified test_model_infer to call test_model_inference_mtp if spec_algo is 'MTP' (L19-21).
  - Modified the default values for --batch_size, --input_len, and --output_len in the argument parser (L30-32).
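To illustrate the pre_layer_infer.py and mtp_pre_process.py entries above, here is a rough sketch of the idea behind MTP input preparation: the draft module sees token ids shifted by one position together with the main model's hidden states, fused through enorm/hnorm/eh_proj. The function names and signatures below are hypothetical, not the PR's actual code.

```python
# Illustrative sketch under the assumption of a DeepSeek-style MTP draft module:
# the token at position t+1 is paired with the main model's hidden state at t.
import torch


def shift_input_ids_for_draft(input_ids: torch.Tensor, new_token_id: int) -> torch.Tensor:
    # Drop the first token and append the freshly sampled one (hypothetical helper).
    return torch.cat([input_ids[1:], input_ids.new_tensor([new_token_id])])


def mtp_pre_layer(token_embeds: torch.Tensor,
                  main_hidden_states: torch.Tensor,
                  enorm: torch.nn.Module,
                  hnorm: torch.nn.Module,
                  eh_proj: torch.nn.Linear) -> torch.Tensor:
    # Normalize embeddings and hidden states separately, concatenate along the
    # feature dimension, then project back to the hidden size via eh_proj.
    fused = torch.cat([enorm(token_embeds), hnorm(main_hidden_states)], dim=-1)
    return eh_proj(fused)
```

The draft module's own transformer layers would then run on this fused state, and the shared lm_head (per the deepseek_mtp/model.py entry above, which reuses the main model's lm_head weights) produces the draft logits.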
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.
You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Speculative steps,
Drafting tokens, quick and bold,
Main model checks all.
Footnotes
1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.
Code Review
This pull request introduces Multi-Token Prediction (MTP) speculative decoding, primarily for DeepSeek models. The changes involve significant refactoring of the model's forward pass to use ModelInput and ModelOutput dataclasses, which improves code organization. New classes and methods are added to support the MTP algorithm, including specialized memory managers, layer inference logic, and weight handling for draft models.
The core MTP logic involves the main model making a prediction, followed by one or more draft models making subsequent predictions using the main model's hidden states. These draft predictions are then verified.
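As a rough mental model (not the PR's actual backend code; names such as draft_models and the forward signatures below are assumptions), one MTP decode step could look like this:

```python
# Simplified sketch of one MTP speculative decode step, assuming greedy sampling
# and one proposed token per draft model.
import torch


def mtp_decode_step(main_model, draft_models, model_input, prev_draft_tokens):
    # 1. The main model scores the last accepted token plus the drafted tokens.
    model_output = main_model.forward(model_input)
    main_next_tokens = torch.argmax(model_output.logits, dim=-1)

    # 2. Verify: a draft token counts as accepted only while it matches the main
    #    model's prediction at the same position.
    accepted_len = 0
    for draft_tok, main_tok in zip(prev_draft_tokens, main_next_tokens.tolist()):
        if draft_tok != main_tok:
            break
        accepted_len += 1

    # 3. Draft: each MTP module consumes the main model's hidden states and
    #    proposes a token for the next step.
    draft_tokens, hidden_states = [], model_output.hidden_states
    for draft_model in draft_models:
        draft_out = draft_model.forward(model_input, hidden_states)  # assumed signature
        draft_tokens.append(int(torch.argmax(draft_out.logits[-1])))
        hidden_states = draft_out.hidden_states

    return main_next_tokens, accepted_len, draft_tokens
```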
Overall, the changes are substantial and introduce a complex new feature. The refactoring to ModelInput/ModelOutput is a good step. One potential issue was found in the logging logic that needs addressing. The renaming of predict_logics to predict_logits throughout the codebase is a welcome improvement for clarity.
Summary of Findings
- Potential ZeroDivisionError in Logging: In lightllm/server/httpserver/manager.py, the calculation of avg_token_per_step (line 544) can lead to a ZeroDivisionError if out_token_counter equals metadata["mtp_accepted_len"]. This needs to be handled to prevent runtime errors.
- Refactoring of Model Forward Pass: The forward method in basemodel.py and its usage in various backends have been refactored to use the ModelInput and ModelOutput dataclasses. This is a good change that improves code clarity and maintainability by grouping related parameters.
- MTP Speculative Decoding Implementation: The core logic for MTP, including draft model handling, hidden state passing, and token verification, has been implemented. New classes like Deepseek3MTPModel and ContinuesBatchWithMTPBackend, plus specialized pre/post processing functions, support this.
- CUDA Graph and Kernel Adjustments: Changes to CUDA graph handling and the copy_kv_index_to_req Triton kernel accommodate the multi-token nature of speculative decoding.
- Type Hinting in batch_objs.py: In the ModelOutput dataclass, torch.tensor is used for type hints. It is recommended to use torch.Tensor for consistency.
Merge Readiness
This pull request introduces a significant new feature (MTP speculative decoding) and includes substantial refactoring. While the overall structure and MTP logic appear sound, there is a high-severity issue in lightllm/server/httpserver/manager.py (potential ZeroDivisionError) that must be addressed before merging. Additionally, a medium-severity suggestion regarding type hinting in lightllm/common/basemodel/batch_objs.py should be considered.
Given these points, I recommend addressing the identified issues. I am not authorized to approve pull requests, so please ensure further review and approval from other maintainers after the changes are made.
The relevant lines in lightllm/server/httpserver/manager.py:

    x_session_id = request.headers.get("X-Session-Id", "") if request is not None else ""
    prompt_cache_ratio = prompt_cache_len / prompt_tokens

    avg_token_per_step = out_token_counter / (out_token_counter - metadata["mtp_accepted_len"])
The calculation for avg_token_per_step could lead to a ZeroDivisionError if out_token_counter is equal to metadata["mtp_accepted_len"]. This might happen if, for example, all tokens generated so far were accepted speculative tokens, or if out_token_counter is 0.
Could you add a check to prevent division by zero here? For example, you could set avg_token_per_step to a default value (like out_token_counter or float('inf') or None) or skip this calculation if the denominator is zero.
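For instance, a minimal guard (one possible shape of the fix, reusing the variable names from the snippet above) could be:

```python
# Avoid dividing by zero when every emitted token so far was an accepted draft token.
denominator = out_token_counter - metadata["mtp_accepted_len"]
avg_token_per_step = out_token_counter / denominator if denominator > 0 else float(out_token_counter)
```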
The type hints in lightllm/common/basemodel/batch_objs.py:

    logits: torch.tensor
    hidden_states: torch.tensor
Consider using torch.Tensor for type hints instead of torch.tensor for consistency with PyTorch's official type hinting. While torch.tensor is a function to create tensors, torch.Tensor is the type. This is a minor point but improves consistency.
Suggested change:

    logits: torch.Tensor
    hidden_states: torch.Tensor
This reverts commit b609b5f.