Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions examples/speculative_decoding/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
Daring-Anteater
input_conversations
synthetic_conversations
ckpts
16 changes: 16 additions & 0 deletions examples/speculative_decoding/collect_hidden_states/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Collect hidden states from a dataset of conversations."""
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Extract hidden states from an HF-compatible LLM."""

import argparse
import asyncio
import json
from pathlib import Path

import torch
from tqdm import tqdm as tqdm
from transformers import AutoModel, AutoTokenizer

REMOVE_THINK_CHAT_TEMPLATE = (
"{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}"
)


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="""Collect hidden states from conversations
by running full conversations through a Hugging Face model."""
)

## Model & Generation Parameters ##
parser.add_argument(
"--model",
type=str,
required=True,
help="Name of the served model.",
)

## Client Parameters ##
parser.add_argument(
"--max-seq-len",
type=int,
default=3072,
help="""Maximum number of tokens in a conversation. Longer conversations will be skipped.
Defaults to 3072 tokens.""",
)

## I/O Parameters ##
parser.add_argument(
"--input-file",
type=Path,
required=True,
help="""Path to the input `jsonl` file containing conversations.
Each entry must have a unique `conversation_id` field and a `conversations` field
containing a list of messages.""",
)
parser.add_argument(
"--output-dir",
type=Path,
required=True,
help="""Root directory in which to save the hidden states.
The data will be saved as a torch (`.pt`) dump file for each conversation.""",
)
parser.add_argument(
"--debug-max-num-conversations",
type=int,
default=None,
help="""For debugging purposes, limit the number of conversations processed.
Default is None, meaning no limit.""",
)

return parser.parse_args()


async def main(args: argparse.Namespace) -> None:
all_conversations = []
with args.input_file.open("r", encoding="utf-8") as f:
all_conversations.extend([json.loads(line) for line in f if line.strip()])

if any(not entry.get("conversation_id") for entry in all_conversations):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @benchislett, seems like there is a bug to this line:

  1. When some entry has a conversation_id=0, not entry.get("conversation_id") will return True, causing an error raised.
  2. Since we add the fallback below to allow no conversation_id, we should probably also remove this check here.

msg = "All conversations must have a 'conversation_id' field."
raise ValueError(msg)

print("Loaded", len(all_conversations), "conversations from", args.input_file)

model = AutoModel.from_pretrained(args.model, torch_dtype="auto", device_map="auto")
num_hidden_layers = getattr(model.config, "num_hidden_layers", None)

tokenizer = AutoTokenizer.from_pretrained(args.model)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")

output_dir = args.output_dir
output_dir.mkdir(parents=True, exist_ok=True)
num_skipped_too_long = 0
num_invalid = 0
num_success = 0
num_total_conversations = min(
len(all_conversations), args.debug_max_num_conversations or len(all_conversations)
)
for idx, entry in enumerate(
tqdm(
all_conversations[: args.debug_max_num_conversations],
desc="Processing conversations",
total=num_total_conversations,
)
):
conversation_id = entry.get("conversation_id", "{:08d}".format(idx))
conversations = entry["conversations"]
if not conversations or not isinstance(conversations, list):
num_invalid += 1
continue

# Tokenize and check length
input_ids = tokenizer.apply_chat_template(
conversations, return_tensors="pt", add_generation_template=False
)
num_input_tokens = input_ids.shape[1]
if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len:
num_skipped_too_long += 1
continue

# Get hidden states
with torch.inference_mode():
outputs = model(input_ids=input_ids.to(model.device), output_hidden_states=True)
if num_hidden_layers is None:
num_hidden_layers = len(outputs.hidden_states) - 1
Comment on lines +131 to +135
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix device handling with device_map='auto' (model.device may be missing or wrong).

Pass inputs to the embedding device; relying on model.device can raise or send tensors to CPU.

-        with torch.inference_mode():
-            outputs = model(input_ids=input_ids.to(model.device), output_hidden_states=True)
+        with torch.inference_mode():
+            emb_device = model.get_input_embeddings().weight.device
+            outputs = model(input_ids=input_ids.to(emb_device), output_hidden_states=True)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Get hidden states
with torch.inference_mode():
outputs = model(input_ids=input_ids.to(model.device), output_hidden_states=True)
if num_hidden_layers is None:
num_hidden_layers = len(outputs.hidden_states) - 1
# Get hidden states
with torch.inference_mode():
emb_device = model.get_input_embeddings().weight.device
outputs = model(input_ids=input_ids.to(emb_device), output_hidden_states=True)
if num_hidden_layers is None:
num_hidden_layers = len(outputs.hidden_states) - 1
🤖 Prompt for AI Agents
In
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py
around lines 129 to 133, the code moves inputs to model.device which can be
missing or incorrect when using device_map='auto'; instead obtain the embeddings
device (e.g. model.get_input_embeddings().weight.device) and move input_ids to
that device before calling the model; use that device value as a fallback if
model.device exists but prefer the embedding weight device to ensure tensors end
up on the correct device for sharded/auto-mapped models.

