-
Notifications
You must be signed in to change notification settings - Fork 162
Feature: Offline training for EAGLE3 #300
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
47ced1c
e9cd48f
c6c84d2
f74bf59
bd27b86
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,4 @@ | ||
Daring-Anteater | ||
input_conversations | ||
synthetic_conversations | ||
ckpts |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Collect hidden states from a dataset of conversations.""" |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,178 @@ | ||||||||||||||||||||||||
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||||||||||||||||||||||||
# SPDX-License-Identifier: Apache-2.0 | ||||||||||||||||||||||||
# | ||||||||||||||||||||||||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||||||||||||||
# you may not use this file except in compliance with the License. | ||||||||||||||||||||||||
# You may obtain a copy of the License at | ||||||||||||||||||||||||
# | ||||||||||||||||||||||||
# http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||||||||||||||
# | ||||||||||||||||||||||||
# Unless required by applicable law or agreed to in writing, software | ||||||||||||||||||||||||
# distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||||||||||||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||||||||||||||
# See the License for the specific language governing permissions and | ||||||||||||||||||||||||
# limitations under the License. | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
"""Extract hidden states from an HF-compatible LLM.""" | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
import argparse | ||||||||||||||||||||||||
import asyncio | ||||||||||||||||||||||||
import json | ||||||||||||||||||||||||
from pathlib import Path | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
import torch | ||||||||||||||||||||||||
from tqdm import tqdm as tqdm | ||||||||||||||||||||||||
from transformers import AutoModel, AutoTokenizer | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
# Jinja fragment which, when present in a chat template, replaces `content`
# with the text following the last `</think>` marker. `main` strips this
# fragment from the tokenizer's chat template before tokenizing.
REMOVE_THINK_CHAT_TEMPLATE = (
    "{% if '</think>' in content %}"
    "{% set content = content.split('</think>')[-1] %}"
    "{% endif %}"
)
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
def parse_args() -> argparse.Namespace:
    """Parse the command-line arguments of the hidden-state collection script.

    Returns:
        The parsed namespace with the attributes `model`, `max_seq_len`,
        `input_file`, `output_dir` and `debug_max_num_conversations`.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Collect hidden states from conversations "
            "by running full conversations through a Hugging Face model."
        )
    )

    ## Model Parameters ##
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        # The model is loaded locally with `from_pretrained`, not queried from a server.
        help="Name or local path of the Hugging Face model to load.",
    )

    ## Tokenization Parameters ##
    parser.add_argument(
        "--max-seq-len",
        type=int,
        default=3072,
        help=(
            "Maximum number of tokens in a conversation. Longer conversations "
            "(and conversations of 10 tokens or fewer) will be skipped. "
            "Defaults to 3072 tokens."
        ),
    )

    ## I/O Parameters ##
    parser.add_argument(
        "--input-file",
        type=Path,
        required=True,
        help=(
            "Path to the input `jsonl` file containing conversations. "
            "Each entry must have a unique `conversation_id` field and a `conversations` field "
            "containing a list of messages."
        ),
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        required=True,
        help=(
            "Root directory in which to save the hidden states. "
            "The data will be saved as a torch (`.pt`) dump file for each conversation."
        ),
    )
    parser.add_argument(
        "--debug-max-num-conversations",
        type=int,
        default=None,
        help=(
            "For debugging purposes, limit the number of conversations processed. "
            "Default is None, meaning no limit."
        ),
    )

    return parser.parse_args()
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
def _load_conversations(input_file: Path) -> list:
    """Read every non-blank line of a `jsonl` file and validate the conversation ids.

    Raises:
        ValueError: If any entry has no `conversation_id` (missing, None or "").
    """
    with input_file.open("r", encoding="utf-8") as f:
        conversations = [json.loads(line) for line in f if line.strip()]

    # Compare explicitly instead of testing truthiness: a legitimate id such as
    # the integer 0 is falsy and must not be rejected.
    if any(entry.get("conversation_id") in (None, "") for entry in conversations):
        msg = "All conversations must have a 'conversation_id' field."
        raise ValueError(msg)
    return conversations


