Feature: Offline training for EAGLE3 #300

Conversation
Walkthrough

Adds dataset-prep utilities and scripts to build conversation datasets, new tools to compute/send/sample per-conversation hidden states and emit per-conversation .pt artifacts, integrates an offline training path (dataset/collator/dataloader/CLI wiring), updates a transformer plugin for offline usage, updates .gitignore, and removes a legacy launcher.

Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant U as User
    participant Prep as Prepare scripts
    participant HF as compute_hidden_states_hf
    participant API as OpenAI-like server
    participant DS as Disk (.pt)
    participant Trainer as Trainer (offline)
    U->>Prep: run add_* → produce input_conversations JSONL
    alt local HF compute
        U->>HF: compute_hidden_states_hf (--model, --input-file)
        HF->>DS: write per-conversation .pt (input_ids, hidden_states, aux_hidden_states)
    else send to server
        U->>API: send_conversations_for_hiddens (--base-url, --input-file)
        API->>DS: server writes per-conversation .pt
    end
    U->>Trainer: launch_train.sh --offline-data PATH
    Trainer->>DS: read .pt files
    Trainer->>Trainer: OfflineSupervisedDataset + DataCollatorForOffline → training (eagle_offline=true)
```
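For orientation, each per-conversation artifact can be inspected directly with torch.load. This is a minimal sketch, not part of the PR; the file path is hypothetical, and the key names follow the expected_keys check in sample_hidden_states.py (input_ids, hidden_states, aux_hidden_states, conversation_id).

```python
# Sketch: peek at one per-conversation .pt artifact produced by the offline pipeline.
import torch

data = torch.load("/mnt/md0/eagle-hidden-states/llama1b/mtbench/sample.pt", map_location="cpu")
print(data["conversation_id"])
print(data["input_ids"].shape)          # token ids for the rendered conversation
print(data["hidden_states"].shape)      # final-layer hidden states, one row per token
print(data["aux_hidden_states"].shape)  # concatenated auxiliary-layer hidden states
```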
```mermaid
sequenceDiagram
    autonumber
    participant C as send_conversations_for_hiddens
    participant Tok as Tokenizer
    participant Meta as /tmp/meta.json
    participant API as OpenAI-like endpoint
    participant Out as Output Dir
    C->>Tok: apply_chat_template(conversations) -> input_ids / prompt
    C->>Meta: write {conversation_id, output_file}
    C->>API: completions.create(model, prompt, max_tokens=1)
    alt success
        API-->>C: 200 OK
        Note right of API: serving engine coordinates writing .pt to Out
        API->>Out: .pt created
    else error / too long
        API-->>C: error
        C->>C: increment counters, skip
    end
    C->>Meta: cleanup entry
```
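The client half of this flow can be approximated with the OpenAI Python SDK. The sketch below is illustrative only: the endpoint URL and conversation are assumptions, and the real send_conversations_for_hiddens.py additionally handles metadata, counters, and skipping.

```python
# Sketch of the send-for-hiddens client: render the prompt via the chat template, then request
# a 1-token completion so the serving engine runs prefill; the engine writes the .pt server-side.
from openai import OpenAI
from transformers import AutoTokenizer

model = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model)

conversation = [{"role": "user", "content": "Explain speculative decoding in one sentence."}]
prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # assumed endpoint
client.completions.create(model=model, prompt=prompt, max_tokens=1)
```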
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
```python
# Extend the data sample with the hidden states from the .pt file
max_length = self.tokenizer.model_max_length
offline_data = torch.load(offline_file_path)
```
Curious:
- With the current implementation, will there be any prefetching mechanism for this tensor loading (perhaps taken care of inside the HF trainer)?
- Considering there's a limit on total disk bandwidth, could data loading become a bottleneck limiting training speed (if we further optimize the training loop, e.g. the TTT part)?
- Yes, I believe standard pytorch/HF datasets will handle prefetching, multiple loading worker processes, etc.
- Indeed, disk bandwidth can absolutely bottleneck training. If this becomes an issue, we can offset this using techniques such as compressed hidden-states on-disk, or using a file system with faster read speeds (e.g. using many smaller disks for high parallel read throughput). However, the speed-of-light for even a single moderate-quality disk is quite good, so networking a few 1-4TB disks together can easily saturate the (optimized) GPU throughput.
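To make the prefetching point concrete, a plain map-style Dataset over the .pt files plus DataLoader worker processes already overlaps disk reads with GPU work. A minimal sketch, assuming one sample per file and a hypothetical path; padding/collation is deferred to a real collator such as the one in eagle_utils.py:

```python
# Sketch: background workers prefetch per-conversation .pt files while the GPU is busy.
from pathlib import Path

import torch
from torch.utils.data import DataLoader, Dataset


class HiddenStateFiles(Dataset):
    def __init__(self, root: str):
        self.files = sorted(Path(root).glob("*.pt"))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        return torch.load(self.files[idx], map_location="cpu")


loader = DataLoader(
    HiddenStateFiles("/mnt/md0/eagle-hidden-states/llama1b/mtbench"),  # hypothetical path
    batch_size=1,       # variable-length samples; a real collator would pad them
    num_workers=4,      # parallel disk reads
    prefetch_factor=2,  # batches fetched ahead per worker
    pin_memory=True,    # faster host-to-device copies
)
```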
Can we remove gen_synthetic_conversations from the PR, as we have https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/speculative_decoding/distributed_generate?
Force-pushed from 4282dbe to f6cb37a
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files

```
@@            Coverage Diff             @@
##             main     #300      +/-   ##
==========================================
- Coverage   73.84%   73.83%   -0.01%
==========================================
  Files         172      172
  Lines       17453    17453
==========================================
- Hits        12888    12887       -1
- Misses       4565     4566       +1
```

☔ View full report in Codecov by Sentry.
LGTM. Tried offline workflow with tinyllama and got reasonable AR.
Force-pushed from 1839ab5 to f92be76
/ok to test f92be76
Actionable comments posted: 1
♻️ Duplicate comments (5)
examples/speculative_decoding/main.py (1)

188-190: Good: honor EagleConfig.eagle_offline per prior feedback.
This aligns with the earlier request to use the existing config flag.

examples/speculative_decoding/prepare_input_conversations/utils.py (2)
26-36: Add a network timeout and avoid multi-context in one line for clarity.
Prevents indefinite hangs and improves readability; keep parent dir creation.
Apply:
```diff
 async def download_file(url: str, destination: Path) -> None:
     """Download a file from a URL to a specified destination."""
     destination.parent.mkdir(parents=True, exist_ok=True)
-    async with aiohttp.ClientSession() as session, session.get(url) as response:
-        if response.status != 200:
-            msg = f"Failed to download {url}: {response.status}"
-            raise RuntimeError(msg)
-        content = await response.read()
-        destination.write_bytes(content)
-        print(f"Downloaded {url} to {destination}")
+    timeout = aiohttp.ClientTimeout(total=600)
+    async with aiohttp.ClientSession(timeout=timeout) as session:
+        async with session.get(url) as response:
+            if response.status != 200:
+                msg = f"Failed to download {url}: {response.status}"
+                raise RuntimeError(msg)
+            content = await response.read()
+            destination.write_bytes(content)
+            print(f"Downloaded {url} to {destination}")
```
45-86: Prevent duplicate IDs within the same update and validate required keys.
Currently, duplicates inside the provided conversations list can be appended multiple times because existing_ids isn't updated as you add. Also, missing "conversations" will raise KeyError later; validate early.
Apply:
```diff
 def add_conversations_to_split(conversations: list, dataset_dir: Path, split: str) -> None:
 @@
-    # Open the dataset file for the specified split, or create it if it doesn't exist
-    dataset_file = dataset_dir / f"{split}.jsonl"
+    # Ensure output directory exists and open/create split file
+    dataset_dir.mkdir(parents=True, exist_ok=True)
+    dataset_file = dataset_dir / f"{split}.jsonl"
 @@
-    existing_ids = {entry["conversation_id"] for entry in all_conversations}
+    existing_ids = {entry["conversation_id"] for entry in all_conversations}
     num_new_entries = 0
     num_duplicates = 0
     for entry in conversations:
-        if entry.get("conversation_id") is None:
+        entry_id = entry.get("conversation_id")
+        if entry_id is None:
             raise ValueError("Each conversation must have a 'conversation_id' field.")
-        if entry["conversation_id"] not in existing_ids:
+        if "conversations" not in entry:
+            raise ValueError("Each conversation must have a 'conversations' field.")
+        if entry_id not in existing_ids:
             all_conversations.append(
                 {
-                    "conversation_id": entry["conversation_id"],
+                    "conversation_id": entry_id,
                     "conversations": entry["conversations"],
                 }
             )
             num_new_entries += 1
+            existing_ids.add(entry_id)
         else:
             num_duplicates += 1
 @@
-    dataset_dir.mkdir(parents=True, exist_ok=True)
     with dataset_file.open("w", encoding="utf-8") as f:
         for entry in all_conversations:
             f.write(json.dumps(entry, ensure_ascii=False) + "\n")
```

examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (2)
27-29: Unify REMOVE_THINK_CHAT_TEMPLATE with training and guard None chat_template.
Avoid divergence from training preprocessing and prevent AttributeError when chat_template is None.
```diff
-from transformers import AutoModel, AutoTokenizer
+from transformers import AutoModel, AutoTokenizer
+try:
+    # Keep in sync with training; import if available when run from examples/speculative_decoding
+    from eagle_utils import REMOVE_THINK_CHAT_TEMPLATE
+except Exception:
+    # Fallback for standalone runs; ensure value matches eagle_utils.REMOVE_THINK_CHAT_TEMPLATE
+    REMOVE_THINK_CHAT_TEMPLATE = (
+        "{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}"
+    )
 @@
-REMOVE_THINK_CHAT_TEMPLATE = (
-    "{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}"
-)
 @@
-    tokenizer = AutoTokenizer.from_pretrained(args.model)
+    tokenizer = AutoTokenizer.from_pretrained(args.model)
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
-    tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")
+    if getattr(tokenizer, "chat_template", None):
+        tokenizer.chat_template = tokenizer.chat_template.replace(
+            REMOVE_THINK_CHAT_TEMPLATE, ""
+        )
```

Also applies to: 96-100
138-148: Clamp/deduplicate auxiliary layer indices; avoid out-of-range/duplicates on small models.
Current logic can pick invalid indices (e.g., 2 when num_hidden_layers < 3).

```diff
-    # Extract hidden states from layers with index (2, N/2, N-3), and the output hidden states
+    # Extract hidden states from early/mid/late layers; clamp within [0, N-1]
     hidden_states = outputs.hidden_states
-    selected_layer_indices = [
-        2,
-        max(0, num_hidden_layers // 2),
-        max(1, num_hidden_layers - 3),
-    ]
-    selected_layer_indices = sorted(set(selected_layer_indices))
+    candidates = [2, num_hidden_layers // 2, num_hidden_layers - 3]
+    selected_layer_indices = sorted(
+        {i for i in candidates if 0 <= i <= num_hidden_layers - 1}
+    )
+    if not selected_layer_indices:
+        selected_layer_indices = [0]
     aux_hidden_states = torch.cat(
         [hidden_states[i].squeeze(0).cpu() for i in selected_layer_indices], dim=-1
     )
```
🧹 Nitpick comments (12)
examples/speculative_decoding/main.py (3)
72-81: Annotate as Optional to match None default.
offline_data_path defaults to None but is typed as str. Make it str | None for consistency with eval_data_path and type checkers.

Apply this diff:

```diff
-    offline_data_path: str = field(
+    offline_data_path: str | None = field(
```
144-145: Make the offline switch robust to empty strings.
"is not None" treats "" as offline. Prefer truthiness.

Apply this diff:

```diff
-use_offline_training = data_args.offline_data_path is not None
+use_offline_training = bool(data_args.offline_data_path)
```
150-159: num_hidden_layers=0: add safe fallback and always set num_orig_hidden_layers.
Keep your space-saving default but harden for models that can't instantiate with 0 layers, and ensure num_orig_hidden_layers is set even when resuming or the override path changes.

Apply this diff:
- model_kwargs = {"num_hidden_layers": 0} if use_offline_training else {} - model = transformers.AutoModelForCausalLM.from_pretrained( - model_args.model_name_or_path, torch_dtype="auto", **model_kwargs - ) - if use_offline_training: - # When doing offline training, we need to set num_hidden_layers - # since we override it when loading the model for space savings - model_config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path) - model.config.num_orig_hidden_layers = model_config.num_hidden_layers + model_kwargs = {"num_hidden_layers": 0} if use_offline_training else {} + try: + model = transformers.AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, torch_dtype="auto", **model_kwargs + ) + except Exception as e: + if use_offline_training: + print_rank_0("num_hidden_layers=0 failed; falling back to 1 for offline training.") + model = transformers.AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, torch_dtype="auto", num_hidden_layers=1 + ) + else: + raise + if use_offline_training and not hasattr(model.config, "num_orig_hidden_layers"): + # Record original depth for plugins that need it. + base_cfg = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path) + model.config.num_orig_hidden_layers = base_cfg.num_hidden_layersPlease run a quick smoke test on at least one model per family you target (e.g., Llama, Mistral, Qwen2, Phi‑3) to confirm no fallback is triggered unexpectedly and training proceeds.
examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (2)
1-1: Add a shebang to fix ShellCheck SC2148 and enable script execution.
Without a shebang, shells may interpret this file incorrectly and ShellCheck flags SC2148.
Apply:

```diff
+#!/usr/bin/env bash
 # SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```
16-25: Fix relative paths; make script location-agnostic; add strict mode and ensure output dir.
The script lives in prepare_input_conversations/, but calls into prepare_input_conversations/*.py again, which breaks when run from repo root or this directory. Also add set -euo pipefail and create a data dir.
Apply:
-# Example script to prepare a dataset of prompts for generation -# Lines in this script can be uncommented to include specific datasets/splits in the prompt dataset. +# Example script to prepare a dataset of prompts for generation +# Lines in this script can be uncommented to include specific datasets/splits in the prompt dataset. +set -euo pipefail +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +# Optional: change DATASET_DIR or pass via env +DATASET_DIR="${DATASET_DIR:-${SCRIPT_DIR}/data}" +mkdir -p "${DATASET_DIR}" -python3 prepare_input_conversations/add_daring_anteater.py --output-split-name train -# python3 prepare_input_conversations/add_sharegpt.py --output-split-name train -# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split train_sft --output-split-name train -# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split train_gen --output-split-name train -# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split test_sft --output-split-name mix_test -# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split test_gen --output-split-name mix_test -python3 prepare_input_conversations/add_mtbench.py --output-split-name mix_test +python3 "${SCRIPT_DIR}/add_daring_anteater.py" --output-split-name train ${DATASET_DIR:+--dataset-dir "${DATASET_DIR}"} +# python3 "${SCRIPT_DIR}/add_sharegpt.py" --output-split-name train ${DATASET_DIR:+--dataset-dir "${DATASET_DIR}"} +# python3 "${SCRIPT_DIR}/add_ultrachat.py" --ultrachat-split train_sft --output-split-name train ${DATASET_DIR:+--dataset-dir "${DATASET_DIR}"} +# python3 "${SCRIPT_DIR}/add_ultrachat.py" --ultrachat-split train_gen --output-split-name train ${DATASET_DIR:+--dataset-dir "${DATASET_DIR}"} +# python3 "${SCRIPT_DIR}/add_ultrachat.py" --ultrachat-split test_sft --output-split-name mix_test ${DATASET_DIR:+--dataset-dir "${DATASET_DIR}"} +# python3 "${SCRIPT_DIR}/add_ultrachat.py" --ultrachat-split test_gen --output-split-name mix_test ${DATASET_DIR:+--dataset-dir "${DATASET_DIR}"} +python3 "${SCRIPT_DIR}/add_mtbench.py" --output-split-name mix_test ${DATASET_DIR:+--dataset-dir "${DATASET_DIR}"}Note: If the add_* CLIs don’t support --dataset-dir, the parameter will be omitted; the mkdir remains safe.
examples/speculative_decoding/prepare_input_conversations/utils.py (1)
93-125: Make ratio check tolerant to FP error and avoid mutating global RNG state.
isclose avoids false negatives; using a local Random(seed) preserves global RNG for callers.
Apply:
+import math @@ - if train_ratio + val_ratio + test_ratio != 1.0: + if not math.isclose(train_ratio + val_ratio + test_ratio, 1.0, rel_tol=0.0, abs_tol=1e-9): msg = "Ratios must sum to 1.0" raise ValueError(msg) @@ - if shuffle: - random.seed(seed) - random.shuffle(conversations) + if shuffle: + random.Random(seed).shuffle(conversations)examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh (1)
1-24: Add shebang/safe flags, fix debug flag name, and quote paths.
Without a shebang, running as an executable fails; ShellCheck SC2148 also flags this. The commented flag name doesn't exist in the CLI (should be --debug-max-num-conversations). Also make invocation path-robust and quote args.
+#!/usr/bin/env bash +set -euo pipefail + +# Resolve script directory to make path robust when invoked from anywhere. +script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + -python3 collect_hidden_states/send_conversations_for_hiddens.py \ - --model meta-llama/Llama-3.2-1B-Instruct \ - --input-file synthetic_conversations/mtbench.jsonl \ - --output-dir /mnt/md0/eagle-hidden-states/llama1b/mtbench/ -# --debug-max-num-conversations-per-split 1000 +python3 "$script_dir/send_conversations_for_hiddens.py" \ + --model "meta-llama/Llama-3.2-1B-Instruct" \ + --input-file "synthetic_conversations/mtbench.jsonl" \ + --output-dir "/mnt/md0/eagle-hidden-states/llama1b/mtbench/" +# Optional: limit processed conversations during debugging +# --debug-max-num-conversations 1000examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (3)
93-95: Set eval() before inference.
Minor but standard to disable dropout and ensure deterministic behavior.
```diff
-    model = AutoModel.from_pretrained(args.model, torch_dtype="auto", device_map="auto")
+    model = AutoModel.from_pretrained(args.model, torch_dtype="auto", device_map="auto")
+    model.eval()
```
103-127: Counter name is misleading; it tracks too-short and too-long.
Tiny naming nit; adjust for clarity.
- num_skipped_too_long = 0 + num_filtered_by_length = 0 @@ - if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len: - num_skipped_too_long += 1 + if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len: + num_filtered_by_length += 1 continue @@ - if num_skipped_too_long > 0: - print(f"Skipped {num_skipped_too_long} conversations due to length constraints.") + if num_filtered_by_length > 0: + print(f"Skipped {num_filtered_by_length} conversations due to length constraints.")Also applies to: 163-164
82-83: Async isn't used; simplify to synchronous main.
Removes unnecessary asyncio plumbing.
-async def main(args: argparse.Namespace) -> None: +def main(args: argparse.Namespace) -> None: @@ -if __name__ == "__main__": - cli_args = parse_args() - asyncio.run(main(cli_args)) +if __name__ == "__main__": + cli_args = parse_args() + main(cli_args)Also applies to: 176-178
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh (1)
1-23: Add shebang/safe flags, quote args, and create output dir.
Prevents exec failures (SC2148) and path issues; ensures output dir exists.
+#!/usr/bin/env bash +set -euo pipefail + +# Resolve script directory +script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +# Output directory (edit as needed) +OUTPUT_DIR="/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/" +mkdir -p "$OUTPUT_DIR" + -python3 collect_hidden_states/compute_hidden_states_hf.py \ - --model meta-llama/Llama-3.2-1B-Instruct \ - --input-file synthetic_conversations/daring-anteater.jsonl \ - --output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/ +python3 "$script_dir/compute_hidden_states_hf.py" \ + --model "meta-llama/Llama-3.2-1B-Instruct" \ + --input-file "synthetic_conversations/daring-anteater.jsonl" \ + --output-dir "$OUTPUT_DIR" +# Optional: +# --max-seq-len 3072examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh (1)
16-36: Harden DP runner: shebang/safe flags, temp workspace, quoting, and mkdir.
Avoids /tmp collisions, improves safety, and makes paths robust.
+#!/usr/bin/env bash +set -euo pipefail + -INPUT_FILE=synthetic_conversations/daring-anteater.jsonl -OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/ +INPUT_FILE="synthetic_conversations/daring-anteater.jsonl" +OUTPUT_DIR="/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/" +mkdir -p "$OUTPUT_DIR" + -split -n l/8 --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part- +tmpdir="$(mktemp -d)" +trap 'rm -rf "$tmpdir"' EXIT +split -n l/8 --numeric-suffixes=0 -d --additional-suffix=.jsonl "$INPUT_FILE" "$tmpdir/part-" for i in $(seq 0 7) do -CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py --model meta-llama/Llama-3.2-1B-Instruct --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR & + CUDA_VISIBLE_DEVICES="$i" python3 collect_hidden_states/compute_hidden_states_hf.py \ + --model "meta-llama/Llama-3.2-1B-Instruct" \ + --input-file "$tmpdir/part-0${i}.jsonl" \ + --output-dir "$OUTPUT_DIR" & done wait - -rm /tmp/part-*.jsonl +# Temporary files cleaned via trap
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (21)
examples/speculative_decoding/.gitignore (1 hunks)
examples/speculative_decoding/collect_hidden_states/__init__.py (1 hunks)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (1 hunks)
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh (1 hunks)
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh (1 hunks)
examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh (1 hunks)
examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py (1 hunks)
examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (1 hunks)
examples/speculative_decoding/eagle_utils.py (4 hunks)
examples/speculative_decoding/launch.sh (0 hunks)
examples/speculative_decoding/launch_train.sh (4 hunks)
examples/speculative_decoding/main.py (4 hunks)
examples/speculative_decoding/prepare_input_conversations/__init__.py (1 hunks)
examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py (1 hunks)
examples/speculative_decoding/prepare_input_conversations/add_mtbench.py (1 hunks)
examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py (1 hunks)
examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py (1 hunks)
examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (1 hunks)
examples/speculative_decoding/prepare_input_conversations/utils.py (1 hunks)
examples/speculative_decoding/train_eagle3_and_export.sh (2 hunks)
modelopt/torch/speculative/plugins/transformers.py (4 hunks)
💤 Files with no reviewable changes (1)
- examples/speculative_decoding/launch.sh
🚧 Files skipped from review as they are similar to previous changes (13)
- examples/speculative_decoding/.gitignore
- examples/speculative_decoding/collect_hidden_states/__init__.py
- examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py
- examples/speculative_decoding/launch_train.sh
- examples/speculative_decoding/prepare_input_conversations/__init__.py
- examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py
- examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py
- examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py
- examples/speculative_decoding/train_eagle3_and_export.sh
- examples/speculative_decoding/prepare_input_conversations/add_mtbench.py
- modelopt/torch/speculative/plugins/transformers.py
- examples/speculative_decoding/eagle_utils.py
- examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: benchislett
PR: NVIDIA/TensorRT-Model-Optimizer#300
File: examples/speculative_decoding/main.py:150-159
Timestamp: 2025-09-17T14:14:44.950Z
Learning: In examples/speculative_decoding/main.py, setting num_hidden_layers=0 for offline training has been successfully tested by benchislett with past models and works in their use case for EAGLE3 offline training.
📚 Learning: 2025-09-17T14:14:44.950Z
Learnt from: benchislett
PR: NVIDIA/TensorRT-Model-Optimizer#300
File: examples/speculative_decoding/main.py:150-159
Timestamp: 2025-09-17T14:14:44.950Z
Learning: In examples/speculative_decoding/main.py, setting num_hidden_layers=0 for offline training has been successfully tested by benchislett with past models and works in their use case for EAGLE3 offline training.
Applied to files:
examples/speculative_decoding/main.py
🧬 Code graph analysis (4)
examples/speculative_decoding/main.py (1)
examples/speculative_decoding/eagle_utils.py (1)
make_eagle_supervised_data_module
(238-312)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (2)
examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (2)
parse_args
(30-90)main
(93-206)examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py (2)
parse_args
(25-46)main
(49-84)
examples/speculative_decoding/prepare_input_conversations/utils.py (1)
modelopt/torch/utils/random.py (2)
random
(59-61)shuffle
(148-150)
examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (1)
examples/speculative_decoding/main.py (1)
train
(117-271)
🪛 Shellcheck (0.10.0)
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh
[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.
(SC2148)
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh
[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.
(SC2148)
examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh
[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.
(SC2148)
examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh
[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.
(SC2148)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (3)
examples/speculative_decoding/main.py (2)
230-230: Good: data module gated by use_offline_training.
Clear separation between online/offline loaders and collators.

237-238: Nice: allow disabling AR validation with non-positive steps.
Useful for offline runs without inline eval.

examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (1)

27-29: Verify single source of truth for REMOVE_THINK_CHAT_TEMPLATE
rg returned "No files were searched" due to the glob filter; cannot confirm duplicates. Ensure no other copies exist and training/collection stay in lockstep — from the repo root run and paste outputs of:
rg -n --hidden --no-ignore 'REMOVE_THINK_CHAT_TEMPLATE'
rg -n --hidden --no-ignore ''
Location: examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py:27-29
```python
# Get hidden states
with torch.inference_mode():
    outputs = model(input_ids=input_ids.to(model.device), output_hidden_states=True)
if num_hidden_layers is None:
    num_hidden_layers = len(outputs.hidden_states) - 1
```
Fix device handling with device_map='auto' (model.device may be missing or wrong).
Pass inputs to the embedding device; relying on model.device can raise or send tensors to CPU.
```diff
-    with torch.inference_mode():
-        outputs = model(input_ids=input_ids.to(model.device), output_hidden_states=True)
+    with torch.inference_mode():
+        emb_device = model.get_input_embeddings().weight.device
+        outputs = model(input_ids=input_ids.to(emb_device), output_hidden_states=True)
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
 # Get hidden states
 with torch.inference_mode():
-    outputs = model(input_ids=input_ids.to(model.device), output_hidden_states=True)
+    emb_device = model.get_input_embeddings().weight.device
+    outputs = model(input_ids=input_ids.to(emb_device), output_hidden_states=True)
 if num_hidden_layers is None:
     num_hidden_layers = len(outputs.hidden_states) - 1
```
🤖 Prompt for AI Agents
In
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py
around lines 129 to 133, the code moves inputs to model.device which can be
missing or incorrect when using device_map='auto'; instead obtain the embeddings
device (e.g. model.get_input_embeddings().weight.device) and move input_ids to
that device before calling the model; use that device value as a fallback if
model.device exists but prefer the embedding weight device to ensure tensors end
up on the correct device for sharded/auto-mapped models.
@benchislett You need to sign your commits with an SSH key. Please take a look at https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md#%EF%B8%8F-signing-your-work
Force-pushed from bd27b86 to 8708a77
"Code Quality" check failure seems erroneous to me. The namespace in question does seem to exist and be imported correctly:
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/speculative_decoding/launch_train.sh (1)

91-95: Guard division by zero when no GPUs are visible.
DEFAULT_SAVE_STEPS will divide by zero if torch reports 0 CUDA devices (e.g., CI, CPU-only env). Provide a sane fallback.
Apply:

```diff
 GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
-# Calculate save_steps
-DEFAULT_SAVE_STEPS=$((8192 / GPU_COUNT))
+# Calculate save_steps
+if [[ "${GPU_COUNT}" -gt 0 ]]; then
+  DEFAULT_SAVE_STEPS=$((8192 / GPU_COUNT))
+else
+  echo "Warning: No GPUs detected by torch; defaulting save_steps to 8192."
+  DEFAULT_SAVE_STEPS=8192
+fi
```
🧹 Nitpick comments (19)
examples/speculative_decoding/launch_train.sh (1)
127-136: Good: validates offline data path and builds args.
Consider also quoting the path to tolerate spaces.
Apply:
```diff
-    OFFLINE_TRAINING_ARGS="--offline-data-path $OFFLINE_DATA_PATH"
+    OFFLINE_TRAINING_ARGS="--offline-data-path \"$OFFLINE_DATA_PATH\""
```

modelopt/torch/speculative/plugins/transformers.py (1)
338-349: Clamp aux layer IDs when offline and num_layers <= 0.
If num_layers resolves to 0/None offline, current defaults may produce invalid indices (e.g., 1). Clamp to valid range and drop negatives.
Apply:
- num_layers = self.config.num_hidden_layers - if self.eagle_offline and (num_layers is None or num_layers <= 0): - num_layers = getattr(self.config, "num_orig_hidden_layers", 0) - - self.eagle_config.eagle_aux_hidden_state_layer_ids = [ - 1, - max(0, num_layers // 2 - 1), - max(0, num_layers - 4), - ] - self.eagle_config.eagle_aux_hidden_state_layer_ids = list( - set(self.eagle_config.eagle_aux_hidden_state_layer_ids) - ) + num_layers = self.config.num_hidden_layers + if self.eagle_offline and (num_layers is None or num_layers <= 0): + num_layers = getattr(self.config, "num_orig_hidden_layers", 0) + # Propose defaults then clamp into [0, max(0, num_layers-1)] and dedupe/sort + candidate_ids = [1, max(0, num_layers // 2 - 1), max(0, num_layers - 4)] + hi = max(0, num_layers - 1) + clamped = sorted({min(max(i, 0), hi) for i in candidate_ids}) + self.eagle_config.eagle_aux_hidden_state_layer_ids = clampedexamples/speculative_decoding/collect_hidden_states/run_send_conversations.sh (1)
16-23: Add shebang and safe shell flags; optional: parametrize endpoint via env.
Apply:
+#!/usr/bin/env bash +set -euo pipefail + python3 collect_hidden_states/send_conversations_for_hiddens.py \ --model meta-llama/Llama-3.2-1B-Instruct \ --input-file synthetic_conversations/mtbench.jsonl \ --output-dir /mnt/md0/eagle-hidden-states/llama1b/mtbench/ # --debug-max-num-conversations-per-split 1000examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (1)
16-26: Add shebang/safe flags and make calls cwd-agnostic.
Use script-relative paths so it works from any directory.
Apply:
+#!/usr/bin/env bash +set -euo pipefail + +# Example script to prepare a dataset of prompts for generation # Lines in this script can be uncommented to include specific datasets/splits in the prompt dataset. -python3 prepare_input_conversations/add_daring_anteater.py --output-split-name train -# python3 prepare_input_conversations/add_sharegpt.py --output-split-name train -# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split train_sft --output-split-name train -# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split train_gen --output-split-name train -# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split test_sft --output-split-name mix_test -# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split test_gen --output-split-name mix_test -python3 prepare_input_conversations/add_mtbench.py --output-split-name mix_test +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +python3 "$ROOT/prepare_input_conversations/add_daring_anteater.py" --output-split-name train +# python3 "$ROOT/prepare_input_conversations/add_sharegpt.py" --output-split-name train +# python3 "$ROOT/prepare_input_conversations/add_ultrachat.py" --ultrachat-split train_sft --output-split-name train +# python3 "$ROOT/prepare_input_conversations/add_ultrachat.py" --ultrachat-split train_gen --output-split-name train +# python3 "$ROOT/prepare_input_conversations/add_ultrachat.py" --ultrachat-split test_sft --output-split-name mix_test +# python3 "$ROOT/prepare_input_conversations/add_ultrachat.py" --ultrachat-split test_gen --output-split-name mix_test +python3 "$ROOT/prepare_input_conversations/add_mtbench.py" --output-split-name mix_testexamples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh (1)
16-36: Add shebang/safe flags; consider quoting vars and checking split results.
Apply:
+#!/usr/bin/env bash +set -euo pipefail + # Example usage of the script to compute the hidden states for a conversation dataset # This script computes hidden states using a Hugging Face model and saves them to # the specified output directory. It does so in a data-parallel manner across 8 GPUs, by splitting # the input file into 8 parts and running 8 processes in parallel, one on each GPU. @@ -INPUT_FILE=synthetic_conversations/daring-anteater.jsonl -OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/ +INPUT_FILE="synthetic_conversations/daring-anteater.jsonl" +OUTPUT_DIR="/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/" -split -n l/8 --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part- +split -n l/8 --numeric-suffixes=0 -d --additional-suffix=.jsonl "$INPUT_FILE" /tmp/part- for i in $(seq 0 7) do -CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py --model meta-llama/Llama-3.2-1B-Instruct --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR & + CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py \ + --model meta-llama/Llama-3.2-1B-Instruct \ + --input-file "/tmp/part-0${i}.jsonl" \ + --output-dir "$OUTPUT_DIR" & done wait -rm /tmp/part-*.jsonl +rm -f /tmp/part-*.jsonlexamples/speculative_decoding/collect_hidden_states/sample_hidden_states.py (2)
63-63: Load tensors onto CPU to avoid CUDA dependency.
Apply:
- data = torch.load(file) + data = torch.load(file, map_location="cpu")
64-73: Don't require exact key set; accept supersets.
Future producers may add fields; only assert required keys present.
Apply:
- expected_keys = [ + expected_keys = [ "input_ids", "hidden_states", "aux_hidden_states", "conversation_id", ] - if set(expected_keys) != set(data.keys()): - print(f"File {file} does not contain all expected keys: {expected_keys}") - print(f" Found keys: {list(data.keys())}") - continue + missing = [k for k in expected_keys if k not in data] + if missing: + print(f"File {file} missing required keys: {missing}. Found: {list(data.keys())}") + continueexamples/speculative_decoding/prepare_input_conversations/add_ultrachat.py (2)
23-27: Make imports robust to script working directory/package execution.
from utils import ...
will fail if invoked outside this folder or as a module. Prefer relative import with a fallback for script execution.-from utils import ( +try: + from .utils import ( + dataset_splits_explanation, + id_for_conversation, + update_dataset_file_with_conversations, + ) +except ImportError: + from utils import ( dataset_splits_explanation, id_for_conversation, update_dataset_file_with_conversations, -) + )
61-65: Prefer iterating rows directly to avoid repeated random access.
Hugging Face datasets support direct iteration; it's clearer and avoids repeated indexing overhead.
- for i in tqdm( - range(len(ds)), + for i in tqdm( + range(len(ds)), # or: enumerate(ds) desc=f"Loading UltraChat split {args.ultrachat_split}", total=len(ds), ):If acceptable, switch to
for i, row in enumerate(tqdm(ds, ...))
and userow[...]
.examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py (2)
23-27: Harden import path (same concern as other prep scripts).
Apply the same relative-import fallback pattern suggested for add_ultrachat.py to avoid ImportError when run outside the folder.
63-76: Defensive role parsing.
.lower()
on a missing/None role will raise. The diff above normalizes viastr(...).lower()
and falls back cleanly.examples/speculative_decoding/eagle_utils.py (2)
176-236: Offline dataset: load safely on CPU and validate keys/shapes.
- Use
map_location="cpu"
to avoid accidental CUDA deserialization errors.- Validate presence of expected keys and provide a concise error.
- offline_data = torch.load(offline_file_path) + offline_data = torch.load(offline_file_path, map_location="cpu") + for k in ("input_ids", "hidden_states", "aux_hidden_states"): + if k not in offline_data: + raise ValueError(f"Missing key '{k}' in offline data: {offline_file_path}") offline_data["input_ids"] = offline_data["input_ids"][:max_length] offline_data["hidden_states"] = offline_data["hidden_states"][:max_length, :] offline_data["aux_hidden_states"] = offline_data["aux_hidden_states"][:max_length, :] # Make sure the input_ids have the same shape if preprocessed_base["input_ids"].shape != offline_data["input_ids"].shape: - msg = f"""Input IDs from offline data do not match the preprocessed input IDs - for offline data sample at {offline_file_path}.""" - raise ValueError(msg) + raise ValueError( + f"Shape mismatch for input_ids at {offline_file_path}: " + f"offline={tuple(offline_data['input_ids'].shape)} " + f"preproc={tuple(preprocessed_base['input_ids'].shape)}" + )
251-256: Handle blank lines in JSONL and set explicit encoding.
Minor robustness: ignore empty lines and open with utf-8.
- with open(data_args.data_path) as f: - if data_args.data_path.endswith("jsonl"): - data_json = [json.loads(line) for line in f] + with open(data_args.data_path, "r", encoding="utf-8") as f: + if data_args.data_path.endswith("jsonl"): + data_json = [json.loads(line) for line in f if line.strip()] else: data_json = json.load(f)examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (1)
95-95: Null-safe chat template handling.
Some tokenizers have no chat_template. Guard before replace.
- tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "") + if getattr(tokenizer, "chat_template", None): + tokenizer.chat_template = tokenizer.chat_template.replace( + REMOVE_THINK_CHAT_TEMPLATE, "" + )examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py (3)
23-28: Import robustness (same pattern as other prep scripts).
Apply the relative-import fallback used in prior comments to avoid ImportError when running from different CWDs.
62-85: Be lenient on unknown roles; prefer skip over raise to avoid aborting long runs.
Unexpected roles in ShareGPT can abort the entire job. Treat them as "skip this conversation" rather than raising.
- else: - err_msg = f"Unknown role in conversation: {turn.get('from')}" - raise ValueError(err_msg) + else: + # Unknown role; skip this conversation + return None
87-101: Download step: ensure parent directory exists.
~/.cache
may not exist. download_file creates parents, but direct existence check may fail earlier if parent is missing. Create dirs before checking.- args.sharegpt_file = Path("~/.cache/sharegpt.json").expanduser().resolve() + args.sharegpt_file = Path("~/.cache/sharegpt.json").expanduser() + args.sharegpt_file.parent.mkdir(parents=True, exist_ok=True) + args.sharegpt_file = args.sharegpt_file.resolve()examples/speculative_decoding/prepare_input_conversations/utils.py (2)
103-110: Avoid strict float equality for ratio sum.
Use a small tolerance to prevent false negatives due to FP rounding.
- if train_ratio + val_ratio + test_ratio != 1.0: + total = train_ratio + val_ratio + test_ratio + if abs(total - 1.0) > 1e-9: msg = "Ratios must sum to 1.0" raise ValueError(msg)
115-124: Don't mutate caller's list or global RNG state when shuffling.
Shuffle a copy with a local RNG for determinism without side effects.
- if shuffle: - random.seed(seed) - random.shuffle(conversations) + if shuffle: + rng = random.Random(seed) + pool = conversations.copy() + rng.shuffle(pool) + else: + pool = conversations @@ - train_conversations = conversations[:train_count] - val_conversations = conversations[train_count : train_count + val_count] - test_conversations = conversations[train_count + val_count :] + train_conversations = pool[:train_count] + val_conversations = pool[train_count : train_count + val_count] + test_conversations = pool[train_count + val_count :]
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (21)
examples/speculative_decoding/.gitignore (1 hunks)
examples/speculative_decoding/collect_hidden_states/__init__.py (1 hunks)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (1 hunks)
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh (1 hunks)
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh (1 hunks)
examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh (1 hunks)
examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py (1 hunks)
examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (1 hunks)
examples/speculative_decoding/eagle_utils.py (4 hunks)
examples/speculative_decoding/launch.sh (0 hunks)
examples/speculative_decoding/launch_train.sh (4 hunks)
examples/speculative_decoding/main.py (4 hunks)
examples/speculative_decoding/prepare_input_conversations/__init__.py (1 hunks)
examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py (1 hunks)
examples/speculative_decoding/prepare_input_conversations/add_mtbench.py (1 hunks)
examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py (1 hunks)
examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py (1 hunks)
examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (1 hunks)
examples/speculative_decoding/prepare_input_conversations/utils.py (1 hunks)
examples/speculative_decoding/train_eagle3_and_export.sh (2 hunks)
modelopt/torch/speculative/plugins/transformers.py (4 hunks)
💤 Files with no reviewable changes (1)
- examples/speculative_decoding/launch.sh
✅ Files skipped from review due to trivial changes (1)
- examples/speculative_decoding/collect_hidden_states/__init__.py
🚧 Files skipped from review as they are similar to previous changes (5)
- examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py
- examples/speculative_decoding/prepare_input_conversations/__init__.py
- examples/speculative_decoding/train_eagle3_and_export.sh
- examples/speculative_decoding/.gitignore
- examples/speculative_decoding/prepare_input_conversations/add_mtbench.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-17T14:14:44.961Z
Learnt from: benchislett
PR: NVIDIA/TensorRT-Model-Optimizer#300
File: examples/speculative_decoding/main.py:150-159
Timestamp: 2025-09-17T14:14:44.961Z
Learning: In examples/speculative_decoding/main.py, setting num_hidden_layers=0 for offline training has been successfully tested by benchislett with past models and works in their use case for EAGLE3 offline training.
Applied to files:
modelopt/torch/speculative/plugins/transformers.py
examples/speculative_decoding/main.py
🧬 Code graph analysis (9)
examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py (1)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (2)
parse_args
(32-79)main
(82-171)
modelopt/torch/speculative/plugins/transformers.py (3)
modelopt/torch/opt/dynamic.py (2)
register
(1069-1096)config
(1265-1278)modelopt/torch/speculative/eagle/eagle_model.py (1)
EagleModel
(23-51)modelopt/torch/speculative/plugins/megatron_eagle.py (1)
_set_default_aux_hidden_state_layers
(682-694)
examples/speculative_decoding/main.py (1)
examples/speculative_decoding/eagle_utils.py (1)
make_eagle_supervised_data_module
(238-312)
examples/speculative_decoding/eagle_utils.py (1)
modelopt/torch/utils/logging.py (1)
print_rank_0
(92-95)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (2)
examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (2)
parse_args
(30-90)main
(93-204)examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py (2)
parse_args
(25-46)main
(49-84)
examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py (1)
examples/speculative_decoding/prepare_input_conversations/utils.py (4)
dataset_splits_explanation
(157-172)download_file
(26-35)id_for_conversation
(38-42)update_dataset_file_with_conversations
(127-154)
examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py (1)
examples/speculative_decoding/prepare_input_conversations/utils.py (3)
dataset_splits_explanation
(157-172)id_for_conversation
(38-42)update_dataset_file_with_conversations
(127-154)
examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py (1)
examples/speculative_decoding/prepare_input_conversations/utils.py (3)
dataset_splits_explanation
(157-172)id_for_conversation
(38-42)update_dataset_file_with_conversations
(127-154)
examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (1)
examples/speculative_decoding/main.py (1)
train
(117-271)
🪛 Shellcheck (0.11.0)
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh
[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.
(SC2148)
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh
[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.
(SC2148)
examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh
[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.
(SC2148)
examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh
[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.
(SC2148)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (22)
examples/speculative_decoding/launch_train.sh (3)

33-36: Good: adds offline-data plumbed through to main.
Flag and value parsing look fine and align with DataArguments.offline_data_path.

111-113: Increase default training sequence length to 2048 — OK.

168-169: OK: injects offline args into CMD.

modelopt/torch/speculative/plugins/transformers.py (3)
314-314: Parameter name corrected to past_key_values.
Aligns with newer Transformers API; resolves earlier confusion.

425-429: Good: device selection for offline models via lm_head.
Avoids indexing model.layers when depth is zero.

811-813: Cache handling for offline/legacy paths looks correct.
Dummy DynamicCache for offline, and from_legacy conversion when needed.
Also applies to: 826-828
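For reference, the two cache paths mentioned here map onto the transformers cache utilities; the sketch below illustrates the idea only, and the exact wiring inside the plugin may differ.

```python
# Sketch: a placeholder DynamicCache when no base-model layers run (offline), and conversion
# from a legacy tuple-style cache when one is passed in.
from transformers import DynamicCache

offline_cache = DynamicCache()              # dummy cache for the offline path
legacy_past_key_values = ()                 # placeholder for a legacy (tuple-form) cache
converted = DynamicCache.from_legacy_cache(legacy_past_key_values)
```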
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh (1)
16-23: Add shebang and safe shell flags; ensure script fails fast.
Apply:

```diff
+#!/usr/bin/env bash
+set -euo pipefail
+
 python3 collect_hidden_states/compute_hidden_states_hf.py \
     --model meta-llama/Llama-3.2-1B-Instruct \
     --input-file synthetic_conversations/daring-anteater.jsonl \
     --output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
```

examples/speculative_decoding/main.py (6)
72-81: Docs read well; field name aligns with launch script.

144-145: Simple switch to enable offline training — OK.

188-190: Correct: set EagleConfig.eagle_offline via config when offline.

230-231: Good: pass use_offline_training into data module.

237-239: Nice: allow disabling AR validation via non-positive steps.

150-158: num_hidden_layers=0 override — keep, but verify across Transformers versions/models.
Attempted verification here but the environment raised ModuleNotFoundError: No module named 'transformers'. Run the provided check locally or in CI with Transformers installed matching the repo constraint (>=4.48,<5.0) against these models: TinyLlama/TinyLlama-1.1B-Chat-v1.0, mistralai/Mistral-7B-v0.1, Qwen/Qwen2-1.5B, microsoft/Phi-3-mini-4k-instruct; if any fail, remove or add model-specific handling.
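The check referenced above is not reproduced in this thread; a minimal smoke-test sketch of the idea (the model list and output are illustrative, not the original script) could be:

```python
# Hypothetical smoke test: confirm each target base model still loads with num_hidden_layers=0,
# mirroring the space-saving override used for offline EAGLE3 training in main.py.
import transformers

for name in ["TinyLlama/TinyLlama-1.1B-Chat-v1.0", "Qwen/Qwen2-1.5B"]:
    try:
        model = transformers.AutoModelForCausalLM.from_pretrained(
            name, torch_dtype="auto", num_hidden_layers=0
        )
        print(f"{name}: OK ({sum(p.numel() for p in model.parameters())} params)")
    except Exception as exc:
        print(f"{name}: FAILED ({exc})")
```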
File: examples/speculative_decoding/main.py:150-158

examples/speculative_decoding/eagle_utils.py (2)
271-281
: Conversation ID fallback logic looks good and resolves prior “conv_id=0” bug.Using
is None
avoids treating0
as falsy. This addresses previous feedback.
353-380
: Collator for offline hidden states looks correct.Padding logic is consistent with base collator; keys match dataset outputs; batch structure is clear.
Please confirm downstream training expects:
- base_model_outputs.base_model_hidden_states: [B, L, D_base]
- base_model_outputs.aux_hidden_states: [B, L, D_aux]
and that both are on CPU at collate time and moved to device by the trainer.examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (2)
136-146
: Clamp/dedupe auxiliary layer indices for small-layer models.Index 2 can be out of range when num_hidden_layers < 2. Clamp to [0, num_hidden_layers] and dedupe in order.
- selected_layer_indices = [ - 2, - max(0, num_hidden_layers // 2), - max(1, num_hidden_layers - 3), - ] - selected_layer_indices = sorted(set(selected_layer_indices)) + candidates = [2, num_hidden_layers // 2, num_hidden_layers - 3] + lo, hi = 0, num_hidden_layers # hidden_states has num_hidden_layers+1 entries + selected_layer_indices = [] + seen = set() + for i in candidates: + i = min(max(i, lo), hi) + if i not in seen: + seen.add(i) + selected_layer_indices.append(i) + if not selected_layer_indices: + selected_layer_indices = [hi]
129-135
: Fix device handling for device_map='auto'.
model.device
may be missing/incorrect under sharded loading. Use the embeddings’ device.- with torch.inference_mode(): - outputs = model(input_ids=input_ids.to(model.device), output_hidden_states=True) + with torch.inference_mode(): + emb_device = model.get_input_embeddings().weight.device + outputs = model(input_ids=input_ids.to(emb_device), output_hidden_states=True)examples/speculative_decoding/prepare_input_conversations/utils.py (5)
87-89
: Resolved: output dir is created before writing.Creating
dataset_dir
before opening the file addresses first-run failures.
38-43
: Stable content hash looks good.Deterministic ID via sorted JSON + fixed separators is appropriate here.
Confirm upstream producers don’t inject nondeterministic fields (timestamps, UUIDs) into
conversation
before hashing, or IDs will change run-to-run.
157-172
: Helpful split helptext.Clear and accurate, including percent-escaping.
26-35
: Add timeout and stream to disk to avoid hangs and large-memory reads.No network timeout and
await response.read()
can hang and load large files into RAM. UseClientTimeout
and stream chunks.async def download_file(url: str, destination: Path) -> None: """Download a file from a URL to a specified destination.""" - destination.parent.mkdir(parents=True, exist_ok=True) - async with aiohttp.ClientSession() as session, session.get(url) as response: - if response.status != 200: - msg = f"Failed to download {url}: {response.status}" - raise RuntimeError(msg) - content = await response.read() - destination.write_bytes(content) - print(f"Downloaded {url} to {destination}") + destination.parent.mkdir(parents=True, exist_ok=True) + timeout = aiohttp.ClientTimeout(total=600) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.get(url) as response: + if response.status // 100 != 2: + raise RuntimeError(f"Failed to download {url}: {response.status}") + with destination.open("wb") as f: + async for chunk in response.content.iter_chunked(1 << 20): + f.write(chunk) + print(f"Downloaded {url} to {destination}")
62-77
: Prevent duplicate IDs within the same batch and validate required key.If the input contains duplicate
conversation_id
s in the same call, both are written becauseexisting_ids
isn’t updated. Also, missing"conversations"
will raiseKeyError
later; validate explicitly.existing_ids = {entry["conversation_id"] for entry in all_conversations} @@ for entry in conversations: if entry.get("conversation_id") is None: raise ValueError("Each conversation must have a 'conversation_id' field.") + if "conversations" not in entry: + raise ValueError("Each conversation must have a 'conversations' field.") if entry["conversation_id"] not in existing_ids: all_conversations.append( { "conversation_id": entry["conversation_id"], "conversations": entry["conversations"], } ) + existing_ids.add(entry["conversation_id"]) num_new_entries += 1 else: num_duplicates += 1
```python
processed_conversations = []
for msg in conversations:
    if "from" in msg:
        role = msg["from"].lower()
    elif "role" in msg:
        role = msg["role"].lower()
    else:
        continue
    if role == "human":
        role = "user"
    elif role == "gpt":
        role = "assistant"

    if "value" in msg:
        content = msg["value"]
    elif "text" in msg:
        content = msg["text"]
    elif "content" in msg:
        content = msg["content"]
    else:
        continue
    content = content.strip()
    if content:
        processed_conversations.append({"role": role, "content": content})

input_conversations.append(
    {"conversation_id": prompt_id, "conversations": processed_conversations}
)
```
Skip empty conversations and compute ID from processed messages.
- If no valid messages are added, you’ll append an entry with an empty list; downstream preprocess() indexes source[0] and will crash.
- The conversation_id should be derived from the processed messages to ensure deduplication matches the actual content used for training.
- conversations = ds[i]["conversations"]
+ conversations = ds[i]["conversations"]
if conversations and isinstance(conversations, list):
- prompt_id = f"daring-anteater-{i:05}_" + id_for_conversation(conversations)
processed_conversations = []
for msg in conversations:
- if "from" in msg:
- role = msg["from"].lower()
- elif "role" in msg:
- role = msg["role"].lower()
+ from_field = msg.get("from")
+ role_field = msg.get("role")
+ role = (str(from_field or role_field or "")).lower()
else:
continue
if role == "human":
role = "user"
elif role == "gpt":
role = "assistant"
- if "value" in msg:
- content = msg["value"]
- elif "text" in msg:
- content = msg["text"]
- elif "content" in msg:
- content = msg["content"]
+ if "value" in msg:
+ content = msg["value"]
+ elif "text" in msg:
+ content = msg["text"]
+ elif "content" in msg:
+ content = msg["content"]
else:
continue
content = content.strip()
if content:
processed_conversations.append({"role": role, "content": content})
- input_conversations.append(
- {"conversation_id": prompt_id, "conversations": processed_conversations}
- )
+ if processed_conversations:
+ cid = id_for_conversation(processed_conversations)
+ prompt_id = f"daring-anteater-{i:05}_{cid}"
+ input_conversations.append(
+ {"conversation_id": prompt_id, "conversations": processed_conversations}
+ )
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In
examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py
around lines 62 to 89, the current logic can append entries with empty
conversations and uses prompt_id instead of an ID derived from the actual
processed messages; update the code to (1) skip appending when
processed_conversations is empty so downstream preprocess() does not index into
an empty source, and (2) compute conversation_id from the
processed_conversations (for example by deterministic serialization and hashing
of the sequence of role+content pairs) so deduplication is based on the actual
messages used for training.
```python
prompt = ds[i]["prompt"].strip()
prompt_id = ds[i]["prompt_id"].strip()
if prompt and prompt_id:
    msgs = [{"role": "user", "content": prompt}]
```
Guard against missing/typed fields before calling .strip().
Some rows may have None/missing keys or non‑string prompt_id. Current code will crash. Coerce to str and validate.
- prompt = ds[i]["prompt"].strip()
- prompt_id = ds[i]["prompt_id"].strip()
- if prompt and prompt_id:
+ prompt_raw = ds[i].get("prompt", "")
+ prompt_id_raw = ds[i].get("prompt_id", "")
+ prompt = str(prompt_raw).strip() if prompt_raw is not None else ""
+ prompt_id = str(prompt_id_raw).strip() if prompt_id_raw is not None else ""
+ if prompt and prompt_id:
📝 Committable suggestion (review carefully before committing; it must accurately replace the highlighted lines):

```python
prompt_raw = ds[i].get("prompt", "")
prompt_id_raw = ds[i].get("prompt_id", "")
prompt = str(prompt_raw).strip() if prompt_raw is not None else ""
prompt_id = str(prompt_id_raw).strip() if prompt_id_raw is not None else ""
if prompt and prompt_id:
    msgs = [{"role": "user", "content": prompt}]
```
🤖 Prompt for AI Agents
In examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py
around lines 66 to 69, the code calls .strip() directly on ds[i]["prompt"] and
ds[i]["prompt_id"], which can crash if fields are missing or not strings; change
to safely access with dict.get (or check "prompt" in ds[i]), coerce values to
str only if not None (e.g., val = ds[i].get("prompt"); prompt = str(val).strip()
if val is not None else ""), validate that prompt and prompt_id are non-empty
after stripping, and ensure prompt_id is of the expected form (or converted)
before using it to build msgs; only create msgs when both cleaned strings are
non-empty.
@benchislett #300 (comment)
Force-pushed from 8708a77 to bf8d1a8 (compare)
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
modelopt/torch/speculative/plugins/transformers.py (1)

569-585: Mask dtype mismatch will crash concat (bool vs float). `attention_mask_0` is float (from expand_mask), but you create boolean masks (`.bool()`) and concatenate with floats → runtime error. You also assign booleans into float tensors. Keep masks in the same dtype as `attention_mask_0`.

Pattern fix for this block (apply similarly to the later two blocks at Lines 602-627 and 642-681):

```diff
-        zero_mask = torch.ones_like(attention_mask_0).bool()
-        mask_2_1 = attention_mask_0.clone().detach()
+        zero_mask = torch.ones_like(attention_mask_0)
+        mask_2_1 = attention_mask_0.clone().detach()
         mask_2_1[:, :, :, :-1] = mask_2_1[:, :, :, 1:]
-        mask_2_2 = torch.ones_like(attention_mask_0).bool()
+        mask_2_2 = torch.ones_like(attention_mask_0)
         for i in range(1, seq_length - 1):
-            mask_2_2[:, :, i, i] = False
+            mask_2_2[:, :, i, i] = 0
         cat_attention_mask = torch.cat(
             (
                 torch.cat((attention_mask_0, zero_mask), dim=-1),
                 torch.cat((mask_2_1, mask_2_2), dim=-1),
             ),
             dim=-2,
         )
-        cat_attention_mask = cat_attention_mask.masked_fill(cat_attention_mask == 1, dtypemin)
+        cat_attention_mask = cat_attention_mask.masked_fill(cat_attention_mask == 1, dtypemin)
```

examples/speculative_decoding/eagle_utils.py (1)

315-351: Pad labels with IGNORE_TOKEN_ID, not 0. The current collator pads `labels` with 0, which will be treated as valid targets and corrupt the loss. Add a pad-value parameter and use `IGNORE_TOKEN_ID` for labels.

```diff
 class DataCollatorWithPadding:
-    def paddingtensor2d(self, intensors, length):
+    def paddingtensor2d(self, intensors, length, pad_value=0):
         n, dim = intensors.shape
-        padding_tensor = torch.zeros(length - n, dim, dtype=intensors.dtype)
+        padding_tensor = torch.zeros(length - n, dim, dtype=intensors.dtype) + pad_value
         outtensors = torch.cat((intensors, padding_tensor))
         return outtensors

-    def paddingtensor(self, intensors, length):
-        padding_tensor = torch.zeros(length - intensors.shape[0], dtype=intensors.dtype)
+    def paddingtensor(self, intensors, length, pad_value=0):
+        padding_tensor = torch.zeros(length - intensors.shape[0], dtype=intensors.dtype) + pad_value
         outtensors = torch.cat((intensors, padding_tensor))
         return outtensors

     def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
         max_length = max(item["input_ids"].shape[0] for item in features)
         batch_input_ids = torch.stack(
-            [self.paddingtensor(item["input_ids"], max_length) for item in features]
+            [self.paddingtensor(item["input_ids"], max_length, pad_value=0) for item in features]
         )
         batch_attention_mask = torch.stack(
-            [self.paddingtensor(item["attention_mask"], max_length) for item in features]
+            [self.paddingtensor(item["attention_mask"], max_length, pad_value=0) for item in features]
         )
         batch_loss_mask = torch.stack(
-            [self.paddingtensor(item["loss_mask"], max_length) for item in features]
+            [self.paddingtensor(item["loss_mask"], max_length, pad_value=0) for item in features]
         )
         batch_labels = torch.stack(
-            [self.paddingtensor(item["labels"], max_length) for item in features]
+            [self.paddingtensor(item["labels"], max_length, pad_value=IGNORE_TOKEN_ID) for item in features]
         )
```
🧹 Nitpick comments (11)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (3)
89-91: Ensure eval mode for deterministic inference. Set `model.eval()` after loading.

```diff
-    model = AutoModel.from_pretrained(args.model, torch_dtype="auto", device_map="auto")
+    model = AutoModel.from_pretrained(args.model, torch_dtype="auto", device_map="auto")
+    model.eval()
```

18-20: Drop asyncio (not used) or add awaits. `main` is async but contains no awaits; simplify to a regular function.

```diff
-import asyncio
@@
-async def main(args: argparse.Namespace) -> None:
+def main(args: argparse.Namespace) -> None:
@@
 if __name__ == "__main__":
     cli_args = parse_args()
-    asyncio.run(main(cli_args))
+    main(cli_args)
```

Also applies to: 82-82, 174-176

112-113: Sanitize conversation_id for filenames. An untrusted `conversation_id` could contain path separators or illegal chars.

```diff
+import re
@@
-    conversation_id = entry.get("conversation_id", "{:08d}".format(idx))
+    raw_id = entry.get("conversation_id", "{:08d}".format(idx))
+    conversation_id = re.sub(r"[^A-Za-z0-9._-]", "_", str(raw_id))
```

Also applies to: 148-159, 18-26
modelopt/torch/speculative/plugins/transformers.py (1)
337-343: Safer fallback when num_orig_hidden_layers is missing. If offline and `num_hidden_layers <= 0` but `num_orig_hidden_layers` is absent, you set `num_layers = 0`, yet later include layer id `1`. Use a non-zero fallback.

```diff
-        if self.eagle_offline and (num_layers is None or num_layers <= 0):
-            num_layers = getattr(self.config, "num_orig_hidden_layers", 0)
+        if self.eagle_offline and (num_layers is None or num_layers <= 0):
+            num_layers = getattr(self.config, "num_orig_hidden_layers", self.config.num_hidden_layers)
```

examples/speculative_decoding/prepare_input_conversations/utils.py (3)
26-36: Add a network timeout to avoid hanging downloads. `aiohttp.ClientSession()` without a timeout can hang indefinitely.

```diff
 async def download_file(url: str, destination: Path) -> None:
     """Download a file from a URL to a specified destination."""
     destination.parent.mkdir(parents=True, exist_ok=True)
-    async with aiohttp.ClientSession() as session, session.get(url) as response:
+    timeout = aiohttp.ClientTimeout(total=600)
+    async with aiohttp.ClientSession(timeout=timeout) as session:
+        async with session.get(url) as response:
             if response.status != 200:
                 msg = f"Failed to download {url}: {response.status}"
                 raise RuntimeError(msg)
             content = await response.read()
             destination.write_bytes(content)
     print(f"Downloaded {url} to {destination}")
```

104-109: Float-sum equality is brittle. Use a tolerance when validating ratios.

```diff
-    if train_ratio + val_ratio + test_ratio != 1.0:
+    if abs((train_ratio + val_ratio + test_ratio) - 1.0) > 1e-6:
         msg = "Ratios must sum to 1.0"
         raise ValueError(msg)
```

50-61: Validate 'conversations' presence on existing data. Guard against malformed prior entries to avoid KeyError later.

```diff
     if any(not entry.get("conversation_id") for entry in all_conversations):
         msg = "All existing conversations must have a 'conversation_id' field."
         raise ValueError(msg)
+    if any("conversations" not in entry for entry in all_conversations):
+        msg = "All existing conversations must have a 'conversations' field."
+        raise ValueError(msg)
```

examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (1)
1-26: Add a shebang and strict mode; ensure the output dir exists. Improves portability and early failure.

```diff
+#!/usr/bin/env bash
+set -euo pipefail
+
+# Ensure default output dir exists (scripts default to input_conversations/)
+mkdir -p input_conversations
```

examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh (3)
1-1: Add a shebang (or ShellCheck directive) to satisfy SC2148 and make the script directly executable. Without a shebang, shells and tooling can misinterpret the script.

Apply one of the following diffs (preferred: add a shebang):

```diff
+#!/usr/bin/env bash
 # SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
```

Or, if you intentionally keep it sourced/run via `bash file.sh`, add a ShellCheck directive:

```diff
+# shellcheck shell=bash
 # SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
```

16-23: Harden and parameterize the example so it runs from any cwd and fails fast on errors. Resolve brittle relative paths, create the output dir, and enforce required env for the OpenAI-compatible endpoint.

Apply this diff:

```diff
-# Example usage of the script to send conversations for hidden state collection
-# This script sends conversations to a (local) OpenAI-compatible server for processing and collects hidden states.
-
-python3 collect_hidden_states/send_conversations_for_hiddens.py \
-    --model meta-llama/Llama-3.2-1B-Instruct \
-    --input-file synthetic_conversations/mtbench.jsonl \
-    --output-dir /mnt/md0/eagle-hidden-states/llama1b/mtbench/
-# --debug-max-num-conversations-per-split 1000
+# Example usage of the script to send conversations for hidden state collection.
+# This script sends conversations to a (local) OpenAI-compatible server for processing and collects hidden states.
+
+set -euo pipefail
+IFS=$'\n\t'
+
+# Resolve repo root relative to this script so it works from any cwd
+SCRIPT_DIR="$(cd -- "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+REPO_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)"
+
+# Required env for OpenAI-compatible server (fail early if missing)
+: "${OPENAI_API_BASE:?Set OPENAI_API_BASE to your server URL, e.g. http://127.0.0.1:8000/v1}"
+: "${OPENAI_API_KEY:?Set OPENAI_API_KEY (dummy value is fine for many local servers)}"
+
+# Tunables (override via env: MODEL, INPUT_FILE, OUTPUT_DIR)
+MODEL="${MODEL:-meta-llama/Llama-3.2-1B-Instruct}"
+INPUT_FILE="${INPUT_FILE:-$REPO_ROOT/examples/speculative_decoding/synthetic_conversations/mtbench.jsonl}"
+OUTPUT_DIR="${OUTPUT_DIR:-/mnt/md0/eagle-hidden-states/llama1b/mtbench}"
+
+# Sanity checks
+[ -f "$INPUT_FILE" ] || { echo "Input not found: $INPUT_FILE" >&2; exit 1; }
+mkdir -p "$OUTPUT_DIR"
+
+python3 "$REPO_ROOT/collect_hidden_states/send_conversations_for_hiddens.py" \
+    --model "$MODEL" \
+    --input-file "$INPUT_FILE" \
+    --output-dir "$OUTPUT_DIR"
+# Optional:
+# --debug-max-num-conversations-per-split 1000
```

19-23: Make the script executable in git. If this is intended to be run directly, mark it executable.

```bash
#!/bin/bash
# From repo root
git update-index --chmod=+x examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (21)
examples/speculative_decoding/.gitignore
(1 hunks)examples/speculative_decoding/collect_hidden_states/__init__.py
(1 hunks)examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py
(1 hunks)examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh
(1 hunks)examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh
(1 hunks)examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh
(1 hunks)examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py
(1 hunks)examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py
(1 hunks)examples/speculative_decoding/eagle_utils.py
(4 hunks)examples/speculative_decoding/launch.sh
(0 hunks)examples/speculative_decoding/launch_train.sh
(4 hunks)examples/speculative_decoding/main.py
(4 hunks)examples/speculative_decoding/prepare_input_conversations/__init__.py
(1 hunks)examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py
(1 hunks)examples/speculative_decoding/prepare_input_conversations/add_mtbench.py
(1 hunks)examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py
(1 hunks)examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py
(1 hunks)examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh
(1 hunks)examples/speculative_decoding/prepare_input_conversations/utils.py
(1 hunks)examples/speculative_decoding/train_eagle3_and_export.sh
(2 hunks)modelopt/torch/speculative/plugins/transformers.py
(4 hunks)
💤 Files with no reviewable changes (1)
- examples/speculative_decoding/launch.sh
🚧 Files skipped from review as they are similar to previous changes (10)
- examples/speculative_decoding/.gitignore
- examples/speculative_decoding/collect_hidden_states/init.py
- examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py
- examples/speculative_decoding/prepare_input_conversations/add_mtbench.py
- examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py
- examples/speculative_decoding/launch_train.sh
- examples/speculative_decoding/train_eagle3_and_export.sh
- examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py
- examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py
- examples/speculative_decoding/prepare_input_conversations/init.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-17T14:14:44.961Z
Learnt from: benchislett
PR: NVIDIA/TensorRT-Model-Optimizer#300
File: examples/speculative_decoding/main.py:150-159
Timestamp: 2025-09-17T14:14:44.961Z
Learning: In examples/speculative_decoding/main.py, setting num_hidden_layers=0 for offline training has been successfully tested by benchislett with past models and works in their use case for EAGLE3 offline training.
Applied to files:
modelopt/torch/speculative/plugins/transformers.py
examples/speculative_decoding/main.py
🧬 Code graph analysis (6)
examples/speculative_decoding/eagle_utils.py (1)
modelopt/torch/utils/logging.py (1)
print_rank_0
(92-95)
modelopt/torch/speculative/plugins/transformers.py (2)
modelopt/torch/speculative/eagle/eagle_model.py (1)
EagleModel
(23-51)modelopt/torch/speculative/plugins/megatron_eagle.py (1)
_set_default_aux_hidden_state_layers
(682-694)
examples/speculative_decoding/main.py (1)
examples/speculative_decoding/eagle_utils.py (1)
make_eagle_supervised_data_module
(238-312)
examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py (1)
examples/speculative_decoding/prepare_input_conversations/utils.py (4)
dataset_splits_explanation
(157-172)download_file
(26-35)id_for_conversation
(38-42)update_dataset_file_with_conversations
(127-154)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (2)
examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (2)
parse_args
(30-90)main
(93-204)examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py (2)
parse_args
(25-46)main
(49-84)
examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (1)
examples/speculative_decoding/main.py (1)
train
(117-271)
🪛 Shellcheck (0.11.0)
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh
[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.
(SC2148)
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh
[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.
(SC2148)
examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh
[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.
(SC2148)
examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh
[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.
(SC2148)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (11)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (3)
92-96
: Guard chat_template access; avoid AttributeError and centralize constant.
.chat_template
can be None; direct.replace(...)
will crash. Also prefer importing the shared constant to avoid drift.Apply:
- tokenizer = AutoTokenizer.from_pretrained(args.model) + tokenizer = AutoTokenizer.from_pretrained(args.model) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "") + if getattr(tokenizer, "chat_template", None): + tokenizer.chat_template = tokenizer.chat_template.replace( + REMOVE_THINK_CHAT_TEMPLATE, "" + )And replace the local constant with a single source of truth:
-from transformers import AutoModel, AutoTokenizer - -REMOVE_THINK_CHAT_TEMPLATE = ( - "{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}" -) +from transformers import AutoModel, AutoTokenizer +from eagle_utils import REMOVE_THINK_CHAT_TEMPLATE # keep centralized
127-131
: Fix device placement with device_map='auto'.
model.device
can be missing/wrong under sharding; send inputs to the embeddings device.- with torch.inference_mode(): - outputs = model(input_ids=input_ids.to(model.device), output_hidden_states=True) + with torch.inference_mode(): + emb_device = model.get_input_embeddings().weight.device + outputs = model(input_ids=input_ids.to(emb_device), output_hidden_states=True)
136-146
: Harden layer selection; avoid OOB/duplicates for small-layer models.Index 2 can be out of range; clamp/dedupe defensively.
- selected_layer_indices = [ - 2, - max(0, num_hidden_layers // 2), - max(1, num_hidden_layers - 3), - ] - selected_layer_indices = sorted(set(selected_layer_indices)) + candidates = [2, num_hidden_layers // 2, num_hidden_layers - 3] + selected_layer_indices = sorted( + {i for i in candidates if 0 <= i <= num_hidden_layers - 1} + ) + if not selected_layer_indices: + raise ValueError( + f"No valid auxiliary layer indices for num_hidden_layers={num_hidden_layers}." + )modelopt/torch/speculative/plugins/transformers.py (3)
309-318
: Double‑check kwarg name for decoder layer cachePassing
past_key_values=
intoLlamaDecoderLayer
may break on some Transformers versions that still usepast_key_value
. Confirm the minimum version you support; otherwise gate by version or forward both via kwargs filtering.I can generate a small gate to support both signatures if needed.
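A minimal sketch of such a gate, keyed off the decoder layer's forward signature (illustrative only; the plugin's own kwarg handling is authoritative):

```python
# Sketch: pick the cache kwarg name supported by the installed Transformers
# version, since older releases use `past_key_value` and newer ones use
# `past_key_values`. Hypothetical helper, not the repository's code.
import inspect


def cache_kwargs_for(decoder_layer, cache):
    params = inspect.signature(decoder_layer.forward).parameters
    if "past_key_values" in params:
        return {"past_key_values": cache}
    if "past_key_value" in params:
        return {"past_key_value": cache}
    return {}  # layer does not accept a cache argument


# Usage sketch:
# layer_outputs = decoder_layer(hidden_states, **cache_kwargs_for(decoder_layer, cache))
```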
425-435
: Good device selection for offline pathReading device from
lm_head
when base model has 0 layers avoids attribute errors. Nice.
332-334
: Fix NameError: import OfflineEagleDMRegistryDecorator references
OfflineEagleDMRegistry
but it’s not imported; this fails at import time.Apply:
-from ..eagle.conversion import EagleDMRegistry +from ..eagle.conversion import EagleDMRegistry, OfflineEagleDMRegistryexamples/speculative_decoding/eagle_utils.py (1)
353-381
: Collator extension for offline HS looks goodBuilds batched hidden/aux tensors aligned to max HS length; composes with base batch cleanly.
examples/speculative_decoding/main.py (2)
150-158
: Overriding num_hidden_layers=0: verify supported model families/versionsThis trick can work but is not universally guaranteed across HF architectures/versions. Since you store
num_orig_hidden_layers
, the offline path is consistent; please confirm on the minimum Transformers version in setup and the families you target (Llama, Mistral, Qwen2, Phi‑3).
187-206
: Correct use of EagleConfig.eagle_offlineSetting
config["eagle_offline"]=True
wires the offline path coherently through conversion. LGTM.examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py (1)
87-96
: Nice: async download with cache and deterministic IDsEnd-to-end flow is clear and robust; IDs are stable via SHA-256 of messages.
examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh (1)
1-15
: License header looks good.SPDX identifiers and copyright block match project norms.
conversation_id = entry.get("conversation_id", "{:08d}".format(idx)) | ||
conversations = entry["conversations"] | ||
if not conversations or not isinstance(conversations, list): | ||
num_invalid += 1 | ||
continue | ||
|
Avoid KeyError when 'conversations' is missing.
Indexing with entry["conversations"]
will crash; use .get()
and validate.
- conversations = entry["conversations"]
+ conversations = entry.get("conversations")
if not conversations or not isinstance(conversations, list):
num_invalid += 1
continue
📝 Committable suggestion (review carefully before committing; it must accurately replace the highlighted lines):

```python
conversation_id = entry.get("conversation_id", "{:08d}".format(idx))
conversations = entry.get("conversations")
if not conversations or not isinstance(conversations, list):
    num_invalid += 1
    continue
```
🤖 Prompt for AI Agents
In
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py
around lines 112 to 117, replace the unsafe indexing entry["conversations"] with
a safe lookup and validation: use entry.get("conversations") into the
conversations variable, then check if conversations is truthy and
isinstance(conversations, list); if not, increment num_invalid and continue.
Ensure you do not assume the key exists and handle None or non-list types
consistently.
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# |
Fix split suffix/filenames; add shebang and strict mode.
As written, split
likely creates part-0..7.jsonl
, but the loop reads part-00..07.jsonl
. Also add shebang/safety.
+#!/usr/bin/env bash
+set -euo pipefail
@@
-INPUT_FILE=synthetic_conversations/daring-anteater.jsonl
-OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
+INPUT_FILE=synthetic_conversations/daring-anteater.jsonl
+OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
+mkdir -p "$OUTPUT_DIR"
@@
-split -n l/8 --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part-
+split -n l/8 -d -a 2 --numeric-suffixes=00 --additional-suffix=.jsonl "$INPUT_FILE" /tmp/part-
@@
-for i in $(seq 0 7)
-do
-CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py --model meta-llama/Llama-3.2-1B-Instruct --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR &
-done
+for i in $(seq 0 7); do
+ part="/tmp/part-$(printf '%02d' "$i").jsonl"
+ CUDA_VISIBLE_DEVICES="$i" python3 collect_hidden_states/compute_hidden_states_hf.py \
+ --model meta-llama/Llama-3.2-1B-Instruct \
+ --input-file "$part" \
+ --output-dir "$OUTPUT_DIR" &
+done
wait
@@
-rm /tmp/part-*.jsonl
+rm -f /tmp/part-*.jsonl
Also applies to: 25-36
🧰 Tools
🪛 Shellcheck (0.11.0)
[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.
(SC2148)
🤖 Prompt for AI Agents
In
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh
around lines 1-3 (and similarly lines 25-36), the script lacks a shebang and
strict shell settings, and the split/loop filename suffixes are inconsistent:
split produces single-digit suffixes but the loop expects two-digit names. Add a
shebang (#!/usr/bin/env bash) and enable strict mode (set -euo pipefail), change
the split invocation to produce numeric two-digit suffixes and a .jsonl
extension (use split -d -a 2 --additional-suffix=.jsonl <...> part-), and update
the consumer loop to either glob part-*.jsonl or explicitly format two-digit
indices (e.g., part-00.jsonl..part-07.jsonl) so names match; apply the same
changes to the block at lines 25-36.
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# |
Add shebang and safe shell flags; make script executable-friendly.
Without a shebang, shells/tools can misinterpret the file; enable strict mode.
+#!/usr/bin/env bash
+set -euo pipefail
@@
-python3 collect_hidden_states/compute_hidden_states_hf.py \
+python3 collect_hidden_states/compute_hidden_states_hf.py \
--model meta-llama/Llama-3.2-1B-Instruct \
--input-file synthetic_conversations/daring-anteater.jsonl \
--output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
Also applies to: 16-23
🧰 Tools
🪛 Shellcheck (0.11.0)
[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.
(SC2148)
🤖 Prompt for AI Agents
In examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh
around lines 1-3 (and similarly apply to lines 16-23), the script lacks a
shebang and safe shell flags; add a top-line shebang (e.g., #!/usr/bin/env bash)
and enable strict mode by setting set -euo pipefail and a safe IFS (IFS=$'\n\t')
at the top of the file; ensure the file mode is executable (chmod +x) in the
repository or note it in the commit so the script runs robustly in CI and user
environments.
examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh
```python
# Extend the data sample with the hidden states from the .pt file
max_length = self.tokenizer.model_max_length
offline_data = torch.load(offline_file_path)
offline_data["input_ids"] = offline_data["input_ids"][:max_length]
offline_data["hidden_states"] = offline_data["hidden_states"][:max_length, :]
offline_data["aux_hidden_states"] = offline_data["aux_hidden_states"][:max_length, :]

# Make sure the input_ids have the same shape
```
Load offline tensors on CPU to avoid GPU spikes
torch.load(offline_file_path)
may move tensors to GPU if saved with device tags, causing dataloader-side GPU OOMs.
- offline_data = torch.load(offline_file_path)
+ offline_data = torch.load(offline_file_path, map_location="cpu")
📝 Committable suggestion (review carefully before committing; it must accurately replace the highlighted lines):

```python
# Extend the data sample with the hidden states from the .pt file
max_length = self.tokenizer.model_max_length
offline_data = torch.load(offline_file_path, map_location="cpu")
offline_data["input_ids"] = offline_data["input_ids"][:max_length]
offline_data["hidden_states"] = offline_data["hidden_states"][:max_length, :]
offline_data["aux_hidden_states"] = offline_data["aux_hidden_states"][:max_length, :]

# Make sure the input_ids have the same shape
```
🤖 Prompt for AI Agents
In examples/speculative_decoding/eagle_utils.py around lines 214 to 221,
torch.load(offline_file_path) can deserialize tensors onto their saved device
and risk moving them to GPU; change the load to explicitly map tensors to CPU
(e.g., use torch.load(..., map_location="cpu")) and, if needed, ensure any
tensors used later are .to(device) at the appropriate time so the dataloader
never accidentally allocates GPU memory during loading.
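For illustration, a minimal sketch of that pattern (hypothetical dataset class, not the PR's offline dataset; batches are moved to the GPU only inside the training loop):

```python
# Sketch: keep offline tensors on CPU in __getitem__ and let the DataLoader
# handle prefetching/pinning, so worker processes never touch the GPU.
import torch
from torch.utils.data import DataLoader, Dataset


class OfflineHiddenStatesDataset(Dataset):
    def __init__(self, files):
        self.files = list(files)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # map_location="cpu" avoids deserializing onto a saved GPU device
        return torch.load(self.files[idx], map_location="cpu")


# loader = DataLoader(OfflineHiddenStatesDataset(pt_files), batch_size=1,
#                     num_workers=4, pin_memory=True)
# for batch in loader:
#     hidden = batch["hidden_states"].to("cuda", non_blocking=True)
```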
Force-pushed from bf8d1a8 to 338d0e3 (compare)
Signed-off-by: Benjamin Chislett <[email protected]>
Signed-off-by: Benjamin Chislett <[email protected]> Signed-off-by: Benjamin Chislett <[email protected]>
Signed-off-by: Benjamin Chislett <[email protected]> Signed-off-by: Benjamin Chislett <[email protected]>
Signed-off-by: Benjamin Chislett <[email protected]> Signed-off-by: Benjamin Chislett <[email protected]>
…rsation_id Signed-off-by: Benjamin Chislett <[email protected]> Signed-off-by: Benjamin Chislett <[email protected]>
Signed-off-by: h-guo18 <[email protected]>
Force-pushed from 338d0e3 to 5add4cb (compare)
Actionable comments posted: 7
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
examples/speculative_decoding/launch_train.sh (1)
91-95
: Division by zero when no CUDA devices.GPU_COUNT can be 0; DEFAULT_SAVE_STEPS=$((8192 / 0)) will abort the script on machines without visible GPUs (or before CUDA init in some CI).
Harden with a guard:
-GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())") -# Calculate save_steps -DEFAULT_SAVE_STEPS=$((8192 / GPU_COUNT)) +GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())") +if [[ -z "$GPU_COUNT" || "$GPU_COUNT" -le 0 ]]; then + GPU_COUNT=1 +fi +DEFAULT_SAVE_STEPS=$((8192 / GPU_COUNT))examples/speculative_decoding/eagle_utils.py (1)
316-341
: Fix NameError in collator padding helpers and pad labels with IGNORE_TOKEN_IDintensors → intensors causes a runtime NameError. Also, labels are padded with 0; use IGNORE_TOKEN_ID to avoid loss on pad tokens. Adds truncation guards too.
- def paddingtensor2d(self, intensors, length): - n, dim = intensors.shape - padding_tensor = torch.zeros(length - n, dim, dtype=intensors.dtype) - outtensors = torch.cat((intensors, padding_tensor)) - return outtensors + def paddingtensor2d(self, intensors, length): + n, dim = intensors.shape + if length < n: + return intensors[:length] + padding_tensor = torch.zeros(length - n, dim, dtype=intensors.dtype) + outtensors = torch.cat((intensors, padding_tensor)) + return outtensors - - def paddingtensor(self, intensors, length): - padding_tensor = torch.zeros(length - intensors.shape[0], dtype=intensors.dtype) - outtensors = torch.cat((intensors, padding_tensor)) - return outtensors + def paddingtensor(self, intensors, length, pad_value=0): + if length < intensors.shape[0]: + return intensors[:length] + padding_tensor = torch.full((length - intensors.shape[0],), pad_value, dtype=intensors.dtype) + outtensors = torch.cat((intensors, padding_tensor)) + return outtensors @@ - batch_input_ids = torch.stack( - [self.paddingtensor(item["input_ids"], max_length) for item in features] - ) + pad_id = 0 # collator lacks tokenizer context; 0 is typical for LLaMA-family + batch_input_ids = torch.stack( + [self.paddingtensor(item["input_ids"], max_length, pad_id) for item in features] + ) @@ - batch_attention_mask = torch.stack( - [self.paddingtensor(item["attention_mask"], max_length) for item in features] - ) + batch_attention_mask = torch.stack( + [self.paddingtensor(item["attention_mask"], max_length, 0) for item in features] + ) @@ - batch_loss_mask = torch.stack( - [self.paddingtensor(item["loss_mask"], max_length) for item in features] - ) + batch_loss_mask = torch.stack( + [self.paddingtensor(item["loss_mask"], max_length, 0) for item in features] + ) @@ - batch_labels = torch.stack( - [self.paddingtensor(item["labels"], max_length) for item in features] - ) + batch_labels = torch.stack( + [self.paddingtensor(item["labels"], max_length, IGNORE_TOKEN_ID) for item in features] + )
🧹 Nitpick comments (14)
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh (2)
25-27
: Quote vars and ensure output directory exists.Prevent word-splitting; create the output directory.
-INPUT_FILE=synthetic_conversations/daring-anteater.jsonl -OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/ +INPUT_FILE="synthetic_conversations/daring-anteater.jsonl" +OUTPUT_DIR="/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/" +mkdir -p "$OUTPUT_DIR"
36-36
: Rely on trap for cleanup; avoid risky rm patterns.The trap above safely removes the entire temp directory.
-rm /tmp/part-*.jsonl +# cleanup handled by trapexamples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh (1)
16-19
: Parameterize paths, quote, and ensure output dir exists.Improves portability and robustness; keeps the example concise.
-# Example usage of the script to compute the hidden states for a conversation dataset +# Example usage of the script to compute the hidden states for a conversation dataset @@ -python3 collect_hidden_states/compute_hidden_states_hf.py \ - --model meta-llama/Llama-3.2-1B-Instruct \ - --input-file synthetic_conversations/daring-anteater.jsonl \ - --output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/ +MODEL="${MODEL:-meta-llama/Llama-3.2-1B-Instruct}" +INPUT_FILE="${INPUT_FILE:-synthetic_conversations/daring-anteater.jsonl}" +OUTPUT_DIR="${OUTPUT_DIR:-/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/}" +mkdir -p "$OUTPUT_DIR" + +# Resolve Python entry relative to this script for robustness +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" +python3 "$SCRIPT_DIR/compute_hidden_states_hf.py" \ + --model "$MODEL" \ + --input-file "$INPUT_FILE" \ + --output-dir "$OUTPUT_DIR"examples/speculative_decoding/launch_train.sh (4)
33-36
: CLI flag name is slightly misleading vs downstream arg; consider aliasing.Downstream you pass --offline-data-path to main.py. Here the user-facing flag is --offline-data. Consider supporting both for clarity:
- Keep --offline-data (backward-compat)
- Add an alias --offline-data-path (sets the same var)
Apply this minimal alias:
- --offline-data*) + --offline-data*|--offline-data-path*) if [[ "$1" != *=* ]]; then shift; fi OFFLINE_DATA_PATH="${1#*=}" ;;
81-84
: Invalid-argument error prints the value, not the flag.For inputs like --foo=bar, message shows “bar” instead of “--foo”. Print $1 directly.
- >&2 printf "Error: Invalid argument ${1#*=}\n" + >&2 printf "Error: Invalid argument %s\n" "$1"
146-171
: Quoting/arg safety when paths contain spaces.You already quote OFFLINE_DATA_PATH and DATA in the proposed patches. For consistency, consider quoting other path-like vars (MODEL, OUTPUT_DIR) or build CMD as an array to avoid word-splitting. Non-blocking.
107-109
: Unused vars (REDRAFTER_*).REDRAFTER_TOKENS and REDRAFTER_NUM_LAYERS aren’t used in this script. Remove or wire them, or add a comment if they’re placeholders for future flags.
examples/speculative_decoding/eagle_utils.py (2)
35-35
: Guard chat_template mutationSome tokenizers may lack chat_template or it may be None; also consider not mutating per-call.
- tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "") + if getattr(tokenizer, "chat_template", None): + tokenizer.chat_template = tokenizer.chat_template.replace( + REMOVE_THINK_CHAT_TEMPLATE, "" + )Optional follow-up: move this one-time adjustment to dataset init to avoid repeated mutation.
294-297
: Shuffle before split to avoid ordering biasCurrent 95/5 split preserves source order; shuffle with a seed.
- num_train = int(len(valid_entries) * 0.95) + # Optional: shuffle to avoid order bias + import random + rnd = random.Random(getattr(data_args, "seed", 42)) + rnd.shuffle(valid_entries) + num_train = int(len(valid_entries) * 0.95)examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh (2)
1-1
: Add a shebang (and strict mode) so the script runs as a shell script and satisfies ShellCheck.Without a shebang (SC2148), shells may guess incorrectly; strict mode avoids silent failures.
+#!/usr/bin/env bash +set -euo pipefail
19-19
: Make invocation robust regardless of the current working directory.Resolve the Python script relative to this file’s location to avoid path errors.
-python3 collect_hidden_states/send_conversations_for_hiddens.py \ +python3 "$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)/send_conversations_for_hiddens.py" \examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (1)
189-191
: Remove redundant continue.Control flow already proceeds to the next iteration.
- num_success += 1 - continue + num_success += 1examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py (2)
87-96
: Resolve output directory path for consistency.Normalize
--output-dir
(supports~
, avoids relative path surprises).Apply this diff:
async def main(args: argparse.Namespace) -> None: @@ if not args.sharegpt_file.exists(): err_msg = f"ShareGPT file {args.sharegpt_file} does not exist." raise FileNotFoundError(err_msg) + # Normalize output dir + args.output_dir = args.output_dir.expanduser().resolve()
62-83
: PII/safety filter hook (optional).ShareGPT "unfiltered" may contain PII/toxicity. Consider an optional sanitizer callback applied to
msgs
before hashing/emit.I can add a lightweight regex-based scrubber and a
--drop-pii
flag if you want.
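For reference, one possible shape such a scrubber could take (a hedged sketch; the patterns and the `--drop-pii` flag are hypothetical and not part of this PR):

```python
# Lightweight regex-based scrubber sketch, applied to message text before
# hashing/emitting. Patterns are illustrative and deliberately conservative.
import re

PII_PATTERNS = [
    re.compile(r"[\w.+-]+@[\w-]+\.[\w.-]+"),      # email addresses
    re.compile(r"\b(?:\+?\d[\s-]?){9,14}\d\b"),   # phone-like digit runs
    re.compile(r"\b\d{1,3}(?:\.\d{1,3}){3}\b"),   # IPv4 addresses
]


def scrub_pii(text: str, replacement: str = "[REDACTED]") -> str:
    for pattern in PII_PATTERNS:
        text = pattern.sub(replacement, text)
    return text


# Applied to each message before hashing/emitting:
# msgs = [{"role": m["role"], "content": scrub_pii(m["content"])} for m in msgs]
```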
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (21)
examples/speculative_decoding/.gitignore
(1 hunks)examples/speculative_decoding/collect_hidden_states/__init__.py
(1 hunks)examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py
(1 hunks)examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh
(1 hunks)examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh
(1 hunks)examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh
(1 hunks)examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py
(1 hunks)examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py
(1 hunks)examples/speculative_decoding/eagle_utils.py
(4 hunks)examples/speculative_decoding/launch.sh
(0 hunks)examples/speculative_decoding/launch_train.sh
(4 hunks)examples/speculative_decoding/main.py
(4 hunks)examples/speculative_decoding/prepare_input_conversations/__init__.py
(1 hunks)examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py
(1 hunks)examples/speculative_decoding/prepare_input_conversations/add_mtbench.py
(1 hunks)examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py
(1 hunks)examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py
(1 hunks)examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh
(1 hunks)examples/speculative_decoding/prepare_input_conversations/utils.py
(1 hunks)examples/speculative_decoding/train_eagle3_and_export.sh
(2 hunks)modelopt/torch/speculative/plugins/transformers.py
(4 hunks)
💤 Files with no reviewable changes (1)
- examples/speculative_decoding/launch.sh
🚧 Files skipped from review as they are similar to previous changes (12)
- examples/speculative_decoding/prepare_input_conversations/init.py
- examples/speculative_decoding/.gitignore
- examples/speculative_decoding/train_eagle3_and_export.sh
- examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py
- modelopt/torch/speculative/plugins/transformers.py
- examples/speculative_decoding/main.py
- examples/speculative_decoding/prepare_input_conversations/utils.py
- examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py
- examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py
- examples/speculative_decoding/prepare_input_conversations/add_mtbench.py
- examples/speculative_decoding/collect_hidden_states/init.py
- examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-17T14:13:57.091Z
Learnt from: benchislett
PR: NVIDIA/TensorRT-Model-Optimizer#300
File: examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py:154-166
Timestamp: 2025-09-17T14:13:57.091Z
Learning: The `apply_chat_template` method in HuggingFace transformers has `tokenize=True` as the default parameter, meaning it returns token IDs even when `return_tensors=None` is used (returns as Python list).
Applied to files:
examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py
🧬 Code graph analysis (4)
examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py (1)
examples/speculative_decoding/prepare_input_conversations/utils.py (4)
dataset_splits_explanation
(157-172)download_file
(26-35)id_for_conversation
(38-42)update_dataset_file_with_conversations
(127-154)
examples/speculative_decoding/eagle_utils.py (1)
modelopt/torch/utils/logging.py (1)
print_rank_0
(92-95)
examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (3)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (1)
main
(82-171)examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py (1)
main
(49-84)modelopt/torch/speculative/utils.py (1)
tokenize
(212-229)
examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (1)
examples/speculative_decoding/main.py (1)
train
(117-271)
🪛 Shellcheck (0.11.0)
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh
[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.
(SC2148)
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh
[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.
(SC2148)
examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh
[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.
(SC2148)
examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh
[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.
(SC2148)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: build-docs
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
🔇 Additional comments (18)
examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (1)
19-25
: Fix broken relative paths; parameterize and create output dir.This file lives in .../prepare_input_conversations/, so prefixing with prepare_input_conversations/ points to a non‑existent nested directory. Also ensure outputs land in a known dir. Flag name --output-split-name looks correct.
-python3 prepare_input_conversations/add_daring_anteater.py --output-split-name train +OUTPUT_DIR="${OUTPUT_DIR:-data}" +mkdir -p "$OUTPUT_DIR" +python3 "$DIR/add_daring_anteater.py" --output-split-name train --output-dir "$OUTPUT_DIR" @@ -python3 prepare_input_conversations/add_mtbench.py --output-split-name mix_test +python3 "$DIR/add_mtbench.py" --output-split-name mix_test --output-dir "$OUTPUT_DIR"To confirm flags exist in the repo, run:
#!/bin/bash rg -n -C2 -g 'examples/**' -e 'add_argument\(\s*"--output-dir"' -e 'add_argument\(\s*"--output-split-name"'examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh (3)
1-15
: Add shebang and strict mode; make script executable.Without a shebang, shells/tools guess the interpreter; add safe flags. This also addresses SC2148.
Apply:
+#!/usr/bin/env bash +set -euo pipefail +IFS=$'\n\t'Also ensure the file is executable (chmod +x).
28-28
: Make split suffix length explicit; use temp dir and trap for cleanup.Avoid filename mismatches and ensure cleanup even on failure.
-split -n l/8 --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part- +# temp workspace for parts +TMPDIR="$(mktemp -d)" +trap 'rm -rf "$TMPDIR"' EXIT +PREFIX="$TMPDIR/part-" +# detect GPU count or default to 8 +NGPUS="${NGPUS:-8}" +if command -v nvidia-smi >/dev/null 2>&1; then + NGPUS=$(nvidia-smi -L | wc -l | tr -d ' ') +fi +split -n "l/${NGPUS}" -d -a 2 --numeric-suffixes=00 --additional-suffix=.jsonl "$INPUT_FILE" "$PREFIX"
30-34
: Robust parallel launch with zero-padded parts and per‑PID status check.Fail fast if any worker fails; avoid silent errors.
-for i in $(seq 0 7) -do -CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py --model meta-llama/Llama-3.2-1B-Instruct --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR & -done -wait +PIDS=() +for ((i=0; i<NGPUS; i++)); do + part="${PREFIX}$(printf '%02d' "$i").jsonl" + [ -s "$part" ] || continue + CUDA_VISIBLE_DEVICES="$i" python3 collect_hidden_states/compute_hidden_states_hf.py \ + --model meta-llama/Llama-3.2-1B-Instruct \ + --input-file "$part" \ + --output-dir "$OUTPUT_DIR" & + PIDS+=("$!") +done + +ec=0 +for pid in "${PIDS[@]}"; do + if ! wait "$pid"; then ec=1; fi +done +exit "$ec"examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh (1)
1-3
: Add shebang and strict mode.Executable script should declare interpreter and fail fast (SC2148).
+#!/usr/bin/env bash +set -euo pipefail +IFS=$'\n\t'examples/speculative_decoding/launch_train.sh (2)
111-113
: Default seq len bumped to 2048; confirm headroom.At 2048, memory can be tight on smaller GPUs depending on batch/heads. Fine by default; just confirm your example configs fit typical 24–48GB cards. If not, document an override in README/docs.
138-142
: Accelerate --multi_gpu flag: confirm support with your pinned Accelerate version.Some setups rely on config/num_processes rather than --multi_gpu. If your CI/README pins a version that supports it, fine; otherwise derive processes from NUM_GPU and pass --num_processes.
Would you like me to add a version check and fall back to --num_processes "$NUM_GPU" if --multi_gpu isn’t supported?
examples/speculative_decoding/eagle_utils.py (5)
270-281
: Conversation ID handling: confirm naming and typesIf offline files are named with zero-padding (e.g., 00000000.pt) or IDs are ints vs strings, this strict match will drop samples.
- Are offline files named exactly f"{conv_id}.pt" without padding or casting differences?
- If not, normalize: conv_id = str(conv_id) and/or support zero-padded matches. I can provide a small normalizer if needed.
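A small normalizer along those lines could look like this (hypothetical helper; the zero-padding width is an assumption):

```python
# Sketch: match a conversation ID to its offline .pt file, tolerating integer
# IDs and zero-padded filenames. Not the repository's actual lookup code.
from pathlib import Path


def find_offline_file(output_dir: Path, conv_id) -> Path | None:
    conv_id = str(conv_id)
    candidates = [conv_id]
    if conv_id.isdigit():
        candidates.append(conv_id.zfill(8))  # e.g. 42 -> 00000042
    for name in candidates:
        path = output_dir / f"{name}.pt"
        if path.exists():
            return path
    return None
```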
353-380
: Validate hidden_size consistency across batchHidden-state second dims must match to stack cleanly; add cheap checks.
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]: base_batch = super().__call__(features) if "kwargs" not in features[0]: raise ValueError("No kwargs found in batch features. Offline data required.") features = [item["kwargs"]["base_model_outputs"] for item in features] + hs_dim = features[0]["base_model_hidden_states"].shape[1] + if any(f["base_model_hidden_states"].shape[1] != hs_dim for f in features): + raise ValueError("Inconsistent base_model_hidden_states hidden_size in batch.") + if any(f["aux_hidden_states"].shape[1] != hs_dim for f in features): + raise ValueError("Inconsistent aux_hidden_states hidden_size in batch.")
257-265
: Offline data wiring looks goodGood gating, path assertion, and clear user feedback for offline mode.
272-279
: LGTM on explicit None checks for conversation_idThis addresses the prior “conv_id=0” pitfall from earlier feedback.
199-235
: CPU-load offline tensors, validate keys, and truncate preprocessed base before shape checkAvoid dataloader-side GPU allocations and make shape checks robust when generation/training max lengths differ.
- # Extend the data sample with the hidden states from the .pt file - max_length = self.tokenizer.model_max_length - offline_data = torch.load(offline_file_path) - offline_data["input_ids"] = offline_data["input_ids"][:max_length] - offline_data["hidden_states"] = offline_data["hidden_states"][:max_length, :] - offline_data["aux_hidden_states"] = offline_data["aux_hidden_states"][:max_length, :] + # Extend the data sample with the hidden states from the .pt file + max_length = self.tokenizer.model_max_length + # Truncate base fields before validation (do not mutate cache) + preprocessed_base_trunc = { + "input_ids": preprocessed_base["input_ids"][:max_length], + "labels": preprocessed_base["labels"][:max_length], + "attention_mask": preprocessed_base["attention_mask"][:max_length], + "loss_mask": preprocessed_base["loss_mask"][:max_length], + } + offline_data = torch.load(offline_file_path, map_location="cpu") + required_keys = {"input_ids", "hidden_states", "aux_hidden_states"} + if not required_keys.issubset(offline_data): + missing = required_keys.difference(offline_data.keys()) + raise ValueError(f"Offline data at {offline_file_path} missing keys: {missing}") + offline_data["input_ids"] = offline_data["input_ids"][:max_length] + offline_data["hidden_states"] = offline_data["hidden_states"][:max_length, :] + offline_data["aux_hidden_states"] = offline_data["aux_hidden_states"][:max_length, :] @@ - if preprocessed_base["input_ids"].shape != offline_data["input_ids"].shape: + if preprocessed_base_trunc["input_ids"].shape != offline_data["input_ids"].shape: msg = f"""Input IDs from offline data do not match the preprocessed input IDs for offline data sample at {offline_file_path}.""" raise ValueError(msg) @@ - ret = {**preprocessed_base} # Shallow copy so we don't accidentally modify the cache + ret = {**preprocessed_base_trunc} # Shallow copy so we don't accidentally modify the cacheexamples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (2)
105-111
: LGTM: tokenizer setup and BOS presence check.Pad fallback and BOS enforcement are sensible for this workflow.
152-163
: Avoid leaking /tmp/meta.json on skips; guard BOS check.When skipping for length, the temp meta file persists and can confuse the watcher. Also guard against empty inputs.
num_input_tokens = len(input_ids) if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len: num_too_long += 1 + # Clean up meta file if we decide to skip + if temp_meta_file.exists(): + temp_meta_file.unlink() continue - if input_ids[0] == bos_token_id: + if input_ids and input_ids[0] == bos_token_id: # Remove the leading BOS token. vLLM's completion generation # endpoint will prepend the BOS token automatically. input_ids = input_ids[1:]examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh (1)
21-21
: Fix incorrect input path: mtbench.jsonl lives under input_conversations by default.Current path breaks after gen_synthetic_conversations removal; align with add_mtbench.py output or pre-generate the file.
- --input-file synthetic_conversations/mtbench.jsonl \ + --input-file input_conversations/mtbench.jsonl \examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py (3)
Lines 62-85: Robustify role mapping; avoid halting on unknown roles. ShareGPT variants sometimes use assistant as the sender and may include rare tool/function roles. Raising on an unknown role aborts the whole run. Prefer skipping such conversations; also coalesce assistant senders into the assistant role. Apply this diff:
```diff
 def parse_sharegpt_conversation(sharegpt_conv: dict) -> list[dict] | None:
@@
-    for turn in sharegpt_conv.get("conversations", []):
-        if turn.get("from") in ["human", "user"]:
+    for turn in sharegpt_conv.get("conversations", []):
+        sender = turn.get("from")
+        if sender in ("human", "user"):
             role = "user"
-        elif turn.get("from") in ["gpt", "chatgpt", "bard"]:
+        elif sender in ("gpt", "chatgpt", "assistant", "bard"):
             role = "assistant"
-        elif turn.get("from") == "system":
+        elif sender == "system":
             # ShareGPT system messages are metadata, skip them
             continue
-        elif turn.get("from") == "bing":
+        elif sender == "bing":
             # Bing conversations are skipped for training, omit it
             return None
         else:
-            err_msg = f"Unknown role in conversation: {turn.get('from')}"
-            raise ValueError(err_msg)
+            # Unknown role in this conversation; skip the whole sample to be safe.
+            return None
@@
-        value = turn.get("value", "").strip()
+        value = (turn.get("value") or "").strip()
         if value:
             msgs.append({"role": role, "content": value})
```
Lines 127-131: Entry-point wiring looks good. Async orchestration is correct and self-contained.
Line 30: Validate dataset URL stability (optional). HF "resolve/main" links occasionally move. Consider a fallback via the datasets library (load_dataset(..., streaming=True)) if the download fails. Would you like me to wire an optional --use-hf-datasets path that streams directly from the hub?
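For reference, the suggested fallback would look roughly like the sketch below. This is a hedged illustration, not code from this PR: the URL is a placeholder, and the record layout is assumed to follow the usual ShareGPT "conversations" format.
```python
# Hedged sketch: fall back to the `datasets` library if the direct HTTP download fails.
from datasets import load_dataset

# Placeholder; reuse whatever URL or local path add_sharegpt.py already targets.
SHAREGPT_JSON = "https://huggingface.co/datasets/<org>/<repo>/resolve/main/sharegpt.json"

def iter_sharegpt_records(path_or_url: str = SHAREGPT_JSON):
    # Streaming avoids materializing the full file before iteration starts.
    ds = load_dataset("json", data_files=path_or_url, split="train", streaming=True)
    for record in ds:
        # Each record is expected to carry a "conversations" list of turns.
        yield record
```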
Resolved review threads:
examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (2)
examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py (2)
examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (1)
/ok to test 5add4cb
Signed-off-by: Benjamin Chislett <[email protected]>
Signed-off-by: h-guo18 <[email protected]>
Co-authored-by: h-guo18 <[email protected]>
Signed-off-by: Ye Yu <[email protected]>
What does this PR do?
Type of change: Feature support for offline training of EAGLE3 heads using the HuggingFace training script
Overview:
This PR contains two primary components:
1. New helper scripts, ported from previous EAGLE3 training scripts, to make data preparation for offline training easy. These include:
- prepare_input_conversations: short Python scripts that load commonly used training datasets and output a standardized jsonl dataset file that is ready for training.
- gen_synthetic_conversations: scripts that generate synthetic conversations, using a conversation dataset as a collection of prompts for the model. Currently, OpenAI-compatible endpoints are used to generate the conversations. This is a known bottleneck of dataset preparation, so high performance is key; any serving engine can be used, and an example script demonstrates how to launch vllm for inference.
- collect_hidden_states: scripts that extract and save the hidden states produced for a conversation dataset, for use in offline training. These include a script that sends completion requests to a local inference server running a patched inference loop, as well as a script that uses the HF transformers AutoModel to generate the hidden states explicitly (see the sketch after this overview). Using an inference server is more performant, but either approach works well for generating hidden states.
2. Support for offline training in the EAGLE3 training scripts. This is triggered by passing --offline-data X to launch.sh, which will then launch main.py with --offline-training True --offline-data-path $OFFLINE_DATA_PATH --omit-target-layers False. See below for example usage.
Currently, the inline evaluation using ar_validate.py does not work when the target model's hidden layers are deleted, so the memory footprint of the output checkpoints is not smaller than with online training. However, all performance gains during training should still be present. --omit-target-layers controls this behaviour and will be re-enabled when offline support is added to the validation script.
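As a mental model for what the hidden-state dumps contain, the sketch below shows one way to produce a per-conversation .pt record with HF transformers. It is a hedged illustration only: the model id, the choice of aux layers, and the save layout are assumptions patterned on the field names used in this PR, not the shipped collect_hidden_states script.
```python
# Hedged sketch: dump hidden states for one conversation to a .pt file.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.2-1B-Instruct"  # example target model
device = "cuda" if torch.cuda.is_available() else "cpu"

tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device).eval()

conversation = [{"role": "user", "content": "Explain speculative decoding in one sentence."}]
input_ids = tok.apply_chat_template(conversation, return_tensors="pt").to(device)

with torch.no_grad():
    out = model(input_ids, output_hidden_states=True)

hs = out.hidden_states  # tuple of (num_layers + 1) tensors, each [1, seq_len, hidden]
record = {
    "input_ids": input_ids[0].cpu(),
    "hidden_states": hs[-1][0].cpu(),  # final-layer states for every position
    # Which intermediate layers feed the EAGLE3 aux states is configuration-dependent;
    # the low/mid/high layers picked here are purely illustrative.
    "aux_hidden_states": torch.cat([hs[2][0], hs[len(hs) // 2][0], hs[-3][0]], dim=-1).cpu(),
}
torch.save(record, "conversation_0.pt")  # one .pt artifact per conversation
```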
Usage
Here are the steps to reproduce an offline training run of Llama 3.2 1B-Instruct on the Daring-Anteater dataset. Note that all commands are expected to be run with modelopt installed, and from the base directory examples/speculative_decoding.
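The end-to-end flow is sketched below. The exact script names and most flags are assumptions for illustration (only --offline-data and --input-file are named elsewhere in this PR); adapt them to the helpers that ship in prepare_input_conversations/ and collect_hidden_states/.
```bash
# Illustrative only: script names and some flags below are assumptions, not verbatim commands.

# 1) Build a standardized input-conversation jsonl for Daring-Anteater.
python prepare_input_conversations/add_daring_anteater.py \
  --output-dir input_conversations/

# 2) Dump per-conversation hidden states to .pt files, here with the
#    HF-transformers-based collector (a server-based collector also exists).
python collect_hidden_states/compute_hidden_states_hf.py \
  --model meta-llama/Llama-3.2-1B-Instruct \
  --input-file input_conversations/daring-anteater.jsonl \
  --output-dir hidden_states/llama3.2-1b-daring-anteater/

# 3) Launch offline EAGLE3 training from the dumped hidden states
#    (the remaining launch.sh arguments are the same as for an online run).
./launch.sh --offline-data hidden_states/llama3.2-1b-daring-anteater/
```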
Testing
Training was tested and evaluated on the example setup above, reporting ~2.2 AL (acceptance length) after 2 epochs in all cases. Offline and online training produced nearly identical loss curves and acceptance rates at each evaluation.
Summary by CodeRabbit
New Features
Chores