Skip to content

Commit add61db

Browse files
benchislett and h-guo18 authored
Feature: Offline training for EAGLE3 (NVIDIA#300)
Signed-off-by: Benjamin Chislett <[email protected]> Signed-off-by: Benjamin Chislett <[email protected]> Signed-off-by: h-guo18 <[email protected]> Co-authored-by: h-guo18 <[email protected]>
1 parent c0590b0 commit add61db

21 files changed

+1425
-175
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
11
Daring-Anteater
2+
input_conversations
3+
synthetic_conversations
4+
ckpts
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Collect hidden states from a dataset of conversations."""
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Extract hidden states from an HF-compatible LLM."""
17+
18+
import argparse
19+
import asyncio
20+
import json
21+
from pathlib import Path
22+
23+
import torch
24+
from tqdm import tqdm as tqdm
25+
from transformers import AutoModel, AutoTokenizer
26+
27+
# Jinja snippet that keeps only the text after the last `</think>` tag of a message.
# main() strips this snippet out of the tokenizer's chat template when it is present.
REMOVE_THINK_CHAT_TEMPLATE = (
    "{% if '</think>' in content %}"
    "{% set content = content.split('</think>')[-1] %}"
    "{% endif %}"
)
30+
31+
32+
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for hidden-state collection.

    Returns:
        Namespace with ``model``, ``max_seq_len``, ``input_file``, ``output_dir`` and
        ``debug_max_num_conversations``.
    """
    parser = argparse.ArgumentParser(
        description="""Collect hidden states from conversations
        by running full conversations through a Hugging Face model."""
    )

    ## Model Parameters ##
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        # The model is loaded in-process via `from_pretrained`; nothing is "served" here.
        help="Name or local path of the Hugging Face model to load.",
    )

    ## Sequence Length Parameters ##
    parser.add_argument(
        "--max-seq-len",
        type=int,
        default=3072,
        help="""Maximum number of tokens in a conversation. Longer conversations will be skipped.
        Defaults to 3072 tokens.""",
    )

    ## I/O Parameters ##
    parser.add_argument(
        "--input-file",
        type=Path,
        required=True,
        help="""Path to the input `jsonl` file containing conversations.
        Each entry must have a unique `conversation_id` field and a `conversations` field
        containing a list of messages.""",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        required=True,
        help="""Root directory in which to save the hidden states.
        The data will be saved as a torch (`.pt`) dump file for each conversation.""",
    )
    parser.add_argument(
        "--debug-max-num-conversations",
        type=int,
        default=None,
        help="""For debugging purposes, limit the number of conversations processed.
        Default is None, meaning no limit.""",
    )

    return parser.parse_args()
80+
81+
82+
async def main(args: argparse.Namespace) -> None:
    """Run each conversation through the model and dump its hidden states to disk.

    For every usable conversation in ``args.input_file`` a ``<conversation_id>.pt`` file is
    written to ``args.output_dir``. Each file holds the token ids, the model's final hidden
    states, and the concatenation (along the last dim) of a few intermediate layers.

    Args:
        args: Parsed CLI arguments; reads ``model``, ``max_seq_len``, ``input_file``,
            ``output_dir`` and ``debug_max_num_conversations``.
    """
    with args.input_file.open("r", encoding="utf-8") as f:
        all_conversations = [json.loads(line) for line in f if line.strip()]

    print("Loaded", len(all_conversations), "conversations from", args.input_file)

    model = AutoModel.from_pretrained(args.model, torch_dtype="auto", device_map="auto")
    num_hidden_layers = getattr(model.config, "num_hidden_layers", None)

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Only string templates can be patched; a missing template is reported later by
    # `apply_chat_template` itself instead of failing here with an AttributeError.
    if isinstance(tokenizer.chat_template, str):
        tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")

    output_dir = args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)

    # Derive the total from the slice actually iterated so that the progress bar and the
    # final summary always agree (including a debug limit of 0).
    limit = args.debug_max_num_conversations
    selected_conversations = all_conversations if limit is None else all_conversations[:limit]
    num_total_conversations = len(selected_conversations)

    num_skipped_too_long = 0
    num_invalid = 0
    num_success = 0
    progress = tqdm(
        selected_conversations,
        desc="Processing conversations",
        total=num_total_conversations,
    )
    for idx, entry in enumerate(progress):
        conversation_id = entry.get("conversation_id", f"{idx:08d}")
        # A missing `conversations` field is counted as invalid rather than raising KeyError.
        conversations = entry.get("conversations")
        if not conversations or not isinstance(conversations, list):
            num_invalid += 1
            continue

        # Tokenize and check length
        input_ids = tokenizer.apply_chat_template(
            conversations, return_tensors="pt", add_generation_prompt=False
        )
        num_input_tokens = input_ids.shape[1]
        # Very short conversations (<= 10 tokens) are skipped along with over-long ones.
        if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len:
            num_skipped_too_long += 1
            continue

        # Get hidden states
        with torch.inference_mode():
            outputs = model(input_ids=input_ids.to(model.device), output_hidden_states=True)

        hidden_states = outputs.hidden_states
        if num_hidden_layers is None:
            # Config did not expose the layer count: infer it from the first forward pass.
            num_hidden_layers = len(hidden_states) - 1
        else:
            assert num_hidden_layers + 1 == len(hidden_states), (
                f"Expected {num_hidden_layers}+1 layers of hidden states, but got {len(outputs.hidden_states)}."
            )

        # Extract hidden states from layers with index (2, N/2, N-3), and the output hidden states
        selected_layer_indices = sorted(
            {
                2,
                max(0, num_hidden_layers // 2),
                max(1, num_hidden_layers - 3),
            }
        )
        aux_hidden_states = torch.cat(
            [hidden_states[i].squeeze(0).cpu() for i in selected_layer_indices], dim=-1
        )
        output_hidden_states = outputs.last_hidden_state.squeeze(0).cpu()

        output_file = output_dir / f"{conversation_id}.pt"
        with open(output_file, "wb") as f:
            torch.save(
                {
                    "input_ids": input_ids.squeeze(0).cpu(),
                    "hidden_states": output_hidden_states,
                    "aux_hidden_states": aux_hidden_states,
                    "conversation_id": conversation_id,
                },
                f,
            )
        # Count the conversation only once its dump has actually been written.
        num_success += 1

    if num_skipped_too_long > 0:
        print(f"Skipped {num_skipped_too_long} conversations due to length constraints.")
    if num_invalid > 0:
        print(f"Skipped {num_invalid} invalid conversations without proper fields.")

    if num_success == num_total_conversations:
        print(f"Successfully processed all {num_success} conversations.")
    else:
        print(
            f"Successfully processed {num_success} out of {num_total_conversations} conversations."
        )
172+
173+
174+
if __name__ == "__main__":
    # Parse the CLI flags and drive the coroutine entry point to completion.
    asyncio.run(main(parse_args()))
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
# Example usage of the script to compute the hidden states for a conversation dataset
17+
# This script computes hidden states using a Hugging Face model and saves them to
18+
# the specified output directory.
19+
20+
# Model, input dataset and destination for the hidden-state dumps.
MODEL=meta-llama/Llama-3.2-1B-Instruct
INPUT_FILE=synthetic_conversations/daring-anteater.jsonl
OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/

python3 collect_hidden_states/compute_hidden_states_hf.py \
    --model "$MODEL" \
    --input-file "$INPUT_FILE" \
    --output-dir "$OUTPUT_DIR"
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
# Example usage of the script to compute the hidden states for a conversation dataset
17+
# This script computes hidden states using a Hugging Face model and saves them to
18+
# the specified output directory. It does so in a data-parallel manner across 8 GPUs, by splitting
19+
# the input file into 8 parts and running 8 processes in parallel, one on each GPU.
20+
21+
# Note: depending on the write-throughput of the destination disk, this is not guaranteed
22+
# to yield a speed improvement compared to running the model-parallel version. Consider
23+
# benchmarking on a smaller dataset before launching a large run.
24+
25+
INPUT_FILE=synthetic_conversations/daring-anteater.jsonl
OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
NUM_GPUS=8

# Use a private scratch directory so concurrent runs cannot clobber each other's shards,
# and so cleanup never removes unrelated /tmp/part-*.jsonl files. The trap also cleans up
# when the script is interrupted or fails part-way.
SHARD_DIR=$(mktemp -d) || exit 1
trap 'rm -rf "$SHARD_DIR"' EXIT

# Split on line boundaries into NUM_GPUS shards named part-00.jsonl, part-01.jsonl, ...
split -n "l/$NUM_GPUS" --numeric-suffixes=0 --suffix-length=2 --additional-suffix=.jsonl \
    "$INPUT_FILE" "$SHARD_DIR/part-" || exit 1

# NOTE: every shard writes into the same OUTPUT_DIR, so entries need a unique
# `conversation_id`; the per-file index fallback would collide across shards.
for i in $(seq 0 $((NUM_GPUS - 1)))
do
    shard=$(printf '%s/part-%02d.jsonl' "$SHARD_DIR" "$i")
    CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py \
        --model meta-llama/Llama-3.2-1B-Instruct \
        --input-file "$shard" \
        --output-dir "$OUTPUT_DIR" &
done
wait
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
# Example usage of the script to send conversations for hidden state collection
17+
# This script sends conversations to a (local) OpenAI-compatible server for processing and collects hidden states.
18+
19+
# Model, input dataset and destination for the hidden-state dumps.
MODEL=meta-llama/Llama-3.2-1B-Instruct
INPUT_FILE=synthetic_conversations/mtbench.jsonl
OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/mtbench/

python3 collect_hidden_states/send_conversations_for_hiddens.py \
    --model "$MODEL" \
    --input-file "$INPUT_FILE" \
    --output-dir "$OUTPUT_DIR"
# --debug-max-num-conversations-per-split 1000
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Utility script to print a sample of hidden states extracted from a dataset."""
17+
18+
import argparse
19+
import random
20+
from pathlib import Path
21+
22+
import torch
23+
24+
25+
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the hidden-state sampling utility.

    Returns:
        Namespace with ``input_path`` (a directory or a single ``.pt`` file) and
        ``num_samples``.
    """
    parser = argparse.ArgumentParser(
        # Adjacent literals are concatenated verbatim, so each fragment after the first
        # carries its own leading space.
        description="Print a sample of hidden states from a dataset."
        " This script will crawl the provided directory for hidden state files,"
        " and print a small number of samples."
    )

    parser.add_argument(
        "--input-path",
        type=Path,
        required=True,
        help="Path to the base input directory containing hidden states."
        " Alternatively, this can be a path to a specific `.pt` file.",
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=1,
        help="Number of samples to print per split. If input_path is a file, this is ignored."
        " Defaults to 1.",
    )
    return parser.parse_args()
47+
48+
49+
def main(args: argparse.Namespace) -> None:
    """Print shapes, dtypes and a few values from sampled hidden-state dump files.

    Args:
        args: Parsed CLI arguments. ``input_path`` is either a single ``.pt`` file or a
            directory whose top-level ``*.pt`` files are candidates; ``num_samples`` caps
            how many of the directory's files are printed.
    """
    # Iterate through the input directory and find all hidden state files
    if args.input_path.is_file():
        all_files = [args.input_path]
        # A specific file was requested: always show it, regardless of --num-samples.
        sampled_files = all_files
    else:
        all_files = list(args.input_path.glob("*.pt"))
        # Clamp so that a negative value prints nothing instead of raising in random.sample.
        num_samples = max(0, args.num_samples)
        sampled_files = (
            random.sample(all_files, num_samples) if len(all_files) > num_samples else all_files
        )

    expected_keys = [
        "input_ids",
        "hidden_states",
        "aux_hidden_states",
        "conversation_id",
    ]
    for i, file in enumerate(sampled_files):
        # weights_only=True restricts unpickling to tensors and plain containers, which is
        # all a hidden-state dump should contain, and avoids executing arbitrary pickles.
        data = torch.load(file, weights_only=True)
        if not isinstance(data, dict):
            # Stray `.pt` files (e.g. a bare tensor) are reported and skipped, not fatal.
            print(f"File {file} does not contain a dictionary (found {type(data).__name__}).")
            continue
        if set(expected_keys) != set(data.keys()):
            print(f"File {file} does not contain all expected keys: {expected_keys}")
            print(f"  Found keys: {list(data.keys())}")
            continue
        print(f"Sample {i + 1}: {file.name}")
        for key in ["input_ids", "hidden_states", "aux_hidden_states"]:
            print(f"{key}: {data[key].shape} {data[key].dtype} {data[key].device}")
        print(f"conversation_id: {data['conversation_id']}")
        input_ids_list = data["input_ids"].tolist()
        hidden_states = data["hidden_states"]
        print(f"Sample of input_ids (first 10 tokens): \n{input_ids_list[:10]}")
        print(f"Sample of input_ids (last 10 tokens): \n{input_ids_list[-10:]}")
        print(f"Sample of hidden_states (first 10 positions): \n{hidden_states[:10]}")

    print(f"\n\nDone. Found: {len(all_files)} files in total.")
85+
86+
87+
if __name__ == "__main__":
    # Parse the CLI flags and hand them straight to the sampling routine.
    main(parse_args())

0 commit comments

Comments
 (0)