Skip to content

Conversation

@stnie
Copy link
Collaborator

@stnie stnie commented Jan 6, 2026

Summary by CodeRabbit

  • Enhancements

    • Improved per-request sequence length tracking and validation during token generation.
    • Enhanced end-of-sequence detection with more accurate per-request end-token handling.
    • Refined finish-reason computation to better respect individual request constraints.
  • Tests

    • Updated test suite to validate per-request length constraints in sampling workflow.

✏️ Tip: You can customize this high-level summary in your review settings.

Description

Adjust _are_end_id and _are_max_length function in _write_finish_reasons to omit host to device communication, by storing the end_ids and max_lengths information of each request in a device buffer. This information is only stored once per request as it is constant over time and therefore does not need to be gathered from the request every iteration.
This change is intended to improve overall performance of _write_finish_reasons and sample_async

Performance Results (Preliminary)

ToT:
Request / second (TinyLlama-1.1B, ISL 1024, OSL 2048, BatchSize 512)

  • Greedy 20.1982
  • TopK 18.8219
  • TopP 18.6185
  • TopKTopP 20.2469

This PR:

  • Greedy 20.4853
  • TopK 19.3886
  • TopP 19.2565
  • TopKTopP 20.2168

NsightSystems measurements:
average runtime in ms (TinyLlama-1.1B, ISL 1024, OSL 2048, BatchSize 512)
ToT:

  • write_finish_reasons 1.832ms
  • sample_async: 7.385ms
  • update_requests: 5.168ms

This PR:

  • write_finish_reasons 0.517ms
  • sample_async: 6.445ms
  • update_requests: 5.753ms (Slowdown is caused by longer cudaEventSynchronize)

Test Coverage

The general functionality did not change, therefore the following tests continue to cover the functionality.

  • test_torch_sampler.py::test_write_finish_reasons
  • test_torch_sampler.py::test_are_stop_words_isnt_called_when_no_stop_words

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@stnie
Copy link
Collaborator Author

stnie commented Jan 6, 2026

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #30754 [ run ] triggered by Bot. Commit: 060564a

@tensorrt-cicd
Copy link
Collaborator

PR_Github #30754 [ run ] completed with state SUCCESS. Commit: 060564a
/LLM/main/L0_MergeRequest_PR pipeline #23737 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@zhenhuaw-me zhenhuaw-me added the Decoding/Sampling <NV>Token sampling algorithms in TRTLLM for text gen (top-k, top-p, beam). label Jan 7, 2026
@stnie stnie force-pushed the develop/sampler/write_finish_reasons_perf branch from 060564a to 1f9cd59 Compare January 7, 2026 09:51
@stnie
Copy link
Collaborator Author

stnie commented Jan 7, 2026

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #30888 [ run ] triggered by Bot. Commit: 1f9cd59

@tensorrt-cicd
Copy link
Collaborator

PR_Github #30888 [ run ] completed with state SUCCESS. Commit: 1f9cd59
/LLM/main/L0_MergeRequest_PR pipeline #23850 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@stnie stnie force-pushed the develop/sampler/write_finish_reasons_perf branch from 89264aa to 16bdf59 Compare January 8, 2026 14:00
@stnie
Copy link
Collaborator Author

stnie commented Jan 8, 2026

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #31071 [ run ] triggered by Bot. Commit: 16bdf59

@tensorrt-cicd
Copy link
Collaborator

PR_Github #31071 [ run ] completed with state DISABLED
CI server is currently disabled for scheduled maintenance. Estimated completion time: 6 AM PST on 12/29.

stnie added 4 commits January 9, 2026 09:32
- Added `max_lengths_tensor` to the `Store` class to cache max_length per request.
- Introduced `_is_new_request` method to streamline request state checks.
- Updated `setup_sampler_step` to fill `max_lengths_tensor` based on request parameters.
- Refactored `_are_max_length` for improved token processing on device.

