Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions examples/speculative_decoding/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
Daring-Anteater
input_conversations
synthetic_conversations
ckpts
16 changes: 16 additions & 0 deletions examples/speculative_decoding/collect_hidden_states/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Collect hidden states from a dataset of conversations."""
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Extract hidden states from an HF-compatible LLM."""

import argparse
import asyncio
import json
from pathlib import Path

import torch
from tqdm import tqdm as tqdm
from transformers import AutoModel, AutoTokenizer

# Jinja fragment that, inside a chat template, replaces a message's `content`
# with only the text following the last `</think>` tag. `main` deletes this
# fragment from the tokenizer's chat template (a no-op when the template does
# not contain it), so the text preceding `</think>` stays in the tokenized
# conversation.
REMOVE_THINK_CHAT_TEMPLATE = (
"{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}"
)


def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options for hidden-state collection.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the script entry point relies on.
            Passing an explicit list allows the parser to be driven from tests.

    Returns:
        Namespace with the attributes ``model``, ``max_seq_len``, ``input_file``,
        ``output_dir`` and ``debug_max_num_conversations``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Collect hidden states from conversations "
            "by running full conversations through a Hugging Face model."
        )
    )

    ## Model Parameters ##
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        # The model is loaded in-process with `from_pretrained`; nothing is served.
        help="Name or path of the Hugging Face model to load.",
    )

    ## Tokenization Parameters ##
    parser.add_argument(
        "--max-seq-len",
        type=int,
        default=3072,
        help=(
            "Maximum number of tokens in a conversation. "
            "Longer conversations will be skipped. Defaults to 3072 tokens."
        ),
    )

    ## I/O Parameters ##
    parser.add_argument(
        "--input-file",
        type=Path,
        required=True,
        help=(
            "Path to the input `jsonl` file containing conversations. "
            "Each entry must have a unique `conversation_id` field and a "
            "`conversations` field containing a list of messages."
        ),
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        required=True,
        help=(
            "Root directory in which to save the hidden states. "
            "The data will be saved as a torch (`.pt`) dump file for each conversation."
        ),
    )
    parser.add_argument(
        "--debug-max-num-conversations",
        type=int,
        default=None,
        help=(
            "For debugging purposes, limit the number of conversations processed. "
            "Default is None, meaning no limit."
        ),
    )

    return parser.parse_args(argv)


