[None][feat] Implement sampling for MTP 1-model #10019
Force-pushed from 281a63f to 9a94402

/bot run
📝 Walkthrough

This change introduces a new abstract base class, `SpecWorkerBase`.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~40 minutes

Pre-merge checks: ❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/speculative/interface.py (1)
24-38: Guard against negative values in `TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS`

`get_force_num_accepted_tokens()` accepts any integer, including negatives, and `force_num_accepted_tokens` is later used to overwrite `num_accepted_tokens`. A negative override would be nonsensical and could lead to invalid lengths or indexing in downstream logic.

Consider clamping to non-negative values (e.g., treat `< 0` as `0` with a warning) instead of accepting arbitrary ints:

```diff
-    try:
-        return int(env_value)
+    try:
+        value = int(env_value)
+        if value < 0:
+            logger.warning(
+                f"{FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR} must be non-negative, "
+                f"got '{env_value}'. Using default value 0.")
+            return 0
+        return value
```
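Below is a minimal, self-contained sketch of the suggested guard. It assumes the override is read from `os.environ` and that `0` means "no override", as the default in the suggested diff implies. It also catches `ValueError` for non-integer values. The logger setup and the function body are illustrative only, not the actual code in `interface.py`.

```python
import logging
import os

logger = logging.getLogger(__name__)

# Env var name taken from the review comment; everything else is illustrative.
FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR = "TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS"


def get_force_num_accepted_tokens() -> int:
    """Return the forced accepted-token count, or 0 if unset or invalid.

    Sketch only: assumes 0 means "no override".
    """
    env_value = os.environ.get(FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR, "0")
    try:
        value = int(env_value)
    except ValueError:
        logger.warning("%s must be an integer, got %r. Using default value 0.",
                       FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR, env_value)
        return 0
    if value < 0:
        # Clamp nonsensical negative overrides instead of propagating them.
        logger.warning("%s must be non-negative, got %r. Using default value 0.",
                       FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR, env_value)
        return 0
    return value
```

Keeping only `int(env_value)` inside the `try` follows the guideline of limiting the except clause to the one expected error.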
🧹 Nitpick comments (3)
tensorrt_llm/_torch/speculative/eagle3.py (1)
546-559: Clarify `draft_decoder` docstring vs. actual behavior

The docstring now says "Sampling draft tokens with support for non-greedy sampling.", but the implementation still does a plain `torch.argmax` over logits. Non-greedy sampling is handled earlier for target tokens via the shared sampler, not here.

Consider rephrasing the docstring so it does not imply that this method itself performs non-greedy sampling, or document explicitly that draft tokens remain greedy.
tensorrt_llm/_torch/speculative/interface.py (1)
361-420: Add invariants / fallbacks around advanced sampling in `_sample_tokens_for_batch`

The advanced path assumes:
- `spec_metadata.temperatures`, `top_ks`, and `top_ps` are initialized and long enough, and
- `num_tokens = num_contexts + num_gens * (self.max_draft_len + 1)` matches `logits.shape[0]`.

If upstream code forgets to call `populate_sampling_params_for_one_model` or changes the logits layout, this will fail with hard-to-debug runtime errors.

Consider:
- Adding a sanity check on shapes (e.g., `assert logits.shape[0] == num_tokens`) in debug builds, and
- Either asserting or gracefully falling back to greedy sampling when any of `temperatures`/`top_ks`/`top_ps` is `None` or too short.

This keeps the base class robust against future changes in caller assumptions.
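One way to express the suggested invariant and fallback is a hypothetical helper; `use_advanced_sampling` is not a TensorRT-LLM function. It assumes the sampling-parameter sequences are indexed per logits row, so "long enough" means at least `num_tokens` entries. If the real code stores them per request, the length check would change accordingly.

```python
from typing import Optional, Sequence


def use_advanced_sampling(num_logits_rows: int, num_contexts: int, num_gens: int,
                          max_draft_len: int,
                          temperatures: Optional[Sequence[float]],
                          top_ks: Optional[Sequence[int]],
                          top_ps: Optional[Sequence[float]]) -> bool:
    """Return True if the advanced path is safe, False to fall back to greedy.

    Hypothetical helper illustrating the review suggestion.
    """
    # Each generation request contributes max_draft_len draft rows plus one extra row.
    num_tokens = num_contexts + num_gens * (max_draft_len + 1)
    # A layout mismatch is a programming error, so fail loudly (skipped under `python -O`).
    assert num_logits_rows == num_tokens, (
        f"expected {num_tokens} logits rows, got {num_logits_rows}")
    for params in (temperatures, top_ks, top_ps):
        if params is None or len(params) < num_tokens:
            return False  # missing or too-short sampling params: use greedy
    return True
```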
tensorrt_llm/_torch/speculative/mtp.py (1)
761-762: Outdated docstring: sampling is no longer limited to greedy.

The docstring states "Currently only support greedy sampling", but with the new `_sample_tokens_for_batch` integration, advanced sampling (temperature, top-k, top-p) is now supported when enabled via `spec_metadata.allow_advanced_sampling`.

```diff
 Currently only support greedy sampling. All decoding is done using Top1 and token equality is used
-for acceptance.
+for acceptance when advanced sampling is disabled. When enabled, temperature, top-k, and top-p
+sampling are also supported.
```
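The greedy acceptance rule in the old docstring can be made concrete with a generic pure-Python sketch. This is not the MTP implementation. The sketch picks Top-1 target tokens, compares them with the draft tokens by equality, and accepts drafts until the first mismatch. The target's own token at that position is always emitted. The row layout (one logits row per draft position plus one bonus row) is an assumption.

```python
from typing import List, Sequence


def argmax(row: Sequence[float]) -> int:
    """Top-1 token id for one row of logits (first index wins ties)."""
    return max(range(len(row)), key=row.__getitem__)


def greedy_accept(draft_tokens: Sequence[int],
                  target_logits: Sequence[Sequence[float]]) -> List[int]:
    """Accept draft tokens while they equal the target model's Top-1 choice.

    Assumes target_logits has len(draft_tokens) + 1 rows. Returns the accepted
    draft tokens followed by the target token at the first mismatch (or the bonus row).
    """
    target_tokens = [argmax(row) for row in target_logits]
    accepted = []
    for draft, target in zip(draft_tokens, target_tokens):
        if draft != target:
            break  # token equality fails: stop accepting drafts
        accepted.append(draft)
    # The target's token at the first rejected (or bonus) position is always emitted.
    accepted.append(target_tokens[len(accepted)])
    return accepted
```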
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- tensorrt_llm/_torch/speculative/__init__.py (2 hunks)
- tensorrt_llm/_torch/speculative/eagle3.py (3 hunks)
- tensorrt_llm/_torch/speculative/interface.py (2 hunks)
- tensorrt_llm/_torch/speculative/mtp.py (3 hunks)
- tests/integration/defs/accuracy/test_llm_api_pytorch.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+
Indent Python code with 4 spaces; do not use tabs
Always maintain the namespace when importing in Python, even if only one class or function from a module is used (e.g., use `from package.subpackage import foo` and then `foo.SomeClass()` instead of `from package.subpackage.foo import SomeClass`)
Python filenames should use snake_case (e.g., `some_file.py`)
Python class names should use PascalCase (e.g., `class SomeClass`)
Python function and method names should use snake_case (e.g., `def my_awesome_function():`)
Python local variable names should use snake_case, with prefix `k` for variable names that start with a number (e.g., `k_99th_percentile = ...`)
Python global variables should use upper snake_case with prefix `G` (e.g., `G_MY_GLOBAL = ...`)
Python constants should use upper snake_case (e.g., `MY_CONSTANT = ...`)
Avoid shadowing variables declared in an outer scope in Python
Initialize all externally visible members of a Python class in the constructor
For Python interfaces that may be used outside a file, prefer docstrings over comments
Python comments should be reserved for code within a function, or interfaces that are local to a file
Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx
Python attributes and variables can be documented inline with type and description (e.g., `self.x = 5` followed by `"""<type>: Description of 'x'"""`)
Avoid using reflection in Python when functionality can be easily achieved without reflection
When using try-except blocks in Python, limit the except clause to the smallest set of specific errors possible instead of catching all exceptions
When using try-except blocks in Python to handle multiple possible variable types (duck-typing), keep the body of the try as small as possible and use the else block to implement the logic
Files:
- tensorrt_llm/_torch/speculative/__init__.py
- tensorrt_llm/_torch/speculative/mtp.py
- tensorrt_llm/_torch/speculative/interface.py
- tests/integration/defs/accuracy/test_llm_api_pytorch.py
- tensorrt_llm/_torch/speculative/eagle3.py
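The try/except/else guideline can be illustrated with a small duck-typing helper; `to_token_list` is hypothetical. Only the probe that may raise sits in the `try` body. The `except` clause names the single expected error, and the main logic runs in `else`.

```python
from typing import List, Sequence, Union


def to_token_list(tokens: Union[int, Sequence[int]]) -> List[int]:
    """Normalize a single token id or an iterable of ids to a list."""
    try:
        iterator = iter(tokens)  # only the duck-typing probe lives in the try body
    except TypeError:
        return [tokens]  # not iterable: treat as a single token id
    else:
        return list(iterator)  # main logic goes in the else block
```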
**/*.{cpp,h,cu,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
All TensorRT-LLM Open Source Software code files should contain an NVIDIA copyright header that includes the current year at the top
Files:
- tensorrt_llm/_torch/speculative/__init__.py
- tensorrt_llm/_torch/speculative/mtp.py
- tensorrt_llm/_torch/speculative/interface.py
- tests/integration/defs/accuracy/test_llm_api_pytorch.py
- tensorrt_llm/_torch/speculative/eagle3.py
🧠 Learnings (3)
📚 Learning: 2025-08-14T15:38:01.771Z
Learnt from: MatthiasKohl
Repo: NVIDIA/TensorRT-LLM PR: 6904
File: cpp/tensorrt_llm/pybind/thop/bindings.cpp:55-57
Timestamp: 2025-08-14T15:38:01.771Z
Learning: In TensorRT-LLM Python bindings, tensor parameter collections like mla_tensor_params and spec_decoding_tensor_params are kept as required parameters without defaults to maintain API consistency, even when it might affect backward compatibility.
Applied to files:
tests/integration/defs/accuracy/test_llm_api_pytorch.py
📚 Learning: 2025-08-26T09:37:10.463Z
Learnt from: jiaganc
Repo: NVIDIA/TensorRT-LLM PR: 7031
File: tensorrt_llm/bench/dataclasses/configuration.py:90-104
Timestamp: 2025-08-26T09:37:10.463Z
Learning: In TensorRT-LLM's bench configuration, the `get_pytorch_perf_config()` method returns `self.pytorch_config` which is a Dict[str, Any] that can contain default values including `cuda_graph_config`, making the fallback `llm_args["cuda_graph_config"]` safe to use.
Applied to files:
tests/integration/defs/accuracy/test_llm_api_pytorch.py
📚 Learning: 2025-08-26T09:37:10.463Z
Learnt from: jiaganc
Repo: NVIDIA/TensorRT-LLM PR: 7031
File: tensorrt_llm/bench/dataclasses/configuration.py:90-104
Timestamp: 2025-08-26T09:37:10.463Z
Learning: In TensorRT-LLM, the `get_pytorch_perf_config()` method returns `self.pytorch_config` which can contain default `cuda_graph_config` values, so `llm_args` may already have this config before the extra options processing.
Applied to files:
tests/integration/defs/accuracy/test_llm_api_pytorch.py
🧬 Code graph analysis (4)
tensorrt_llm/_torch/speculative/__init__.py (1)
- tensorrt_llm/_torch/speculative/interface.py (2): `SpecMetadata` (183-358), `SpecWorkerBase` (361-420)

tensorrt_llm/_torch/speculative/mtp.py (2)
- tensorrt_llm/_torch/speculative/interface.py (3): `SpecMetadata` (183-358), `max_draft_len` (374-378), `_sample_tokens_for_batch` (385-420)
- tensorrt_llm/_torch/speculative/eagle3.py (1): `max_draft_len` (367-368)

tensorrt_llm/_torch/speculative/interface.py (5)
- tensorrt_llm/_torch/pyexecutor/guided_decoder.py (1): `CapturableGuidedDecoder` (420-569)
- tensorrt_llm/_torch/speculative/eagle3.py (1): `max_draft_len` (367-368)
- tensorrt_llm/_torch/speculative/mtp.py (1): `max_draft_len` (359-360)
- tensorrt_llm/_torch/pyexecutor/model_engine.py (1): `set_guided_decoder` (478-485)
- tensorrt_llm/_torch/speculative/one_model_sampler.py (1): `sampling_batch_spec_dec_one_model` (76-91)

tensorrt_llm/_torch/speculative/eagle3.py (2)
- tensorrt_llm/_torch/speculative/interface.py (2): `SpecMetadata` (183-358), `max_draft_len` (374-378)
- tensorrt_llm/_torch/speculative/mtp.py (1): `max_draft_len` (359-360)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (6)
tensorrt_llm/_torch/speculative/__init__.py (1)
3-3: Expose `SpecWorkerBase` in public speculative API

Importing `SpecWorkerBase` from `.interface` and adding it to `__all__` is consistent with the new shared worker abstraction and keeps the public surface coherent with `SpecMetadata`. Looks good.

Also applies to: 22-22
tests/integration/defs/accuracy/test_llm_api_pytorch.py (1)
1343-1348: Wire sampling params only when MTP is enabled

Conditionally building `mtp_config` and `sampling_params` for `mtp_nextn > 0` and passing `sampling_params` into `GSM8K.evaluate` keeps the non-MTP path unchanged while exercising the new 1-model sampling behavior. This looks correct and consistent with nearby tests.

Also applies to: 1357-1357
tensorrt_llm/_torch/speculative/eagle3.py (1)
14-15: Refactor `Eagle3OneModelWorker` to use `SpecWorkerBase`

Switching `Eagle3OneModelWorker` to inherit from `SpecWorkerBase`, calling `super().__init__()`, and exposing `max_draft_len` via `spec_config` cleanly aligns Eagle3 1-model with the shared speculative worker interface (common sampling, guided decoder handling). The refactor is coherent and low risk.

Also applies to: 359-369
tensorrt_llm/_torch/speculative/mtp.py (3)
18-18: LGTM!

Import correctly updated to bring in `SpecMetadata` and `SpecWorkerBase` from the interface module, following the namespace import convention.
350-360: LGTM!

The refactoring to inherit from `SpecWorkerBase` and the new `max_draft_len` property implementation are correct. The property returns `num_nextn_predict_layers`, which aligns with how this value is used throughout the file (as `mtp_num_modules`) and matches the expected interface from the base class.
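The inheritance shape these comments describe can be sketched minimally, assuming an abstract `max_draft_len` property on the base class. The class bodies, the `guided_decoder` attribute, and the `SimpleNamespace` config stand-ins are hypothetical. Only the two return values come from the review: `num_nextn_predict_layers` for MTP and `spec_config.max_draft_len` for Eagle3. Placing `num_nextn_predict_layers` on `spec_config` is an assumption.

```python
from abc import ABC, abstractmethod
from types import SimpleNamespace


class SpecWorkerBase(ABC):
    """Hypothetical minimal shape of a shared speculative-decoding worker base."""

    def __init__(self) -> None:
        self.guided_decoder = None  # hypothetical shared attribute, set later if used

    @property
    @abstractmethod
    def max_draft_len(self) -> int:
        """Maximum number of draft tokens proposed per step."""


class MTPWorker(SpecWorkerBase):
    def __init__(self, spec_config) -> None:
        super().__init__()
        self.spec_config = spec_config

    @property
    def max_draft_len(self) -> int:
        # One draft token per next-n prediction layer (attribute location assumed).
        return self.spec_config.num_nextn_predict_layers


class Eagle3OneModelWorker(SpecWorkerBase):
    def __init__(self, spec_config) -> None:
        super().__init__()
        self.spec_config = spec_config

    @property
    def max_draft_len(self) -> int:
        return self.spec_config.max_draft_len


# Usage with stand-in config objects:
mtp = MTPWorker(SimpleNamespace(num_nextn_predict_layers=3))
eagle = Eagle3OneModelWorker(SimpleNamespace(max_draft_len=4))
```

With the abstract property, a subclass that forgets `max_draft_len` cannot be instantiated. The shared sampling code can then rely on `self.max_draft_len` regardless of the speculative mode.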
892-893: Correct integration with the base class sampling method.

The change to use `_sample_tokens_for_batch` enables both greedy and advanced sampling (temperature, top-k, top-p) based on `spec_metadata.allow_advanced_sampling`.
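The following pure-Python sketch shows greedy versus temperature/top-k/top-p sampling for a single row of logits. It is not `_sample_tokens_for_batch` or `sampling_batch_spec_dec_one_model`. The flag name mirrors `spec_metadata.allow_advanced_sampling`. The conventions are assumptions chosen for illustration: top-k is applied before top-p, `temperature <= 0` or `top_k == 1` means greedy, and `top_k <= 0` means no top-k limit.

```python
import math
import random
from typing import Optional, Sequence


def sample_token(logits: Sequence[float], temperature: float, top_k: int, top_p: float,
                 allow_advanced_sampling: bool,
                 rng: Optional[random.Random] = None) -> int:
    """Pick one token id from a row of logits (hypothetical, pure-Python sketch)."""
    greedy = max(range(len(logits)), key=logits.__getitem__)
    if not allow_advanced_sampling or temperature <= 0 or top_k == 1:
        return greedy
    if rng is None:
        rng = random.Random()
    # Temperature-scaled softmax; subtracting the max keeps exp() from overflowing.
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    weights = [math.exp(x - peak) for x in scaled]
    total = sum(weights)
    probs = sorted(((w / total, i) for i, w in enumerate(weights)), reverse=True)
    # Top-k: keep only the k most likely tokens.
    if top_k > 0:
        probs = probs[:top_k]
    # Top-p: keep the smallest prefix whose cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for p, i in probs:
        kept.append((p, i))
        cumulative += p
        if cumulative >= top_p:
            break
    # random.choices renormalizes the surviving weights internally.
    return rng.choices([i for _, i in kept], weights=[p for p, _ in kept], k=1)[0]
```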
PR_Github #28448 [ run ] triggered by Bot. Commit:

Force-pushed from 9a94402 to dabe1c4

/bot run

PR_Github #28454 [ run ] triggered by Bot. Commit:

Force-pushed from dabe1c4 to f62aa30

Force-pushed from 9612a78 to 69e0391

/bot run

PR_Github #28464 [ run ] triggered by Bot. Commit:

PR_Github #28464 [ run ] completed with state

Force-pushed from 85f5295 to 5c88b34

/bot run

PR_Github #28600 [ run ] triggered by Bot. Commit:

PR_Github #28600 [ run ] completed with state

Force-pushed from 5c88b34 to 0f2144c

/bot run

PR_Github #28604 [ run ] triggered by Bot. Commit:

PR_Github #28604 [ run ] completed with state

Force-pushed from 0f2144c to f00ff93

/bot run

PR_Github #28823 [ run ] triggered by Bot. Commit:

PR_Github #28823 [ run ] completed with state

/bot run

PR_Github #29164 [ run ] triggered by Bot. Commit:

PR_Github #29164 [ run ] completed with state

Signed-off-by: Mike Iovine <[email protected]>

Force-pushed from b184cc8 to a026a9e

/bot run --disable-fail-fast

PR_Github #30193 [ run ] triggered by Bot. Commit:

PR_Github #30193 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #30291 [ run ] triggered by Bot. Commit:

PR_Github #30291 [ run ] completed with state
Description

Add sampling support to MTP 1-model, using the same approach as EAGLE3.

Also refactored a few things to avoid code duplication, introducing a new `SpecWorkerBase` to facilitate reuse.

Test Coverage

Made the existing DSV3 Lite accuracy tests use sampling params when MTP is enabled.
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help

`/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...`

Provide a user-friendly way for developers to interact with a Jenkins server.

Run `/bot [-h|--help]` to print this help message.

See details below for each supported subcommand.

Details

`run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]`

Launch build/test pipelines. All previously running jobs will be killed.

- `--reuse-test (optional)pipeline-id` (OPTIONAL): Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
- `--disable-reuse-test` (OPTIONAL): Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.
- `--disable-fail-fast` (OPTIONAL): Disable fail fast on build/tests/infra failures.
- `--skip-test` (OPTIONAL): Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
- `--stage-list "A10-PyTorch-1, xxx"` (OPTIONAL): Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
- `--gpu-type "A30, H100_PCIe"` (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
- `--test-backend "pytorch, cpp"` (OPTIONAL): Skip test stages which don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
- `--only-multi-gpu-test` (OPTIONAL): Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
- `--disable-multi-gpu-test` (OPTIONAL): Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
- `--add-multi-gpu-test` (OPTIONAL): Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.
- `--post-merge` (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- `--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx"` (OPTIONAL): Run the ordinary L0 pre-merge pipeline and the specified test stages. Examples: `--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx"`.
- `--detailed-log` (OPTIONAL): Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
- `--debug` (OPTIONAL): Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the `stage-list` parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see `docs/source/reference/ci-overview.md` and the `scripts/test_to_stage_mapping.py` helper.

kill

`kill`

Kill all running builds associated with the pull request.

skip

`skip --comment COMMENT`

Skip testing for the latest commit on the pull request. `--comment "Reason for skipping build/test"` is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

`reuse-pipeline`

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.