Skip to content

Commit f74bf59

Browse files
committed
Minor tweaks to data preparation
Signed-off-by: Benjamin Chislett <[email protected]>
1 parent c6c84d2 commit f74bf59

File tree

6 files changed

+27
-15
lines changed

6 files changed

+27
-15
lines changed

examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424
from tqdm import tqdm as tqdm
2525
from transformers import AutoModel, AutoTokenizer
2626

27+
REMOVE_THINK_CHAT_TEMPLATE = (
28+
"{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}"
29+
)
30+
2731

2832
def parse_args() -> argparse.Namespace:
2933
parser = argparse.ArgumentParser(
@@ -92,6 +96,7 @@ async def main(args: argparse.Namespace) -> None:
9296
tokenizer = AutoTokenizer.from_pretrained(args.model)
9397
if tokenizer.pad_token is None:
9498
tokenizer.pad_token = tokenizer.eos_token
99+
tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")
95100

96101
output_dir = args.output_dir
97102
output_dir.mkdir(parents=True, exist_ok=True)
@@ -132,7 +137,12 @@ async def main(args: argparse.Namespace) -> None:
132137
)
133138
# Extract hidden states from layers with index (2, N/2, N-3), and the output hidden states
134139
hidden_states = outputs.hidden_states
135-
selected_layer_indices = [2, num_hidden_layers // 2, num_hidden_layers - 3]
140+
selected_layer_indices = [
141+
2,
142+
max(0, num_hidden_layers // 2),
143+
max(1, num_hidden_layers - 3),
144+
]
145+
selected_layer_indices = sorted(set(selected_layer_indices))
136146
aux_hidden_states = torch.cat(
137147
[hidden_states[i].squeeze(0).cpu() for i in selected_layer_indices], dim=-1
138148
)

examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
# This script computes hidden states using a Hugging Face model and saves them to
1818
# the specified output directory.
1919

20-
python3 collect_hidden_states/compute_hiddens_hf.py \
20+
python3 collect_hidden_states/compute_hidden_states_hf.py \
2121
--model meta-llama/Llama-3.2-1B-Instruct \
2222
--input-file synthetic_conversations/daring-anteater.jsonl \
2323
--output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/

examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ async def main(args: argparse.Namespace) -> None:
152152
)
153153

154154
input_ids = tokenizer.apply_chat_template(
155-
conversations, return_tensors=None, add_generation_template=False
155+
conversations, return_tensors=None, add_generation_template=False, tokenize=True
156156
)
157157
num_input_tokens = len(input_ids)
158158
if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len:

examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
# Example script to prepare a dataset of prompts for generation
1717
# Lines in this script can be uncommented to include specific datasets/splits in the prompt dataset.
1818

19-
python3 make_prompts_for_gen/add_sharegpt.py --output-split eval --output-file data/mtbench_prompts_dataset.json
20-
# python3 make_prompts_for_gen/add_ultrachat.py --ultrachat-split train_sft --output-split train
21-
# python3 make_prompts_for_gen/add_ultrachat.py --ultrachat-split train_gen --output-split train
22-
# python3 make_prompts_for_gen/add_ultrachat.py --ultrachat-split test_sft --output-split mix_test
23-
# python3 make_prompts_for_gen/add_ultrachat.py --ultrachat-split test_gen --output-split mix_test
24-
python3 make_prompts_for_gen/add_mtbench.py --output-split train --output-file data/mtbench_prompts_dataset.json
25-
# python3 make_prompts_for_gen/add_mtbench.py --output-split eval --output-file data/mtbench_prompts_dataset.json
19+
python3 prepare_input_conversations/add_daring_anteater.py --output-split-name train
20+
# python3 prepare_input_conversations/add_sharegpt.py --output-split-name train
21+
# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split train_sft --output-split-name train
22+
# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split train_gen --output-split-name train
23+
# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split test_sft --output-split-name mix_test
24+
# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split test_gen --output-split-name mix_test
25+
python3 prepare_input_conversations/add_mtbench.py --output-split-name mix_test

examples/speculative_decoding/prepare_input_conversations/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
async def download_file(url: str, destination: Path) -> None:
2727
"""Download a file from a URL to a specified destination."""
28+
destination.parent.mkdir(parents=True, exist_ok=True)
2829
async with aiohttp.ClientSession() as session, session.get(url) as response:
2930
if response.status != 200:
3031
msg = f"Failed to download {url}: {response.status}"
@@ -83,7 +84,8 @@ def add_conversations_to_split(conversations: list, dataset_dir: Path, split: st
8384
else:
8485
print(f"Added {num_new_entries} new conversations to {dataset_file}.")
8586

86-
with open(dataset_file, "w", encoding="utf-8") as f:
87+
dataset_dir.mkdir(parents=True, exist_ok=True)
88+
with dataset_file.open("w", encoding="utf-8") as f:
8789
for entry in all_conversations:
8890
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
8991

modelopt/torch/speculative/plugins/transformers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -336,10 +336,10 @@ class HFEagleModel(EagleModel):
336336

337337
def _set_default_aux_hidden_state_layers(self):
338338
# Read a custom config attribute since we override num_hidden_layers for offline training
339-
if self.eagle_offline:
340-
num_layers = self.config.num_orig_hidden_layers
341-
else:
342-
num_layers = self.config.num_hidden_layers
339+
num_layers = self.config.num_hidden_layers
340+
if self.eagle_offline and (num_layers is None or num_layers <= 0):
341+
num_layers = getattr(self.config, "num_orig_hidden_layers", 0)
342+
343343
self.eagle_config.eagle_aux_hidden_state_layer_ids = [
344344
1,
345345
max(0, num_layers // 2 - 1),

0 commit comments

Comments
 (0)