Commit f92be76

Removed DetachedHFEagleModel, and misc tweaks
Signed-off-by: Benjamin Chislett <[email protected]>
1 parent b02a3b4 commit f92be76

11 files changed: +57, -647 lines
Lines changed: 3 additions & 0 deletions
@@ -1 +1,4 @@
 Daring-Anteater
+input_conversations
+synthetic_conversations
+ckpts

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Example usage of the script to compute the hidden states for a conversation dataset.
+# This script computes hidden states using a Hugging Face model and saves them to
+# the specified output directory. It does so in a data-parallel manner across 8 GPUs, by splitting
+# the input file into 8 parts and running 8 processes in parallel, one on each GPU.
+
+# Note: depending on the write throughput of the destination disk, this is not guaranteed
+# to yield a speed improvement compared to running the model-parallel version. Consider
+# benchmarking on a smaller dataset before launching a large run.
+
+INPUT_FILE=synthetic_conversations/daring-anteater.jsonl
+OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
+
+split -n l/8 --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part-
+
+for i in $(seq 0 7)
+do
+CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py --model meta-llama/Llama-3.2-1B-Instruct --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR &
+done
+wait
+
+rm /tmp/part-*.jsonl
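
For context, the per-GPU worker launched above (collect_hidden_states/compute_hidden_states_hf.py) is not shown in this section of the diff. Below is a minimal sketch of what such a worker is assumed to do; the helper name, the JSONL field name, the dtype choices, and the output file layout are illustrative assumptions rather than the real script's behaviour.

import json
from pathlib import Path

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def dump_hidden_states(model_name: str, input_file: str, output_dir: str) -> None:
    """Sketch: run each conversation through the model and save its hidden states."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to("cuda")
    model.eval()
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    with open(input_file) as f:
        for idx, line in enumerate(f):
            sample = json.loads(line)
            # Assumes each JSONL record carries a chat-template-compatible "conversations" list.
            input_ids = tokenizer.apply_chat_template(
                sample["conversations"], return_tensors="pt"
            ).to(model.device)
            with torch.no_grad():
                outputs = model(input_ids, output_hidden_states=True)
            # Keep the final layer's hidden states together with the exact token ids used,
            # one file per conversation.
            torch.save(
                {
                    "input_ids": input_ids.squeeze(0).cpu(),
                    "hidden_states": outputs.hidden_states[-1].squeeze(0).cpu(),
                },
                out_dir / f"sample_{idx:08d}.pt",
            )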

examples/speculative_decoding/eagle_utils.py

Lines changed: 5 additions & 6 deletions
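
The first hunk below relaxes the offline-data check: instead of requiring the cached input_ids to be exactly equal to the freshly preprocessed ones, it only requires matching shapes, and the cached input_ids then override the preprocessed ones in the returned sample. A toy illustration of the relaxed check (a standalone sketch, not code from this file):

import torch

preprocessed_ids = torch.tensor([1, 15, 42, 7])  # from re-tokenizing the prompt
offline_ids = torch.tensor([1, 15, 43, 7])       # stored alongside the cached hidden states

print(torch.equal(preprocessed_ids, offline_ids))    # False -> the old check would raise
print(preprocessed_ids.shape == offline_ids.shape)   # True  -> the new check passes,
                                                     # and offline_ids become the sample's input_ids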
@@ -219,12 +219,13 @@ def __getitem__(self, i) -> dict[str, torch.Tensor]:
         offline_data["aux_hidden_states"] = offline_data["aux_hidden_states"][:max_length, :]
 
         # Make sure the input_ids have the same shape
-        if not torch.equal(preprocessed_base["input_ids"], offline_data["input_ids"]):
+        if preprocessed_base["input_ids"].shape != offline_data["input_ids"].shape:
             msg = f"""Input IDs from offline data do not match the preprocessed input IDs
             for offline data sample at {offline_file_path}."""
             raise ValueError(msg)
 
         ret = {**preprocessed_base} # Shallow copy so we don't accidentally modify the cache
+        ret["input_ids"] = offline_data["input_ids"]
         ret["kwargs"] = {
             "base_model_outputs": {
                 "base_model_hidden_states": offline_data["hidden_states"],
@@ -370,11 +371,9 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
 
         batch = {
             **base_batch,
-            "kwargs": {
-                "base_model_outputs": {
-                    "base_model_hidden_states": batch_hidden_states,
-                    "aux_hidden_states": batch_aux_hidden_states,
-                }
+            "base_model_outputs": {
+                "base_model_hidden_states": batch_hidden_states,
+                "aux_hidden_states": batch_aux_hidden_states,
             },
         }
 
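
The hunk above drops the intermediate "kwargs" wrapper in the collator, so the cached tensors sit at the top level of the collated batch. A runnable schematic of the resulting structure, using dummy tensors (the contents of base_batch and all shapes are illustrative assumptions):

import torch

# Dummy stand-ins for the collated tensors; shapes are illustrative only.
base_batch = {"input_ids": torch.zeros(4, 128, dtype=torch.long)}
batch_hidden_states = torch.zeros(4, 128, 2048)
batch_aux_hidden_states = torch.zeros(4, 128, 3 * 2048)

batch = {
    **base_batch,
    "base_model_outputs": {
        "base_model_hidden_states": batch_hidden_states,
        "aux_hidden_states": batch_aux_hidden_states,
    },
}
assert "kwargs" not in batch  # the extra nesting level is gone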

examples/speculative_decoding/gen_synthetic_conversations/__init__.py

Lines changed: 0 additions & 16 deletions
This file was deleted.
