
Conversation

@h-guo18 h-guo18 (Contributor) commented Oct 7, 2025

What does this PR do?

Type of change: New feature

Overview:
TRT-LLM recently added support for dumping intermediate hidden states: https://github.com/NVIDIA/TensorRT-LLM/pull/7012/files

This PR adds example scripts to dump hidden states with TRT-LLM.

To resolve a data-format mismatch, the script post-processes the dumped files to match the exact format produced by the previously used HF dumper.

Usage

See examples/speculative_decoding/collect_hidden_states/

Testing

Diff on dumped files

Diffed the dumped files against those from the previously used HF dumper. Tested with DP=1, DP=4 and TP=1, TP=2:

  • File names match;
  • File contents: keys and values match exactly; tensor values are allclose.
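
The comparison amounts to a file-by-file diff. A minimal sketch of what such a check could look like (hypothetical helper, not part of this PR; assumes each dump is a dict of tensors and scalars):

```python
import torch

def diff_dump_files(hf_file: str, trtllm_file: str, atol: float = 1e-5) -> None:
    """Compare an HF-dumped .pt file against its post-processed TRT-LLM counterpart."""
    hf = torch.load(hf_file, weights_only=True)
    tl = torch.load(trtllm_file, weights_only=True)
    assert hf.keys() == tl.keys(), "key sets differ"
    for key in hf:
        if isinstance(hf[key], torch.Tensor):
            assert torch.allclose(hf[key], tl[key], atol=atol), f"tensor mismatch: {key}"
        else:
            assert hf[key] == tl[key], f"value mismatch: {key}"
```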

E2E AR on TinyLlama

With a fixed training setup for 50,000 steps:

| Metric (average) | TRT-LLM-dumped hiddens | HF-dumped hiddens |
| --- | --- | --- |
| AL | 2.82 (63558/22523) | 2.83 (63558/22480) |
| E2E | 0.5730 | 0.5698 |
| TTFT | 0.0106 | 0.0104 |
| ITL | 0.0041 | 0.0041 |
| TPS | 700.5946 | 705.1892 |

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • New Features

    • Added a high-performance utility to extract and save per-conversation hidden states from TensorRT-LLM compatible models, with async generation, progress tracking, conversation validation, length skipping, and per-conversation outputs.
  • Scripts

    • Added single-run and data-parallel launcher scripts for hidden-state extraction; data-parallel runner is configurable via DP_SIZE.
  • Documentation

    • Updated README with hidden-state dumping workflow, backend options, usage notes, and licensing headers.

@copy-pr-bot copy-pr-bot bot commented Oct 7, 2025

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai coderabbitai bot (Contributor) commented Oct 7, 2025

Walkthrough

Adds a TRT-LLM hidden-state extraction Python script, two TRT-LLM runner scripts (single and DP), and parameterizes an existing HF DP runner; implements CLI, tokenizer/model setup, async TRT-LLM generation, per-conversation post-processing, and DP orchestration.
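
A rough shape of that async orchestration (hypothetical names; the real script adds validation, length skipping, and progress tracking):

```python
import asyncio

async def dump_all(conversations, generate_async, post_process):
    # Enqueue one async generation per conversation; each TRT-LLM dump is
    # post-processed into an HF-style per-conversation .pt file as it finishes.
    async def handle(idx, conv):
        await generate_async(conv["input_ids"])
        post_process(idx, conv["conversation_id"])

    await asyncio.gather(*(handle(i, c) for i, c in enumerate(conversations)))
```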

Changes

Cohort / File(s) and summary of changes:

- **TRT-LLM hidden-state extraction** (`examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py`): New script adding parse_args() and main(); loads the tokenizer/config, reads jsonl conversations, configures TRT-LLM (CUDA graph, KV cache, chunked prefill, tensor/mixture parallelism), enqueues async generations, post-processes TRT-LLM dumps into the HF-like format (removes id, maps to conversation_id, attaches hidden_states), saves per-conversation .pt files, tracks skipped/invalid/success counts, and shuts down the engine with a summary.
- **TRT-LLM runners** (`examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh`, `examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh`): New shell scripts: a single-run launcher and a DP launcher. The DP script splits the input into parts, sets CUDA_VISIBLE_DEVICES per rank, launches compute_hidden_states_trtllm.py for each rank in the background with --dp-rank, waits for completion, and cleans up the temporary split files.
- **HF DP parameterization** (`examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh`): Introduces a DP_SIZE variable and replaces the hard-coded split count and loop bounds with DP_SIZE-driven logic; otherwise retains the original splitting and parallel-launch behavior.
- **Docs / README** (`examples/speculative_decoding/README.md`): Updated wording and added a "Dumping Hidden States to Disk" section describing the TRT-LLM and HF backends, example commands, notes on TRT-LLM installation, references to the DP scripts, training/offline workflow clarifications, and minor typo/phrasing fixes.

Sequence Diagram(s)

```mermaid
sequenceDiagram
  autonumber
  participant User
  participant Script as compute_hidden_states_trtllm.py
  participant Tokenizer
  participant TRTLLM as TRT-LLM Engine
  participant FS as Filesystem

  User->>Script: CLI args (model, input, output, limits, TP, CUDA graph, dp-rank)
  Script->>FS: Read jsonl conversations
  Script->>Tokenizer: Load tokenizer and chat template
  Script->>TRTLLM: Initialize engine (TP, KV cache, CUDA graph, prefill)
  loop For each conversation
    Script->>Script: Validate structure and length
    Script->>Tokenizer: Build prompt from messages
    Script->>TRTLLM: Enqueue async generation (save hidden states)
  end
  loop As tasks complete
    TRTLLM-->>Script: Hidden-state dump (per item)
    Script->>Script: Post-process fields (remove id, rename to conversation_id)
    Script->>FS: Save `conversation_id`.pt
  end
  Script->>TRTLLM: Shutdown
  Script-->>User: Summary (processed, skipped, invalid)
```
```mermaid
sequenceDiagram
  autonumber
  participant User
  participant DP as run_trtllm_compute_hiddens_dp.sh
  participant Split as split
  participant Worker as compute_hidden_states_trtllm.py[*]
  participant FS as Filesystem

  User->>DP: Execute with INPUT, OUTPUT, DP_SIZE
  DP->>Split: split INPUT into /tmp/part-*.jsonl (DP_SIZE parts)
  par Parallel workers
    DP->>Worker: Launch with --input-file /tmp/part-0X.jsonl --dp-rank X
    Worker->>FS: Write per-conversation `.pt` outputs
  end
  DP->>DP: wait
  DP->>FS: Cleanup /tmp/part-*.jsonl
  DP-->>User: Exit
```

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Poem

I hop through tensors, swift and bright,
Gathering whispers from each hidden byte.
Split into trails, eight paths unwind,
GPUs hum secrets, layer by kind.
Carrots of files, neatly stowed—my find. 🥕

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 20.00%, below the required threshold of 80.00%. | Run @coderabbitai generate docstrings to improve docstring coverage. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped: CodeRabbit's high-level summary is enabled. |
| Title Check | ✅ Passed | The pull request title clearly and concisely describes the main feature addition (a TRTLLM dumper for Eagle's offline training pipeline) without unnecessary detail or noise, making it straightforward for reviewers to understand the core change at a glance. |

@h-guo18 h-guo18 self-assigned this Oct 7, 2025
@codecov codecov bot commented Oct 7, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.79%. Comparing base (340eb7a) to head (c784ad0).
⚠️ Report is 4 commits behind head on main.

Additional details and impacted files
```
@@           Coverage Diff           @@
##             main     #404   +/-   ##
=======================================
  Coverage   73.79%   73.79%
=======================================
  Files         171      171
  Lines       17591    17591
=======================================
  Hits        12982    12982
  Misses       4609     4609
```

☔ View full report in Codecov by Sentry.

@h-guo18 h-guo18 marked this pull request as ready for review October 8, 2025 00:09
@h-guo18 h-guo18 requested a review from a team as a code owner October 8, 2025 00:09
@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 2

🧹 Nitpick comments (6)
examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh (1)

20-23: Add shebang for shell portability.

The script is missing a shebang line. While this may work when sourced or run with an explicit shell, adding #!/bin/bash at the top improves portability and clarifies the intended interpreter.

As per static analysis hints

Apply this diff:

```diff
+#!/bin/bash
+
 # SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```
examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh (1)

25-37: Add shebang for shell portability.

The script is missing a shebang line. Adding #!/bin/bash at the top improves portability and clarifies the intended interpreter.

As per static analysis hints

Apply this diff:

```diff
+#!/bin/bash
+
 # SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py (4)

107-110: Consider memory efficiency for large input files.

Loading all conversations into memory at once could be problematic for very large datasets. For production use, consider streaming the file or processing in batches.
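
A streaming alternative is straightforward; a sketch (not the script's current behavior):

```python
import json

def iter_conversations(path: str):
    """Yield one conversation per jsonl line instead of loading the whole file."""
    with open(path) as f:
        for line in f:
            if line.strip():
                yield json.loads(line)
```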


161-162: Security: torch.load on untrusted data.

Using torch.load() on files that could potentially be tampered with poses a security risk, as it can execute arbitrary Python code during deserialization. Since this script processes internally-generated TRTLLM output, the risk is low, but consider using torch.load(..., weights_only=True) if available in your PyTorch version for additional safety.
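
For example (assuming a PyTorch version that supports the flag; `trtllm_dumped_file` is the variable from the reviewed code):

```python
# weights_only=True restricts unpickling to tensors and primitive containers,
# refusing arbitrary Python objects during deserialization.
trtllm_dumped = torch.load(trtllm_dumped_file, weights_only=True)
```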


179-180: Redundant file existence check.

Path.unlink() accepts a missing_ok parameter that makes the existence check redundant.

Apply this diff:

```diff
-        if trtllm_dumped_file.exists():
-            trtllm_dumped_file.unlink()
+        trtllm_dumped_file.unlink(missing_ok=True)
```

182-190: Consider adding error handling for generation failures.

The dump_hidden_states function doesn't handle exceptions from generate_async or file operations. If a single conversation fails, it could crash the entire pipeline. Consider wrapping in try-except to log errors and continue processing remaining conversations.

Example:

```diff
 async def dump_hidden_states(idx: int, conversation_id: int, input_ids: list[int]):
     nonlocal num_success
-    await llm_spec.generate_async(input_ids, sampling_params)
-    # TRTLLM API name files starts from 1
-    # ref:https://github.com/NVIDIA/TensorRT-LLM/pull/7012
-    trtllm_dumped_file = args.output_dir / f"{spec_config['file_prefix']}_{idx + 1}.pt"
-    _post_process_trtllm_dumped(trtllm_dumped_file, conversation_id)
-    num_success += 1
-    pbar.update(1)
+    try:
+        await llm_spec.generate_async(input_ids, sampling_params)
+        # TRTLLM API name files starts from 1
+        # ref:https://github.com/NVIDIA/TensorRT-LLM/pull/7012
+        trtllm_dumped_file = args.output_dir / f"{spec_config['file_prefix']}_{idx + 1}.pt"
+        _post_process_trtllm_dumped(trtllm_dumped_file, conversation_id)
+        num_success += 1
+    except Exception as e:
+        print(f"Error processing conversation {conversation_id}: {e}")
+    finally:
+        pbar.update(1)
```
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 340eb7a and 450c773.

📒 Files selected for processing (4)
  • examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh (1 hunks)
🧰 Additional context used
🪛 Shellcheck (0.11.0)
examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

🔇 Additional comments (5)
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh (1)

27-31: LGTM! Clean parameterization improves maintainability.

The introduction of DP_SIZE variable and its use throughout the script eliminates magic numbers and makes scaling easier.

examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py (4)

16-35: LGTM! Proper environment setup and imports.

Setting TLLM_LOG_LEVEL before importing tensorrt_llm is the correct approach to control logging verbosity.


137-137: Verify layer selection works for all model sizes.

The hardcoded layer selection {1, num_hidden_layers // 2 - 1, num_hidden_layers - 4} may not work correctly for small models (e.g., models with fewer than 5 layers). Consider adding validation.
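
Evaluating the set by hand shows how it degenerates for small models (a quick check, not code from the PR):

```python
def capture_layers(num_hidden_layers: int) -> set[int]:
    # Mirrors the script's eagle3_layers_to_capture expression.
    return {1, num_hidden_layers // 2 - 1, num_hidden_layers - 4}

print(sorted(capture_layers(16)))  # [1, 7, 12]: three distinct layers
print(sorted(capture_layers(4)))   # [0, 1]: collapses to two entries and includes layer 0
```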

Apply this diff to add validation:

```diff
+    if num_hidden_layers < 5:
+        raise ValueError(f"Model has only {num_hidden_layers} layers, need at least 5 for layer selection")
+
     spec_config = {
         "output_directory": str(args.output_dir),
         "write_interval": 1,
         "file_prefix": f"dp_{args.dp_rank}",
         "eagle3_layers_to_capture": {1, num_hidden_layers // 2 - 1, num_hidden_layers - 4},
     }
```

196-196: Handle None slice correctly.

When args.debug_max_num_conversations is None, the slice all_conversations[:None] works correctly but is implicit. The current code is correct due to Python's slice behavior, but adding explicit handling would improve clarity.
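
Python's slice semantics make this safe, if implicit:

```python
conversations = ["a", "b", "c"]
print(conversations[:None])  # ['a', 'b', 'c']: a None upper bound means "no limit"
print(conversations[:2])     # ['a', 'b']
```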


215-234: LGTM! Proper async execution and cleanup.

The async execution, LLM shutdown, and summary reporting are correctly implemented.

Comment on lines +90 to +94
"--use-cuda-graph",
type=bool,
default=True,
help="""Whether to use CUDA graph.""",
)

⚠️ Potential issue | 🟠 Major

Fix boolean argument handling.

Using type=bool for boolean flags in argparse is problematic. Any non-empty string (including "False" or "0") will be interpreted as True. Use action='store_true' or action='store_false' instead.
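
A quick reproduction of the pitfall:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use-cuda-graph", type=bool, default=True)

# bool() on any non-empty string is True, so the flag cannot be disabled:
print(parser.parse_args(["--use-cuda-graph", "False"]).use_cuda_graph)  # True
print(parser.parse_args(["--use-cuda-graph", "0"]).use_cuda_graph)      # True
```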

Apply this diff:

```diff
     parser.add_argument(
         "--use-cuda-graph",
-        type=bool,
-        default=True,
+        action='store_true',
+        default=False,
         help="""Whether to use CUDA graph.""",
     )
```

Or if you want it enabled by default:

```diff
     parser.add_argument(
-        "--use-cuda-graph",
-        type=bool,
-        default=True,
-        help="""Whether to use CUDA graph.""",
+        "--no-cuda-graph",
+        action='store_false',
+        dest='use_cuda_graph',
+        help="""Disable CUDA graph (enabled by default).""",
     )
```
🤖 Prompt for AI Agents
In
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py
around lines 90 to 94, the argparse boolean flag currently uses type=bool which
treats any non-empty string as True; change the argument to use
action='store_true' (or action='store_false' if you prefer the inverse) so the
flag is parsed correctly, and set default accordingly (omit type and use
default=True/False as needed or use action='store_false' to have it enabled by
default).

@h-guo18 h-guo18 force-pushed the haoguo/eagle-trtllm-dumper branch from d69a1c3 to 067bc1a on October 8, 2025 00:27
Signed-off-by: h-guo18 <[email protected]>
@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 3

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 450c773 and dfd0c62.

📒 Files selected for processing (2)
  • examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh (1 hunks)
🧰 Additional context used
🪛 Shellcheck (0.11.0)
examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: build-docs
  • GitHub Check: code-quality

Comment on lines +1 to +30
```bash
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example usage of the script to compute the hidden states for a conversation dataset
# This script computes hidden states using TensorRT-LLM and saves them to
# the specified output directory. It does so in a data-parallel manner across 8 GPUs, by splitting
# the input file into 8 parts and running 8 processes in parallel, one on each GPU.

# Note: depending on the write-throughput of the destination disk, this is not guaranteed
# to yield a speed improvement compared to running the model-parallel version. Consider
# benchmarking on a smaller dataset before launching a large run.

INPUT_FILE=synthetic_conversations/daring-anteater.jsonl
OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
DP_SIZE=8
MODEL=meta-llama/Llama-3.2-1B-Instruct
export TLLM_LOG_LEVEL="error";
```


🛠️ Refactor suggestion | 🟠 Major

Add a shebang to define the interpreter.

Same as the single-run script: make Bash explicit so the runtime shell is unambiguous and SC2148 is silenced.

```diff
+#!/usr/bin/env bash
+
```
🧰 Tools
🪛 Shellcheck (0.11.0)

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

🤖 Prompt for AI Agents
In
examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh
around lines 1-30, the script lacks a shebang; add a Bash shebang as the very
first line (e.g., use env to locate bash) so the interpreter is explicit and
shellcheck SC2148 is silenced, then save and ensure the script remains
executable (chmod +x) if needed.

Comment on lines +36 to +41
```bash
export CUDA_VISIBLE_DEVICES=$i; python3 collect_hidden_states/compute_hidden_states_trtllm.py --model $MODEL --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR --dp-rank $i &

# #On SLURM:
# PORT=$((10012 + i)); export TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR="tcp://127.0.0.1:$PORT"; trtllm-llmapi-launch python3 collect_hidden_states/compute_hidden_states_trtllm.py --model $MODEL --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR --dp-rank $i

done
```

⚠️ Potential issue | 🟠 Major

Fix suffix formatting for DP ranks ≥10.

Line 36 hardcodes /tmp/part-0${i}.jsonl, which works only for single-digit ranks; once DP_SIZE reaches 10 the actual file is /tmp/part-10.jsonl, so the launch fails. Format the suffix with printf to stay aligned with split output.

```diff
-for i in $(seq 0 $((DP_SIZE-1)))
-do
-
-export CUDA_VISIBLE_DEVICES=$i;  python3 collect_hidden_states/compute_hidden_states_trtllm.py --model $MODEL --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR --dp-rank $i &
+for i in $(seq 0 $((DP_SIZE-1)))
+do
+  suffix=$(printf "%02d" "$i")
+  CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_trtllm.py \
+    --model "$MODEL" \
+    --input-file "/tmp/part-${suffix}.jsonl" \
+    --output-dir "$OUTPUT_DIR" \
+    --dp-rank "$i" &
```
🤖 Prompt for AI Agents
In
examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh
around lines 36-41, the input file path is hardcoded as /tmp/part-0${i}.jsonl
which breaks for DP ranks >= 10; replace that construction with a printf-based
filename (e.g. use $(printf "/tmp/part-%s.jsonl" "$i") or similar) so the
generated suffix matches split's output for multi-digit ranks, and update the
python invocation to use that formatted filename.

Comment on lines +1 to +24
```bash
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example usage of the script to compute the hidden states for a conversation dataset
# This script computes hidden states using TensorRT-LLM and saves them to
# the specified output directory.

export TLLM_LOG_LEVEL="error";
python3 collect_hidden_states/compute_hidden_states_trtllm.py \
--model meta-llama/Llama-3.2-1B-Instruct \
--input-file synthetic_conversations/daring-anteater.jsonl \
--output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
```

🛠️ Refactor suggestion | 🟠 Major

Add a shebang to define the interpreter.

Without a shebang the script runs under whatever /bin/sh points to, which can diverge from the Bash semantics used here and also trips ShellCheck (SC2148). Please make Bash explicit.

```diff
+#!/usr/bin/env bash
+
```
🧰 Tools
🪛 Shellcheck (0.11.0)

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

🤖 Prompt for AI Agents
examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh
lines 1-24: the script lacks a shebang so it may run under /bin/sh instead of
bash; add a bash shebang (e.g. #!/usr/bin/env bash) as the first line and ensure
the file is executable (chmod +x) so it runs with Bash semantics and avoids
ShellCheck SC2148.

@h-guo18 h-guo18 requested a review from ChenhanYu October 8, 2025 01:59
@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 5

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between dfd0c62 and c784ad0.

📒 Files selected for processing (2)
  • examples/speculative_decoding/README.md (3 hunks)
  • examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py (1)
modelopt/deploy/llm/generate.py (2)
  • LLM (53-291)
  • max_seq_len (154-156)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (1)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py (1)

125-270: LGTM! Well-structured async processing pipeline.

The implementation correctly handles:

  • Conversation deduplication by checking existing output files
  • Tokenizer setup with pad token fallback and chat template adjustment
  • TRT-LLM configuration with CUDA graph, KV cache, and parallelism settings
  • Async task orchestration with proper progress tracking
  • Post-processing to match HF dump format
  • Comprehensive error reporting

The async pattern efficiently batches hidden-state extraction across multiple conversations.

Comment on lines +195 to +221
```python
def _post_process_trtllm_dumped(trtllm_dumped_file: str, conversation_id: int):
    """Post-process the TRTLLM dumped file to same format as HF dumped:
    1. Remove id field, replace it with conversation_id
    2. Rename hidden_state field to hidden_states
    3. From list of length 1 to dict
    4. Rename file to conversation_id.pt
    """
    with open(trtllm_dumped_file, "rb") as f:
        trtllm_dumped = torch.load(f)
    assert isinstance(trtllm_dumped, list) and len(trtllm_dumped) == 1, (
        "TRTLLM dumped should be a list with one element"
    )
    assert (
        isinstance(trtllm_dumped[0], dict)
        and "id" in trtllm_dumped[0]
        and "hidden_state" in trtllm_dumped[0]
    ), "TRTLLM dumped should have an 'id' and 'hidden_states' field"
    trtllm_dumped = trtllm_dumped[0]
    trtllm_dumped.pop("id")
    trtllm_dumped["conversation_id"] = conversation_id
    trtllm_dumped["hidden_states"] = trtllm_dumped.pop("hidden_state")
    output_file = args.output_dir / f"{conversation_id}.pt"
    with open(output_file, "wb") as f:
        torch.save(trtllm_dumped, f)

    if trtllm_dumped_file.exists():
        trtllm_dumped_file.unlink()
```

⚠️ Potential issue | 🟡 Minor

Correct the type hint for trtllm_dumped_file parameter.

The type hint declares trtllm_dumped_file: str, but the function calls .exists() (line 220) and .unlink() (line 221), which are Path methods. At the call site (line 229), a Path object is passed. Update the type hint to Path for consistency.

Apply this diff:

```diff
-    def _post_process_trtllm_dumped(trtllm_dumped_file: str, conversation_id: int):
+    def _post_process_trtllm_dumped(trtllm_dumped_file: Path, conversation_id: int):
```
🤖 Prompt for AI Agents
In
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py
around lines 195 to 221, the parameter trtllm_dumped_file is annotated as str
but the function uses Path methods (.exists(), .unlink()) and is called with a
Path at the call site; update the type hint to Path (from pathlib) for
trtllm_dumped_file to match usage and callers, and add an import for Path at the
top if not already present.

For large models, you can export intermediate hidden states to disk and train only the draft model. This significantly reduces GPU memory requirements, but requires several to tens of terabytes of disk storage depending on dataset size.

First, dump the base model's hidden states with the following command:
### Dumpping Hidden States to Disk

⚠️ Potential issue | 🟡 Minor

Fix typo in heading.

"Dumpping" should be "Dumping".

Apply this diff:

```diff
-### Dumpping Hidden States to Disk
+### Dumping Hidden States to Disk
```
🤖 Prompt for AI Agents
In examples/speculative_decoding/README.md around line 87, the heading contains
a typo "Dumpping Hidden States to Disk"; change "Dumpping" to "Dumping" so the
heading reads "Dumping Hidden States to Disk".

First, dump the base model's hidden states with the following command:
### Dumpping Hidden States to Disk

We support two backends for generating base model hidden states. For better effciency, it is recommended to use TRT-LLM:

⚠️ Potential issue | 🟡 Minor

Fix typo.

"effciency" should be "efficiency".

Apply this diff:

```diff
-We support two backends for generating base model hidden states. For better effciency, it is recommended to use TRT-LLM:
+We support two backends for generating base model hidden states. For better efficiency, it is recommended to use TRT-LLM:
```
🤖 Prompt for AI Agents
In examples/speculative_decoding/README.md around line 89, there is a typo:
change the word "effciency" to "efficiency" so the sentence reads "For better
efficiency, it is recommended to use TRT-LLM:"; update that single word in the
line.

Comment on lines +92 to +96
```bash
python collect_hidden_states/compute_hidden_states_trtllm.py \
    --model $BASE_MODEL \
    --input-file Daring-Anteater/train.jsonl \
    --output-dir $HIDDEN_STATES_DIR
```

⚠️ Potential issue | 🟡 Minor

Remove trailing space after line continuation.

Line 93 has a trailing space after the backslash, which breaks bash line continuation. The backslash must be immediately followed by a newline.

Apply this diff:

```diff
 python collect_hidden_states/compute_hidden_states_trtllm.py \
-            --model $BASE_MODEL \ 
+            --model $BASE_MODEL \
             --input-file Daring-Anteater/train.jsonl \
             --output-dir $HIDDEN_STATES_DIR
```
🤖 Prompt for AI Agents
In examples/speculative_decoding/README.md around lines 92 to 96, the bash line
continuation backslash on line 93 has a trailing space which breaks the
continuation; remove the trailing space so the backslash is the last character
on the line (and ensure the line ends with a newline, not CRLF) to restore
proper shell line continuation.


Alternatively, we can export the checkpoint and run evaluation on serving frameworks. See sections below.
Offline checkpoints does not support this evaluation due to missing of base model modules. To evaluate offline checkpoint, please export first and evaluate with serving frameworks.

⚠️ Potential issue | 🟡 Minor

Fix grammar.

"Offline checkpoints does not support" should be "Offline checkpoints do not support" (plural subject requires plural verb).

Apply this diff:

```diff
-Offline checkpoints does not support this evaluation due to missing of base model modules. To evaluate offline checkpoint, please export first and evaluate with serving frameworks.
+Offline checkpoints do not support this evaluation due to missing base model modules. To evaluate offline checkpoint, please export first and evaluate with serving frameworks.
```
🤖 Prompt for AI Agents
In examples/speculative_decoding/README.md around line 133, the sentence
"Offline checkpoints does not support this evaluation due to missing of base
model modules." has a subject-verb agreement error; change "Offline checkpoints
does not support" to "Offline checkpoints do not support" so the plural subject
matches the plural verb (optionally also remove "of" to read "due to missing
base model modules").

@h-guo18 h-guo18 merged commit abed33c into main Oct 8, 2025
27 checks passed
@h-guo18 h-guo18 deleted the haoguo/eagle-trtllm-dumper branch October 8, 2025 03:24