else:
assert num_hidden_layers + 1 == len(outputs.hidden_states), (
f"Expected {num_hidden_layers}+1 layers of hidden states, but got {len(outputs.hidden_states)}."
)
# Extract hidden states from layers with index (2, N/2, N-3), and the output hidden states
hidden_states = outputs.hidden_states
selected_layer_indices = [
2,
max(0, num_hidden_layers // 2),
max(1, num_hidden_layers - 3),
]
selected_layer_indices = sorted(set(selected_layer_indices))
aux_hidden_states = torch.cat(
[hidden_states[i].squeeze(0).cpu() for i in selected_layer_indices], dim=-1
)
output_hidden_states = outputs.last_hidden_state.squeeze(0).cpu()
output_file = output_dir / f"{conversation_id}.pt"
num_success += 1
with open(output_file, "wb") as f:
torch.save(
{
"input_ids": input_ids.squeeze(0).cpu(),
"hidden_states": output_hidden_states,
"aux_hidden_states": aux_hidden_states,
"conversation_id": conversation_id,
},
f,
)

if num_skipped_too_long > 0:
print(f"Skipped {num_skipped_too_long} conversations due to length constraints.")
if num_invalid > 0:
print(f"Skipped {num_invalid} invalid conversations without proper fields.")

if num_success == num_total_conversations:
print(f"Successfully processed all {num_success} conversations.")
else:
print(
f"Successfully processed {num_success} out of {num_total_conversations} conversations."
)


if __name__ == "__main__":
cli_args = parse_args()
asyncio.run(main(cli_args))
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example usage of the script to compute the hidden states for a conversation dataset
# This script computes hidden states using a Hugging Face model and saves them to
# the specified output directory.

python3 collect_hidden_states/compute_hidden_states_hf.py \
--model meta-llama/Llama-3.2-1B-Instruct \
--input-file synthetic_conversations/daring-anteater.jsonl \
--output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example usage of the script to compute the hidden states for a conversation dataset
# This script computes hidden states using a Hugging Face model and saves them to
# the specified output directory. It does so in a data-parallel manner across 8 GPUs, by splitting
# the input file into 8 parts and running 8 processes in parallel, one on each GPU.

# Note: depending on the write-throughput of the destination disk, this is not guaranteed
# to yield a speed improvement compared to running the model-parallel version. Consider
# benchmarking on a smaller dataset before launching a large run.

INPUT_FILE=synthetic_conversations/daring-anteater.jsonl
OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/

split -n l/8 --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part-

for i in $(seq 0 7)
do
CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py --model meta-llama/Llama-3.2-1B-Instruct --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR &
done
wait

rm /tmp/part-*.jsonl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example usage of the script to send conversations for hidden state collection
# This script sends conversations to a (local) OpenAI-compatible server for processing and collects hidden states.

python3 collect_hidden_states/send_conversations_for_hiddens.py \
--model meta-llama/Llama-3.2-1B-Instruct \
--input-file synthetic_conversations/mtbench.jsonl \
--output-dir /mnt/md0/eagle-hidden-states/llama1b/mtbench/
# --debug-max-num-conversations-per-split 1000
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility script to print a sample of hidden states extracted from a dataset."""

import argparse
import random
from pathlib import Path

import torch


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Print a sample of hidden states from a dataset."
"This script will crawl the provided directory for hidden state files,"
" and print a small number of samples."
)

parser.add_argument(
"--input-path",
type=Path,
required=True,
help="Path to the base input directory containing hidden states."
"Alternatively, this can be a path to a specific `.pt` file.",
)
parser.add_argument(
"--num-samples",
type=int,
default=1,
help="Number of samples to print per split. If input_path is a file, this is ignored. "
"Defaults to 1.",
)
return parser.parse_args()


def main(args: argparse.Namespace) -> None:
# Iterate through the input directory and find all hidden state files
if args.input_path.is_file():
all_files = [args.input_path]
else:
all_files = list(args.input_path.glob("*.pt"))

sampled_files = (
random.sample(all_files, args.num_samples)
if len(all_files) > args.num_samples
else all_files
)

for i, file in enumerate(sampled_files):
data = torch.load(file)
expected_keys = [
"input_ids",
"hidden_states",
"aux_hidden_states",
"conversation_id",
]
if set(expected_keys) != set(data.keys()):
print(f"File {file} does not contain all expected keys: {expected_keys}")
print(f" Found keys: {list(data.keys())}")
continue
print(f"Sample {i + 1}: {file.name}")
for key in ["input_ids", "hidden_states", "aux_hidden_states"]:
print(f"{key}: {data[key].shape} {data[key].dtype} {data[key].device}")
print(f"conversation_id: {data['conversation_id']}")
input_ids_list = data["input_ids"].tolist()
hidden_states = data["hidden_states"]
print(f"Sample of input_ids (first 10 tokens): \n{input_ids_list[:10]}")
print(f"Sample of input_ids (last 10 tokens): \n{input_ids_list[-10:]}")
print(f"Sample of hidden_states (first 10 positions): \n{hidden_states[:10]}")

print(f"\n\nDone. Found: {len(all_files)} files in total.")


if __name__ == "__main__":
args = parse_args()
main(args)
Loading
Loading