[TRTLLM-9735][feat] Add processed logprobs functionality to TorchSampler #9675
/bot run
PR_Github #26827 [ run ] triggered by Bot. Commit:
PR_Github #26827 [ run ] completed with state
Force-pushed from 0b21e18 to d64119a
/bot run
PR_Github #27140 [ run ] triggered by Bot. Commit:
PR_Github #27140 [ run ] completed with state
Force-pushed from 4aa1d0c to bfb2195
/bot run --disable-fail-fast
PR_Github #27296 [ run ] triggered by Bot. Commit:
ixlmar left a comment
Reviewed code minus tests.
Re the type annotations, I recommend checking them with a type checker (e.g., in the IDE) as long as we don't have one in CI. Ideally, we could maintain type correctness for sampling_utils.py, flashinfer_sampling_utils.py, and the newer parts of sampler.py.
PR_Github #27296 [ run ] completed with state
Force-pushed from bfb2195 to 18531f1
/bot run
PR_Github #27528 [ run ] triggered by Bot. Commit:
PR_Github #27528 [ run ] completed with state
Force-pushed from d66c983 to 5f75078
/bot run --disable-fail-fast
Force-pushed from 64e640c to 3ff5537
if tokens is not None:
    for t in range(logprobs.size(0)):
        token_id = tokens[t]
        token_logprob = logprobs[t, token_id].item()
I think this will cause a device sync. Is this really what we want here? Did this function sync before this PR?
I think this is for feature completeness of the TRT flow and is not perf-tuned. Since we are considering deprecating the TRT flow, I think perf tuning here is low priority?
@hchings to confirm.
Could you please run some performance tests and make sure there is no additional sync in the sampler? Extra syncs there can easily introduce regressions given the recent changes to the sampler.
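The concern above is that `.item()` inside a Python loop copies one scalar at a time to the host, and when `logprobs` lives on a GPU each copy forces a device synchronization. A minimal sketch of the usual alternative, assuming `logprobs` is a `[num_tokens, vocab_size]` tensor and `tokens` holds one sampled token ID per row (shapes inferred from the snippet, not confirmed by the PR): gather every sampled-token log-probability in one device-side op, then transfer once. Both function names are hypothetical.

```python
import torch


def sampled_logprobs_loop(logprobs: torch.Tensor, tokens: torch.Tensor) -> list[float]:
    """Mirrors the reviewed snippet: one .item() per row (one sync per row on GPU)."""
    out = []
    for t in range(logprobs.size(0)):
        token_id = tokens[t]
        out.append(logprobs[t, token_id].item())
    return out


def sampled_logprobs_batched(logprobs: torch.Tensor, tokens: torch.Tensor) -> list[float]:
    """Hypothetical batched variant: a single gather, then a single host transfer."""
    # gather selects logprobs[t, tokens[t]] for every row t; the index must be int64 and 2-D.
    picked = logprobs.gather(1, tokens.to(torch.long).unsqueeze(1)).squeeze(1)
    # One device-to-host copy for the whole batch instead of one per token.
    return picked.cpu().tolist()
```

To check a test run for remaining synchronizing calls, PyTorch's `torch.cuda.set_sync_debug_mode("warn")` makes synchronizing CUDA operations emit a warning.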
return resolve_sampling_strategy(params, vocab_size=vocab_size)
if not hasattr(request, "py_sampling_strategy") or _get_max_beam_width(request) > 1:
    params = _request_get_sampling_params(request)
    request.py_sampling_strategy = resolve_sampling_strategy(params, vocab_size=vocab_size)
I noticed this function changes its previous behavior: it used to leave the input request unmodified, and now it modifies it.
Is this really what we want here?
I think this is to avoid repeated computation: the strategy is computed once, stored on request.py_sampling_strategy, and the cached value is reused afterwards.
Yes, this caches the result. Based on profiling, we found that computing it every time incurs unacceptable overhead.
def _return_log_probs(self, requests: list[LlmRequest]) -> bool:
    return any(req.py_return_log_probs for req in requests)

def _prepare_log_probs(self, requests: list[LlmRequest]) -> None:
A style nit: for a function that does not return anything, I think we normally don't declare -> None, right?
If we write nothing, it's up to the IDE to infer the actual return type. The documentation shows examples that explicitly annotate -> None: https://docs.python.org/3/library/typing.html#type-aliases
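The practical difference the reply points at can be checked with the standard `typing` module: without a return annotation the function carries no declared return type, so each tool decides for itself (inferring it, or treating it as `Any`), whereas `-> None` is recorded explicitly and visible to `typing.get_type_hints`. The function names below are illustrative only.

```python
from typing import get_type_hints


def prepare_implicit(requests: list[str]):
    # No return annotation: the return type is left to each tool.
    for _ in requests:
        pass


def prepare_explicit(requests: list[str]) -> None:
    # Explicit annotation: a type checker flags any attempt to return a value.
    for _ in requests:
        pass


implicit_hints = get_type_hints(prepare_implicit)
explicit_hints = get_type_hints(prepare_explicit)
```

Here `implicit_hints` has no `"return"` key, while `explicit_hints["return"]` is `NoneType` (`get_type_hints` normalizes `None` to `type(None)`).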
    return_probs: bool,
    group_metadata: StrategyMetadata | None = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor], float | None]:
When should we use Optional[float], and when float | None?
I think it's like "west const" vs. "east const": it's a matter of convention.
As explained before, I think we are moving to the new syntax but don't want to change everything at once.
hchings left a comment
@tongyuantongyu Could you share some benchmarks or Nsight screenshots of _process_logprobs in the PR description as well? Thanks.
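One lightweight way to produce numbers like those requested, short of a full Nsight Systems capture, is `torch.profiler` with a labeled range around the code under test. A minimal sketch, assuming a CPU-only run and a hypothetical `process_logprobs_sketch` standing in for `_process_logprobs` (it is not the PR's implementation):

```python
import torch
from torch.profiler import ProfilerActivity, profile, record_function


def process_logprobs_sketch(logits: torch.Tensor, k: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical stand-in for the code under test: top-k log-probabilities per row."""
    logprobs = torch.log_softmax(logits, dim=-1)
    top = torch.topk(logprobs, k, dim=-1)
    return top.values, top.indices


logits = torch.randn(8, 1000)

# CPU-only capture; add ProfilerActivity.CUDA to the list when profiling on a GPU.
with profile(activities=[ProfilerActivity.CPU]) as prof:
    # The labeled range is recorded as one event spanning all ops issued inside it.
    with record_function("process_logprobs_sketch"):
        values, indices = process_logprobs_sketch(logits, k=5)

# Per-op aggregate table; the labeled range appears as its own row.
summary = prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)
labels = {evt.key for evt in prof.key_averages()}
```

For an Nsight Systems timeline, the same region can instead be bracketed with `torch.cuda.nvtx.range_push("...")` and `torch.cuda.nvtx.range_pop()` so it shows up as a named NVTX range.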
n: int = 1
best_of: Optional[int] = None
use_beam_search: bool = False
logprobs_mode: LogprobMode = LogprobMode.RAW
Suggested change:
logprobs_mode: LogprobMode = LogprobMode.RAW
Move it close to the other logprob parameters.

# Keep the below fields in sync with tllme.OutputConfig
logprobs: Optional[int] = None
prompt_logprobs: Optional[int] = None
Suggested change:
prompt_logprobs: Optional[int] = None
logprobs_mode: LogprobMode = LogprobMode.RAW
> # Keep the below fields in sync with tllme.OutputConfig
We don't want to add it to the legacy binding type tllme.OutputConfig.
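One possible placement that satisfies both comments keeps `logprobs_mode` next to the other logprob fields but outside the block that must mirror `tllme.OutputConfig`. A sketch, assuming a simplified dataclass: the field names come from the snippets above, while `SamplingParamsSketch`, `OUTPUT_CONFIG_FIELDS`, and the enum's string values are hypothetical. The RAW/PROCESSED members follow the modes named in the commit messages, and the `__post_init__` check mirrors the "validation against LogProbsMode values" they describe.

```python
from dataclasses import dataclass, fields
from enum import Enum
from typing import Optional


class LogprobMode(str, Enum):
    # Hypothetical values: RAW = log-softmax of unmodified logits,
    # PROCESSED = log-softmax after sampling transforms (temperature, top-k, ...).
    RAW = "raw"
    PROCESSED = "processed"


@dataclass
class SamplingParamsSketch:
    n: int = 1
    best_of: Optional[int] = None
    use_beam_search: bool = False

    # Logprob option that is deliberately NOT part of tllme.OutputConfig.
    logprobs_mode: LogprobMode = LogprobMode.RAW

    # Keep the below fields in sync with tllme.OutputConfig
    logprobs: Optional[int] = None
    prompt_logprobs: Optional[int] = None

    def __post_init__(self) -> None:
        # Accept plain strings such as "processed"; raises ValueError for unknown modes.
        self.logprobs_mode = LogprobMode(self.logprobs_mode)


# Hypothetical list of fields mirrored into the legacy binding type.
OUTPUT_CONFIG_FIELDS = ("logprobs", "prompt_logprobs")
```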
- Introduces a new optional parameter, logprobs_mode, to the SamplingParams and LlmRequest classes, allowing users to specify the mode of log probabilities to return. - Create process_logprobs function to remove logprobs processing code from process_requests. - add batching based on logprobs_mode to sample_batched_by_strategy - additionally return processed logits from sampling Signed-off-by: Stefan Niebler <[email protected]>
fix step remove dependency on return_generation_logits align API across backends, add tests test fix group sampling strategy Signed-off-by: Erin Ho <[email protected]> Signed-off-by: Yuan Tong <[email protected]>
Signed-off-by: Erin Ho <[email protected]> Signed-off-by: Yuan Tong <[email protected]>
- Expand test_logits_logprobs to perform a check for processed logprobs - Fix processed logprobs for greedy sampling and when using temperature Signed-off-by: Stefan Niebler <[email protected]>
…chSampler - Updated tensor allocation in TorchSampler to use pinned memory for improved performance during D2H copies. - Modified test_sampled_token_always_in_logprobs to include logprobs_mode parameter for enhanced testing of log probabilities. Signed-off-by: Stefan Niebler <[email protected]>
…lation when not needed - Added LogProbsMode class to define modes for log probabilities: RAW and PROCESSED. - Updated SamplingParams and LlmRequest to utilize LogProbsMode for logprobs_mode parameter. - Enhanced validation to check logprobs_mode against LogProbsMode values. - Modified TorchSampler and related classes to support new logprobs_mode functionality. - Modified TorchSampler to only calculate logprobs when a request needs it - Updated tests to cover new logprobs_mode behavior and ensure correct processing of log probabilities. Signed-off-by: Stefan Niebler <[email protected]>
… logprobs handling - Added max_topk_logprobs parameter to AutoDeployConfig and LlmRequest to control the number of top-k logprobs storable for each token. - Updated TorchSampler to accommodate max_topk_logprobs in logprobs processing and storage. - Enhanced logprobs handling in the sampling process to support both sampled and top-k logprobs. - Enabled batched processing of logprobs to enhance logprobs performance - Modified tests to validate the new max_topk_logprobs functionality and ensure correct logprobs output. Signed-off-by: Stefan Niebler <[email protected]> Signed-off-by: Yuan Tong <[email protected]>
Signed-off-by: Yuan Tong <[email protected]>
Signed-off-by: Yuan Tong <[email protected]>
Signed-off-by: Yuan Tong <[email protected]>
…gy when using beam search Signed-off-by: Stefan Niebler <[email protected]>
…t indices instead of a boolean tensor Signed-off-by: Stefan Niebler <[email protected]>
…ch.testing.assert_close for improved readability and precision Signed-off-by: Stefan Niebler <[email protected]>
Signed-off-by: Yuan Tong <[email protected]>
Signed-off-by: Yuan Tong <[email protected]>
Signed-off-by: Yuan Tong <[email protected]>
Signed-off-by: Yuan Tong <[email protected]>
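The commits distinguish RAW from PROCESSED logprobs and add batched top-k logprobs (`max_topk_logprobs`). A minimal sketch of that distinction, assuming RAW means log-softmax over the model's unmodified logits and PROCESSED means log-softmax after the sampling transforms (here only temperature scaling and top-k masking; top-p is omitted). The function names are hypothetical and the real TorchSampler code path may differ. Note how greedy sampling (top-k = 1) yields a processed logprob of exactly 0 for the chosen token, the case one of the commits mentions fixing.

```python
import torch


def raw_logprobs(logits: torch.Tensor) -> torch.Tensor:
    # Log-probabilities of the unmodified model distribution.
    return torch.log_softmax(logits.float(), dim=-1)


def processed_logprobs(logits: torch.Tensor, temperature: float, top_k: int) -> torch.Tensor:
    # Log-probabilities of the distribution actually sampled from (temperature > 0 assumed):
    # scale by temperature, then mask tokens below the k-th largest value to -inf.
    # Ties with the k-th value are kept, so a row may retain more than k tokens.
    scaled = logits.float() / temperature
    kth_value = torch.topk(scaled, top_k, dim=-1).values[..., -1:]
    masked = scaled.masked_fill(scaled < kth_value, float("-inf"))
    return torch.log_softmax(masked, dim=-1)


def topk_logprobs(logprobs: torch.Tensor, sampled: torch.Tensor, k: int):
    # Batched over rows: the sampled token's logprob plus the k most likely tokens.
    sampled_lp = logprobs.gather(1, sampled.unsqueeze(1)).squeeze(1)
    top_vals, top_ids = torch.topk(logprobs, k, dim=-1)
    return sampled_lp, top_vals, top_ids
```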
Force-pushed from 3ff5537 to d416a39

Summary by CodeRabbit
Release Notes
New Features
Bug Fixes
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...
Provide a user-friendly way for developers to interact with a Jenkins server.
Run /bot [-h|--help] to print this help message.
See details below for each supported subcommand.
Details
run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]
Launch build/test pipelines. All previously running jobs will be killed.
--reuse-test (optional)pipeline-id (OPTIONAL): Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option is always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
--disable-reuse-test (OPTIONAL): Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests are run regardless of previous successes.
--disable-fail-fast (OPTIONAL): Disable fail fast on build/tests/infra failures.
--skip-test (OPTIONAL): Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
--stage-list "A10-PyTorch-1, xxx" (OPTIONAL): Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
--gpu-type "A30, H100_PCIe" (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
--test-backend "pytorch, cpp" (OPTIONAL): Skip test stages which don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
--only-multi-gpu-test (OPTIONAL): Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
--disable-multi-gpu-test (OPTIONAL): Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
--add-multi-gpu-test (OPTIONAL): Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.
--post-merge (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL): Run the ordinary L0 pre-merge pipeline and the specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
--detailed-log (OPTIONAL): Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
--debug (OPTIONAL): Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.
For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.
kill
kill: Kill all running builds associated with the pull request.
skip
skip --comment COMMENT: Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
reuse-pipeline
reuse-pipeline: Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.