Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions examples/speculative_decoding/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
Daring-Anteater
input_conversations
synthetic_conversations
ckpts
16 changes: 16 additions & 0 deletions examples/speculative_decoding/collect_hidden_states/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Collect hidden states from a dataset of conversations."""
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Extract hidden states from an HF-compatible LLM."""

import argparse
import asyncio
import json
from pathlib import Path

import torch
from tqdm import tqdm as tqdm
from transformers import AutoModel, AutoTokenizer

REMOVE_THINK_CHAT_TEMPLATE = (
"{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}"
)


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="""Collect hidden states from conversations
by running full conversations through a Hugging Face model."""
)

## Model & Generation Parameters ##
parser.add_argument(
"--model",
type=str,
required=True,
help="Name of the served model.",
)

## Client Parameters ##
parser.add_argument(
"--max-seq-len",
type=int,
default=3072,
help="""Maximum number of tokens in a conversation. Longer conversations will be skipped.
Defaults to 3072 tokens.""",
)

## I/O Parameters ##
parser.add_argument(
"--input-file",
type=Path,
required=True,
help="""Path to the input `jsonl` file containing conversations.
Each entry must have a unique `conversation_id` field and a `conversations` field
containing a list of messages.""",
)
parser.add_argument(
"--output-dir",
type=Path,
required=True,
help="""Root directory in which to save the hidden states.
The data will be saved as a torch (`.pt`) dump file for each conversation.""",
)
parser.add_argument(
"--debug-max-num-conversations",
type=int,
default=None,
help="""For debugging purposes, limit the number of conversations processed.
Default is None, meaning no limit.""",
)

return parser.parse_args()


async def main(args: argparse.Namespace) -> None:
all_conversations = []
with args.input_file.open("r", encoding="utf-8") as f:
all_conversations.extend([json.loads(line) for line in f if line.strip()])

print("Loaded", len(all_conversations), "conversations from", args.input_file)

model = AutoModel.from_pretrained(args.model, torch_dtype="auto", device_map="auto")
num_hidden_layers = getattr(model.config, "num_hidden_layers", None)

tokenizer = AutoTokenizer.from_pretrained(args.model)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")

output_dir = args.output_dir
output_dir.mkdir(parents=True, exist_ok=True)
num_skipped_too_long = 0
num_invalid = 0
num_success = 0
num_total_conversations = min(
len(all_conversations), args.debug_max_num_conversations or len(all_conversations)
)
for idx, entry in enumerate(
tqdm(
all_conversations[: args.debug_max_num_conversations],
desc="Processing conversations",
total=num_total_conversations,
)
):
conversation_id = entry.get("conversation_id", "{:08d}".format(idx))
conversations = entry["conversations"]
if not conversations or not isinstance(conversations, list):
num_invalid += 1
continue

Comment on lines +112 to +117
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Avoid KeyError when 'conversations' is missing.

Indexing with entry["conversations"] will crash; use .get() and validate.

-        conversations = entry["conversations"]
+        conversations = entry.get("conversations")
         if not conversations or not isinstance(conversations, list):
             num_invalid += 1
             continue
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
conversation_id = entry.get("conversation_id", "{:08d}".format(idx))
conversations = entry["conversations"]
if not conversations or not isinstance(conversations, list):
num_invalid += 1
continue
conversation_id = entry.get("conversation_id", "{:08d}".format(idx))
conversations = entry.get("conversations")
if not conversations or not isinstance(conversations, list):
num_invalid += 1
continue
🤖 Prompt for AI Agents
In
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py
around lines 112 to 117, replace the unsafe indexing entry["conversations"] with
a safe lookup and validation: use entry.get("conversations") into the
conversations variable, then check if conversations is truthy and
isinstance(conversations, list); if not, increment num_invalid and continue.
Ensure you do not assume the key exists and handle None or non-list types
consistently.

# Tokenize and check length
input_ids = tokenizer.apply_chat_template(
conversations, return_tensors="pt", add_generation_template=False
)
num_input_tokens = input_ids.shape[1]
if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len:
num_skipped_too_long += 1
continue

# Get hidden states
with torch.inference_mode():
outputs = model(input_ids=input_ids.to(model.device), output_hidden_states=True)
if num_hidden_layers is None:
num_hidden_layers = len(outputs.hidden_states) - 1
Comment on lines +127 to +131
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix device handling with device_map='auto' (model.device may be missing or wrong).

Pass inputs to the embedding device; relying on model.device can raise or send tensors to CPU.

