examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py
@@ -0,0 +1,234 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Extract hidden states from an HF-compatible LLM."""

import os

os.environ["TLLM_LOG_LEVEL"] = "error"
import argparse
import asyncio
import json
from pathlib import Path

import torch
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, SaveHiddenStatesDecodingConfig
from tqdm import tqdm as tqdm
from transformers import AutoConfig, AutoTokenizer

REMOVE_THINK_CHAT_TEMPLATE = (
"{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}"
)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="""Collect hidden states from conversations
        by running full conversations through a TensorRT-LLM model."""
    )

    ## Model & Generation Parameters ##
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Name of the served model.",
    )

    ## Client Parameters ##
    parser.add_argument(
        "--max-seq-len",
        type=int,
        default=3072,
        help="""Maximum number of tokens in a conversation. Longer conversations will be skipped.
        Defaults to 3072 tokens.""",
    )

    ## I/O Parameters ##
    parser.add_argument(
        "--input-file",
        type=Path,
        required=True,
        help="""Path to the input `jsonl` file containing conversations.
        Each entry must have a unique `conversation_id` field and a `conversations` field
        containing a list of messages.""",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        required=True,
        help="""Root directory in which to save the hidden states.
        The data will be saved as a torch (`.pt`) dump file for each conversation.""",
    )
    parser.add_argument(
        "--debug-max-num-conversations",
        type=int,
        default=None,
        help="""For debugging purposes, limit the number of conversations processed.
        Default is None, meaning no limit.""",
    )
    parser.add_argument(
        "--dp-rank",
        type=int,
        default=0,
        help="""Data parallel rank.""",
    )
    parser.add_argument(
        "--use-cuda-graph",
        type=bool,
        default=True,
        help="""Whether to use CUDA graph.""",
    )
Comment on lines +90 to +94
⚠️ Potential issue | 🟠 Major

Fix boolean argument handling.

Using type=bool for boolean flags in argparse is problematic. Any non-empty string (including "False" or "0") will be interpreted as True. Use action='store_true' or action='store_false' instead.

Apply this diff:

     parser.add_argument(
         "--use-cuda-graph",
-        type=bool,
-        default=True,
+        action='store_true',
+        default=False,
         help="""Whether to use CUDA graph.""",
     )

Or if you want it enabled by default:

     parser.add_argument(
-        "--use-cuda-graph",
-        type=bool,
-        default=True,
-        help="""Whether to use CUDA graph.""",
+        "--no-cuda-graph",
+        action='store_false',
+        dest='use_cuda_graph',
+        help="""Disable CUDA graph (enabled by default).""",
     )
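
For illustration, a minimal standalone sketch (not part of the PR) showing why type=bool cannot actually be switched off from the command line:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use-cuda-graph", type=bool, default=True)

# argparse passes the raw string to bool(); any non-empty string is truthy,
# so "--use-cuda-graph False" and "--use-cuda-graph 0" still yield True.
print(parser.parse_args(["--use-cuda-graph", "False"]).use_cuda_graph)  # True
print(parser.parse_args(["--use-cuda-graph", "0"]).use_cuda_graph)      # True
print(parser.parse_args([]).use_cuda_graph)                             # True (default)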

    parser.add_argument(
        "--tp-size-per-dp",
        type=int,
        default=2,
        help="""Tensor parallel size per data parallel.""",
    )

    return parser.parse_args()


def main(args: argparse.Namespace) -> None:
    # Load conversations
    all_conversations = []
    with args.input_file.open("r", encoding="utf-8") as f:
        all_conversations.extend([json.loads(line) for line in f if line.strip()])
    print("Loaded", len(all_conversations), "conversations from", args.input_file)

    # Get model config and tokenizer
    model_config = AutoConfig.from_pretrained(args.model)
    num_hidden_layers = getattr(model_config, "num_hidden_layers", None)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")

    # Set up LLM
    llm_common_config = {
        "model": args.model,
        "attn_backend": "TRTLLM",
        "disable_overlap_scheduler": False,
        "cuda_graph_config": CudaGraphConfig(batch_sizes=[1, 2, 4])
        if args.use_cuda_graph
        else None,
        "max_batch_size": 16,
        "kv_cache_config": KvCacheConfig(enable_block_reuse=False, free_gpu_memory_fraction=0.5),
        "enable_chunked_prefill": False,
        "tensor_parallel_size": args.tp_size_per_dp,
    }
    spec_config = {
        "output_directory": str(args.output_dir),
        "write_interval": 1,
        "file_prefix": f"dp_{args.dp_rank}",
        # Capture an early, middle, and late layer, e.g. {1, 7, 12} for a 16-layer model.
        "eagle3_layers_to_capture": {1, num_hidden_layers // 2 - 1, num_hidden_layers - 4},
    }
    sampling_params = SamplingParams(max_tokens=32, temperature=0)

    llm_spec = LLM(
        **llm_common_config, speculative_config=SaveHiddenStatesDecodingConfig(**spec_config)
    )

    args.output_dir.mkdir(parents=True, exist_ok=True)
    num_skipped_too_long = 0
    num_invalid = 0
    num_success = 0
    num_total_conversations = min(
        len(all_conversations), args.debug_max_num_conversations or len(all_conversations)
    )
    pbar = tqdm(total=num_total_conversations, desc=f"DP#{args.dp_rank} Processing conversations")

    def _post_process_trtllm_dumped(trtllm_dumped_file: str, conversation_id: int):
        """Post-process the TRT-LLM dumped file into the same format as the HF dump:

        1. Remove the `id` field and replace it with `conversation_id`
        2. Rename the `hidden_state` field to `hidden_states`
        3. Convert the single-element list into a dict
        4. Rename the file to `<conversation_id>.pt`
        """
        with open(trtllm_dumped_file, "rb") as f:
            trtllm_dumped = torch.load(f)
        assert isinstance(trtllm_dumped, list) and len(trtllm_dumped) == 1, (
            "TRTLLM dumped should be a list with one element"
        )
        assert (
            isinstance(trtllm_dumped[0], dict)
            and "id" in trtllm_dumped[0]
            and "hidden_state" in trtllm_dumped[0]
        ), "TRTLLM dumped should have an 'id' and 'hidden_state' field"
        trtllm_dumped = trtllm_dumped[0]
        trtllm_dumped.pop("id")
        trtllm_dumped["conversation_id"] = conversation_id
        trtllm_dumped["hidden_states"] = trtllm_dumped.pop("hidden_state")
        output_file = args.output_dir / f"{conversation_id}.pt"
        with open(output_file, "wb") as f:
            torch.save(trtllm_dumped, f)

        if trtllm_dumped_file.exists():
            trtllm_dumped_file.unlink()
Comment on lines +195 to +221
⚠️ Potential issue | 🟡 Minor

Correct the type hint for trtllm_dumped_file parameter.

The type hint declares trtllm_dumped_file: str, but the function calls .exists() (line 220) and .unlink() (line 221), which are Path methods. At the call site (line 229), a Path object is passed. Update the type hint to Path for consistency.

Apply this diff:

-    def _post_process_trtllm_dumped(trtllm_dumped_file: str, conversation_id: int):
+    def _post_process_trtllm_dumped(trtllm_dumped_file: Path, conversation_id: int):
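
As a small illustration of the distinction (hypothetical file name, sketch only):

from pathlib import Path

dumped = Path("/tmp/dp_0_1.pt")  # hypothetical dump file name

# Path objects provide the filesystem helpers the function relies on;
# a plain str has neither .exists() nor .unlink().
if dumped.exists():
    dumped.unlink()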


    async def dump_hidden_states(idx: int, conversation_id: int, input_ids: list[int]):
        nonlocal num_success
        await llm_spec.generate_async(input_ids, sampling_params)
        # The TRT-LLM API numbers dumped files starting from 1
        # ref: https://github.com/NVIDIA/TensorRT-LLM/pull/7012
        trtllm_dumped_file = args.output_dir / f"{spec_config['file_prefix']}_{idx + 1}.pt"
        _post_process_trtllm_dumped(trtllm_dumped_file, conversation_id)
        num_success += 1
        pbar.update(1)

    async def submit_generates():
        nonlocal num_skipped_too_long
        nonlocal num_invalid
        tasks = []
        for idx, entry in enumerate(all_conversations[: args.debug_max_num_conversations]):
            conversation_id = entry.get("conversation_id", "{:08d}".format(idx))

            conversations = entry["conversations"]
            if not conversations or not isinstance(conversations, list):
                num_invalid += 1
                continue

            input_ids = tokenizer.apply_chat_template(conversations, add_generation_template=False)
            num_input_tokens = (
                input_ids.shape[1] if isinstance(input_ids, torch.Tensor) else len(input_ids)
            )
            if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len:
                num_skipped_too_long += 1
                continue

            tasks.append(dump_hidden_states(idx, conversation_id, input_ids))
        await asyncio.gather(*tasks)

    asyncio.run(submit_generates())
    llm_spec.shutdown()
    print("LLM shutdown")

    if num_skipped_too_long > 0:
        print(f"Skipped {num_skipped_too_long} conversations due to length constraints.")
    if num_invalid > 0:
        print(f"Skipped {num_invalid} invalid conversations without proper fields.")

    if num_success == num_total_conversations:
        print(f"Successfully processed all {num_success} conversations.")
    else:
        print(
            f"Successfully processed {num_success} out of {num_total_conversations} conversations."
        )


if __name__ == "__main__":
    cli_args = parse_args()
    main(cli_args)
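
For reference, a minimal sketch of the input record format the script expects and of how a resulting dump could be inspected; the field names are taken from the code above, while the concrete paths and values are made up for illustration:

import json
import torch

# One line of the input .jsonl file (fields per the argparse help above):
example_record = {
    "conversation_id": 12345,
    "conversations": [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
    ],
}
print(json.dumps(example_record))

# After a run, each conversation is post-processed into <conversation_id>.pt
# in --output-dir, with the fields set by _post_process_trtllm_dumped:
dump = torch.load("output_dir/12345.pt")  # hypothetical path
print(dump["conversation_id"], type(dump["hidden_states"]))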
@@ -24,10 +24,11 @@

 INPUT_FILE=synthetic_conversations/daring-anteater.jsonl
 OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
+DP_SIZE=8
 
-split -n l/8 --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part-
+split -n l/$DP_SIZE --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part-
 
-for i in $(seq 0 7)
+for i in $(seq 0 $((DP_SIZE-1)))
 do
 CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py --model meta-llama/Llama-3.2-1B-Instruct --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR &
 done
@@ -0,0 +1,23 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example usage of the script to compute the hidden states for a conversation dataset
# This script computes hidden states using TensorRT-LLM and saves them to
# the specified output directory.

python3 collect_hidden_states/compute_hidden_states_trtllm.py \
--model meta-llama/Llama-3.2-1B-Instruct \
--input-file synthetic_conversations/daring-anteater.jsonl \
--output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
@@ -0,0 +1,37 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example usage of the script to compute the hidden states for a conversation dataset.
# This script computes hidden states using TensorRT-LLM and saves them to
# the specified output directory. It does so in a data-parallel manner across DP_SIZE GPUs,
# by splitting the input file into DP_SIZE parts and running one process per GPU.

# Note: depending on the write-throughput of the destination disk, this is not guaranteed
# to yield a speed improvement compared to running the model-parallel version. Consider
# benchmarking on a smaller dataset before launching a large run.

INPUT_FILE=synthetic_conversations/daring-anteater.jsonl
OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
DP_SIZE=8

split -n l/$DP_SIZE --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part-

for i in $(seq 0 $((DP_SIZE-1)))
do
CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_trtllm.py --model meta-llama/Llama-3.2-1B-Instruct --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR --dp-rank $i &
done
wait

rm /tmp/part-*.jsonl
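
A quick post-run sanity check, sketched in Python under the assumption that the paths above are used, comparing the number of input conversations to the number of dumped .pt files:

from pathlib import Path

input_file = Path("synthetic_conversations/daring-anteater.jsonl")
output_dir = Path("/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/")

num_conversations = sum(1 for line in input_file.open() if line.strip())
num_dumps = len(list(output_dir.glob("*.pt")))
print(f"{num_dumps}/{num_conversations} conversations produced a hidden-state dump")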