Commit 4282dbe

Offline training for EAGLE3
Signed-off-by: Benjamin Chislett <[email protected]>
1 parent 76e8ce2 commit 4282dbe

22 files changed (+2028, -22 lines)

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Collect hidden states from a dataset of conversations."""

Lines changed: 168 additions & 0 deletions
@@ -0,0 +1,168 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Extract hidden states from an HF-compatible LLM."""

import argparse
import asyncio
import json
from pathlib import Path

import torch
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="""Collect hidden states from conversations
        by running full conversations through a Hugging Face model."""
    )

    ## Model & Generation Parameters ##
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Name or path of the Hugging Face model.",
    )

    ## Client Parameters ##
    parser.add_argument(
        "--max-seq-len",
        type=int,
        default=3072,
        help="""Maximum number of tokens in a conversation. Longer conversations will be skipped.
        Defaults to 3072 tokens.""",
    )

    ## I/O Parameters ##
    parser.add_argument(
        "--input-file",
        type=Path,
        required=True,
        help="""Path to the input `jsonl` file containing conversations.
        Each entry must have a unique `conversation_id` field and a `conversations` field
        containing a list of messages.""",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        required=True,
        help="""Root directory in which to save the hidden states.
        The data will be saved as a torch (`.pt`) dump file for each conversation.""",
    )
    parser.add_argument(
        "--debug-max-num-conversations",
        type=int,
        default=None,
        help="""For debugging purposes, limit the number of conversations processed.
        Default is None, meaning no limit.""",
    )

    return parser.parse_args()


async def main(args: argparse.Namespace) -> None:
    all_conversations = []
    with args.input_file.open("r", encoding="utf-8") as f:
        all_conversations.extend([json.loads(line) for line in f if line.strip()])

    if any(not entry.get("conversation_id") for entry in all_conversations):
        msg = "All conversations must have a 'conversation_id' field."
        raise ValueError(msg)

    print("Loaded", len(all_conversations), "conversations from", args.input_file)

    model = AutoModel.from_pretrained(args.model, torch_dtype="auto", device_map="auto")
    num_hidden_layers = getattr(model.config, "num_hidden_layers", None)

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    output_dir = args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)
    num_skipped_too_long = 0
    num_invalid = 0
    num_success = 0
    num_total_conversations = min(
        len(all_conversations), args.debug_max_num_conversations or len(all_conversations)
    )
    for entry in tqdm(
        all_conversations[: args.debug_max_num_conversations],
        desc="Processing conversations",
        total=num_total_conversations,
    ):
        conversation_id = entry["conversation_id"]
        conversations = entry["conversations"]
        if not conversations or not isinstance(conversations, list):
            num_invalid += 1
            continue

        # Tokenize and check length
        input_ids = tokenizer.apply_chat_template(
            conversations, return_tensors="pt", add_generation_prompt=False
        )
        num_input_tokens = input_ids.shape[1]
        if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len:
            num_skipped_too_long += 1
            continue

        # Get hidden states
        with torch.inference_mode():
            outputs = model(input_ids=input_ids.to(model.device), output_hidden_states=True)
        if num_hidden_layers is None:
            num_hidden_layers = len(outputs.hidden_states) - 1
        else:
            assert num_hidden_layers + 1 == len(outputs.hidden_states), (
                f"Expected {num_hidden_layers}+1 layers of hidden states, but got {len(outputs.hidden_states)}."
            )
        # Extract hidden states from layers with index (2, N/2, N-3), and the output hidden states
        hidden_states = outputs.hidden_states
        selected_layer_indices = [2, num_hidden_layers // 2, num_hidden_layers - 3]
        aux_hidden_states = torch.cat(
            [hidden_states[i].squeeze(0).cpu() for i in selected_layer_indices], dim=-1
        )
        output_hidden_states = outputs.last_hidden_state.squeeze(0).cpu()
        output_file = output_dir / f"{conversation_id}.pt"
        num_success += 1
        with open(output_file, "wb") as f:
            torch.save(
                {
                    "input_ids": input_ids.squeeze(0).cpu(),
                    "hidden_states": output_hidden_states,
                    "aux_hidden_states": aux_hidden_states,
                    "conversation_id": conversation_id,
                },
                f,
            )

    if num_skipped_too_long > 0:
        print(f"Skipped {num_skipped_too_long} conversations due to length constraints.")
    if num_invalid > 0:
        print(f"Skipped {num_invalid} invalid conversations without proper fields.")

    if num_success == num_total_conversations:
        print(f"Successfully processed all {num_success} conversations.")
    else:
        print(
            f"Successfully processed {num_success} out of {num_total_conversations} conversations."
        )


if __name__ == "__main__":
    cli_args = parse_args()
    asyncio.run(main(cli_args))

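Each processed conversation is written to its own `.pt` dump keyed by `input_ids`, `hidden_states`, `aux_hidden_states`, and `conversation_id`. A minimal sketch of how a downstream consumer might read one of these dumps back; the file path is a hypothetical example under the output directory used in the usage script later in this commit:

import torch

# Hypothetical path: one dump per conversation under --output-dir, named <conversation_id>.pt
dump_path = "/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/example-conversation.pt"
data = torch.load(dump_path, map_location="cpu")

input_ids = data["input_ids"]                  # shape (seq_len,)
hidden_states = data["hidden_states"]          # last-layer hidden states, shape (seq_len, hidden_size)
aux_hidden_states = data["aux_hidden_states"]  # layers 2, N/2, N-3 concatenated, shape (seq_len, 3 * hidden_size)
print(data["conversation_id"], input_ids.shape, hidden_states.shape, aux_hidden_states.shape)
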
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example usage of the script to compute the hidden states for a conversation dataset
# This script computes hidden states using a Hugging Face model and saves them to
# the specified output directory.

python3 collect_hidden_states/compute_hiddens_hf.py \
    --model meta-llama/Llama-3.2-1B-Instruct \
    --input-file synthetic_conversations/daring-anteater.jsonl \
    --output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example usage of the script to send conversations for hidden state collection
# This script sends conversations to a (local) OpenAI-compatible server for processing and collects hidden states.

python3 collect_hidden_states/send_conversations_for_hiddens.py \
    --model meta-llama/Llama-3.2-1B-Instruct \
    --input-file synthetic_conversations/mtbench.jsonl \
    --output-dir /mnt/md0/eagle-hidden-states/llama1b/mtbench/
# --debug-max-num-conversations-per-split 1000

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility script to print a sample of hidden states extracted from a dataset."""

import argparse
import random
from pathlib import Path

import torch


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Print a sample of hidden states from a dataset. "
        "This script will crawl the provided directory for hidden state files,"
        " and print a small number of samples."
    )

    parser.add_argument(
        "--input-path",
        type=Path,
        required=True,
        help="Path to the base input directory containing hidden states. "
        "Alternatively, this can be a path to a specific `.pt` file.",
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=1,
        help="Number of samples to print per split. If input_path is a file, this is ignored. "
        "Defaults to 1.",
    )
    return parser.parse_args()


def main(args: argparse.Namespace) -> None:
    # Iterate through the input directory and find all hidden state files
    if args.input_path.is_file():
        all_files = [args.input_path]
    else:
        all_files = list(args.input_path.glob("*.pt"))

    sampled_files = (
        random.sample(all_files, args.num_samples)
        if len(all_files) > args.num_samples
        else all_files
    )

    for i, file in enumerate(sampled_files):
        data = torch.load(file)
        expected_keys = [
            "input_ids",
            "hidden_states",
            "aux_hidden_states",
            "conversation_id",
        ]
        if set(expected_keys) != set(data.keys()):
            print(f"File {file} does not contain all expected keys: {expected_keys}")
            print(f" Found keys: {list(data.keys())}")
            continue
        print(f"Sample {i + 1}: {file.name}")
        for key in ["input_ids", "hidden_states", "aux_hidden_states"]:
            print(f"{key}: {data[key].shape} {data[key].dtype} {data[key].device}")
        print(f"conversation_id: {data['conversation_id']}")
        input_ids_list = data["input_ids"].tolist()
        hidden_states = data["hidden_states"]
        print(f"Sample of input_ids (first 10 tokens): \n{input_ids_list[:10]}")
        print(f"Sample of input_ids (last 10 tokens): \n{input_ids_list[-10:]}")
        print(f"Sample of hidden_states (first 10 positions): \n{hidden_states[:10]}")

    print(f"\n\nDone. Found: {len(all_files)} files in total.")


if __name__ == "__main__":
    args = parse_args()
    main(args)
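For reference, a sketch of invoking the sampling utility above in the same style as the other usage scripts in this commit; the script path is assumed (file names are not shown in this diff), and the input directory reuses the example output location from the earlier usage script:

# Assumed script path; flags follow the argparse definition above.
python3 collect_hidden_states/print_sample_hiddens.py \
    --input-path /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/ \
    --num-samples 2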