async def main(args: argparse.Namespace) -> None:
    """Run each conversation through a Hugging Face model and dump its hidden states.

    For every processed conversation a `<conversation_id>.pt` file is written to
    `args.output_dir` holding a dict with the keys `input_ids`, `hidden_states`
    (the model's last hidden state), `aux_hidden_states` (the hidden states of
    the selected intermediate layers concatenated along the last dimension)
    and `conversation_id`.

    Conversations whose `conversations` field is missing, empty or not a list
    are counted as invalid. Conversations that tokenize to 10 tokens or fewer,
    or to more than `args.max_seq_len` tokens, are skipped.

    Args:
        args: Namespace providing `input_file`, `output_dir`, `model`,
            `max_seq_len` and `debug_max_num_conversations`.

    Raises:
        ValueError: If an input entry has no `conversation_id`.
        RuntimeError: If the model returns an unexpected number of hidden-state layers.
    """
    all_conversations = _load_conversations(args.input_file)
    print("Loaded", len(all_conversations), "conversations from", args.input_file)

    model = AutoModel.from_pretrained(args.model, torch_dtype="auto", device_map="auto")
    num_hidden_layers = getattr(model.config, "num_hidden_layers", None)
    # Per review feedback: with device_map="auto" the model may be spread over
    # several devices, so inputs are sent to the device that holds the input
    # embeddings rather than relying on `model.device`.
    input_device = model.get_input_embeddings().weight.device

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Only string templates can be patched; anything else (e.g. no template) is left untouched.
    if isinstance(tokenizer.chat_template, str):
        tokenizer.chat_template = tokenizer.chat_template.replace(
            REMOVE_THINK_CHAT_TEMPLATE, ""
        )

    output_dir = args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)

    # Slice once so that the reported total always matches what is iterated
    # (a limit of None keeps every conversation).
    selected_conversations = all_conversations[: args.debug_max_num_conversations]
    num_total_conversations = len(selected_conversations)
    num_skipped_length = 0
    num_invalid = 0
    num_success = 0

    for entry in tqdm(
        selected_conversations,
        desc="Processing conversations",
        total=num_total_conversations,
    ):
        conversation_id = entry["conversation_id"]
        conversations = entry.get("conversations")
        if not conversations or not isinstance(conversations, list):
            num_invalid += 1
            continue

        # Tokenize and check length
        input_ids = tokenizer.apply_chat_template(
            conversations, return_tensors="pt", add_generation_prompt=False
        )
        num_input_tokens = input_ids.shape[1]
        if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len:
            num_skipped_length += 1
            continue

        # Get hidden states
        with torch.inference_mode():
            outputs = model(input_ids=input_ids.to(input_device), output_hidden_states=True)
        hidden_states = outputs.hidden_states
        if num_hidden_layers is None:
            # The config did not say; infer it from the first forward pass
            # (this code treats the output as having one more entry than layers).
            num_hidden_layers = len(hidden_states) - 1
        elif num_hidden_layers + 1 != len(hidden_states):
            msg = (
                f"Expected {num_hidden_layers}+1 layers of hidden states, "
                f"but got {len(hidden_states)}."
            )
            raise RuntimeError(msg)

        # Extract hidden states from layers with index (2, N/2, N-3), and the output hidden states
        selected_layer_indices = sorted(
            {
                2,
                max(0, num_hidden_layers // 2),
                max(1, num_hidden_layers - 3),
            }
        )
        aux_hidden_states = torch.cat(
            [hidden_states[i].squeeze(0).cpu() for i in selected_layer_indices], dim=-1
        )
        output_hidden_states = outputs.last_hidden_state.squeeze(0).cpu()

        output_file = output_dir / f"{conversation_id}.pt"
        with open(output_file, "wb") as f:
            torch.save(
                {
                    "input_ids": input_ids.squeeze(0).cpu(),
                    "hidden_states": output_hidden_states,
                    "aux_hidden_states": aux_hidden_states,
                    "conversation_id": conversation_id,
                },
                f,
            )
        # Count the conversation only once its file has actually been written.
        num_success += 1

    if num_skipped_length > 0:
        print(f"Skipped {num_skipped_length} conversations due to length constraints.")
    if num_invalid > 0:
        print(f"Skipped {num_invalid} invalid conversations without proper fields.")

    if num_success == num_total_conversations:
        print(f"Successfully processed all {num_success} conversations.")
    else:
        print(
            f"Successfully processed {num_success} out of {num_total_conversations} conversations."
        )
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
if __name__ == "__main__":
    # Script entry point: parse the CLI arguments and drive the `main` coroutine.
    asyncio.run(main(parse_args()))
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# Example usage of the script to compute the hidden states for a conversation dataset | ||
# This script computes hidden states using a Hugging Face model and saves them to | ||
# the specified output directory. | ||
|
||
MODEL=meta-llama/Llama-3.2-1B-Instruct
INPUT_FILE=synthetic_conversations/daring-anteater.jsonl
OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/

# Run the whole dataset through the model in a single process.
python3 collect_hidden_states/compute_hidden_states_hf.py \
    --model "$MODEL" \
    --input-file "$INPUT_FILE" \
    --output-dir "$OUTPUT_DIR"
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# Example usage of the script to compute the hidden states for a conversation dataset | ||
# This script computes hidden states using a Hugging Face model and saves them to | ||
# the specified output directory. It does so in a data-parallel manner across 8 GPUs, by splitting | ||
# the input file into 8 parts and running 8 processes in parallel, one on each GPU. | ||
|
||
# Note: depending on the write-throughput of the destination disk, this is not guaranteed | ||
# to yield a speed improvement compared to running the model-parallel version. Consider | ||
# benchmarking on a smaller dataset before launching a large run. | ||
|
||
INPUT_FILE=synthetic_conversations/daring-anteater.jsonl
OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
NUM_GPUS=8

# Use a private scratch directory so that concurrent runs cannot clobber each
# other's parts, and so that cleanup never touches unrelated /tmp/part-* files.
TMP_DIR=$(mktemp -d) || exit 1
trap 'rm -rf "$TMP_DIR"' EXIT

# Split on line boundaries into NUM_GPUS parts named part-00.jsonl, part-01.jsonl, ...
split -n "l/${NUM_GPUS}" --numeric-suffixes=0 --additional-suffix=.jsonl \
    "$INPUT_FILE" "$TMP_DIR/part-" || exit 1

# Launch one worker per GPU, each pinned to its own device and input part.
pids=""
for i in $(seq 0 $((NUM_GPUS - 1)))
do
    part=$(printf '%s/part-%02d.jsonl' "$TMP_DIR" "$i")
    CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py \
        --model meta-llama/Llama-3.2-1B-Instruct \
        --input-file "$part" \
        --output-dir "$OUTPUT_DIR" &
    pids="$pids $!"
done

# Wait for each worker individually so that a failing worker is reported
# through the exit status instead of being silently ignored.
status=0
for pid in $pids
do
    wait "$pid" || status=1
done

exit "$status"
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# Example usage of the script to send conversations for hidden state collection | ||
# This script sends conversations to a (local) OpenAI-compatible server for processing and collects hidden states. | ||
|
||
MODEL=meta-llama/Llama-3.2-1B-Instruct
INPUT_FILE=synthetic_conversations/mtbench.jsonl
OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/mtbench/

python3 collect_hidden_states/send_conversations_for_hiddens.py \
    --model "$MODEL" \
    --input-file "$INPUT_FILE" \
    --output-dir "$OUTPUT_DIR"
# --debug-max-num-conversations-per-split 1000
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Utility script to print a sample of hidden states extracted from a dataset.""" | ||
|
||
import argparse | ||
import random | ||
from pathlib import Path | ||
|
||
import torch | ||
|
||
|
||
def parse_args() -> argparse.Namespace:
    """Parse the command-line arguments of the hidden-state sampling script.

    Returns:
        The parsed namespace with the attributes `input_path` and `num_samples`.
    """
    parser = argparse.ArgumentParser(
        # Adjacent literals are concatenated verbatim, so each fragment carries
        # its own separating space.
        description="Print a sample of hidden states from a dataset. "
        "This script will crawl the provided directory for hidden state files, "
        "and print a small number of samples."
    )

    parser.add_argument(
        "--input-path",
        type=Path,
        required=True,
        help="Path to the base input directory containing hidden states. "
        "Alternatively, this can be a path to a specific `.pt` file.",
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=1,
        help="Number of samples to print. If input_path is a file, this is ignored. "
        "Defaults to 1.",
    )
    return parser.parse_args()
|
||
|
||
def main(args: argparse.Namespace) -> None:
    """Print a summary of a few hidden-state dump files.

    `args.input_path` is either a single file or a directory whose `*.pt`
    files are candidates. At most `args.num_samples` files are drawn at random
    from a directory. Files whose key set differs from the expected one are
    reported and skipped.
    """
    source = args.input_path
    # Collect the candidate files: the path itself, or every `.pt` file directly under it.
    candidates = [source] if source.is_file() else list(source.glob("*.pt"))

    if len(candidates) > args.num_samples:
        chosen = random.sample(candidates, args.num_samples)
    else:
        chosen = candidates

    expected_keys = [
        "input_ids",
        "hidden_states",
        "aux_hidden_states",
        "conversation_id",
    ]
    tensor_keys = ("input_ids", "hidden_states", "aux_hidden_states")

    for sample_number, file in enumerate(chosen, start=1):
        data = torch.load(file)

        # The dump must hold exactly the expected keys, no more and no fewer.
        if set(data) != set(expected_keys):
            print(f"File {file} does not contain all expected keys: {expected_keys}")
            print(f" Found keys: {list(data.keys())}")
            continue

        print(f"Sample {sample_number}: {file.name}")
        for key in tensor_keys:
            tensor = data[key]
            print(f"{key}: {tensor.shape} {tensor.dtype} {tensor.device}")
        print(f"conversation_id: {data['conversation_id']}")

        token_ids = data["input_ids"].tolist()
        final_states = data["hidden_states"]
        print(f"Sample of input_ids (first 10 tokens): \n{token_ids[:10]}")
        print(f"Sample of input_ids (last 10 tokens): \n{token_ids[-10:]}")
        print(f"Sample of hidden_states (first 10 positions): \n{final_states[:10]}")

    print(f"\n\nDone. Found: {len(candidates)} files in total.")
|
||
|
||
if __name__ == "__main__":
    # Script entry point: parse the CLI arguments and print the requested samples.
    main(parse_args())
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @benchislett, there seems to be a bug on this line:
not entry.get("conversation_id")
will evaluate to True, causing an error to be raised.