Signed-off-by: Stefan Niebler <[email protected]>
- Introduced `end_ids` tensor in the `Store` class to store end IDs for each request.
- Updated `setup_sampler_step` to fill `end_ids` based on request parameters.
- Refactored `_are_end_id` method to utilize the new `end_ids` tensor for better performance.

Signed-off-by: Stefan Niebler <[email protected]>
…usted override of setup_sampler_step

Signed-off-by: Stefan Niebler <[email protected]>
@stnie stnie force-pushed the develop/sampler/write_finish_reasons_perf branch from 16bdf59 to 570094a Compare January 9, 2026 08:32
@stnie
Copy link
Collaborator Author

stnie commented Jan 9, 2026

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #31228 [ run ] triggered by Bot. Commit: 570094a

@stnie stnie changed the title [TRTLLM-9687][perf] Improve performance of _write_finish_reasons in TorchSampler [TRTLLM-10312][perf] Improve performance of _write_finish_reasons in TorchSampler Jan 9, 2026
@tensorrt-cicd
Copy link
Collaborator

PR_Github #31228 [ run ] completed with state SUCCESS. Commit: 570094a
/LLM/main/L0_MergeRequest_PR pipeline #24133 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@stnie
Copy link
Collaborator Author

stnie commented Jan 12, 2026

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #31530 [ run ] triggered by Bot. Commit: 570094a

@tensorrt-cicd
Copy link
Collaborator

PR_Github #31530 [ run ] completed with state SUCCESS. Commit: 570094a
/LLM/main/L0_MergeRequest_PR pipeline #24374 completed with status: 'SUCCESS'

@stnie stnie marked this pull request as ready for review January 12, 2026 12:04
@stnie stnie requested review from a team as code owners January 12, 2026 12:04
@coderabbitai
Copy link
Contributor

coderabbitai bot commented Jan 12, 2026

📝 Walkthrough

Walkthrough

The changes introduce per-request buffer tracking to TorchSampler by adding max_lengths_tensor and end_ids fields to the Store data class. These enable sequence length and end-ID comparisons during finish-reason computation. A new _is_new_request helper identifies when to initialize per-request buffers, and method signatures are updated to propagate sequence length information through sampling paths.

Changes

Cohort / File(s) Summary
Core Sampler Modifications
tensorrt_llm/_torch/pyexecutor/sampler.py
Added max_lengths_tensor and end_ids fields to Store dataclass for per-request tracking. Introduced _is_new_request() helper to detect new requests. Refactored _write_finish_reasons() to accept seq_lens parameter and updated internal length/end-ID comparison helpers (_are_max_length, _are_end_id) to operate on explicit tensors. Updated finish-reason writing and beam-search preparation logic to propagate sequence length information and manage per-request buffers.
Speculative Sampling Compatibility
tensorrt_llm/_torch/speculative/mtp.py
Added optional end_ids and max_lengths_tensor fields to MTPSampler.Store dataclass for API compatibility. Added placeholder setup_sampler_step() hook method.
Test Updates
tests/unittest/_torch/sampler/test_torch_sampler.py
Updated test fixture to construct seq_lens tensor and assign max_seq_lens and end_ids to sampler.store. Modified _write_finish_reasons() call to include new seq_lens parameter in finish-reason writing test cases.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 35.71% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly and specifically describes the main change: improving performance of _write_finish_reasons in TorchSampler by reducing host-to-device communication.
Description check ✅ Passed The PR description clearly explains the objective, motivation, performance improvements, test coverage, and follows the template structure with all required sections present.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
  • 📝 Generate docstrings

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (1)

1339-1355: Per-request buffer initialization is functionally correct.

The logic correctly initializes max_lengths_tensor and end_ids for new requests. However, based on learnings about performance in this file, the current loop with per-element fill_() calls could be batched for better efficiency when there are many new requests.

Consider batching the initialization:

♻️ Optional optimization for batched initialization
     @override
     def setup_sampler_step(self, scheduled_requests: ScheduledRequests):
         """Setup the sampler step for the requests

         Args:
             requests: list[LlmRequest]. The requests to setup the sampler step for
         """
         if self._use_beam_search:
             self._prepare_beam_search(scheduled_requests.all_requests())
