Feature: Offline training for EAGLE3 #300
base: main
Conversation
Walkthrough

Adds dataset preparation utilities, tools to compute/send/sample per-conversation hidden states, an offline training path that consumes precomputed .pt hidden-states (dataset/collator/dataloader changes and CLI/launcher wiring), transformer plugin adjustments for offline execution, .gitignore updates, and removal of a legacy launcher.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant U as User
    participant Prep as Prepare scripts
    participant HF as compute_hidden_states_hf
    participant API as OpenAI-like server
    participant DS as Disk (.pt)
    participant Trainer as Trainer (offline)
    U->>Prep: run add_* → produce input_conversations JSONL
    alt local HF compute
        U->>HF: compute_hidden_states_hf (--model, --input-file)
        HF->>DS: write per-conversation .pt (input_ids, hidden_states, aux_hidden_states)
    else send to server
        U->>API: send_conversations_for_hiddens (--base-url, --input-file)
        API->>DS: server writes per-conversation .pt
    end
    U->>Trainer: launch_train.sh --offline-data PATH
    Trainer->>DS: read .pt files
    Trainer->>Trainer: OfflineSupervisedDataset + DataCollatorForOffline → training with eagle_offline=true
```
```mermaid
sequenceDiagram
    autonumber
    participant C as send_conversations_for_hiddens
    participant Tok as Tokenizer
    participant Meta as /tmp/meta.json
    participant API as OpenAI-like endpoint
    participant Out as Output Dir
    C->>Tok: apply_chat_template(conversations) -> input_ids / prompt
    C->>Meta: write {conversation_id, output_file}
    C->>API: completions.create(model, prompt, max_tokens=1)
    alt success
        API-->>C: 200 OK
        Note right of API: serving engine coordinates writing .pt to Out
        API->>Out: .pt created
    else error / too long
        API-->>C: error
        C->>C: increment counters, skip
    end
    C->>Meta: cleanup entry
```
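To make the diagrams concrete, here is a minimal sketch of inspecting one of the per-conversation `.pt` artifacts. The key names (`input_ids`, `hidden_states`, `aux_hidden_states`) follow the walkthrough above; the exact file layout is an assumption for illustration.

```python
from pathlib import Path

import torch

# Hypothetical output directory from a hidden-states collection run.
output_dir = Path("/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/")

for pt_file in sorted(output_dir.glob("*.pt"))[:3]:
    # map_location="cpu" avoids accidentally re-materializing tensors on GPU.
    data = torch.load(pt_file, map_location="cpu")
    print(pt_file.name)
    print("  input_ids:        ", tuple(data["input_ids"].shape))
    print("  hidden_states:    ", tuple(data["hidden_states"].shape))
    print("  aux_hidden_states:", tuple(data["aux_hidden_states"].shape))
```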
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
```python
# Extend the data sample with the hidden states from the .pt file
max_length = self.tokenizer.model_max_length
offline_data = torch.load(offline_file_path)
```
Curious:

- With the current implementation, will there be any pre-fetching mechanism for this tensor loading (perhaps taken care of inside the HF trainer)?
- Considering there's a limit on total disk bandwidth, could data loading become a bottleneck limiting training speed (if we further optimize the training loop, e.g. the TTT part)?
- Yes, I believe standard pytorch/HF datasets will handle prefetching, multiple loading worker processes, etc.
- Indeed, disk bandwidth can absolutely bottleneck training. If this becomes an issue, we can offset this using techniques such as compressed hidden-states on-disk, or using a file system with faster read speeds (e.g. using many smaller disks for high parallel read throughput). However, the speed-of-light for even a single moderate-quality disk is quite good, so networking a few 1-4TB disks together can easily saturate the (optimized) GPU throughput.
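To illustrate the prefetching point, a map-style dataset that lazily loads each `.pt` file composes directly with a standard PyTorch `DataLoader`, whose worker processes overlap disk reads with the training step. This is a sketch under the assumption of one file per sample; it is not the PR's actual dataset class.

```python
from pathlib import Path

import torch
from torch.utils.data import DataLoader, Dataset


class HiddenStatesDataset(Dataset):
    """Lazily loads one precomputed hidden-states file per sample."""

    def __init__(self, data_dir: str):
        self.files = sorted(Path(data_dir).glob("*.pt"))

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, idx: int) -> dict:
        # Runs inside a worker process; CPU map_location keeps workers off the GPU.
        return torch.load(self.files[idx], map_location="cpu")


# num_workers > 0 enables background loading; each worker keeps
# prefetch_factor batches in flight, hiding disk latency.
loader = DataLoader(
    HiddenStatesDataset("/path/to/hidden-states"),
    batch_size=4,
    num_workers=4,
    prefetch_factor=2,
    pin_memory=True,
    collate_fn=list,  # a real collator would pad and stack tensors here
)
```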
Can we remove gen_synthetic_conversations from the PR, as we already have https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/speculative_decoding/distributed_generate?
Force-pushed 4282dbe to f6cb37a
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@            Coverage Diff             @@
##              main     #300      +/-   ##
==========================================
- Coverage    73.82%   73.82%   -0.01%
==========================================
  Files          172      172
  Lines        17438    17438
==========================================
- Hits         12874    12873       -1
- Misses        4564     4565       +1
```

☔ View full report in Codecov by Sentry.
LGTM. Tried the offline workflow with tinyllama and got reasonable AR.
Force-pushed 1839ab5 to f92be76
/ok to test f92be76
Actionable comments posted: 9
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
examples/speculative_decoding/launch_train.sh (1)
144-170
: --multi_gpu is passed to accelerate instead of main.py (likely unrecognized).

The flag appears intended for main.py but is placed before main.py, so accelerate may reject it. Move it after main.py.

```diff
-export TOKENIZERS_PARALLELISM=False
-CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
+export TOKENIZERS_PARALLELISM=False
+CMD="accelerate launch --mixed_precision bf16 main.py \
+    $MULTI_GPU \
     --mode $MODE \
     --model_name_or_path $MODEL \
     --training_seq_len $TRAINING_SEQ_LEN \
@@
-    --data_path $DATA \
+    ${DATA_ARGS} \
     $OFFLINE_TRAINING_ARGS \
     $SPECULATIVE_ARGS "
```

Add DATA_ARGS a few lines above (see next comment).
modelopt/torch/speculative/plugins/transformers.py (2)
120-132
: Typo: rcache_position should be cache_position.

This will raise an unexpected keyword error in HF models.

```diff
-            rcache_position=cache_position,
+            cache_position=cache_position,
```
310-318
: LlamaDecoderLayer expects past_key_value (singular), not past_key_values.

HF 4.48–4.57 use past_key_value on the decoder layer; passing past_key_values will error. Keep model-level plural, layer-level singular.

```diff
-                past_key_values=past_key_values,
+                past_key_value=past_key_values,
```

If you need to support future HF where the layer accepts plural, add a small shim with signature inspection and forward the appropriate kwarg.
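For reference, a signature-inspection shim along the lines suggested above could look like this (hypothetical helper, not code from the PR):

```python
import inspect


def cache_kwarg_name(decoder_layer) -> str:
    """Return the cache kwarg name accepted by this decoder layer's forward.

    Current HF releases take `past_key_value` on the layer; if a future
    release switches to the plural form, this picks it up automatically.
    """
    params = inspect.signature(decoder_layer.forward).parameters
    return "past_key_values" if "past_key_values" in params else "past_key_value"


# Usage sketch: forward the cache under whichever name the layer supports.
#   kwargs = {cache_kwarg_name(layer): past_key_values}
#   hidden_states = layer(hidden_states, **kwargs)
```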
♻️ Duplicate comments (3)
modelopt/torch/speculative/plugins/transformers.py (1)
811-813
: Good: dummy DynamicCache for offline + legacy cache adapter.

This addresses earlier Cache init issues across transformers versions.
Also applies to: 827-828
examples/speculative_decoding/main.py (1)
187-194
: Thanks for wiring eagle_offline into the config.

This addresses prior feedback to use EagleConfig's flag. LGTM.
examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (1)
143-143
: Document the serving engine requirement inline.

Add a short comment pointing to the patched vLLM branch or required changes so users aren't blocked. This addresses the earlier reviewer question.
🧹 Nitpick comments (33)
examples/speculative_decoding/.gitignore (1)
2-4
: Ignore additional generated artifacts (export, hidden-states dirs).

Add common outputs created by the new workflow so they don't get committed accidentally.

```diff
 Daring-Anteater
 input_conversations
 synthetic_conversations
 ckpts
+export
+hidden_states
```

examples/speculative_decoding/launch_train.sh (3)
91-95
: Guard divide-by-zero when no GPUs are visible.

If torch reports 0 GPUs, DEFAULT_SAVE_STEPS errors. Provide a sane fallback.

```diff
 GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
-# Calculate save_steps
-DEFAULT_SAVE_STEPS=$((8192 / GPU_COUNT))
+# Calculate save_steps with fallback
+if [[ "${GPU_COUNT:-0}" -gt 0 ]]; then
+  DEFAULT_SAVE_STEPS=$((8192 / GPU_COUNT))
+else
+  echo "Warning: No CUDA devices found. Falling back to GPU_COUNT=1 for save_steps."
+  DEFAULT_SAVE_STEPS=8192
+fi
```
107-109
: Remove unused variables (REDRAFTER_*).

They're set but never used in this script.

```diff
-REDRAFTER_TOKENS=${REDRAFTER_TOKENS:-1}
-REDRAFTER_NUM_LAYERS=${REDRAFTER_NUM_LAYERS:-1}
```
81-84
: Error message shows only the value after '='.

For bad args, print the full token for clarity.

```diff
-    >&2 printf "Error: Invalid argument ${1#*=}\n"
+    >&2 printf "Error: Invalid argument: %s\n" "$1"
```

examples/speculative_decoding/train_eagle3_and_export.sh (2)
61-65
: Quote OFFLINE_DATA_PATH and validate early.

Preempt path issues and fail early on typos.

```diff
-if [[ "$OFFLINE_DATA_PATH" != "" ]]; then
-  OFFLINE_DATA_ARGS="--offline-data $OFFLINE_DATA_PATH"
+if [[ -n "$OFFLINE_DATA_PATH" ]]; then
+  if [[ ! -d "$OFFLINE_DATA_PATH" ]]; then
+    echo "Offline data path not found: $OFFLINE_DATA_PATH" >&2
+    exit 1
+  fi
+  OFFLINE_DATA_ARGS="--offline-data \"$OFFLINE_DATA_PATH\""
 else
   OFFLINE_DATA_ARGS=""
 fi
```
72-79
: Quote all interpolations in the training call.

Prevents breakage with spaces/special chars.

```diff
-./launch_train.sh --model $BASE_MODEL \
-  --output_dir $OUTPUT_DIR \
-  $OFFLINE_DATA_ARGS \
-  --data $DATA \
-  --num_gpu $NUM_GPU \
+./launch_train.sh --model "$BASE_MODEL" \
+  --output_dir "$OUTPUT_DIR" \
+  $OFFLINE_DATA_ARGS \
+  --data "$DATA" \
+  --num_gpu "$NUM_GPU" \
   --num_epochs 2 \
   --eagle_config eagle_config.json
```

examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py (3)
63-63
: Load on CPU to avoid accidental GPU allocations.

These artifacts are saved on CPU; loading on CPU is safer and avoids device mismatches.

```diff
-    data = torch.load(file)
+    data = torch.load(file, map_location="cpu")
```
50-55
: Validate that the input path exists and is a dir/file.

Currently a non-existent path silently yields 0 files; fail fast.

```diff
-    if args.input_path.is_file():
+    if not args.input_path.exists():
+        raise FileNotFoundError(f"Input path not found: {args.input_path}")
+    if args.input_path.is_file():
         all_files = [args.input_path]
```
70-73
: Allow a superset of keys for forward compatibility.

Future producers may include extra fields (e.g., metadata). Check for required keys instead of an exact match.

```diff
-        if set(expected_keys) != set(data.keys()):
+        if not set(expected_keys).issubset(set(data.keys())):
             print(f"File {file} does not contain all expected keys: {expected_keys}")
             print(f"  Found keys: {list(data.keys())}")
             continue
```

modelopt/torch/speculative/plugins/transformers.py (1)
425-435
: Device selection logic for offline mode — LGTM with a tiny robustness nit.

Using the lm_head device in offline mode is correct. Consider guarding for models without lm_head (edge adapters).

```diff
-        if eagle_offline:
+        if eagle_offline:
             # For offline training, the base model has no layers.
             # Read the device from the lm_head instead.
-            device = self.lm_head.weight.device
+            device = getattr(self.lm_head, "weight", self.model.embed_tokens.weight).device
```

examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (3)
89-91
: Set eval mode for deterministic, faster inference.

Call model.eval() after loading.

```diff
-    model = AutoModel.from_pretrained(args.model, torch_dtype="auto", device_map="auto")
+    model = AutoModel.from_pretrained(args.model, torch_dtype="auto", device_map="auto")
+    model.eval()
```
78-88
: This script is synchronous; drop asyncio for simplicity.

No awaits inside main. Convert to a regular function and remove the asyncio import/usage.

```diff
@@
-import asyncio
@@
-async def main(args: argparse.Namespace) -> None:
+def main(args: argparse.Namespace) -> None:
@@
 if __name__ == "__main__":
     cli_args = parse_args()
-    asyncio.run(main(cli_args))
+    main(cli_args)
```

Also applies to: 18-21, 166-168
115-123
: Separate "too short" and "too long" counters/logs for clarity.

The variable name num_skipped_too_long also counts short samples (<=10 tokens). Track them separately to make logs actionable.

```diff
@@
-    num_skipped_too_long = 0
+    num_skipped_too_long = 0
+    num_skipped_too_short = 0
@@
-        if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len:
-            num_skipped_too_long += 1
+        if num_input_tokens <= 10:
+            num_skipped_too_short += 1
+            continue
+        if num_input_tokens > args.max_seq_len:
+            num_skipped_too_long += 1
             continue
@@
-    if num_skipped_too_long > 0:
+    if num_skipped_too_short > 0:
+        print(f"Skipped {num_skipped_too_short} conversations for being too short (<=10 tokens).")
+    if num_skipped_too_long > 0:
         print(f"Skipped {num_skipped_too_long} conversations due to length constraints.")
```

Also applies to: 153-163

examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh (1)
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh (1)
25-36
: Add a shebang and safety flags, quote vars, and ensure cleanup on interruption.

Harden the runner and keep temp files tidy even on Ctrl-C.

```diff
+#!/usr/bin/env bash
+set -euo pipefail
+
 INPUT_FILE=synthetic_conversations/daring-anteater.jsonl
 OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
-split -n l/8 --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part-
+split -n l/8 --numeric-suffixes=0 -d --additional-suffix=.jsonl "$INPUT_FILE" /tmp/part-
+trap 'rm -f /tmp/part-*.jsonl' EXIT
 for i in $(seq 0 7)
 do
-CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py --model meta-llama/Llama-3.2-1B-Instruct --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR &
+CUDA_VISIBLE_DEVICES="$i" python3 collect_hidden_states/compute_hidden_states_hf.py \
+  --model meta-llama/Llama-3.2-1B-Instruct \
+  --input-file "/tmp/part-0${i}.jsonl" \
+  --output-dir "$OUTPUT_DIR" &
 done
 wait
-rm /tmp/part-*.jsonl
+# Files are removed by trap on EXIT
```

Also applies to: 1-15
examples/speculative_decoding/eagle_utils.py (2)
214-220
: Load offline tensors onto CPU explicitly to avoid accidental CUDA deserialization.

If any .pt file was saved with CUDA tensors, torch.load will attempt to place them on GPU and can OOM. Force CPU map_location.

```diff
-        offline_data = torch.load(offline_file_path)
+        offline_data = torch.load(offline_file_path, map_location="cpu")
```
353-380
: Batching: align hidden-state sequence length with input_ids to prevent downstream surprises.

You pad hidden states to max_hs_length independently of base_batch padding. If these diverge (shouldn't, but can in edge cases), training code may assume equal lengths. Consider asserting equality or padding HS to the base_batch length.

```diff
-        max_hs_length = max(item["base_model_hidden_states"].shape[0] for item in features)
+        max_hs_length = max(item["base_model_hidden_states"].shape[0] for item in features)
+        # Optional safety: ensure HS length matches token length
+        base_seq_len = base_batch["input_ids"].shape[1]
+        assert max_hs_length == base_seq_len, (
+            f"Hidden-state length ({max_hs_length}) != token length ({base_seq_len})."
+        )
```

examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh (2)
1-1
: Add a shebang to satisfy ShellCheck and ensure the correct interpreter.

Without a shebang, shells may pick inconsistent interpreters, and SC2148 is triggered. Apply this diff:

```diff
+#!/usr/bin/env bash
+set -euo pipefail
```
19-23
: Make the script path-robust and fix the CLI flag name.

- Using a hardcoded relative path will break if users cd into this directory.
- The commented flag name doesn't match the Python CLI (--debug-max-num-conversations).

Apply this diff:

```diff
-python3 collect_hidden_states/send_conversations_for_hiddens.py \
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+python3 "${SCRIPT_DIR}/send_conversations_for_hiddens.py" \
   --model meta-llama/Llama-3.2-1B-Instruct \
   --input-file synthetic_conversations/mtbench.jsonl \
   --output-dir /mnt/md0/eagle-hidden-states/llama1b/mtbench/
-# --debug-max-num-conversations-per-split 1000
+# --debug-max-num-conversations 1000
```

examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (1)
1-1
: Add a shebang to make the script executable and lint-clean.

```diff
+#!/usr/bin/env bash
+set -euo pipefail
```

examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py (2)
41-47
: Accept both --output-split-name and --output-split for CLI consistency.

The example script uses --output-split; add it as an alias to avoid UX breakage.

```diff
     parser.add_argument(
         "--output-split-name",
+        "--output-split",
         type=str,
         default="ultrachat",
         help=dataset_splits_explanation("ultrachat"),
     )
```
58-79
: Nit: async is not needed; the code is synchronous.

Optional: drop async and asyncio.run for simplicity, or await any future async I/O.

examples/speculative_decoding/prepare_input_conversations/add_mtbench.py (2)
45-51
: Support an --output-split alias to match example usage.

```diff
     parser.add_argument(
         "--output-split-name",
+        "--output-split",
         type=str,
         default="mtbench",
         help=dataset_splits_explanation("mtbench"),
     )
```
89-92
: Use the full message list when hashing the conversation ID for consistency.

Other scripts hash the normalized message list; hashing only the prompt string is inconsistent.

```diff
-    prompt_id = f"mtbench-{entry['question_id']:03}_" + id_for_conversation(prompt)
-    input_conversations.append(
-        {"conversation_id": prompt_id, "conversations": [{"role": "user", "content": prompt}]}
-    )
+    msgs = [{"role": "user", "content": prompt}]
+    prompt_id = f"mtbench-{entry['question_id']:03}_" + id_for_conversation(msgs)
+    input_conversations.append(
+        {"conversation_id": prompt_id, "conversations": msgs}
+    )
```

examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py (2)
41-46
: Accept an --output-split alias to align with examples.

```diff
     parser.add_argument(
         "--output-dir",
         type=Path,
         default=Path("input_conversations/"),
         help="Path to save the conversations file(s) into. Default is 'input_conversations/'.",
     )
     parser.add_argument(
         "--output-split-name",
+        "--output-split",
         type=str,
         default="daring-anteater",
         help=dataset_splits_explanation("daring-anteater"),
     )
```
87-90
: Skip empty conversations to avoid downstream failures.

Downstream scripts expect a non-empty conversations list; append only if messages were extracted.

```diff
-    input_conversations.append(
-        {"conversation_id": prompt_id, "conversations": processed_conversations}
-    )
+    if processed_conversations:
+        input_conversations.append(
+            {"conversation_id": prompt_id, "conversations": processed_conversations}
+        )
```

examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (4)
104-107
: Close the OpenAI client to avoid connection leaks.

Use a context manager to ensure the underlying httpx client is closed.

```diff
-    client: AsyncOpenAI = AsyncOpenAI(
-        api_key=args.openai_api_key,
-        base_url=args.base_url,
-    )
+    async with AsyncOpenAI(api_key=args.openai_api_key, base_url=args.base_url) as client:
```

Also indent the subsequent usage accordingly.
175-186
: Catch OpenAI client errors instead of httpx directly; keep a broad fallback.

The OpenAI SDK wraps HTTP errors; catching httpx.HTTPStatusError likely won't trigger. Prefer openai.OpenAIError (and optionally BadRequestError).

```diff
-        except httpx.HTTPStatusError as e:
-            print(f"HTTP error for conversation {conversation_id}: {e}")
-            num_error += 1
-            continue
-        except openai.BadRequestError:
+        except openai.BadRequestError:
             # Most likely the conversation is too long, ignore
             num_too_long += 1
             continue
+        except openai.OpenAIError as e:
+            print(f"OpenAI client error for conversation {conversation_id}: {e}")
+            num_error += 1
+            continue
```
143-153
: Meta file lifecycle: leave fewer footguns.

If a conversation is skipped early (length), the temp meta file persists until the next loop. You already guard at the top, but proactively removing it on skip (see earlier diff) reduces surprises if users parallelize.
188-193
: Minor: redundant continue.

The loop proceeds to the next iteration anyway.

```diff
-            num_success += 1
-            continue
+            num_success += 1
```

examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py (1)
62-85
: Make parsing resilient to unknown roles and non-string values.

Avoid aborting on unexpected roles and guard against non-string "value" fields. Apply:

```diff
 def parse_sharegpt_conversation(sharegpt_conv: dict) -> list[dict] | None:
@@
         elif turn.get("from") == "bing":
             # Bing conversations are skipped for training, omit it
             return None
         else:
-            err_msg = f"Unknown role in conversation: {turn.get('from')}"
-            raise ValueError(err_msg)
+            # Skip unknown roles rather than abort the whole run
+            print(f"Warning: Unknown role in conversation: {turn.get('from')}, skipping turn.")
+            continue
-
-        value = turn.get("value", "").strip()
-        if value:
-            msgs.append({"role": role, "content": value})
+        raw = turn.get("value", "")
+        if not isinstance(raw, str):
+            # Ignore non-string payloads
+            continue
+        value = raw.strip()
+        if value:
+            msgs.append({"role": role, "content": value})
```

examples/speculative_decoding/prepare_input_conversations/utils.py (3)
101-108
: Float equality is brittle; allow a small tolerance.

Avoid false negatives due to FP rounding. Apply:

```diff
-    if train_ratio + val_ratio + test_ratio != 1.0:
-        msg = "Ratios must sum to 1.0"
+    total = train_ratio + val_ratio + test_ratio
+    if abs(total - 1.0) > 1e-9:
+        msg = f"Ratios must sum to 1.0 (got {total})"
         raise ValueError(msg)
```
113-116
: Avoid mutating global RNG state when shuffling.

Use a local Random instance to keep determinism without side effects. Apply:

```diff
-    if shuffle:
-        random.seed(seed)
-        random.shuffle(conversations)
+    if shuffle:
+        rng = random.Random(seed)
+        rng.shuffle(conversations)
```
155-170
: Help text prints "%%" instead of "%".

Use single percent signs so the argparse help is readable. Apply:

```diff
-    - 'mix': Conversations will be randomly mixed and distributed into
-      'train' (80%%), 'val' (10%%), and 'test' (10%%) splits.
-    - 'mix_test': Conversations will be randomly mixed and distributed into
-      'val' (50%%) and 'test' (50%%) splits.
+    - 'mix': Conversations will be randomly mixed and distributed into
+      'train' (80%), 'val' (10%), and 'test' (10%) splits.
+    - 'mix_test': Conversations will be randomly mixed and distributed into
+      'val' (50%) and 'test' (50%) splits.
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (21)
- examples/speculative_decoding/.gitignore (1 hunks)
- examples/speculative_decoding/collect_hidden_states/__init__.py (1 hunks)
- examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (1 hunks)
- examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh (1 hunks)
- examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh (1 hunks)
- examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh (1 hunks)
- examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py (1 hunks)
- examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (1 hunks)
- examples/speculative_decoding/eagle_utils.py (4 hunks)
- examples/speculative_decoding/launch.sh (0 hunks)
- examples/speculative_decoding/launch_train.sh (4 hunks)
- examples/speculative_decoding/main.py (4 hunks)
- examples/speculative_decoding/prepare_input_conversations/__init__.py (1 hunks)
- examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py (1 hunks)
- examples/speculative_decoding/prepare_input_conversations/add_mtbench.py (1 hunks)
- examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py (1 hunks)
- examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py (1 hunks)
- examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (1 hunks)
- examples/speculative_decoding/prepare_input_conversations/utils.py (1 hunks)
- examples/speculative_decoding/train_eagle3_and_export.sh (2 hunks)
- modelopt/torch/speculative/plugins/transformers.py (4 hunks)
💤 Files with no reviewable changes (1)
- examples/speculative_decoding/launch.sh
🧰 Additional context used
🧬 Code graph analysis (12)
examples/speculative_decoding/prepare_input_conversations/utils.py (2)
- modelopt/torch/utils/random.py (2)
  - random (59-61)
  - shuffle (148-150)
- modelopt/torch/_deploy/_runtime/common.py (1)
  - write_bytes (65-67)

examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (1)
- examples/speculative_decoding/main.py (1)
  - train (117-271)

examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py (1)
- examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (2)
  - parse_args (28-75)
  - main (78-163)

examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py (1)
- examples/speculative_decoding/prepare_input_conversations/utils.py (4)
  - dataset_splits_explanation (155-170)
  - download_file (26-34)
  - id_for_conversation (37-41)
  - update_dataset_file_with_conversations (125-152)

examples/speculative_decoding/prepare_input_conversations/add_mtbench.py (1)
- examples/speculative_decoding/prepare_input_conversations/utils.py (4)
  - dataset_splits_explanation (155-170)
  - download_file (26-34)
  - id_for_conversation (37-41)
  - update_dataset_file_with_conversations (125-152)

examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (2)
- examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (2)
  - parse_args (28-75)
  - main (78-163)
- examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py (2)
  - parse_args (25-46)
  - main (49-84)

examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py (1)
- examples/speculative_decoding/prepare_input_conversations/utils.py (3)
  - dataset_splits_explanation (155-170)
  - id_for_conversation (37-41)
  - update_dataset_file_with_conversations (125-152)

examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (2)
- examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (2)
  - parse_args (30-90)
  - main (93-206)
- examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py (2)
  - parse_args (25-46)
  - main (49-84)

examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py (1)
- examples/speculative_decoding/prepare_input_conversations/utils.py (3)
  - dataset_splits_explanation (155-170)
  - id_for_conversation (37-41)
  - update_dataset_file_with_conversations (125-152)

examples/speculative_decoding/eagle_utils.py (1)
- modelopt/torch/utils/logging.py (1)
  - print_rank_0 (92-95)

examples/speculative_decoding/main.py (1)
- examples/speculative_decoding/eagle_utils.py (1)
  - make_eagle_supervised_data_module (238-312)

modelopt/torch/speculative/plugins/transformers.py (2)
- modelopt/torch/speculative/eagle/eagle_model.py (1)
  - EagleModel (23-51)
- modelopt/torch/speculative/plugins/megatron_eagle.py (1)
  - _set_default_aux_hidden_state_layers (682-694)
🪛 Shellcheck (0.10.0)
examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh
[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.
(SC2148)
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh
[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.
(SC2148)
examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh
[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.
(SC2148)
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh
[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.
(SC2148)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (8)
examples/speculative_decoding/prepare_input_conversations/__init__.py (1)
1-16: LGTM — license header and package docstring look good.

examples/speculative_decoding/collect_hidden_states/__init__.py (1)

1-16: LGTM — clean package init with correct licensing.

examples/speculative_decoding/launch_train.sh (1)
127-136
: Don't pass --data_path when using offline data; quote the offline path.

File: examples/speculative_decoding/launch_train.sh, lines 127-136.

Passing an empty --data_path can break the CLI; quote paths to handle spaces and let the offline pipeline supply data.

```diff
 if [[ "$OFFLINE_DATA_PATH" != "" ]]; then
   if [[ ! -d "$OFFLINE_DATA_PATH" ]]; then
     echo "Offline data path $OFFLINE_DATA_PATH does not exist or is not a directory."
     exit 1
   else
-    OFFLINE_TRAINING_ARGS="--offline-data-path $OFFLINE_DATA_PATH"
+    OFFLINE_TRAINING_ARGS="--offline-data-path \"$OFFLINE_DATA_PATH\""
+    DATA_ARGS=""  # let the offline pipeline supply data
   fi
 else
   OFFLINE_TRAINING_ARGS=""
+  DATA_ARGS="--data_path $DATA"
 fi
@@
-  --data_path $DATA \
+  ${DATA_ARGS} \
```

If main.py requires an explicit --offline-training flag, wire it here similarly.
Also applies to: 167-169
examples/speculative_decoding/eagle_utils.py (1)
238-307
: Potential tokenization mismatch with offline data.

preprocess() removes the think-stripping snippet from tokenizer.chat_template, but compute_hidden_states_hf.py originally didn't. That will make input_ids differ and trip the shape check here. After adopting the matching replacement in compute_hidden_states_hf.py, please re-verify that shapes match end-to-end on a sample.
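A quick way to perform that end-to-end check is to re-tokenize one conversation with the training-side tokenizer settings and compare lengths against the stored artifact. A sketch, with the sample path, message list, and key names assumed for illustration:

```python
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
data = torch.load("sample.pt", map_location="cpu")  # hypothetical dumped sample

# The same messages that produced sample.pt (placeholder content).
conversation = [{"role": "user", "content": "..."}]
ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")

assert ids.shape[-1] == data["input_ids"].shape[-1], (
    "Tokenization diverges from the stored hidden states; verify both sides "
    "apply the same chat-template edits (e.g. the think-stripping snippet)."
)
```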
examples/speculative_decoding/main.py (1)
236-241: AR validation: early return when disabled looks good.

The short-circuit is correct and won't interfere with control flow. LGTM.
examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py (2)
87-101
: Ensure runtime deps are documented/installed (aiohttp, tqdm).

Add these to an examples requirements file or docs so first-time runs don't fail.
110-118
: Conversation ID construction looks solid.

Stable hash + source id prefixing should dedupe reliably.
examples/speculative_decoding/prepare_input_conversations/utils.py (1)
37-42
: LGTM: stable conversation hashing.

Deterministic SHA-256 over normalized JSON is appropriate for deduping.
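The idea can be sketched as follows — a deterministic digest over a normalized JSON encoding, so the same messages always map to the same ID (illustrative only; the real helper lives in prepare_input_conversations/utils.py and may differ in details):

```python
import hashlib
import json


def conversation_hash(messages: list[dict]) -> str:
    """Deterministic SHA-256 over a normalized JSON encoding of the messages."""
    normalized = json.dumps(messages, sort_keys=True, ensure_ascii=False)
    return hashlib.sha256(normalized.encode("utf-8")).hexdigest()


msgs = [{"role": "user", "content": "Hello!"}]
print(conversation_hash(msgs))  # stable across runs and machines
```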
Signed-off-by: Benjamin Chislett <[email protected]>
Force-pushed f92be76 to f74bf59
/ok to test f74bf59
Actionable comments posted: 1
♻️ Duplicate comments (5)
examples/speculative_decoding/main.py (1)
188-190
: Good: honor EagleConfig.eagle_offline per prior feedback.
This aligns with the earlier request to use the existing config flag.

examples/speculative_decoding/prepare_input_conversations/utils.py (2)
26-36
: Add a network timeout and avoid multi-context in one line for clarity.

Prevents indefinite hangs and improves readability; keep the parent dir creation. Apply:

```diff
 async def download_file(url: str, destination: Path) -> None:
     """Download a file from a URL to a specified destination."""
     destination.parent.mkdir(parents=True, exist_ok=True)
-    async with aiohttp.ClientSession() as session, session.get(url) as response:
-        if response.status != 200:
-            msg = f"Failed to download {url}: {response.status}"
-            raise RuntimeError(msg)
-        content = await response.read()
-        destination.write_bytes(content)
-        print(f"Downloaded {url} to {destination}")
+    timeout = aiohttp.ClientTimeout(total=600)
+    async with aiohttp.ClientSession(timeout=timeout) as session:
+        async with session.get(url) as response:
+            if response.status != 200:
+                msg = f"Failed to download {url}: {response.status}"
+                raise RuntimeError(msg)
+            content = await response.read()
+            destination.write_bytes(content)
+            print(f"Downloaded {url} to {destination}")
```
45-86
: Prevent duplicate IDs within the same update and validate required keys.

Currently, duplicates inside the provided conversations list can be appended multiple times because existing_ids isn't updated as you add. Also, a missing "conversations" key will raise a KeyError later; validate early. Apply:

```diff
 def add_conversations_to_split(conversations: list, dataset_dir: Path, split: str) -> None:
@@
-    # Open the dataset file for the specified split, or create it if it doesn't exist
-    dataset_file = dataset_dir / f"{split}.jsonl"
+    # Ensure output directory exists and open/create split file
+    dataset_dir.mkdir(parents=True, exist_ok=True)
+    dataset_file = dataset_dir / f"{split}.jsonl"
@@
     existing_ids = {entry["conversation_id"] for entry in all_conversations}
     num_new_entries = 0
     num_duplicates = 0
     for entry in conversations:
-        if entry.get("conversation_id") is None:
+        entry_id = entry.get("conversation_id")
+        if entry_id is None:
             raise ValueError("Each conversation must have a 'conversation_id' field.")
-        if entry["conversation_id"] not in existing_ids:
+        if "conversations" not in entry:
+            raise ValueError("Each conversation must have a 'conversations' field.")
+        if entry_id not in existing_ids:
             all_conversations.append(
                 {
-                    "conversation_id": entry["conversation_id"],
+                    "conversation_id": entry_id,
                     "conversations": entry["conversations"],
                 }
             )
             num_new_entries += 1
+            existing_ids.add(entry_id)
         else:
             num_duplicates += 1
@@
-    dataset_dir.mkdir(parents=True, exist_ok=True)
     with dataset_file.open("w", encoding="utf-8") as f:
         for entry in all_conversations:
             f.write(json.dumps(entry, ensure_ascii=False) + "\n")
```

examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (2)
27-29
: Unify REMOVE_THINK_CHAT_TEMPLATE with training and guard a None chat_template.

Avoid divergence from training preprocessing and prevent an AttributeError when chat_template is None.

```diff
-from transformers import AutoModel, AutoTokenizer
+from transformers import AutoModel, AutoTokenizer
+try:
+    # Keep in sync with training; import if available when run from examples/speculative_decoding
+    from eagle_utils import REMOVE_THINK_CHAT_TEMPLATE
+except Exception:
+    # Fallback for standalone runs; ensure value matches eagle_utils.REMOVE_THINK_CHAT_TEMPLATE
+    REMOVE_THINK_CHAT_TEMPLATE = (
+        "{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}"
+    )
@@
-REMOVE_THINK_CHAT_TEMPLATE = (
-    "{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}"
-)
@@
     tokenizer = AutoTokenizer.from_pretrained(args.model)
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
-    tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")
+    if getattr(tokenizer, "chat_template", None):
+        tokenizer.chat_template = tokenizer.chat_template.replace(
+            REMOVE_THINK_CHAT_TEMPLATE, ""
+        )
```

Also applies to: 96-100
138-148
: Clamp/deduplicate auxiliary layer indices; avoid out-of-range indices or duplicates on small models.

The current logic can pick invalid indices (e.g., 2 when num_hidden_layers < 3).

```diff
-    # Extract hidden states from layers with index (2, N/2, N-3), and the output hidden states
+    # Extract hidden states from early/mid/late layers; clamp within [0, N-1]
     hidden_states = outputs.hidden_states
-    selected_layer_indices = [
-        2,
-        max(0, num_hidden_layers // 2),
-        max(1, num_hidden_layers - 3),
-    ]
-    selected_layer_indices = sorted(set(selected_layer_indices))
+    candidates = [2, num_hidden_layers // 2, num_hidden_layers - 3]
+    selected_layer_indices = sorted(
+        {i for i in candidates if 0 <= i <= num_hidden_layers - 1}
+    )
+    if not selected_layer_indices:
+        selected_layer_indices = [0]
     aux_hidden_states = torch.cat(
         [hidden_states[i].squeeze(0).cpu() for i in selected_layer_indices], dim=-1
     )
```
🧹 Nitpick comments (12)
examples/speculative_decoding/main.py (3)
72-81
: Annotate as Optional to match the None default.

offline_data_path defaults to None but is typed as str. Make it str | None for consistency with eval_data_path and type checkers. Apply this diff:

```diff
-    offline_data_path: str = field(
+    offline_data_path: str | None = field(
```
144-145
: Make the offline switch robust to empty strings.

`is not None` treats `""` as offline. Prefer truthiness. Apply this diff:

```diff
-use_offline_training = data_args.offline_data_path is not None
+use_offline_training = bool(data_args.offline_data_path)
```
150-159
: num_hidden_layers=0: add a safe fallback and always set num_orig_hidden_layers.

Keep the space-saving default but harden for models that can't instantiate with 0 layers, and ensure num_orig_hidden_layers is set even when resuming or when the override path changes. Apply this diff:

```diff
-    model_kwargs = {"num_hidden_layers": 0} if use_offline_training else {}
-    model = transformers.AutoModelForCausalLM.from_pretrained(
-        model_args.model_name_or_path, torch_dtype="auto", **model_kwargs
-    )
-    if use_offline_training:
-        # When doing offline training, we need to set num_hidden_layers
-        # since we override it when loading the model for space savings
-        model_config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path)
-        model.config.num_orig_hidden_layers = model_config.num_hidden_layers
+    model_kwargs = {"num_hidden_layers": 0} if use_offline_training else {}
+    try:
+        model = transformers.AutoModelForCausalLM.from_pretrained(
+            model_args.model_name_or_path, torch_dtype="auto", **model_kwargs
+        )
+    except Exception:
+        if use_offline_training:
+            print_rank_0("num_hidden_layers=0 failed; falling back to 1 for offline training.")
+            model = transformers.AutoModelForCausalLM.from_pretrained(
+                model_args.model_name_or_path, torch_dtype="auto", num_hidden_layers=1
+            )
+        else:
+            raise
+    if use_offline_training and not hasattr(model.config, "num_orig_hidden_layers"):
+        # Record original depth for plugins that need it.
+        base_cfg = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path)
+        model.config.num_orig_hidden_layers = base_cfg.num_hidden_layers
```

Please run a quick smoke test on at least one model per family you target (e.g., Llama, Mistral, Qwen2, Phi-3) to confirm no fallback is triggered unexpectedly and training proceeds.
examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (2)
1-1
: Add a shebang to fix ShellCheck SC2148 and enable script execution.

Without a shebang, shells may interpret this file incorrectly, and ShellCheck flags SC2148. Apply:

```diff
+#!/usr/bin/env bash
 # SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```
16-25
: Fix relative paths; make the script location-agnostic; add strict mode and ensure the output dir.

The script lives in prepare_input_conversations/, but calls into prepare_input_conversations/*.py again, which breaks when run from the repo root or from this directory. Also add set -euo pipefail and create a data dir. Apply:

```diff
 # Example script to prepare a dataset of prompts for generation
 # Lines in this script can be uncommented to include specific datasets/splits in the prompt dataset.
+set -euo pipefail
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+# Optional: change DATASET_DIR or pass via env
+DATASET_DIR="${DATASET_DIR:-${SCRIPT_DIR}/data}"
+mkdir -p "${DATASET_DIR}"
-python3 prepare_input_conversations/add_daring_anteater.py --output-split-name train
-# python3 prepare_input_conversations/add_sharegpt.py --output-split-name train
-# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split train_sft --output-split-name train
-# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split train_gen --output-split-name train
-# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split test_sft --output-split-name mix_test
-# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split test_gen --output-split-name mix_test
-python3 prepare_input_conversations/add_mtbench.py --output-split-name mix_test
+python3 "${SCRIPT_DIR}/add_daring_anteater.py" --output-split-name train ${DATASET_DIR:+--dataset-dir "${DATASET_DIR}"}
+# python3 "${SCRIPT_DIR}/add_sharegpt.py" --output-split-name train ${DATASET_DIR:+--dataset-dir "${DATASET_DIR}"}
+# python3 "${SCRIPT_DIR}/add_ultrachat.py" --ultrachat-split train_sft --output-split-name train ${DATASET_DIR:+--dataset-dir "${DATASET_DIR}"}
+# python3 "${SCRIPT_DIR}/add_ultrachat.py" --ultrachat-split train_gen --output-split-name train ${DATASET_DIR:+--dataset-dir "${DATASET_DIR}"}
+# python3 "${SCRIPT_DIR}/add_ultrachat.py" --ultrachat-split test_sft --output-split-name mix_test ${DATASET_DIR:+--dataset-dir "${DATASET_DIR}"}
+# python3 "${SCRIPT_DIR}/add_ultrachat.py" --ultrachat-split test_gen --output-split-name mix_test ${DATASET_DIR:+--dataset-dir "${DATASET_DIR}"}
+python3 "${SCRIPT_DIR}/add_mtbench.py" --output-split-name mix_test ${DATASET_DIR:+--dataset-dir "${DATASET_DIR}"}
```

Note: If the add_* CLIs don't support --dataset-dir, the parameter will be omitted; the mkdir remains safe.
examples/speculative_decoding/prepare_input_conversations/utils.py (1)
93-125
: Make the ratio check tolerant to FP error and avoid mutating global RNG state.

isclose avoids false negatives; using a local Random(seed) preserves global RNG for callers. Apply:

```diff
+import math
@@
-    if train_ratio + val_ratio + test_ratio != 1.0:
+    if not math.isclose(train_ratio + val_ratio + test_ratio, 1.0, rel_tol=0.0, abs_tol=1e-9):
         msg = "Ratios must sum to 1.0"
         raise ValueError(msg)
@@
-    if shuffle:
-        random.seed(seed)
-        random.shuffle(conversations)
+    if shuffle:
+        random.Random(seed).shuffle(conversations)
```

examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh (1)
1-24
: Add a shebang and safe flags, fix the debug flag name, and quote paths.

Without a shebang, running as an executable fails; ShellCheck SC2148 also flags this. The commented flag name doesn't exist in the CLI (it should be --debug-max-num-conversations). Also make the invocation path-robust and quote args.

```diff
+#!/usr/bin/env bash
+set -euo pipefail
+
+# Resolve script directory to make the path robust when invoked from anywhere.
+script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
-python3 collect_hidden_states/send_conversations_for_hiddens.py \
-  --model meta-llama/Llama-3.2-1B-Instruct \
-  --input-file synthetic_conversations/mtbench.jsonl \
-  --output-dir /mnt/md0/eagle-hidden-states/llama1b/mtbench/
-# --debug-max-num-conversations-per-split 1000
+python3 "$script_dir/send_conversations_for_hiddens.py" \
+  --model "meta-llama/Llama-3.2-1B-Instruct" \
+  --input-file "synthetic_conversations/mtbench.jsonl" \
+  --output-dir "/mnt/md0/eagle-hidden-states/llama1b/mtbench/"
+# Optional: limit processed conversations during debugging
+# --debug-max-num-conversations 1000
```

examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (3)
93-95
: Set eval() before inference.

Minor, but it's standard to disable dropout and ensure deterministic behavior.

```diff
-    model = AutoModel.from_pretrained(args.model, torch_dtype="auto", device_map="auto")
+    model = AutoModel.from_pretrained(args.model, torch_dtype="auto", device_map="auto")
+    model.eval()
```
103-127
: Counter name is misleading; it tracks both too-short and too-long samples.

Tiny naming nit; adjust for clarity.

```diff
-    num_skipped_too_long = 0
+    num_filtered_by_length = 0
@@
-        if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len:
-            num_skipped_too_long += 1
+        if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len:
+            num_filtered_by_length += 1
             continue
@@
-    if num_skipped_too_long > 0:
-        print(f"Skipped {num_skipped_too_long} conversations due to length constraints.")
+    if num_filtered_by_length > 0:
+        print(f"Skipped {num_filtered_by_length} conversations due to length constraints.")
```

Also applies to: 163-164
82-83
: Async isn’t used; simplify to synchronous main.Removes unnecessary asyncio plumbing.
-async def main(args: argparse.Namespace) -> None: +def main(args: argparse.Namespace) -> None: @@ -if __name__ == "__main__": - cli_args = parse_args() - asyncio.run(main(cli_args)) +if __name__ == "__main__": + cli_args = parse_args() + main(cli_args)Also applies to: 176-178
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh (1)
1-23
: Add a shebang and safe flags, quote args, and create the output dir.

Prevents exec failures (SC2148) and path issues; ensures the output dir exists.

```diff
+#!/usr/bin/env bash
+set -euo pipefail
+
+# Resolve script directory
+script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
+# Output directory (edit as needed)
+OUTPUT_DIR="/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/"
+mkdir -p "$OUTPUT_DIR"
+
-python3 collect_hidden_states/compute_hidden_states_hf.py \
-  --model meta-llama/Llama-3.2-1B-Instruct \
-  --input-file synthetic_conversations/daring-anteater.jsonl \
-  --output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
+python3 "$script_dir/compute_hidden_states_hf.py" \
+  --model "meta-llama/Llama-3.2-1B-Instruct" \
+  --input-file "synthetic_conversations/daring-anteater.jsonl" \
+  --output-dir "$OUTPUT_DIR"
+# Optional:
+# --max-seq-len 3072
```
16-36
: Harden DP runner: shebang/safe flags, temp workspace, quoting, and mkdir.Avoids /tmp collisions, improves safety, and makes paths robust.
+#!/usr/bin/env bash +set -euo pipefail + -INPUT_FILE=synthetic_conversations/daring-anteater.jsonl -OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/ +INPUT_FILE="synthetic_conversations/daring-anteater.jsonl" +OUTPUT_DIR="/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/" +mkdir -p "$OUTPUT_DIR" + -split -n l/8 --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part- +tmpdir="$(mktemp -d)" +trap 'rm -rf "$tmpdir"' EXIT +split -n l/8 --numeric-suffixes=0 -d --additional-suffix=.jsonl "$INPUT_FILE" "$tmpdir/part-" for i in $(seq 0 7) do -CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py --model meta-llama/Llama-3.2-1B-Instruct --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR & + CUDA_VISIBLE_DEVICES="$i" python3 collect_hidden_states/compute_hidden_states_hf.py \ + --model "meta-llama/Llama-3.2-1B-Instruct" \ + --input-file "$tmpdir/part-0${i}.jsonl" \ + --output-dir "$OUTPUT_DIR" & done wait - -rm /tmp/part-*.jsonl +# Temporary files cleaned via trap
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (21)
- examples/speculative_decoding/.gitignore (1 hunks)
- examples/speculative_decoding/collect_hidden_states/__init__.py (1 hunks)
- examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (1 hunks)
- examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh (1 hunks)
- examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh (1 hunks)
- examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh (1 hunks)
- examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py (1 hunks)
- examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (1 hunks)
- examples/speculative_decoding/eagle_utils.py (4 hunks)
- examples/speculative_decoding/launch.sh (0 hunks)
- examples/speculative_decoding/launch_train.sh (4 hunks)
- examples/speculative_decoding/main.py (4 hunks)
- examples/speculative_decoding/prepare_input_conversations/__init__.py (1 hunks)
- examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py (1 hunks)
- examples/speculative_decoding/prepare_input_conversations/add_mtbench.py (1 hunks)
- examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py (1 hunks)
- examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py (1 hunks)
- examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (1 hunks)
- examples/speculative_decoding/prepare_input_conversations/utils.py (1 hunks)
- examples/speculative_decoding/train_eagle3_and_export.sh (2 hunks)
- modelopt/torch/speculative/plugins/transformers.py (4 hunks)
💤 Files with no reviewable changes (1)
- examples/speculative_decoding/launch.sh
🚧 Files skipped from review as they are similar to previous changes (13)
- examples/speculative_decoding/.gitignore
- examples/speculative_decoding/collect_hidden_states/__init__.py
- examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py
- examples/speculative_decoding/launch_train.sh
- examples/speculative_decoding/prepare_input_conversations/__init__.py
- examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py
- examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py
- examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py
- examples/speculative_decoding/train_eagle3_and_export.sh
- examples/speculative_decoding/prepare_input_conversations/add_mtbench.py
- modelopt/torch/speculative/plugins/transformers.py
- examples/speculative_decoding/eagle_utils.py
- examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: benchislett
PR: NVIDIA/TensorRT-Model-Optimizer#300
File: examples/speculative_decoding/main.py:150-159
Timestamp: 2025-09-17T14:14:44.950Z
Learning: In examples/speculative_decoding/main.py, setting num_hidden_layers=0 for offline training has been successfully tested by benchislett with past models and works in their use case for EAGLE3 offline training.
📚 Learning: 2025-09-17T14:14:44.950Z
Learnt from: benchislett
PR: NVIDIA/TensorRT-Model-Optimizer#300
File: examples/speculative_decoding/main.py:150-159
Timestamp: 2025-09-17T14:14:44.950Z
Learning: In examples/speculative_decoding/main.py, setting num_hidden_layers=0 for offline training has been successfully tested by benchislett with past models and works in their use case for EAGLE3 offline training.
Applied to files:
examples/speculative_decoding/main.py
🧬 Code graph analysis (4)
examples/speculative_decoding/main.py (1)
- examples/speculative_decoding/eagle_utils.py (1)
  - make_eagle_supervised_data_module (238-312)

examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (2)
- examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (2)
  - parse_args (30-90)
  - main (93-206)
- examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py (2)
  - parse_args (25-46)
  - main (49-84)

examples/speculative_decoding/prepare_input_conversations/utils.py (1)
- modelopt/torch/utils/random.py (2)
  - random (59-61)
  - shuffle (148-150)

examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (1)
- examples/speculative_decoding/main.py (1)
  - train (117-271)
🪛 Shellcheck (0.10.0)
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh
[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.
(SC2148)
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh
[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.
(SC2148)
examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh
[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.
(SC2148)
examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh
[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.
(SC2148)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (3)
examples/speculative_decoding/main.py (2)
230-230
: Good: data module gated by use_offline_training.
Clear separation between online/offline loaders and collators.
237-238
: Nice: allow disabling AR validation with non-positive steps.
Useful for offline runs without inline eval.

examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (1)
27-29
: Verify a single source of truth for REMOVE_THINK_CHAT_TEMPLATE.

rg returned "No files were searched" due to the glob filter, so duplicates cannot be confirmed. Ensure no other copies exist and that training/collection stay in lockstep — from the repo root, run and paste the outputs of:

```
rg -n --hidden --no-ignore 'REMOVE_THINK_CHAT_TEMPLATE'
rg -n --hidden --no-ignore ''
```

Location: examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py:27-29
```python
# Get hidden states
with torch.inference_mode():
    outputs = model(input_ids=input_ids.to(model.device), output_hidden_states=True)
if num_hidden_layers is None:
    num_hidden_layers = len(outputs.hidden_states) - 1
```
Fix device handling with device_map='auto' (model.device may be missing or wrong).
Pass inputs to the embedding device; relying on model.device can raise or send tensors to CPU.
```diff
-    with torch.inference_mode():
-        outputs = model(input_ids=input_ids.to(model.device), output_hidden_states=True)
+    with torch.inference_mode():
+        emb_device = model.get_input_embeddings().weight.device
+        outputs = model(input_ids=input_ids.to(emb_device), output_hidden_states=True)
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
# Get hidden states
with torch.inference_mode():
    emb_device = model.get_input_embeddings().weight.device
    outputs = model(input_ids=input_ids.to(emb_device), output_hidden_states=True)
if num_hidden_layers is None:
    num_hidden_layers = len(outputs.hidden_states) - 1
```
🤖 Prompt for AI Agents
In
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py
around lines 129 to 133, the code moves inputs to model.device which can be
missing or incorrect when using device_map='auto'; instead obtain the embeddings
device (e.g. model.get_input_embeddings().weight.device) and move input_ids to
that device before calling the model; use that device value as a fallback if
model.device exists but prefer the embedding weight device to ensure tensors end
up on the correct device for sharded/auto-mapped models.
…rsation_id Signed-off-by: Benjamin Chislett <[email protected]>
@benchislett You need to sign your commits with an SSH key. Please take a look at https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md#%EF%B8%8F-signing-your-work
```python
with args.input_file.open("r", encoding="utf-8") as f:
    all_conversations.extend([json.loads(line) for line in f if line.strip()])

if any(not entry.get("conversation_id") for entry in all_conversations):
```
Hi @benchislett, it seems there is a bug on this line:

- When an entry has conversation_id=0, `not entry.get("conversation_id")` will return True, causing an error to be raised.
- Since we add the fallback below to allow a missing conversation_id, we should probably also remove this check here.
```python
with args.input_file.open("r", encoding="utf-8") as f:
    all_conversations.extend([json.loads(line) for line in f if line.strip()])

if any(not entry.get("conversation_id") for entry in all_conversations):
```
Similar as above.
```python
valid_entries = []
for entry in data_json:
    conv_id = entry.get("conversation_id") or entry.get("id")
    if not conv_id:
```
Similar as above, this will raise an error when conv_id=0. We should probably use `if conv_id is None:` instead.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Besides, on current line 273: when conversation_id=0, conv_id will fall through to entry.get("id") (possibly None), since the left-hand side of `or` is falsy. We probably want to do this instead:

```python
conv_id = entry.get("conversation_id", entry.get("id"))
```
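To make the pitfall concrete, this is how the two lookups differ when conversation_id is 0 (illustrative snippet):

```python
entry = {"conversation_id": 0, "id": "fallback"}

# `or` treats 0 as falsy and silently takes the fallback:
print(entry.get("conversation_id") or entry.get("id"))  # -> "fallback"

# A default-argument lookup only falls back when the key is absent:
print(entry.get("conversation_id", entry.get("id")))  # -> 0
```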
What does this PR do?

Type of change: Feature support for offline training of EAGLE3 heads using the HuggingFace training script

Overview:

This PR contains two primary components:

1. New helper scripts, ported from previous EAGLE3 training scripts, to facilitate easy data preparation for offline training. These include:
   - prepare_input_conversations: short Python scripts to load commonly-used training datasets and output a standardized jsonl dataset file that is ready for training
   - gen_synthetic_conversations: scripts to generate synthetic conversations using a conversation dataset as a collection of prompts for the model. Currently, OpenAI-compatible endpoints are used to generate conversations. This is a known bottleneck of dataset preparation, so high performance is key. Any serving engine can be used, but an example script demonstrates how to launch vllm for inference.
   - collect_hidden_states: scripts to extract and save hidden states generated from a conversation dataset, for use in offline training. These include a script that can send the completion requests to a local inference server running with a patched inference loop, as well as a script that uses HF transformers AutoModel to generate the hidden states explicitly. Using an inference server is more performant, but either will work well for generating hidden states.
2. Support for offline training in the EAGLE3 training scripts. This is triggered by sending --offline-data X to launch.sh, which will then launch main.py with: --offline-training True --offline-data-path $OFFLINE_DATA_PATH --omit-target-layers False. See below for example usage.

Currently, the inline evaluation using ar_validate.py does not work when the target model's hidden layers are deleted, so the memory footprint of the output checkpoints is not smaller than with online training. However, all performance gains during training should still be present. --omit-target-layers controls this behaviour and will be re-enabled when offline support is added to the validation script.

Usage

Here are the steps to reproduce an offline training run of Llama 3.2 1B-Instruct on the Daring-Anteater dataset. Note that all commands are expected to be run with modelopt installed, and from the base directory at examples/speculative_decoding.

Testing

Training was tested and evaluated on the example setup above, reporting ~2.2 AL after 2 epochs in all cases. Offline and online training produced nearly identical loss curves and acceptance rates at each evaluation.

Summary by CodeRabbit

New Features

Chores