
Commit abed33c

Feat: TRTLLM Dumper for Eagle Offline Training (#404)

Authored by h-guo18
Signed-off-by: h-guo18 <[email protected]>
1 parent 1537885 commit abed33c

File tree: 5 files changed (+369, -9 lines)

examples/speculative_decoding/README.md

Lines changed: 22 additions & 7 deletions
````diff
@@ -82,9 +82,22 @@ The saved modelopt checkpoint is similar in architecture to HF models. It can be
 
 ## Training Draft Model with Offline Base Model
 
-For large models, you can export intermediate hidden states to disk and train only the draft model. This significantly reduces GPU memory requirements, but requires several to tens of terabytes of storage depending on dataset size.
+For large models, you can export intermediate hidden states to disk and train only the draft model. This significantly reduces GPU memory requirements, but requires several to tens of terabytes of disk storage depending on dataset size.
 
-First, dump the base model's hidden states with the following command:
+### Dumping Hidden States to Disk
+
+We support two backends for generating base model hidden states. For better efficiency, we recommend the TRT-LLM backend:
+
+```bash
+python collect_hidden_states/compute_hidden_states_trtllm.py \
+  --model $BASE_MODEL \
+  --input-file Daring-Anteater/train.jsonl \
+  --output-dir $HIDDEN_STATES_DIR
+```
+
+**NOTE**: TRT-LLM must be installed to run the above command.
+
+Alternatively, you can generate the same hidden states with HF:
 
 ```bash
 python collect_hidden_states/compute_hidden_states_hf.py \
@@ -93,9 +106,11 @@ python collect_hidden_states/compute_hidden_states_hf.py \
   --output-dir $HIDDEN_STATES_DIR
 ```
 
-See [`run_hf_compute_hiddens_dp.sh`](./collect_hidden_states/run_hf_compute_hiddens_dp.sh) for a simple example using data parallelism (DP) to accelerate hidden state generation.
+**NOTE**: See [`run_hf_compute_hiddens_dp.sh`](./collect_hidden_states/run_hf_compute_hiddens_dp.sh) and [`run_trtllm_compute_hiddens_dp.sh`](./collect_hidden_states/run_trtllm_compute_hiddens_dp.sh) for simple examples using data parallelism (DP) to accelerate hidden state generation.
+
+### Train Draft Model with Dumped Hidden States
 
-Then, train draft model with `--offline-data` argument:
+Once the hidden states are dumped, launch offline training with the extra `--offline-data` argument:
 
 ```bash
 ./launch_train.sh --model $BASE_MODEL \
@@ -109,13 +124,13 @@ Then, train draft model with `--offline-data` argument:
 
 ## Model Validation
 
-After training draft model, we can evaluate the saved modelopt checkpoint on MT-bench by:
+For online training checkpoints, we can run in-framework evaluation on MT-bench:
 
 ```bash
-python ar_validate.py --model_path $OUTPUT_DIR
+python ar_validate.py --model_path $ONLINE_CKPT
 ```
 
-Alternatively, we can export the checkpoint and run evaluation on serving frameworks. See sections below.
+Offline checkpoints do not support this evaluation because the base model modules are missing. To evaluate an offline checkpoint, export it first and evaluate it with a serving framework.
 
 ## Export
````

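For reference, both backends consume the same input format: a `.jsonl` file with one conversation per line, each entry carrying a unique `conversation_id` and a `conversations` list of chat messages (this is the contract documented by the new script's `--input-file` help text). Below is a minimal sketch that writes such a file; the conversation content is illustrative, not from the commit:

```python
import json

# Minimal input file in the format the dumper scripts expect: one JSON object
# per line, each with a unique `conversation_id` and a `conversations` list
# of chat messages. The conversation text here is made up.
entries = [
    {
        "conversation_id": 0,
        "conversations": [
            {"role": "user", "content": "What is speculative decoding?"},
            {"role": "assistant", "content": "It drafts tokens with a small model and verifies them with the base model."},
        ],
    }
]

with open("train.jsonl", "w", encoding="utf-8") as f:
    for entry in entries:
        f.write(json.dumps(entry) + "\n")
```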
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py

Lines changed: 275 additions & 0 deletions
```python
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Extract hidden states from an LLM using TensorRT-LLM."""

import os

os.environ["TLLM_LOG_LEVEL"] = "error"
import argparse
import asyncio
import json
from pathlib import Path

import torch
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, SaveHiddenStatesDecodingConfig
from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer

REMOVE_THINK_CHAT_TEMPLATE = (
    "{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}"
)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="""Collect hidden states from conversations
        by running full conversations through a TensorRT-LLM model."""
    )

    ## Model & Generation Parameters ##
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Name of the served model.",
    )

    ## Client Parameters ##
    parser.add_argument(
        "--max-seq-len",
        type=int,
        default=3072,
        help="""Maximum number of tokens in a conversation. Longer conversations will be skipped.
        Defaults to 3072 tokens.""",
    )

    ## I/O Parameters ##
    parser.add_argument(
        "--input-file",
        type=Path,
        required=True,
        help="""Path to the input `jsonl` file containing conversations.
        Each entry must have a unique `conversation_id` field and a `conversations` field
        containing a list of messages.""",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        required=True,
        help="""Root directory in which to save the hidden states.
        The data will be saved as a torch (`.pt`) dump file for each conversation.""",
    )
    parser.add_argument(
        "--debug-max-num-conversations",
        type=int,
        default=None,
        help="""For debugging purposes, limit the number of conversations processed.
        Default is None, meaning no limit.""",
    )
    parser.add_argument(
        "--dp-rank",
        type=int,
        default=0,
        help="""Data parallel rank.""",
    )
    parser.add_argument(
        "--use-cuda-graph",
        type=bool,
        default=True,
        help="""Whether to use CUDA graph.""",
    )
    parser.add_argument(
        "--tp",
        type=int,
        default=1,
        help="""tensor_parallel_size for TRTLLM.""",
    )
    # moe_ep * moe_tp * moe_cp should be equal to tp
    # REF: https://nvidia.github.io/TensorRT-LLM/advanced/expert-parallelism.html
    parser.add_argument(
        "--moe_ep",
        type=int,
        default=1,
        help="""moe_expert_parallel_size for TRTLLM.""",
    )
    parser.add_argument(
        "--moe_tp",
        type=int,
        default=1,
        help="""moe_tensor_parallel_size for TRTLLM.""",
    )
    parser.add_argument(
        "--moe_cp",
        type=int,
        default=1,
        help="""moe_cluster_parallel_size for TRTLLM.""",
    )

    return parser.parse_args()


def main(args: argparse.Namespace) -> None:
    # Load conversations
    all_conversations = []
    with args.input_file.open("r", encoding="utf-8") as f:
        all_conversations.extend([json.loads(line) for line in f if line.strip()])
    print("Loaded", len(all_conversations), "conversations from", args.input_file)

    # Remove conversations whose output file already exists
    filtered_conversations = []
    for entry in all_conversations:
        conversation_id = entry.get("conversation_id", None)
        if conversation_id is None:
            filtered_conversations.append(entry)
            continue
        output_file = args.output_dir / f"{conversation_id}.pt"
        if output_file.exists():
            continue
        filtered_conversations.append(entry)
    print(
        "Removed",
        len(all_conversations) - len(filtered_conversations),
        "conversations due to existing output files",
    )
    all_conversations = filtered_conversations

    # Get model config and tokenizer
    model_config = AutoConfig.from_pretrained(args.model)
    num_hidden_layers = getattr(model_config, "num_hidden_layers", None)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")

    # Set up LLM
    llm_common_config = {
        "model": args.model,
        "attn_backend": "TRTLLM",
        "disable_overlap_scheduler": False,
        "cuda_graph_config": CudaGraphConfig(batch_sizes=[1, 2, 4])
        if args.use_cuda_graph
        else None,
        "max_batch_size": 16,
        "kv_cache_config": KvCacheConfig(enable_block_reuse=False, free_gpu_memory_fraction=0.5),
        "enable_chunked_prefill": False,
        "tensor_parallel_size": args.tp,
        "moe_expert_parallel_size": args.moe_ep,
        "moe_tensor_parallel_size": args.moe_tp,
        "moe_cluster_parallel_size": args.moe_cp,
    }
    spec_config = {
        "output_directory": str(args.output_dir),
        "write_interval": 1,
        "file_prefix": f"dp_{args.dp_rank}",
        "eagle3_layers_to_capture": {1, num_hidden_layers // 2 - 1, num_hidden_layers - 4},
    }
    sampling_params = SamplingParams(max_tokens=32, temperature=0)

    llm_spec = LLM(
        **llm_common_config, speculative_config=SaveHiddenStatesDecodingConfig(**spec_config)
    )

    args.output_dir.mkdir(parents=True, exist_ok=True)
    num_skipped_too_long = 0
    num_invalid = 0
    num_success = 0
    num_total_conversations = min(
        len(all_conversations), args.debug_max_num_conversations or len(all_conversations)
    )
    pbar = tqdm(total=num_total_conversations, desc=f"DP#{args.dp_rank} Processing conversations")

    def _post_process_trtllm_dumped(trtllm_dumped_file: Path, conversation_id: int):
        """Post-process the TRTLLM dumped file to the same format as the HF dump:

        1. Remove the `id` field and replace it with `conversation_id`.
        2. Rename the `hidden_state` field to `hidden_states`.
        3. Unwrap the single-element list into a dict.
        4. Rename the file to `<conversation_id>.pt`.
        """
        with open(trtllm_dumped_file, "rb") as f:
            trtllm_dumped = torch.load(f)
        assert isinstance(trtllm_dumped, list) and len(trtllm_dumped) == 1, (
            "TRTLLM dump should be a list with one element"
        )
        assert (
            isinstance(trtllm_dumped[0], dict)
            and "id" in trtllm_dumped[0]
            and "hidden_state" in trtllm_dumped[0]
        ), "TRTLLM dump should have an 'id' and a 'hidden_state' field"
        trtllm_dumped = trtllm_dumped[0]
        trtllm_dumped.pop("id")
        trtllm_dumped["conversation_id"] = conversation_id
        trtllm_dumped["hidden_states"] = trtllm_dumped.pop("hidden_state")
        output_file = args.output_dir / f"{conversation_id}.pt"
        with open(output_file, "wb") as f:
            torch.save(trtllm_dumped, f)

        if trtllm_dumped_file.exists():
            trtllm_dumped_file.unlink()

    async def dump_hidden_states(idx: int, conversation_id: int, input_ids: list[int]):
        nonlocal num_success
        await llm_spec.generate_async(input_ids, sampling_params)
        # The TRTLLM API numbers dumped files starting from 1.
        # ref: https://github.com/NVIDIA/TensorRT-LLM/pull/7012
        trtllm_dumped_file = args.output_dir / f"{spec_config['file_prefix']}_{idx + 1}.pt"
        _post_process_trtllm_dumped(trtllm_dumped_file, conversation_id)
        num_success += 1
        pbar.update(1)

    async def submit_generates():
        nonlocal num_skipped_too_long
        nonlocal num_invalid
        tasks = []
        for idx, entry in enumerate(all_conversations[: args.debug_max_num_conversations]):
            conversation_id = entry.get("conversation_id", "{:08d}".format(idx))

            conversations = entry["conversations"]
            if not conversations or not isinstance(conversations, list):
                num_invalid += 1
                continue

            input_ids = tokenizer.apply_chat_template(conversations, add_generation_prompt=False)
            num_input_tokens = (
                input_ids.shape[1] if isinstance(input_ids, torch.Tensor) else len(input_ids)
            )
            if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len:
                num_skipped_too_long += 1
                continue

            tasks.append(dump_hidden_states(idx, conversation_id, input_ids))
        await asyncio.gather(*tasks)

    asyncio.run(submit_generates())
    llm_spec.shutdown()
    print("LLM shutdown")

    if num_skipped_too_long > 0:
        print(f"Skipped {num_skipped_too_long} conversations due to length constraints.")
    if num_invalid > 0:
        print(f"Skipped {num_invalid} invalid conversations without proper fields.")

    if num_success == num_total_conversations:
        print(f"Successfully processed all {num_success} conversations.")
    else:
        print(
            f"Successfully processed {num_success} out of {num_total_conversations} conversations."
        )


if __name__ == "__main__":
    cli_args = parse_args()
    main(cli_args)
```
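Two details above are worth a worked example. First, the `eagle3_layers_to_capture` set selects one early, one middle, and one late layer; for a 16-layer base model such as Llama-3.2-1B-Instruct it resolves to layers {1, 7, 12}. Second, `_post_process_trtllm_dumped` guarantees each saved `.pt` file is a dict with `conversation_id` and `hidden_states` keys; any other keys are whatever TRT-LLM's dumper writes alongside them and are not enumerated in this commit. A minimal sketch for sanity-checking one dump (the path and filename are illustrative):

```python
from pathlib import Path

import torch

# Load one post-processed dump and check the two fields guaranteed by
# _post_process_trtllm_dumped: `conversation_id` and `hidden_states`.
# Any additional keys come from TRT-LLM's hidden-state dumper itself.
dump_path = Path("/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/00000000.pt")
record = torch.load(dump_path, map_location="cpu")  # tensors may have been saved on GPU

assert "conversation_id" in record and "hidden_states" in record
print("conversation_id:", record["conversation_id"])
print("hidden_states shape:", tuple(record["hidden_states"].shape))
print("all keys:", sorted(record.keys()))
```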

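As the comment in the argument parser notes, the MoE sharding flags must satisfy `moe_ep * moe_tp * moe_cp == tp`. A hypothetical multi-GPU invocation that satisfies this constraint; the flag values are illustrative, not from the commit:

```bash
# Hypothetical 8-GPU run: 8-way tensor parallelism overall, with MoE weights
# split 4-way expert-parallel and 2-way tensor-parallel (4 * 2 * 1 == 8,
# satisfying the moe_ep * moe_tp * moe_cp == tp constraint noted in the script).
python collect_hidden_states/compute_hidden_states_trtllm.py \
  --model $BASE_MODEL \
  --input-file Daring-Anteater/train.jsonl \
  --output-dir $HIDDEN_STATES_DIR \
  --tp 8 --moe_ep 4 --moe_tp 2 --moe_cp 1
```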
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh

Lines changed: 3 additions & 2 deletions
```diff
@@ -24,10 +24,11 @@
 
 INPUT_FILE=synthetic_conversations/daring-anteater.jsonl
 OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
+DP_SIZE=8
 
-split -n l/8 --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part-
+split -n l/$DP_SIZE --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part-
 
-for i in $(seq 0 7)
+for i in $(seq 0 $((DP_SIZE-1)))
 do
 CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py --model meta-llama/Llama-3.2-1B-Instruct --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR &
 done
```
Lines changed: 25 additions & 0 deletions
```bash
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example usage of the script to compute the hidden states for a conversation dataset.
# This script computes hidden states using TensorRT-LLM and saves them to
# the specified output directory.

export TLLM_LOG_LEVEL="error";
python3 collect_hidden_states/compute_hidden_states_trtllm.py \
    --model meta-llama/Llama-3.2-1B-Instruct \
    --input-file synthetic_conversations/daring-anteater.jsonl \
    --output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
```
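The dumper script folds `--dp-rank` into its intermediate file prefix (`dp_<rank>`), so several ranks can safely share one output directory. The `run_trtllm_compute_hiddens_dp.sh` launcher referenced in the README is not reproduced on this page; below is a hedged sketch of what such a data-parallel launch can look like, modeled on `run_hf_compute_hiddens_dp.sh` above rather than taken from the commit:

```bash
# Hypothetical data-parallel launcher for the TRT-LLM dumper, mirroring the
# HF DP script above: shard the input file, give each rank one GPU and a
# distinct --dp-rank so intermediate dump prefixes do not collide.
INPUT_FILE=synthetic_conversations/daring-anteater.jsonl
OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
DP_SIZE=8

split -n l/$DP_SIZE --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part-

for i in $(seq 0 $((DP_SIZE-1)))
do
  CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_trtllm.py \
    --model meta-llama/Llama-3.2-1B-Instruct \
    --input-file /tmp/part-0${i}.jsonl \
    --output-dir $OUTPUT_DIR \
    --dp-rank $i &
done
wait
```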