+        new_requests = [r for r in scheduled_requests.all_requests() if self._is_new_request(r)]
+        if new_requests:
+            seq_slots = torch.tensor([r.py_seq_slot for r in new_requests], dtype=torch.int64, device="cuda")
+            max_lengths = torch.tensor(
+                [min(self.max_seq_len, r.orig_prompt_len + r.py_max_new_tokens) for r in new_requests],
+                dtype=torch.int32, device="cuda"
+            )
+            end_ids_vals = torch.tensor(
+                [r.py_end_id if r.py_end_id is not None else -1 for r in new_requests],
+                dtype=torch.int32, device="cuda"
+            )
+            self.store.max_lengths_tensor[seq_slots] = max_lengths
+            self.store.end_ids[seq_slots] = end_ids_vals
-        for request in scheduled_requests.all_requests():
-            if self._is_new_request(request):
-                self.store.max_lengths_tensor[request.py_seq_slot].fill_(
-                    min(self.max_seq_len, request.orig_prompt_len + request.py_max_new_tokens)
-                )
-                self.store.end_ids[request.py_seq_slot].fill_(
-                    request.py_end_id if request.py_end_id is not None else -1
-                )
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4a09acd and 570094a.

📒 Files selected for processing (3)
  • tensorrt_llm/_torch/pyexecutor/sampler.py
  • tensorrt_llm/_torch/speculative/mtp.py
  • tests/unittest/_torch/sampler/test_torch_sampler.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+
Indent Python code with 4 spaces. Do not use tabs
Always maintain the namespace when importing Python modules, even if only one class or function from a module is used
Python filenames should use snake_case (e.g., some_file.py)
Python classes should use PascalCase (e.g., class SomeClass)
Python functions and methods should use snake_case (e.g., def my_awesome_function():)
Python local variables should use snake_case, with prefix k for variable names that start with a number (e.g., k_99th_percentile)
Python global variables should use upper snake_case with prefix G (e.g., G_MY_GLOBAL)
Python constants should use upper snake_case (e.g., MY_CONSTANT)
Avoid shadowing variables declared in an outer scope in Python
Initialize all externally visible members of a Python class in the constructor
For Python interfaces that may be used outside a file, prefer docstrings over comments
Use comments in Python for code within a function, or interfaces that are local to a file
Use Google-style docstrings for Python classes and functions, which can be parsed by Sphinx
Python attributes and variables can be documented inline with the format """<type>: Description"""
Avoid using reflection in Python when functionality can be easily achieved without reflection
When using try-except blocks in Python, limit the except clause to the smallest set of errors possible
When using try-except blocks in Python to handle multiple possible variable types (duck-typing), keep the body of the try as small as possible and use the else block for the main logic

Files:

  • tests/unittest/_torch/sampler/test_torch_sampler.py
  • tensorrt_llm/_torch/speculative/mtp.py
  • tensorrt_llm/_torch/pyexecutor/sampler.py