-        with torch.inference_mode():
-            outputs = model(input_ids=input_ids.to(model.device), output_hidden_states=True)
+        with torch.inference_mode():
+            emb_device = model.get_input_embeddings().weight.device
+            outputs = model(input_ids=input_ids.to(emb_device), output_hidden_states=True)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Get hidden states
with torch.inference_mode():
outputs = model(input_ids=input_ids.to(model.device), output_hidden_states=True)
if num_hidden_layers is None:
num_hidden_layers = len(outputs.hidden_states) - 1
# Get hidden states
with torch.inference_mode():
emb_device = model.get_input_embeddings().weight.device
outputs = model(input_ids=input_ids.to(emb_device), output_hidden_states=True)
if num_hidden_layers is None:
num_hidden_layers = len(outputs.hidden_states) - 1
🤖 Prompt for AI Agents
In
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py
around lines 129 to 133, the code moves inputs to model.device which can be
missing or incorrect when using device_map='auto'; instead obtain the embeddings
device (e.g. model.get_input_embeddings().weight.device) and move input_ids to
that device before calling the model; use that device value as a fallback if
model.device exists but prefer the embedding weight device to ensure tensors end
up on the correct device for sharded/auto-mapped models.

else:
assert num_hidden_layers + 1 == len(outputs.hidden_states), (
f"Expected {num_hidden_layers}+1 layers of hidden states, but got {len(outputs.hidden_states)}."
)
# Extract hidden states from layers with index (2, N/2, N-3), and the output hidden states
hidden_states = outputs.hidden_states
selected_layer_indices = [
2,
max(0, num_hidden_layers // 2),
max(1, num_hidden_layers - 3),
]
selected_layer_indices = sorted(set(selected_layer_indices))
aux_hidden_states = torch.cat(
[hidden_states[i].squeeze(0).cpu() for i in selected_layer_indices], dim=-1
)
output_hidden_states = outputs.last_hidden_state.squeeze(0).cpu()
output_file = output_dir / f"{conversation_id}.pt"
num_success += 1
with open(output_file, "wb") as f:
torch.save(
{
"input_ids": input_ids.squeeze(0).cpu(),
"hidden_states": output_hidden_states,
"aux_hidden_states": aux_hidden_states,
"conversation_id": conversation_id,
},
f,
)

if num_skipped_too_long > 0:
print(f"Skipped {num_skipped_too_long} conversations due to length constraints.")
if num_invalid > 0:
print(f"Skipped {num_invalid} invalid conversations without proper fields.")

if num_success == num_total_conversations:
print(f"Successfully processed all {num_success} conversations.")
else:
print(
f"Successfully processed {num_success} out of {num_total_conversations} conversations."
)


if __name__ == "__main__":
cli_args = parse_args()
asyncio.run(main(cli_args))
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
Comment on lines +1 to +3
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Add shebang and safe shell flags; make script executable-friendly.

Without a shebang, shells/tools can misinterpret the file; enable strict mode.

+#!/usr/bin/env bash
+set -euo pipefail
@@
-python3 collect_hidden_states/compute_hidden_states_hf.py \
+python3 collect_hidden_states/compute_hidden_states_hf.py \
   --model meta-llama/Llama-3.2-1B-Instruct \
   --input-file synthetic_conversations/daring-anteater.jsonl \
   --output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/

Also applies to: 16-23

🧰 Tools
🪛 Shellcheck (0.11.0)

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

🤖 Prompt for AI Agents
In examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh
around lines 1-3 (and similarly apply to lines 16-23), the script lacks a
shebang and safe shell flags; add a top-line shebang (e.g., #!/usr/bin/env bash)
and enable strict mode by setting set -euo pipefail and a safe IFS (IFS=$'\n\t')
at the top of the file; ensure the file mode is executable (chmod +x) in the
repository or note it in the commit so the script runs robustly in CI and user
environments.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example usage of the script to compute the hidden states for a conversation dataset
# This script computes hidden states using a Hugging Face model and saves them to
# the specified output directory.

python3 collect_hidden_states/compute_hidden_states_hf.py \
--model meta-llama/Llama-3.2-1B-Instruct \
--input-file synthetic_conversations/daring-anteater.jsonl \
--output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
Comment on lines +1 to +3
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix split suffix/filenames; add shebang and strict mode.

As written, split likely creates part-0..7.jsonl, but the loop reads part-00..07.jsonl. Also add shebang/safety.

+#!/usr/bin/env bash
+set -euo pipefail
@@
-INPUT_FILE=synthetic_conversations/daring-anteater.jsonl
-OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
+INPUT_FILE=synthetic_conversations/daring-anteater.jsonl
+OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
+mkdir -p "$OUTPUT_DIR"
@@
-split -n l/8 --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part-
+split -n l/8 -d -a 2 --numeric-suffixes=00 --additional-suffix=.jsonl "$INPUT_FILE" /tmp/part-
@@
-for i in $(seq 0 7)
-do
-CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py --model meta-llama/Llama-3.2-1B-Instruct --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR &
-done
+for i in $(seq 0 7); do
+  part="/tmp/part-$(printf '%02d' "$i").jsonl"
+  CUDA_VISIBLE_DEVICES="$i" python3 collect_hidden_states/compute_hidden_states_hf.py \
+    --model meta-llama/Llama-3.2-1B-Instruct \
+    --input-file "$part" \
+    --output-dir "$OUTPUT_DIR" &
+done
 wait
@@
-rm /tmp/part-*.jsonl
+rm -f /tmp/part-*.jsonl

Also applies to: 25-36

🧰 Tools
🪛 Shellcheck (0.11.0)

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

🤖 Prompt for AI Agents
In
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh
around lines 1-3 (and similarly lines 25-36), the script lacks a shebang and
strict shell settings, and the split/loop filename suffixes are inconsistent:
split produces single-digit suffixes but the loop expects two-digit names. Add a
shebang (#!/usr/bin/env bash) and enable strict mode (set -euo pipefail), change
the split invocation to produce numeric two-digit suffixes and a .jsonl
extension (use split -d -a 2 --additional-suffix=.jsonl <...> part-), and update
the consumer loop to either glob part-*.jsonl or explicitly format two-digit
indices (e.g., part-00.jsonl..part-07.jsonl) so names match; apply the same
changes to the block at lines 25-36.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example usage of the script to compute the hidden states for a conversation dataset
# This script computes hidden states using a Hugging Face model and saves them to
# the specified output directory. It does so in a data-parallel manner across 8 GPUs, by splitting
# the input file into 8 parts and running 8 processes in parallel, one on each GPU.

# Note: depending on the write-throughput of the destination disk, this is not guaranteed
# to yield a speed improvement compared to running the model-parallel version. Consider
# benchmarking on a smaller dataset before launching a large run.

INPUT_FILE=synthetic_conversations/daring-anteater.jsonl
OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/

split -n l/8 --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part-

for i in $(seq 0 7)
do
CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py --model meta-llama/Llama-3.2-1B-Instruct --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR &
done
wait

rm /tmp/part-*.jsonl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example usage of the script to send conversations for hidden state collection
# This script sends conversations to a (local) OpenAI-compatible server for processing and collects hidden states.

python3 collect_hidden_states/send_conversations_for_hiddens.py \
--model meta-llama/Llama-3.2-1B-Instruct \
--input-file synthetic_conversations/mtbench.jsonl \
--output-dir /mnt/md0/eagle-hidden-states/llama1b/mtbench/
# --debug-max-num-conversations-per-split 1000
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility script to print a sample of hidden states extracted from a dataset."""

import argparse
import random
from pathlib import Path

import torch


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Print a sample of hidden states from a dataset."
"This script will crawl the provided directory for hidden state files,"
" and print a small number of samples."
)

parser.add_argument(
"--input-path",
type=Path,
required=True,
help="Path to the base input directory containing hidden states."
"Alternatively, this can be a path to a specific `.pt` file.",
)
parser.add_argument(
"--num-samples",
type=int,
default=1,
help="Number of samples to print per split. If input_path is a file, this is ignored. "
"Defaults to 1.",
)
return parser.parse_args()


def main(args: argparse.Namespace) -> None:
# Iterate through the input directory and find all hidden state files
if args.input_path.is_file():
all_files = [args.input_path]
else:
all_files = list(args.input_path.glob("*.pt"))

sampled_files = (
random.sample(all_files, args.num_samples)
if len(all_files) > args.num_samples
else all_files
)

for i, file in enumerate(sampled_files):
data = torch.load(file)
expected_keys = [
"input_ids",
"hidden_states",
"aux_hidden_states",
"conversation_id",
]
if set(expected_keys) != set(data.keys()):
print(f"File {file} does not contain all expected keys: {expected_keys}")
print(f" Found keys: {list(data.keys())}")
continue
print(f"Sample {i + 1}: {file.name}")
for key in ["input_ids", "hidden_states", "aux_hidden_states"]:
print(f"{key}: {data[key].shape} {data[key].dtype} {data[key].device}")
print(f"conversation_id: {data['conversation_id']}")
input_ids_list = data["input_ids"].tolist()
hidden_states = data["hidden_states"]
print(f"Sample of input_ids (first 10 tokens): \n{input_ids_list[:10]}")
print(f"Sample of input_ids (last 10 tokens): \n{input_ids_list[-10:]}")
print(f"Sample of hidden_states (first 10 positions): \n{hidden_states[:10]}")

print(f"\n\nDone. Found: {len(all_files)} files in total.")


if __name__ == "__main__":
args = parse_args()
main(args)
Loading
Loading