
Commit 450c773

add scripts for trtllm eagle dumper
Signed-off-by: h-guo18 <[email protected]>
1 parent 340eb7a

4 files changed: +297 -2 lines changed
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py

Lines changed: 234 additions & 0 deletions

# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Extract hidden states from an HF-compatible LLM served through TensorRT-LLM."""

import os

# Must be set before tensorrt_llm is imported.
os.environ["TLLM_LOG_LEVEL"] = "error"
import argparse
import asyncio
import json
from pathlib import Path

import torch
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, SaveHiddenStatesDecodingConfig
from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer

# Chat-template fragment that strips reasoning traces; it is removed below so that
# hidden states are computed over full conversations, including think blocks.
REMOVE_THINK_CHAT_TEMPLATE = (
    "{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}"
)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="""Collect hidden states from conversations
        by running full conversations through a TensorRT-LLM model."""
    )

    ## Model & Generation Parameters ##
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Name of the served model.",
    )

    ## Client Parameters ##
    parser.add_argument(
        "--max-seq-len",
        type=int,
        default=3072,
        help="""Maximum number of tokens in a conversation. Longer conversations will be skipped.
        Defaults to 3072 tokens.""",
    )

    ## I/O Parameters ##
    parser.add_argument(
        "--input-file",
        type=Path,
        required=True,
        help="""Path to the input `jsonl` file containing conversations.
        Each entry must have a unique `conversation_id` field and a `conversations` field
        containing a list of messages.""",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        required=True,
        help="""Root directory in which to save the hidden states.
        The data will be saved as a torch (`.pt`) dump file for each conversation.""",
    )
    parser.add_argument(
        "--debug-max-num-conversations",
        type=int,
        default=None,
        help="""For debugging purposes, limit the number of conversations processed.
        Default is None, meaning no limit.""",
    )
    parser.add_argument(
        "--dp-rank",
        type=int,
        default=0,
        help="Data parallel rank.",
    )
    parser.add_argument(
        "--use-cuda-graph",
        # BooleanOptionalAction avoids the argparse `type=bool` pitfall, where any
        # non-empty string (including "False") parses as True.
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Whether to use CUDA graph.",
    )
    parser.add_argument(
        "--tp-size-per-dp",
        type=int,
        default=2,
        help="Tensor parallel size per data parallel rank.",
    )

    return parser.parse_args()


def main(args: argparse.Namespace) -> None:
    # Load conversations
    all_conversations = []
    with args.input_file.open("r", encoding="utf-8") as f:
        all_conversations.extend([json.loads(line) for line in f if line.strip()])
    print("Loaded", len(all_conversations), "conversations from", args.input_file)

    # Get model config and tokenizer
    model_config = AutoConfig.from_pretrained(args.model)
    num_hidden_layers = getattr(model_config, "num_hidden_layers", None)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.chat_template is not None:
        tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")

    # Set up LLM
    llm_common_config = {
        "model": args.model,
        "attn_backend": "TRTLLM",
        "disable_overlap_scheduler": False,
        "cuda_graph_config": CudaGraphConfig(batch_sizes=[1, 2, 4])
        if args.use_cuda_graph
        else None,
        "max_batch_size": 16,
        "kv_cache_config": KvCacheConfig(enable_block_reuse=False, free_gpu_memory_fraction=0.5),
        "enable_chunked_prefill": False,
        "tensor_parallel_size": args.tp_size_per_dp,
    }
    spec_config = {
        "output_directory": str(args.output_dir),
        "write_interval": 1,
        "file_prefix": f"dp_{args.dp_rank}",
        # Capture an early, a middle, and a late decoder layer, following the EAGLE-3 recipe.
        "eagle3_layers_to_capture": {1, num_hidden_layers // 2 - 1, num_hidden_layers - 4},
    }
    sampling_params = SamplingParams(max_tokens=32, temperature=0)

    llm_spec = LLM(
        **llm_common_config, speculative_config=SaveHiddenStatesDecodingConfig(**spec_config)
    )

    args.output_dir.mkdir(parents=True, exist_ok=True)
    num_skipped_too_long = 0
    num_invalid = 0
    num_success = 0
    num_total_conversations = min(
        len(all_conversations), args.debug_max_num_conversations or len(all_conversations)
    )
    pbar = tqdm(total=num_total_conversations, desc=f"DP#{args.dp_rank} Processing conversations")

    def _post_process_trtllm_dumped(trtllm_dumped_file: Path, conversation_id: str | int):
        """Post-process a TRT-LLM dump file into the same format as the HF dumper:

        1. Remove the `id` field and replace it with `conversation_id`.
        2. Rename the `hidden_state` field to `hidden_states`.
        3. Unwrap the single-element list into a dict.
        4. Rename the file to `<conversation_id>.pt`.
        """
        with open(trtllm_dumped_file, "rb") as f:
            trtllm_dumped = torch.load(f)
        assert isinstance(trtllm_dumped, list) and len(trtllm_dumped) == 1, (
            "TRTLLM dump should be a list with one element"
        )
        assert (
            isinstance(trtllm_dumped[0], dict)
            and "id" in trtllm_dumped[0]
            and "hidden_state" in trtllm_dumped[0]
        ), "TRTLLM dump should have an 'id' and a 'hidden_state' field"
        trtllm_dumped = trtllm_dumped[0]
        trtllm_dumped.pop("id")
        trtllm_dumped["conversation_id"] = conversation_id
        trtllm_dumped["hidden_states"] = trtllm_dumped.pop("hidden_state")
        output_file = args.output_dir / f"{conversation_id}.pt"
        with open(output_file, "wb") as f:
            torch.save(trtllm_dumped, f)

        if trtllm_dumped_file.exists():
            trtllm_dumped_file.unlink()

    async def dump_hidden_states(idx: int, conversation_id: str | int, input_ids: list[int]):
        nonlocal num_success
        await llm_spec.generate_async(input_ids, sampling_params)
        # The TRT-LLM API numbers dumped files starting from 1.
        # ref: https://github.com/NVIDIA/TensorRT-LLM/pull/7012
        trtllm_dumped_file = args.output_dir / f"{spec_config['file_prefix']}_{idx + 1}.pt"
        _post_process_trtllm_dumped(trtllm_dumped_file, conversation_id)
        num_success += 1
        pbar.update(1)

    async def submit_generates():
        nonlocal num_skipped_too_long
        nonlocal num_invalid
        tasks = []
        for idx, entry in enumerate(all_conversations[: args.debug_max_num_conversations]):
            conversation_id = entry.get("conversation_id", "{:08d}".format(idx))

            conversations = entry["conversations"]
            if not conversations or not isinstance(conversations, list):
                num_invalid += 1
                continue

            input_ids = tokenizer.apply_chat_template(conversations, add_generation_prompt=False)
            num_input_tokens = (
                input_ids.shape[1] if isinstance(input_ids, torch.Tensor) else len(input_ids)
            )
            if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len:
                num_skipped_too_long += 1
                continue

            tasks.append(dump_hidden_states(idx, conversation_id, input_ids))
        await asyncio.gather(*tasks)

    asyncio.run(submit_generates())
    llm_spec.shutdown()
    print("LLM shutdown")

    if num_skipped_too_long > 0:
        print(f"Skipped {num_skipped_too_long} conversations due to length constraints.")
    if num_invalid > 0:
        print(f"Skipped {num_invalid} invalid conversations without proper fields.")

    if num_success == num_total_conversations:
        print(f"Successfully processed all {num_success} conversations.")
    else:
        print(
            f"Successfully processed {num_success} out of {num_total_conversations} conversations."
        )


if __name__ == "__main__":
    cli_args = parse_args()
    main(cli_args)
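
For downstream consumers, here is a minimal sketch of loading one post-processed dump. The path is hypothetical; only the conversation_id and hidden_states keys are guaranteed by the post-processing step above, and any extra fields written by TRT-LLM are passed through unchanged:

import torch

# Hypothetical output path; the script above names files <conversation_id>.pt.
entry = torch.load("/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/00000042.pt")
print(entry["conversation_id"])
hidden = entry["hidden_states"]
# The tensor layout is an assumption; inspect it rather than relying on it.
print(type(hidden), getattr(hidden, "shape", None))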

examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh

Lines changed: 3 additions & 2 deletions
@@ -24,10 +24,11 @@
 
 INPUT_FILE=synthetic_conversations/daring-anteater.jsonl
 OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
+DP_SIZE=8
 
-split -n l/8 --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part-
+split -n l/$DP_SIZE --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part-
 
-for i in $(seq 0 7)
+for i in $(seq 0 $((DP_SIZE-1)))
 do
 CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py --model meta-llama/Llama-3.2-1B-Instruct --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR &
 done
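
Note that the part-0${i} reference in the loop relies on split's default two-digit numeric suffixes, so it only lines up for DP_SIZE values up to 10. For reference, a minimal Python sketch of the same sharding step, under the assumption that contiguous line chunks are acceptable (split -n l/N balances by bytes, so exact boundaries differ):

from pathlib import Path

DP_SIZE = 8
lines = Path("synthetic_conversations/daring-anteater.jsonl").read_text(encoding="utf-8").splitlines(keepends=True)
per_shard = -(-len(lines) // DP_SIZE)  # ceiling division: lines per shard
for rank in range(DP_SIZE):
    shard = lines[rank * per_shard : (rank + 1) * per_shard]
    # Matches the /tmp/part-0N.jsonl naming used by the loop above.
    Path(f"/tmp/part-{rank:02d}.jsonl").write_text("".join(shard), encoding="utf-8")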
Lines changed: 23 additions & 0 deletions
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example usage of the script to compute the hidden states for a conversation dataset.
# This script computes hidden states using TensorRT-LLM and saves them to
# the specified output directory.

python3 collect_hidden_states/compute_hidden_states_trtllm.py \
    --model meta-llama/Llama-3.2-1B-Instruct \
    --input-file synthetic_conversations/daring-anteater.jsonl \
    --output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
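
A quick sanity check after a run is to compare the per-conversation dumps against the input; a sketch assuming the paths from the example above (the counts can legitimately differ, since too-short, too-long, and malformed conversations are skipped):

from pathlib import Path

input_file = Path("synthetic_conversations/daring-anteater.jsonl")
output_dir = Path("/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/")

num_conversations = sum(1 for line in input_file.open(encoding="utf-8") if line.strip())
num_dumps = len(list(output_dir.glob("*.pt")))
print(f"{num_dumps}/{num_conversations} conversations dumped")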
Lines changed: 37 additions & 0 deletions
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example usage of the script to compute the hidden states for a conversation dataset.
# This script computes hidden states using TensorRT-LLM and saves them to
# the specified output directory. It does so in a data-parallel manner across 8 GPUs, by splitting
# the input file into 8 parts and running 8 processes in parallel, one on each GPU.

# Note: depending on the write throughput of the destination disk, this is not guaranteed
# to yield a speed improvement compared to running the model-parallel version. Consider
# benchmarking on a smaller dataset before launching a large run.

INPUT_FILE=synthetic_conversations/daring-anteater.jsonl
OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
DP_SIZE=8

split -n l/$DP_SIZE --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part-

for i in $(seq 0 $((DP_SIZE-1)))
do
  # Each rank sees a single GPU, so the per-rank tensor parallel size must be 1
  # (the Python script defaults to --tp-size-per-dp 2).
  CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_trtllm.py --model meta-llama/Llama-3.2-1B-Instruct --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR --dp-rank $i --tp-size-per-dp 1 &
done
wait

rm /tmp/part-*.jsonl
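
Each rank writes intermediate dp_<rank>_<n>.pt files that are deleted only after successful post-processing, so leftovers of that form point to conversations that failed mid-run. A minimal check, with the output path assumed from the script above:

from pathlib import Path

output_dir = Path("/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/")
# Intermediate dumps follow the dp_<rank>_<n>.pt pattern from the file_prefix above.
leftovers = sorted(output_dir.glob("dp_*_*.pt"))
if leftovers:
    print(f"{len(leftovers)} intermediate dumps were not post-processed:")
    for path in leftovers[:10]:
        print("  ", path.name)
else:
    print("No leftover intermediate dumps found.")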
