Feat: TRTLLM Dumper for Eagle Offline Training #404
Conversation
Signed-off-by: h-guo18 <[email protected]>
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
Walkthrough

Adds a TRT-LLM hidden-state extraction Python script, two TRT-LLM runner scripts (single and DP), and parameterizes an existing HF DP runner; implements CLI, tokenizer/model setup, async TRT-LLM generation, per-conversation post-processing, and DP orchestration.
Sequence Diagram(s)

```mermaid
sequenceDiagram
  autonumber
  participant User
  participant Script as compute_hidden_states_trtllm.py
  participant Tokenizer
  participant TRTLLM as TRT-LLM Engine
  participant FS as Filesystem
  User->>Script: CLI args (model, input, output, limits, TP, CUDA graph, dp-rank)
  Script->>FS: Read jsonl conversations
  Script->>Tokenizer: Load tokenizer and chat template
  Script->>TRTLLM: Initialize engine (TP, KV cache, CUDA graph, prefill)
  loop For each conversation
    Script->>Script: Validate structure and length
    Script->>Tokenizer: Build prompt from messages
    Script->>TRTLLM: Enqueue async generation (save hidden states)
  end
  loop As tasks complete
    TRTLLM-->>Script: Hidden-state dump (per item)
    Script->>Script: Post-process fields (remove id, rename to conversation_id)
    Script->>FS: Save `conversation_id`.pt
  end
  Script->>TRTLLM: Shutdown
  Script-->>User: Summary (processed, skipped, invalid)
```

```mermaid
sequenceDiagram
  autonumber
  participant User
  participant DP as run_trtllm_compute_hiddens_dp.sh
  participant Split as split
  participant Worker as compute_hidden_states_trtllm.py[*]
  participant FS as Filesystem
  User->>DP: Execute with INPUT, OUTPUT, DP_SIZE
  DP->>Split: split INPUT into /tmp/part-*.jsonl (DP_SIZE parts)
  par Parallel workers
    DP->>Worker: Launch with --input-file /tmp/part-0X.jsonl --dp-rank X
    Worker->>FS: Write per-conversation `.pt` outputs
  end
  DP->>DP: wait
  DP->>FS: Cleanup /tmp/part-*.jsonl
  DP-->>User: Exit
```
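For readers skimming the diagrams, here is a minimal sketch of the enqueue/complete pattern the first diagram describes. Only `generate_async` is taken from the PR's code; `llm_spec`, `sampling_params`, and `post_process` are stand-ins for the script's actual objects:

```python
import asyncio

async def dump_all(conversations, llm_spec, sampling_params, post_process):
    """Sketch of the async fan-out: enqueue one generation per conversation,
    post-process each hidden-state dump as its task completes."""

    async def dump_one(idx: int, conversation_id: int, input_ids: list[int]):
        # Generation triggers the engine-side hidden-state dump for this item.
        await llm_spec.generate_async(input_ids, sampling_params)
        # TRT-LLM numbers dump files starting from 1 (see ref in the review below).
        post_process(idx + 1, conversation_id)

    tasks = [
        asyncio.create_task(dump_one(i, c["conversation_id"], c["input_ids"]))
        for i, c in enumerate(conversations)
    ]
    await asyncio.gather(*tasks)
```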
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks and finishing touches: ❌ 1 failed check (warning), ✅ 2 passed.
Codecov Report

✅ All modified and coverable lines are covered by tests.

```
@@           Coverage Diff           @@
##             main     #404   +/-   ##
=======================================
  Coverage   73.79%   73.79%           
=======================================
  Files         171      171           
  Lines       17591    17591           
=======================================
  Hits        12982    12982           
  Misses       4609     4609
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 2
🧹 Nitpick comments (6)
examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh (1)
20-23: Add shebang for shell portability.

The script is missing a shebang line. While this may work when sourced or run with an explicit shell, adding `#!/bin/bash` at the top improves portability and clarifies the intended interpreter. As per static analysis hints.

Apply this diff:

```diff
+#!/bin/bash
+
 # SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```

examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh (1)
25-37: Add shebang for shell portability.

The script is missing a shebang line. Adding `#!/bin/bash` at the top improves portability and clarifies the intended interpreter. As per static analysis hints.

Apply this diff:

```diff
+#!/bin/bash
+
 # SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```

examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py (4)
107-110: Consider memory efficiency for large input files.

Loading all conversations into memory at once could be problematic for very large datasets. For production use, consider streaming the file or processing in batches.
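If streaming is adopted, a minimal sketch (assuming one JSON object per line, with an illustrative batch size) could look like:

```python
import json

def iter_conversation_batches(input_file: str, batch_size: int = 256):
    """Yield conversations in fixed-size batches instead of reading the
    entire jsonl file into memory at once."""
    batch = []
    with open(input_file, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # skip blank lines
            batch.append(json.loads(line))
            if len(batch) == batch_size:
                yield batch
                batch = []
    if batch:
        yield batch  # flush the final partial batch
```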
161-162: Security: torch.load on untrusted data.

Using `torch.load()` on files that could potentially be tampered with poses a security risk, as it can execute arbitrary Python code during deserialization. Since this script processes internally-generated TRTLLM output, the risk is low, but consider using `torch.load(..., weights_only=True)` if available in your PyTorch version for additional safety.
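For example (a sketch; `weights_only=True` is available in recent PyTorch releases and restricts unpickling to tensors and plain containers):

```python
import torch

# Raises instead of executing arbitrary pickled code; the dumped payload
# (a dict of ints/tensors) round-trips fine under this restriction.
trtllm_dumped = torch.load(trtllm_dumped_file, weights_only=True)
```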
179-180: Redundant file existence check.

`Path.unlink()` accepts a `missing_ok` parameter that makes the existence check redundant.

Apply this diff:

```diff
-        if trtllm_dumped_file.exists():
-            trtllm_dumped_file.unlink()
+        trtllm_dumped_file.unlink(missing_ok=True)
```
182-190: Consider adding error handling for generation failures.

The `dump_hidden_states` function doesn't handle exceptions from `generate_async` or file operations. If a single conversation fails, it could crash the entire pipeline. Consider wrapping in try-except to log errors and continue processing remaining conversations. Example:

```diff
     async def dump_hidden_states(idx: int, conversation_id: int, input_ids: list[int]):
         nonlocal num_success
-        await llm_spec.generate_async(input_ids, sampling_params)
-        # TRTLLM API name files starts from 1
-        # ref:https://github.com/NVIDIA/TensorRT-LLM/pull/7012
-        trtllm_dumped_file = args.output_dir / f"{spec_config['file_prefix']}_{idx + 1}.pt"
-        _post_process_trtllm_dumped(trtllm_dumped_file, conversation_id)
-        num_success += 1
-        pbar.update(1)
+        try:
+            await llm_spec.generate_async(input_ids, sampling_params)
+            # TRTLLM API name files starts from 1
+            # ref:https://github.com/NVIDIA/TensorRT-LLM/pull/7012
+            trtllm_dumped_file = args.output_dir / f"{spec_config['file_prefix']}_{idx + 1}.pt"
+            _post_process_trtllm_dumped(trtllm_dumped_file, conversation_id)
+            num_success += 1
+        except Exception as e:
+            print(f"Error processing conversation {conversation_id}: {e}")
+        finally:
+            pbar.update(1)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)

- examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py (1 hunks)
- examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh (1 hunks)
- examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh (1 hunks)
- examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh (1 hunks)
🧰 Additional context used
🪛 Shellcheck (0.11.0)
examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh
[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.
(SC2148)
examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh
[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.
(SC2148)
🔇 Additional comments (5)
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh (1)
27-31: LGTM! Clean parameterization improves maintainability.

The introduction of the `DP_SIZE` variable and its use throughout the script eliminates magic numbers and makes scaling easier.

examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py (4)
16-35: LGTM! Proper environment setup and imports.

Setting `TLLM_LOG_LEVEL` before importing `tensorrt_llm` is the correct approach to control logging verbosity.
137-137: Verify layer selection works for all model sizes.

The hardcoded layer selection `{1, num_hidden_layers // 2 - 1, num_hidden_layers - 4}` may not work correctly for small models (e.g., models with fewer than 5 layers). Consider adding validation. Apply this diff:

```diff
+    if num_hidden_layers < 5:
+        raise ValueError(f"Model has only {num_hidden_layers} layers, need at least 5 for layer selection")
+
     spec_config = {
         "output_directory": str(args.output_dir),
         "write_interval": 1,
         "file_prefix": f"dp_{args.dp_rank}",
         "eagle3_layers_to_capture": {1, num_hidden_layers // 2 - 1, num_hidden_layers - 4},
     }
```
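To make the failure mode concrete (layer counts here are illustrative, not from the PR): the set literal deduplicates, so small models silently capture fewer layers than intended:

```python
def eagle3_layers(num_hidden_layers: int) -> set[int]:
    # Same expression as in spec_config above.
    return {1, num_hidden_layers // 2 - 1, num_hidden_layers - 4}

print(eagle3_layers(16))  # {1, 7, 12} -- three distinct layers, as intended
print(eagle3_layers(4))   # {0, 1}    -- collapses to two, including layer 0
```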
196-196: Handle None slice correctly.

When `args.debug_max_num_conversations` is `None`, the slice `all_conversations[:None]` works correctly but is implicit. The current code is correct due to Python's slice behavior, but adding explicit handling would improve clarity.
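The Python slice behavior being relied on, for reference:

```python
items = [1, 2, 3]
assert items[:None] == items  # a None stop slices through to the end
limit = None
assert items[:limit] == items[:len(items)]
```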
215-234: LGTM! Proper async execution and cleanup.

The async execution, LLM shutdown, and summary reporting are correctly implemented.
| "--use-cuda-graph", | ||
| type=bool, | ||
| default=True, | ||
| help="""Whether to use CUDA graph.""", | ||
| ) | 
Fix boolean argument handling.
Using `type=bool` for boolean flags in argparse is problematic. Any non-empty string (including "False" or "0") will be interpreted as `True`. Use `action='store_true'` or `action='store_false'` instead.
Apply this diff:
```diff
     parser.add_argument(
         "--use-cuda-graph",
-        type=bool,
-        default=True,
+        action='store_true',
+        default=False,
         help="""Whether to use CUDA graph.""",
     )
```

Or if you want it enabled by default:
```diff
     parser.add_argument(
-        "--use-cuda-graph",
-        type=bool,
-        default=True,
-        help="""Whether to use CUDA graph.""",
+        "--no-cuda-graph",
+        action='store_false',
+        dest='use_cuda_graph',
+        help="""Disable CUDA graph (enabled by default).""",
     )
```
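A third option, if Python 3.9+ can be assumed: `argparse.BooleanOptionalAction` keeps the flag enabled by default while generating an explicit off switch (a sketch, not part of the suggested diff):

```python
import argparse

parser = argparse.ArgumentParser()
# Generates both --use-cuda-graph and --no-use-cuda-graph automatically,
# avoiding the type=bool pitfall while keeping default=True.
parser.add_argument(
    "--use-cuda-graph",
    action=argparse.BooleanOptionalAction,
    default=True,
    help="Whether to use CUDA graph.",
)
print(parser.parse_args([]).use_cuda_graph)                       # True
print(parser.parse_args(["--no-use-cuda-graph"]).use_cuda_graph)  # False
```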
        
          
Resolved review thread: examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh (Outdated)
Signed-off-by: h-guo18 <[email protected]>

Force-pushed d69a1c3 to 067bc1a
    Signed-off-by: h-guo18 <[email protected]>
Actionable comments posted: 3
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)

- examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh (1 hunks)
- examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh (1 hunks)
🧰 Additional context used
🪛 Shellcheck (0.11.0)
examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh
[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.
(SC2148)
examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh
[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.
(SC2148)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
 - GitHub Check: build-docs
 - GitHub Check: code-quality
 
```bash
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example usage of the script to compute the hidden states for a conversation dataset
# This script computes hidden states using TensorRT-LLM and saves them to
# the specified output directory. It does so in a data-parallel manner across 8 GPUs, by splitting
# the input file into 8 parts and running 8 processes in parallel, one on each GPU.

# Note: depending on the write-throughput of the destination disk, this is not guaranteed
# to yield a speed improvement compared to running the model-parallel version. Consider
# benchmarking on a smaller dataset before launching a large run.

INPUT_FILE=synthetic_conversations/daring-anteater.jsonl
OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
DP_SIZE=8
MODEL=meta-llama/Llama-3.2-1B-Instruct
export TLLM_LOG_LEVEL="error";
```
🛠️ Refactor suggestion | 🟠 Major
Add a shebang to define the interpreter.
Same as the single-run script: make Bash explicit so the runtime shell is unambiguous and SC2148 is silenced.
```diff
+#!/usr/bin/env bash
+
```
```bash
export CUDA_VISIBLE_DEVICES=$i;  python3 collect_hidden_states/compute_hidden_states_trtllm.py --model $MODEL --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR --dp-rank $i &

# #On SLURM:
# PORT=$((10012 + i)); export TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR="tcp://127.0.0.1:$PORT"; trtllm-llmapi-launch python3 collect_hidden_states/compute_hidden_states_trtllm.py --model $MODEL --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR --dp-rank $i

done
```
Fix suffix formatting for DP ranks ≥10.
Line 36 hardcodes /tmp/part-0${i}.jsonl, which works only for single-digit ranks; once DP_SIZE reaches 10 the actual file is /tmp/part-10.jsonl, so the launch fails. Format the suffix with printf to stay aligned with split output.
```diff
-for i in $(seq 0 $((DP_SIZE-1)))
-do
-
-export CUDA_VISIBLE_DEVICES=$i;  python3 collect_hidden_states/compute_hidden_states_trtllm.py --model $MODEL --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR --dp-rank $i &
+for i in $(seq 0 $((DP_SIZE-1)))
+do
+  suffix=$(printf "%02d" "$i")
+  CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_trtllm.py \
+    --model "$MODEL" \
+    --input-file "/tmp/part-${suffix}.jsonl" \
+    --output-dir "$OUTPUT_DIR" \
+    --dp-rank "$i" &
```
```bash
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example usage of the script to compute the hidden states for a conversation dataset
# This script computes hidden states using TensorRT-LLM and saves them to
# the specified output directory.

export TLLM_LOG_LEVEL="error";
python3 collect_hidden_states/compute_hidden_states_trtllm.py \
    --model meta-llama/Llama-3.2-1B-Instruct \
    --input-file synthetic_conversations/daring-anteater.jsonl \
    --output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
```
🛠️ Refactor suggestion | 🟠 Major
Add a shebang to define the interpreter.
Without a shebang the script runs under whatever /bin/sh points to, which can diverge from the Bash semantics used here and also trips ShellCheck (SC2148). Please make Bash explicit.
```diff
+#!/usr/bin/env bash
+
```
Signed-off-by: h-guo18 <[email protected]>
        
          
Resolved review threads:

- examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py (Outdated)
- examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py
- examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh
      Signed-off-by: h-guo18 <[email protected]>
Actionable comments posted: 5
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)

- examples/speculative_decoding/README.md (3 hunks)
- examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py (1)
modelopt/deploy/llm/generate.py (2)

- LLM (53-291)
- max_seq_len (154-156)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
 - GitHub Check: code-quality
 - GitHub Check: build-docs
 
🔇 Additional comments (1)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py (1)
125-270: LGTM! Well-structured async processing pipeline.

The implementation correctly handles:

- Conversation deduplication by checking existing output files
- Tokenizer setup with pad token fallback and chat template adjustment
- TRT-LLM configuration with CUDA graph, KV cache, and parallelism settings
- Async task orchestration with proper progress tracking
- Post-processing to match HF dump format
- Comprehensive error reporting

The async pattern efficiently batches hidden-state extraction across multiple conversations.
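The deduplication step mentioned above can be as simple as a pre-filter over the output directory (a sketch following the PR's `<conversation_id>.pt` naming; field names are assumptions):

```python
from pathlib import Path

def filter_already_dumped(conversations: list[dict], output_dir: Path) -> list[dict]:
    """Skip conversations whose hidden-state file already exists on disk."""
    remaining = [
        c for c in conversations
        if not (output_dir / f"{c['conversation_id']}.pt").exists()
    ]
    skipped = len(conversations) - len(remaining)
    print(f"Skipping {skipped} conversations with existing dumps")
    return remaining
```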
```python
    def _post_process_trtllm_dumped(trtllm_dumped_file: str, conversation_id: int):
        """Post-process the TRTLLM dumped file to same format as HF dumped:

        1. Remove id field, replace it with conversation_id
        2. Rename hidden_state field to hidden_states
        3. From list of length 1 to dict
        4. Rename file to conversation_id.pt
        """
        with open(trtllm_dumped_file, "rb") as f:
            trtllm_dumped = torch.load(f)
        assert isinstance(trtllm_dumped, list) and len(trtllm_dumped) == 1, (
            "TRTLLM dumped should be a list with one element"
        )
        assert (
            isinstance(trtllm_dumped[0], dict)
            and "id" in trtllm_dumped[0]
            and "hidden_state" in trtllm_dumped[0]
        ), "TRTLLM dumped should have an 'id' and 'hidden_states' field"
        trtllm_dumped = trtllm_dumped[0]
        trtllm_dumped.pop("id")
        trtllm_dumped["conversation_id"] = conversation_id
        trtllm_dumped["hidden_states"] = trtllm_dumped.pop("hidden_state")
        output_file = args.output_dir / f"{conversation_id}.pt"
        with open(output_file, "wb") as f:
            torch.save(trtllm_dumped, f)

        if trtllm_dumped_file.exists():
            trtllm_dumped_file.unlink()
```
Correct the type hint for the trtllm_dumped_file parameter.

The type hint declares `trtllm_dumped_file: str`, but the function calls `.exists()` (line 220) and `.unlink()` (line 221), which are `Path` methods. At the call site (line 229), a `Path` object is passed. Update the type hint to `Path` for consistency.

Apply this diff:

```diff
-    def _post_process_trtllm_dumped(trtllm_dumped_file: str, conversation_id: int):
+    def _post_process_trtllm_dumped(trtllm_dumped_file: Path, conversation_id: int):
```
> For large models, you can export intermediate hidden states to disk and train only the draft model. This significantly reduces GPU memory requirements, but requires several to tens of terabytes of disk storage depending on dataset size.
>
> First, dump the base model's hidden states with the following command:
> ### Dumpping Hidden States to Disk
Fix typo in heading.
"Dumpping" should be "Dumping".
Apply this diff:

```diff
-### Dumpping Hidden States to Disk
+### Dumping Hidden States to Disk
```
> First, dump the base model's hidden states with the following command:
> ### Dumpping Hidden States to Disk
>
> We support two backends for generating base model hidden states. For better effciency, it is recommended to use TRT-LLM:
Fix typo.
"effciency" should be "efficiency".
Apply this diff:

```diff
-We support two backends for generating base model hidden states. For better effciency, it is recommended to use TRT-LLM:
+We support two backends for generating base model hidden states. For better efficiency, it is recommended to use TRT-LLM:
```
```bash
python collect_hidden_states/compute_hidden_states_trtllm.py \
            --model $BASE_MODEL \ 
            --input-file Daring-Anteater/train.jsonl \
            --output-dir $HIDDEN_STATES_DIR
```
Remove trailing space after line continuation.
Line 93 has a trailing space after the backslash, which breaks bash line continuation. The backslash must be immediately followed by a newline.
Apply this diff:

```diff
 python collect_hidden_states/compute_hidden_states_trtllm.py \
-            --model $BASE_MODEL \ 
+            --model $BASE_MODEL \
             --input-file Daring-Anteater/train.jsonl \
             --output-dir $HIDDEN_STATES_DIR
```
> Alternatively, we can export the checkpoint and run evaluation on serving frameworks. See sections below.
> Offline checkpoints does not support this evaluation due to missing of base model modules. To evaluate offline checkpoint, please export first and evaluate with serving frameworks.
Fix grammar.
"Offline checkpoints does not support" should be "Offline checkpoints do not support" (plural subject requires plural verb).
Apply this diff:

```diff
-Offline checkpoints does not support this evaluation due to missing of base model modules. To evaluate offline checkpoint, please export first and evaluate with serving frameworks.
+Offline checkpoints do not support this evaluation due to missing base model modules. To evaluate an offline checkpoint, please export first and evaluate with serving frameworks.
```
What does this PR do?

Type of change: New feature

Overview:
TRTLLM recently added support for dumping intermediate hidden states: https://github.com/NVIDIA/TensorRT-LLM/pull/7012/files
This PR adds example scripts to dump hidden states with TRTLLM.
To resolve the data format mismatch, the script post-processes the dumped files into exactly the same format as the previously used HF dumper.
Usage

See examples/compute_hidden_states/

Testing
Diff on dumped files

Dumped files were diffed against the previously used HF dumper. Tested with DP=1, DP=4, and TP=1, TP=2:
E2E AR on TinyLlama

On a fixed training setting for 50000 steps:
Summary by CodeRabbit

- New Features
- Scripts
- Documentation