async def main(args: argparse.Namespace) -> None:
    """Run each conversation through a Hugging Face model and dump its hidden states.

    Conversations are read from ``args.input_file`` (one JSON object per
    non-blank line), tokenized with the model's chat template and pushed through
    a single forward pass. For every processed conversation a
    ``<conversation_id>.pt`` file is written to ``args.output_dir`` holding a
    dict with the keys ``input_ids``, ``hidden_states`` (the model's
    ``last_hidden_state``), ``aux_hidden_states`` (selected intermediate layers
    concatenated along the last dimension) and ``conversation_id``.

    Conversations whose ``conversations`` field is missing, empty or not a list
    are counted as invalid; conversations of 10 tokens or fewer, or of more than
    ``args.max_seq_len`` tokens, are skipped. A summary is printed at the end.

    Args:
        args: Parsed command-line options. The attributes read here are
            ``model``, ``input_file``, ``output_dir``, ``max_seq_len`` and
            ``debug_max_num_conversations``.

    Raises:
        ValueError: If any entry of the input file has no ``conversation_id``.
    """
    with args.input_file.open("r", encoding="utf-8") as f:
        all_conversations = [json.loads(line) for line in f if line.strip()]

    # Test for "missing" explicitly: a plain truthiness check would also reject
    # the perfectly valid id 0.
    if any(entry.get("conversation_id") in (None, "") for entry in all_conversations):
        msg = "All conversations must have a 'conversation_id' field."
        raise ValueError(msg)

    print("Loaded", len(all_conversations), "conversations from", args.input_file)

    model = AutoModel.from_pretrained(args.model, torch_dtype="auto", device_map="auto")
    num_hidden_layers = getattr(model.config, "num_hidden_layers", None)
    # With device_map="auto" the weights may be spread over several devices, so
    # the inputs must be placed where the embedding layer lives; `model.device`
    # is only used when the model exposes no input embeddings.
    input_embeddings = model.get_input_embeddings()
    input_device = (
        input_embeddings.weight.device if input_embeddings is not None else model.device
    )

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Only a string template can contain the fragment; skip tokenizers that have
    # no template (or a non-string one) instead of failing on `.replace`.
    if isinstance(tokenizer.chat_template, str):
        tokenizer.chat_template = tokenizer.chat_template.replace(
            REMOVE_THINK_CHAT_TEMPLATE, ""
        )

    output_dir = args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)

    # Slicing with None keeps everything, so the progress total always matches
    # the number of entries that are actually iterated.
    selected_conversations = all_conversations[: args.debug_max_num_conversations]
    num_total_conversations = len(selected_conversations)

    num_skipped_length = 0
    num_invalid = 0
    num_success = 0
    for entry in tqdm(
        selected_conversations,
        desc="Processing conversations",
        total=num_total_conversations,
    ):
        conversation_id = entry["conversation_id"]
        conversations = entry.get("conversations")
        if not conversations or not isinstance(conversations, list):
            num_invalid += 1
            continue

        # Tokenize and check length
        input_ids = tokenizer.apply_chat_template(
            conversations, return_tensors="pt", add_generation_prompt=False
        )
        num_input_tokens = input_ids.shape[1]
        if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len:
            num_skipped_length += 1
            continue

        # Get hidden states
        with torch.inference_mode():
            outputs = model(input_ids=input_ids.to(input_device), output_hidden_states=True)
        hidden_states = outputs.hidden_states
        if num_hidden_layers is None:
            # The config did not declare a layer count; derive it from the first
            # forward pass (the tuple holds one more entry than there are layers).
            num_hidden_layers = len(hidden_states) - 1
        else:
            assert num_hidden_layers + 1 == len(hidden_states), (
                f"Expected {num_hidden_layers}+1 layers of hidden states, but got {len(hidden_states)}."
            )

        # Extract hidden states from layers with index (2, N/2, N-3), and the output hidden states
        selected_layer_indices = sorted(
            {
                2,
                max(0, num_hidden_layers // 2),
                max(1, num_hidden_layers - 3),
            }
        )
        # Each layer is moved to the CPU before concatenation, so layers that
        # live on different devices can still be combined.
        aux_hidden_states = torch.cat(
            [hidden_states[i].squeeze(0).cpu() for i in selected_layer_indices], dim=-1
        )
        output_hidden_states = outputs.last_hidden_state.squeeze(0).cpu()

        output_file = output_dir / f"{conversation_id}.pt"
        with open(output_file, "wb") as f:
            torch.save(
                {
                    "input_ids": input_ids.squeeze(0).cpu(),
                    "hidden_states": output_hidden_states,
                    "aux_hidden_states": aux_hidden_states,
                    "conversation_id": conversation_id,
                },
                f,
            )
        # Count the conversation only once its file has actually been written.
        num_success += 1

    if num_skipped_length > 0:
        print(f"Skipped {num_skipped_length} conversations due to length constraints.")
    if num_invalid > 0:
        print(f"Skipped {num_invalid} invalid conversations without proper fields.")

    if num_success == num_total_conversations:
        print(f"Successfully processed all {num_success} conversations.")
    else:
        print(
            f"Successfully processed {num_success} out of {num_total_conversations} conversations."
        )


if __name__ == "__main__":
    # `main` is declared as a coroutine, so it is driven through an event loop.
    asyncio.run(main(parse_args()))
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example usage of the script to compute the hidden states for a conversation dataset.
# This script computes hidden states using a Hugging Face model and saves them to
# the specified output directory.
#
# The script and input paths below are relative, so run this from the directory
# that contains `collect_hidden_states/` and `synthetic_conversations/`.
# --output-dir is an example location; point it at a disk with enough space.

python3 collect_hidden_states/compute_hidden_states_hf.py \
--model meta-llama/Llama-3.2-1B-Instruct \
--input-file synthetic_conversations/daring-anteater.jsonl \
--output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example usage of the script to compute the hidden states for a conversation dataset
# This script computes hidden states using a Hugging Face model and saves them to
# the specified output directory. It does so in a data-parallel manner across 8 GPUs, by splitting
# the input file into 8 parts and running 8 processes in parallel, one on each GPU.

# Note: depending on the write-throughput of the destination disk, this is not guaranteed
# to yield a speed improvement compared to running the model-parallel version. Consider
# benchmarking on a smaller dataset before launching a large run.

INPUT_FILE=synthetic_conversations/daring-anteater.jsonl
OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
NUM_GPUS=8

# Split into a private scratch directory so that concurrent runs, or unrelated
# /tmp/part-* files, can neither collide with nor be deleted by this script.
# The trap removes the directory however the script exits.
SPLIT_DIR=$(mktemp -d)
trap 'rm -rf "$SPLIT_DIR"' EXIT

# `l/N` splits on line boundaries, so no JSON line is cut in half. The numeric
# suffixes are two digits wide by default: part-00.jsonl, part-01.jsonl, ...
split -n "l/${NUM_GPUS}" --numeric-suffixes=0 --additional-suffix=.jsonl "$INPUT_FILE" "$SPLIT_DIR/part-"

for i in $(seq 0 $((NUM_GPUS - 1)))
do
  # One background process per GPU, each restricted to its own device.
  part_file="$SPLIT_DIR/part-$(printf '%02d' "$i").jsonl"
  CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py --model meta-llama/Llama-3.2-1B-Instruct --input-file "$part_file" --output-dir "$OUTPUT_DIR" &
done
# Block until every background worker has finished before cleaning up.
wait
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example usage of the script to send conversations for hidden state collection.
# This script sends conversations to a (local) OpenAI-compatible server for processing and collects hidden states.
#
# The script and input paths below are relative, so run this from the directory
# that contains `collect_hidden_states/` and `synthetic_conversations/`.

python3 collect_hidden_states/send_conversations_for_hiddens.py \
--model meta-llama/Llama-3.2-1B-Instruct \
--input-file synthetic_conversations/mtbench.jsonl \
--output-dir /mnt/md0/eagle-hidden-states/llama1b/mtbench/
# Optional debugging flag, left disabled. To use it, append ` \` to the
# --output-dir line above and uncomment the line below.
# NOTE(review): presumably caps the number of conversations sent — confirm against the script's --help.
# --debug-max-num-conversations-per-split 1000
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility script to print a sample of hidden states extracted from a dataset."""

import argparse
import random
from pathlib import Path

import torch


def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options of the hidden-state sample printer.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the script entry point relies on.
            Passing an explicit list allows the parser to be driven from tests.

    Returns:
        Namespace with the attributes ``input_path`` and ``num_samples``.
    """
    parser = argparse.ArgumentParser(
        # Adjacent literals are joined verbatim, so each fragment carries its own
        # separating space.
        description="Print a sample of hidden states from a dataset. "
        "This script will crawl the provided directory for hidden state files, "
        "and print a small number of samples."
    )

    parser.add_argument(
        "--input-path",
        type=Path,
        required=True,
        help="Path to the base input directory containing hidden states. "
        "Alternatively, this can be a path to a specific `.pt` file.",
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=1,
        help="Number of samples to print. If input_path is a file, this is ignored. "
        "Defaults to 1.",
    )
    return parser.parse_args(argv)


def main(args: argparse.Namespace) -> None:
    """Print a summary of one or more saved hidden-state dumps.

    If ``args.input_path`` is a file, that file is printed. Otherwise the path is
    treated as a directory, its ``*.pt`` files are collected (non-recursively)
    and up to ``args.num_samples`` of them are chosen at random. Files whose
    top-level keys differ from the expected set are reported and skipped.

    Args:
        args: Parsed command-line options; ``input_path`` and ``num_samples``
            are read.
    """
    expected_keys = [
        "input_ids",
        "hidden_states",
        "aux_hidden_states",
        "conversation_id",
    ]

    if args.input_path.is_file():
        # --num-samples is documented as ignored for a single file, so the file
        # is always printed, whatever the requested sample count.
        all_files = [args.input_path]
        sampled_files = all_files
    else:
        all_files = list(args.input_path.glob("*.pt"))
        # A negative count would make random.sample raise; treat it as zero.
        num_samples = max(args.num_samples, 0)
        sampled_files = (
            random.sample(all_files, num_samples)
            if len(all_files) > num_samples
            else all_files
        )

    for i, file in enumerate(sampled_files):
        # `.pt` dumps are pickles: weights_only=True restricts unpickling to
        # tensors and plain containers, and map_location lets dumps be inspected
        # on a machine without the device they were saved from.
        data = torch.load(file, map_location="cpu", weights_only=True)
        if set(expected_keys) != set(data.keys()):
            print(f"File {file} does not contain all expected keys: {expected_keys}")
            print(f"  Found keys: {list(data.keys())}")
            continue
        print(f"Sample {i + 1}: {file.name}")
        for key in ["input_ids", "hidden_states", "aux_hidden_states"]:
            print(f"{key}: {data[key].shape} {data[key].dtype} {data[key].device}")
        print(f"conversation_id: {data['conversation_id']}")
        input_ids_list = data["input_ids"].tolist()
        hidden_states = data["hidden_states"]
        print(f"Sample of input_ids (first 10 tokens): \n{input_ids_list[:10]}")
        print(f"Sample of input_ids (last 10 tokens): \n{input_ids_list[-10:]}")
        print(f"Sample of hidden_states (first 10 positions): \n{hidden_states[:10]}")

    print(f"\n\nDone. Found: {len(all_files)} files in total.")


if __name__ == "__main__":
    # Parse the command line and hand the options straight to `main`.
    main(parse_args())
Loading
Loading