**/*.{cpp,cc,cxx,h,hpp,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

All TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification

Files:

  • tests/unittest/_torch/sampler/test_torch_sampler.py
  • tensorrt_llm/_torch/speculative/mtp.py
  • tensorrt_llm/_torch/pyexecutor/sampler.py
🧠 Learnings (7)
📓 Common learnings
Learnt from: tongyuantongyu
Repo: NVIDIA/TensorRT-LLM PR: 9655
File: tensorrt_llm/_torch/pyexecutor/sampler.py:3031-3031
Timestamp: 2025-12-12T03:27:18.859Z
Learning: In tensorrt_llm/_torch/pyexecutor/sampler.py, when reviewing code that iterates through requests, ensure it does not convert excessive data into Python lists. Instead, the code should use torch.gather or indexing to gather only the data that will be used in the for loop before converting to Python lists. This minimizes data movement and improves performance.
Learnt from: tongyuantongyu
Repo: NVIDIA/TensorRT-LLM PR: 9655
File: tensorrt_llm/_torch/pyexecutor/sampler.py:3031-3031
Timestamp: 2025-12-12T03:27:18.859Z
Learning: In tensorrt_llm/_torch/pyexecutor/sampler.py, when reviewing code that iterates through requests, ensure it does not access torch.Tensor objects (CPU or GPU) inside the loop. Instead, the code should use .tolist() to convert batched data tensors to Python lists beforehand, and then access the list in the for loop. This is a critical performance consideration.
📚 Learning: 2025-12-12T03:27:18.859Z
Learnt from: tongyuantongyu
Repo: NVIDIA/TensorRT-LLM PR: 9655
File: tensorrt_llm/_torch/pyexecutor/sampler.py:3031-3031
Timestamp: 2025-12-12T03:27:18.859Z
Learning: In tensorrt_llm/_torch/pyexecutor/sampler.py, when reviewing code that iterates through requests, ensure it does not convert excessive data into Python lists. Instead, the code should use torch.gather or indexing to gather only the data that will be used in the for loop before converting to Python lists. This minimizes data movement and improves performance.

Applied to files:

  • tests/unittest/_torch/sampler/test_torch_sampler.py
  • tensorrt_llm/_torch/pyexecutor/sampler.py
📚 Learning: 2025-12-12T03:27:18.859Z
Learnt from: tongyuantongyu
Repo: NVIDIA/TensorRT-LLM PR: 9655
File: tensorrt_llm/_torch/pyexecutor/sampler.py:3031-3031
Timestamp: 2025-12-12T03:27:18.859Z
Learning: In tensorrt_llm/_torch/pyexecutor/sampler.py, when reviewing code that iterates through requests, ensure it does not access torch.Tensor objects (CPU or GPU) inside the loop. Instead, the code should use .tolist() to convert batched data tensors to Python lists beforehand, and then access the list in the for loop. This is a critical performance consideration.

Applied to files:

  • tests/unittest/_torch/sampler/test_torch_sampler.py
📚 Learning: 2025-12-12T03:27:08.565Z
Learnt from: tongyuantongyu
Repo: NVIDIA/TensorRT-LLM PR: 9655
File: tensorrt_llm/_torch/pyexecutor/sampler.py:3031-3031
Timestamp: 2025-12-12T03:27:08.565Z
Learning: In files under tensorrt_llm/_torch/pyexecutor, avoid accessing torch.Tensor objects inside for-loops when iterating over requests. Convert batched tensors to Python lists beforehand using tensor.tolist(), and then iterate over those lists. This improves performance by reducing tensor-bound operations inside hot loops. Apply this pattern to similar code paths that process batches to access simple Python data structures (lists) inside loops.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/sampler.py
📚 Learning: 2025-08-13T16:20:37.987Z
Learnt from: dcampora
Repo: NVIDIA/TensorRT-LLM PR: 6867
File: tensorrt_llm/_torch/pyexecutor/sampler.py:67-72
Timestamp: 2025-08-13T16:20:37.987Z
Learning: In TensorRT-LLM sampler code, performance is prioritized over additional validation checks. The beam_width helper method intentionally returns the first request's beam_width without validating consistency across all requests to avoid performance overhead from iterating through the entire batch.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/sampler.py
📚 Learning: 2025-08-18T08:42:02.640Z
Learnt from: samuellees
Repo: NVIDIA/TensorRT-LLM PR: 6974
File: tensorrt_llm/serve/scripts/benchmark_dataset.py:558-566
Timestamp: 2025-08-18T08:42:02.640Z
Learning: In TensorRT-LLM's RandomDataset (tensorrt_llm/serve/scripts/benchmark_dataset.py), when using --random-token-ids option, sequence length accuracy is prioritized over semantic correctness for benchmarking purposes. The encode/decode operations should use skip_special_tokens=True and add_special_tokens=False to ensure exact target token lengths.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/sampler.py
📚 Learning: 2025-08-28T10:22:02.288Z
Learnt from: ixlmar
Repo: NVIDIA/TensorRT-LLM PR: 7294
File: tensorrt_llm/_torch/pyexecutor/sampler.py:1191-1197
Timestamp: 2025-08-28T10:22:02.288Z
Learning: In tensorrt_llm/_torch/pyexecutor/sampler.py, the object identity comparison `softmax_req_indices is not group_req_indices_cuda` on line ~1191 is intentional and used as an optimization to determine whether to reuse an existing indexer or create a new one, based on which code path was taken during tensor assignment.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/sampler.py
🧬 Code graph analysis (2)
tests/unittest/_torch/sampler/test_torch_sampler.py (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (1)
  • _write_finish_reasons (2464-2535)
tensorrt_llm/_torch/pyexecutor/sampler.py (2)
tensorrt_llm/_torch/pyexecutor/llm_request.py (1)
  • LlmRequest (462-691)
tensorrt_llm/_torch/pyexecutor/scheduler.py (1)
  • ScheduledRequests (21-42)
🔇 Additional comments (10)
tensorrt_llm/_torch/speculative/mtp.py (1)

225-236: LGTM!

The interface-compliance fields (finish_reasons, end_ids, max_lengths_tensor) with None defaults and the no-op setup_sampler_step override correctly satisfy the TorchSampler.Store interface without adding unnecessary logic, since MTPSampler handles finish reasons through its own path in update_requests.

tests/unittest/_torch/sampler/test_torch_sampler.py (1)

694-728: LGTM!

The test setup correctly initializes the new per-request buffers:

  • seq_lens from max_beam_num_tokens represents current sequence lengths
  • max_seq_lens calculation mirrors the sampler logic for maximum allowed sequence lengths
  • end_ids with -1 sentinel for None end IDs

This properly exercises the updated _write_finish_reasons interface with device-resident per-request data.

tensorrt_llm/_torch/pyexecutor/sampler.py (8)

835-841: LGTM!

The new max_lengths_tensor and end_ids fields are well-documented with clear docstrings describing their shape and usage, following Google-style conventions.


890-898: LGTM!

The new tensors are correctly allocated with consistent shape (max_num_sequences,) in both the beam-search and non-beam-search paths.


1329-1337: LGTM!

The _is_new_request helper correctly identifies when to initialize per-request buffers by:

  1. Excluding finished and draft requests
  2. Handling both chunked context (last chunk) and disaggregated generation scenarios

This aligns with the commit message about preventing draft requests from modifying max_lengths_tensor.


1367-1367: LGTM!

Good refactoring to use the centralized _is_new_request helper, ensuring consistent logic across beam search preparation and buffer initialization.


1877-1899: LGTM!

The seq_lens tensor is correctly constructed from max_beam_num_tokens, transferred to CUDA, and passed through to _write_finish_reasons for on-device finish reason computation. This aligns with the PR's goal of reducing host-to-device communication.


2486-2524: LGTM!

The refactored _write_finish_reasons now:

  1. Properly validates device consistency with assertions
  2. Uses pre-stored max_lengths_tensor and end_ids instead of per-request iteration
  3. Keeps finish reason computation entirely on device

This achieves the PR's performance improvement goal by eliminating repeated host-to-device reads.


2537-2538: LGTM!

The refactored _are_end_id performs an efficient batched comparison using tensor broadcasting. The end_ids tensor is correctly reshaped from (batch_size,) to (1, batch_size, 1) and broadcast-compared against tokens of shape (max_tokens, batch_size, max_beam_width).


2540-2559: LGTM!

The refactored _are_max_length correctly computes whether each position reaches max length:

  • lengths_tensor adds step offsets (1 through max_tokens) to current sequence lengths
  • Comparison >= max_lengths_tensor properly identifies when sequences reach their limit
  • All computation stays on device, achieving the performance improvement goal

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Decoding/Sampling <NV>Token sampling algorithms in TRTLLM for text gen (top-k, top-p, beam).

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants