
Conversation

@benchislett (Contributor) commented Sep 8, 2025

What does this PR do?

Type of change: Feature support for offline training of EAGLE3 heads using the HuggingFace training script

Overview:

This PR contains two primary components:

  • New helper scripts, ported from previous EAGLE3 training scripts, to facilitate easy data preparation for offline training. These include:

    • prepare_input_conversations: short Python scripts that load commonly used training datasets and output a standardized JSONL dataset file ready for training (a minimal sketch of the format appears after this list)
    • gen_synthetic_conversations: scripts to generate synthetic conversations, using a conversation dataset as a collection of prompts for the model. Currently, OpenAI-compatible endpoints are used to generate the conversations. This is a known bottleneck of dataset preparation, so high performance is key. Any serving engine can be used; an example script demonstrates how to launch vLLM for inference.
    • collect_hidden_states: scripts to extract and save the hidden states generated from a conversation dataset, for use in offline training. These include a script that sends completion requests to a local inference server running a patched inference loop, as well as a script that uses the HF transformers AutoModel to generate the hidden states explicitly. Using an inference server is more performant, but either approach works well for generating hidden states.
  • Support for offline training in the EAGLE3 training scripts. This is triggered by passing --offline-data X to launch.sh, which will then launch main.py with --offline-training True --offline-data-path $OFFLINE_DATA_PATH --omit-target-layers False. See below for example usage.
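
For reference, here is a minimal sketch of what one line of the standardized JSONL prompt dataset is expected to look like. The required fields (conversation_id and conversations) match what the preparation utilities validate; the hash value, message contents, and output path are illustrative assumptions, not code from this PR.

import json
from pathlib import Path

# Hypothetical entry; only "conversation_id" and "conversations" are the required
# fields validated by the prep utilities. The hash value and messages are made up.
entry = {
    "conversation_id": "d41d8cd98f00b204",
    "conversations": [
        {"role": "user", "content": "Explain speculative decoding in one paragraph."},
        {"role": "assistant", "content": "Speculative decoding drafts tokens with a smaller model..."},
    ],
}

# One JSON object per line (JSONL), appended to the chosen split file.
out_file = Path("input_conversations/train.jsonl")
out_file.parent.mkdir(parents=True, exist_ok=True)
with out_file.open("a", encoding="utf-8") as f:
    f.write(json.dumps(entry, ensure_ascii=False) + "\n")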

Currently, the inline evaluation using ar_validate.py does not work when the target model's hidden layers are deleted, so the memory footprint of the output checkpoints is not smaller than with online training. However, all performance gains during training should still be present. --omit-target-layers controls this behaviour and will be re-enabled once offline support is added to the validation script.

Usage

Here are the steps to reproduce an offline training run of Llama 3.2 1B-Instruct on the Daring-Anteater dataset:

# First, prepare the conversations
python3 prepare_input_conversations/add_daring_anteater.py

# Then generate synthetic conversations
vllm serve meta-llama/Llama-3.2-1B-Instruct # launch any llm server
bash gen_synthetic_conversations/send_completion_reqs_openai.sh
# don't forget to shutdown the llm server

# Compute hidden states
bash collect_hidden_states/run_hf_compute_hiddens.sh

# Launch training
OUTPUT_DIR=ckpts/llama1b-$(date +%Y%m%d_%H%M)
mkdir -p "$(dirname "$OUTPUT_DIR")"
./launch.sh --model meta-llama/Llama-3.2-1B-Instruct \
            --output_dir "$OUTPUT_DIR" \
            --offline-data /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/ \
            --data synthetic_conversations/daring-anteater.jsonl \
            --num_gpu 8 \
            --num_epochs 2 \
            --eagle_config eagle_config.json

Note that all commands are expected to be run with modelopt installed, from the base directory examples/speculative_decoding.

Testing

Training was tested and evaluated on the example setup above, reporting ~2.2 AL after 2 epochs in all cases. Offline and online training produced nearly identical loss curves and acceptance rates at each evaluation.

  • Is this change backward compatible?: Yes. Online training should not be affected in any way.
  • Did you write any new necessary tests?: No; unsure whether we have unit tests for training.
  • Did you add or update any necessary documentation?: No. Waiting for Feat: update eagle3 example; add export #293 to land before updating any READMEs / docs.
  • Did you update Changelog?: No. TODO.

Summary by CodeRabbit

  • New Features

    • Offline training support using precomputed hidden states (data loader/collator and CLI flags); default training sequence length raised to 2048.
    • New dataset preparation utilities for Daring-Anteater, MTBench, ShareGPT, UltraChat, plus split/mix helpers.
    • Hidden-state tooling: compute, send, and sample utilities and multi-GPU runner; runtime changes to enable offline mode.
  • Chores

    • .gitignore updated to ignore input_conversations, synthetic_conversations, and ckpts.
    • Removed legacy centralized launcher; added updated launch/train scripts that propagate offline-data options.


copy-pr-bot bot commented Sep 8, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.


coderabbitai bot commented Sep 8, 2025

Walkthrough

Adds dataset-prep utilities and scripts to build conversation datasets, new tools to compute/send/sample per-conversation hidden states and emit per-conversation .pt artifacts, integrates an offline training path (dataset/collator/dataloader/CLI wiring), updates a transformer plugin for offline usage, updates .gitignore, and removes a legacy launcher.

Changes

Cohort / File(s) Summary
Ignore updates
examples/speculative_decoding/.gitignore
Appends ignore patterns: input_conversations, synthetic_conversations, and ckpts.
Hidden-state collection package & scripts
examples/speculative_decoding/collect_hidden_states/__init__.py, examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py, examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py, examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py, examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh, examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh, examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh
New package and CLIs/scripts to compute (HF), send (OpenAI-compatible), sample, and orchestrate collection of per-conversation hidden states; writes per-conversation .pt files containing input_ids, hidden_states, aux_hidden_states, and conversation_id (a sketch of loading one of these files appears after this table).
Prepare input conversations utilities & scripts
examples/speculative_decoding/prepare_input_conversations/__init__.py, examples/speculative_decoding/prepare_input_conversations/utils.py, examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py, examples/speculative_decoding/prepare_input_conversations/add_mtbench.py, examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py, examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py, examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh
New utilities for async download, deterministic conversation ID hashing, deduplication/append-to-splits, mixing/splitting strategies, and dataset-conversion scripts for Daring-Anteater, MTBench, ShareGPT, UltraChat, plus an example dataset-build script.
Offline training integration (EAGLE)
examples/speculative_decoding/eagle_utils.py, examples/speculative_decoding/main.py, examples/speculative_decoding/launch_train.sh, examples/speculative_decoding/train_eagle3_and_export.sh
Adds OfflineSupervisedDataset and DataCollatorForOffline, extends make_eagle_supervised_data_module(..., use_offline_training), adds DataArguments.offline_data_path, wires launchers/train scripts to accept/validate offline-data paths, increases default training seq len (512→2048), and propagates offline args into invocations.
Transformer plugin updates
modelopt/torch/speculative/plugins/transformers.py
Adjusts HFEagleModel to read config.num_orig_hidden_layers under offline mode, prefer lm_head device when offline, and initialize a dummy DynamicCache() for past_key_values when base_model_outputs are provided; minor comments added.
Removed legacy launcher
examples/speculative_decoding/launch.sh
Deleted legacy bash launcher that previously orchestrated speculative-decoding experiments.
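
A minimal sketch of loading one of the per-conversation .pt artifacts for inspection. The key names come from the summary above; the file path and tensor shapes are illustrative assumptions.

import torch

# Hypothetical path; each conversation gets its own .pt file
data = torch.load("hidden_states/conv_0001.pt", map_location="cpu")

print(data["conversation_id"])          # ID of the source conversation
print(data["input_ids"].shape)          # e.g. [seq_len]
print(data["hidden_states"].shape)      # e.g. [seq_len, hidden_size] (last-layer states)
print(data["aux_hidden_states"].shape)  # e.g. [seq_len, 3 * hidden_size] (concatenated aux layers)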

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant U as User
  participant Prep as Prepare scripts
  participant HF as compute_hidden_states_hf
  participant API as OpenAI-like server
  participant DS as Disk (.pt)
  participant Trainer as Trainer (offline)

  U->>Prep: run add_* → produce input_conversations JSONL
  alt local HF compute
    U->>HF: compute_hidden_states_hf (--model, --input-file)
    HF->>DS: write per-conversation .pt (input_ids, hidden_states, aux_hidden_states)
  else send to server
    U->>API: send_conversations_for_hiddens (--base-url, --input-file)
    API->>DS: server writes per-conversation .pt
  end
  U->>Trainer: launch_train.sh --offline-data PATH
  Trainer->>DS: read .pt files
  Trainer->>Trainer: OfflineSupervisedDataset + DataCollatorForOffline → training (eagle_offline=true)
sequenceDiagram
  autonumber
  participant C as send_conversations_for_hiddens
  participant Tok as Tokenizer
  participant Meta as /tmp/meta.json
  participant API as OpenAI-like endpoint
  participant Out as Output Dir

  C->>Tok: apply_chat_template(conversations) -> input_ids / prompt
  C->>Meta: write {conversation_id, output_file}
  C->>API: completions.create(model, prompt, max_tokens=1)
  alt success
    API-->>C: 200 OK
    Note right of API: serving engine coordinates writing .pt to Out
    API->>Out: .pt created
  else error / too long
    API-->>C: error
    C->>C: increment counters, skip
  end
  C->>Meta: cleanup entry
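
As the second diagram shows, the send-based collection path issues one completion with max_tokens=1 per conversation, so the patched serving engine performs the prefill and persists the .pt artifact itself. A rough sketch of that request pattern with the openai Python client, using a placeholder endpoint, model, and prompt, might look like:

from openai import OpenAI

# Placeholder endpoint/credentials for a locally served, patched inference engine
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.completions.create(
    model="meta-llama/Llama-3.2-1B-Instruct",
    prompt="<prompt rendered via tokenizer.apply_chat_template>",  # placeholder
    max_tokens=1,  # only the prefill is needed; the server writes the hidden states to disk
)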

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

I hop through prompts and stash them neat,
small .pt burrows where hidden thoughts meet.
I nibble tokens, tuck aux layers tight,
offline carrots fuel the training night.
A joyful hop — new tunnels in sight! 🥕🐇

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 42.86%, below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Description Check (✅ Passed): Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check (✅ Passed): The title "Feature: Offline training for EAGLE3" is a concise, single-sentence summary that accurately reflects the PR's primary change of adding offline training support for EAGLE3 and related dataset/hidden-state tooling. It is specific, focused, and clear for a teammate scanning the repository history.



examples/speculative_decoding/eagle_utils.py (review thread):

# Extend the data sample with the hidden states from the .pt file
max_length = self.tokenizer.model_max_length
offline_data = torch.load(offline_file_path)
@h-guo18 (Contributor) commented on Sep 10, 2025:

Curious:

  1. With the current implementation, will there be any pre-fetching mechanism for this tensor loading (perhaps taken care of inside the HF trainer)?
  2. Considering there's a limit on total disk bandwidth, could data loading become a bottleneck limiting training speed (if we further optimize the training loop, e.g. the TTT part)?

@benchislett (Contributor, Author) replied:

  1. Yes, I believe standard PyTorch/HF datasets will handle prefetching, multiple loading worker processes, etc. (see the sketch after this reply).
  2. Indeed, disk bandwidth can absolutely bottleneck training. If this becomes an issue, we can offset it with techniques such as compressing hidden states on disk, or using a file system with faster read speeds (e.g. many smaller disks for high parallel read throughput). However, the speed-of-light for even a single moderate-quality disk is quite good, so networking a few 1-4 TB disks together can easily saturate the (optimized) GPU throughput.
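
A minimal, self-contained sketch of the standard PyTorch DataLoader prefetching knobs referenced in point 1 (the HF Trainer exposes the equivalent dataloader_num_workers setting). The toy dataset below is a placeholder standing in for the offline .pt-backed dataset and is not code from this PR.

import torch
from torch.utils.data import DataLoader, Dataset


class ToyOfflineDataset(Dataset):
    """Placeholder: the real dataset would torch.load() a per-conversation .pt file here."""

    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return {"input_ids": torch.arange(16), "hidden_states": torch.randn(16, 32)}


if __name__ == "__main__":
    loader = DataLoader(
        ToyOfflineDataset(),
        batch_size=4,
        num_workers=4,            # worker processes read files in parallel with GPU compute
        prefetch_factor=2,        # each worker keeps two batches in flight
        pin_memory=True,          # faster host-to-device copies
        persistent_workers=True,  # keep workers alive across epochs
    )
    for batch in loader:
        pass  # a training step would consume `batch` here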

@yeyu-nvidia (Contributor) commented:

Can we remove gen_synthetic_conversations from the PR, since we already have https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/speculative_decoding/distributed_generate?

@benchislett force-pushed the bchislett/offline-eagle-training branch 2 times, most recently from 4282dbe to f6cb37a on September 15, 2025 at 17:13

codecov bot commented Sep 15, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.83%. Comparing base (74061f5) to head (5add4cb).
⚠️ Report is 2 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #300      +/-   ##
==========================================
- Coverage   73.84%   73.83%   -0.01%     
==========================================
  Files         172      172              
  Lines       17453    17453              
==========================================
- Hits        12888    12887       -1     
- Misses       4565     4566       +1     

☔ View full report in Codecov by Sentry.

@h-guo18 (Contributor) commented Sep 15, 2025:

LGTM. Tried offline workflow with tinyllama and got reasonable AR.

@h-guo18 self-requested a review on September 16, 2025 at 18:19
@benchislett force-pushed the bchislett/offline-eagle-training branch from 1839ab5 to f92be76 on September 16, 2025 at 19:23
@benchislett marked this pull request as ready for review on September 16, 2025 at 19:24
@benchislett requested a review from a team as a code owner on September 16, 2025 at 19:24
@benchislett (Contributor, Author) commented:

/ok to test f92be76

@coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (5)
examples/speculative_decoding/main.py (1)

188-190: Good: honor EagleConfig.eagle_offline per prior feedback.
This aligns with the earlier request to use the existing config flag.

examples/speculative_decoding/prepare_input_conversations/utils.py (2)

26-36: Add a network timeout and avoid multi-context in one line for clarity.

Prevents indefinite hangs and improves readability; keep parent dir creation.

Apply:

 async def download_file(url: str, destination: Path) -> None:
     """Download a file from a URL to a specified destination."""
     destination.parent.mkdir(parents=True, exist_ok=True)
-    async with aiohttp.ClientSession() as session, session.get(url) as response:
-        if response.status != 200:
-            msg = f"Failed to download {url}: {response.status}"
-            raise RuntimeError(msg)
-        content = await response.read()
-        destination.write_bytes(content)
-        print(f"Downloaded {url} to {destination}")
+    timeout = aiohttp.ClientTimeout(total=600)
+    async with aiohttp.ClientSession(timeout=timeout) as session:
+        async with session.get(url) as response:
+            if response.status != 200:
+                msg = f"Failed to download {url}: {response.status}"
+                raise RuntimeError(msg)
+            content = await response.read()
+            destination.write_bytes(content)
+            print(f"Downloaded {url} to {destination}")

45-86: Prevent duplicate IDs within the same update and validate required keys.

Currently, duplicates inside the provided conversations list can be appended multiple times because existing_ids isn’t updated as you add. Also, missing "conversations" will raise KeyError later; validate early.

Apply:

 def add_conversations_to_split(conversations: list, dataset_dir: Path, split: str) -> None:
@@
-    # Open the dataset file for the specified split, or create it if it doesn't exist
-    dataset_file = dataset_dir / f"{split}.jsonl"
+    # Ensure output directory exists and open/create split file
+    dataset_dir.mkdir(parents=True, exist_ok=True)
+    dataset_file = dataset_dir / f"{split}.jsonl"
@@
-    existing_ids = {entry["conversation_id"] for entry in all_conversations}
+    existing_ids = {entry["conversation_id"] for entry in all_conversations}
     num_new_entries = 0
     num_duplicates = 0
     for entry in conversations:
-        if entry.get("conversation_id") is None:
+        entry_id = entry.get("conversation_id")
+        if entry_id is None:
             raise ValueError("Each conversation must have a 'conversation_id' field.")
-        if entry["conversation_id"] not in existing_ids:
+        if "conversations" not in entry:
+            raise ValueError("Each conversation must have a 'conversations' field.")
+        if entry_id not in existing_ids:
             all_conversations.append(
                 {
-                    "conversation_id": entry["conversation_id"],
+                    "conversation_id": entry_id,
                     "conversations": entry["conversations"],
                 }
             )
             num_new_entries += 1
+            existing_ids.add(entry_id)
         else:
             num_duplicates += 1
@@
-    dataset_dir.mkdir(parents=True, exist_ok=True)
     with dataset_file.open("w", encoding="utf-8") as f:
         for entry in all_conversations:
             f.write(json.dumps(entry, ensure_ascii=False) + "\n")
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (2)

27-29: Unify REMOVE_THINK_CHAT_TEMPLATE with training and guard None chat_template.

Avoid divergence from training preprocessing and prevent AttributeError when chat_template is None.

-from transformers import AutoModel, AutoTokenizer
+from transformers import AutoModel, AutoTokenizer
+try:
+    # Keep in sync with training; import if available when run from examples/speculative_decoding
+    from eagle_utils import REMOVE_THINK_CHAT_TEMPLATE
+except Exception:
+    # Fallback for standalone runs; ensure value matches eagle_utils.REMOVE_THINK_CHAT_TEMPLATE
+    REMOVE_THINK_CHAT_TEMPLATE = (
+        "{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}"
+    )
@@
-REMOVE_THINK_CHAT_TEMPLATE = (
-    "{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}"
-)
@@
-    tokenizer = AutoTokenizer.from_pretrained(args.model)
+    tokenizer = AutoTokenizer.from_pretrained(args.model)
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
-    tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")
+    if getattr(tokenizer, "chat_template", None):
+        tokenizer.chat_template = tokenizer.chat_template.replace(
+            REMOVE_THINK_CHAT_TEMPLATE, ""
+        )

Also applies to: 96-100


138-148: Clamp/deduplicate auxiliary layer indices; avoid out‑of‑range/duplicates on small models.

Current logic can pick invalid indices (e.g., 2 when num_hidden_layers < 3).

-            # Extract hidden states from layers with index (2, N/2, N-3), and the output hidden states
+            # Extract hidden states from early/mid/late layers; clamp within [0, N-1]
             hidden_states = outputs.hidden_states
-            selected_layer_indices = [
-                2,
-                max(0, num_hidden_layers // 2),
-                max(1, num_hidden_layers - 3),
-            ]
-            selected_layer_indices = sorted(set(selected_layer_indices))
+            candidates = [2, num_hidden_layers // 2, num_hidden_layers - 3]
+            selected_layer_indices = sorted(
+                {i for i in candidates if 0 <= i <= num_hidden_layers - 1}
+            )
+            if not selected_layer_indices:
+                selected_layer_indices = [0]
             aux_hidden_states = torch.cat(
                 [hidden_states[i].squeeze(0).cpu() for i in selected_layer_indices], dim=-1
             )
🧹 Nitpick comments (12)
examples/speculative_decoding/main.py (3)

72-81: Annotate as Optional to match None default.

offline_data_path defaults to None but is typed as str. Make it str | None for consistency with eval_data_path and type checkers.

Apply this diff:

-    offline_data_path: str = field(
+    offline_data_path: str | None = field(

144-145: Make the offline switch robust to empty strings.

is not None treats "" as offline. Prefer truthiness.

Apply this diff:

-use_offline_training = data_args.offline_data_path is not None
+use_offline_training = bool(data_args.offline_data_path)

150-159: num_hidden_layers=0: add safe fallback and always set num_orig_hidden_layers.

Keep your space‑saving default but harden for models that can’t instantiate with 0 layers, and ensure num_orig_hidden_layers is set even when resuming or the override path changes.

Apply this diff:

-        model_kwargs = {"num_hidden_layers": 0} if use_offline_training else {}
-        model = transformers.AutoModelForCausalLM.from_pretrained(
-            model_args.model_name_or_path, torch_dtype="auto", **model_kwargs
-        )
-        if use_offline_training:
-            # When doing offline training, we need to set num_hidden_layers
-            # since we override it when loading the model for space savings
-            model_config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path)
-            model.config.num_orig_hidden_layers = model_config.num_hidden_layers
+        model_kwargs = {"num_hidden_layers": 0} if use_offline_training else {}
+        try:
+            model = transformers.AutoModelForCausalLM.from_pretrained(
+                model_args.model_name_or_path, torch_dtype="auto", **model_kwargs
+            )
+        except Exception as e:
+            if use_offline_training:
+                print_rank_0("num_hidden_layers=0 failed; falling back to 1 for offline training.")
+                model = transformers.AutoModelForCausalLM.from_pretrained(
+                    model_args.model_name_or_path, torch_dtype="auto", num_hidden_layers=1
+                )
+            else:
+                raise
+        if use_offline_training and not hasattr(model.config, "num_orig_hidden_layers"):
+            # Record original depth for plugins that need it.
+            base_cfg = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path)
+            model.config.num_orig_hidden_layers = base_cfg.num_hidden_layers

Please run a quick smoke test on at least one model per family you target (e.g., Llama, Mistral, Qwen2, Phi‑3) to confirm no fallback is triggered unexpectedly and training proceeds.

examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (2)

1-1: Add a shebang to fix ShellCheck SC2148 and enable script execution.

Without a shebang, shells may interpret this file incorrectly and ShellCheck flags SC2148.

Apply:

+#!/usr/bin/env bash
 # SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

16-25: Fix relative paths; make script location-agnostic; add strict mode and ensure output dir.

The script lives in prepare_input_conversations/, but calls into prepare_input_conversations/*.py again, which breaks when run from repo root or this directory. Also add set -euo pipefail and create a data dir.

Apply:

-# Example script to prepare a dataset of prompts for generation
-# Lines in this script can be uncommented to include specific datasets/splits in the prompt dataset.
+# Example script to prepare a dataset of prompts for generation
+# Lines in this script can be uncommented to include specific datasets/splits in the prompt dataset.
+set -euo pipefail
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+# Optional: change DATASET_DIR or pass via env
+DATASET_DIR="${DATASET_DIR:-${SCRIPT_DIR}/data}"
+mkdir -p "${DATASET_DIR}"

-python3 prepare_input_conversations/add_daring_anteater.py --output-split-name train
-# python3 prepare_input_conversations/add_sharegpt.py --output-split-name train
-# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split train_sft --output-split-name train
-# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split train_gen --output-split-name train
-# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split test_sft --output-split-name mix_test
-# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split test_gen --output-split-name mix_test
-python3 prepare_input_conversations/add_mtbench.py --output-split-name mix_test
+python3 "${SCRIPT_DIR}/add_daring_anteater.py" --output-split-name train ${DATASET_DIR:+--dataset-dir "${DATASET_DIR}"}
+# python3 "${SCRIPT_DIR}/add_sharegpt.py" --output-split-name train ${DATASET_DIR:+--dataset-dir "${DATASET_DIR}"}
+# python3 "${SCRIPT_DIR}/add_ultrachat.py" --ultrachat-split train_sft --output-split-name train ${DATASET_DIR:+--dataset-dir "${DATASET_DIR}"}
+# python3 "${SCRIPT_DIR}/add_ultrachat.py" --ultrachat-split train_gen --output-split-name train ${DATASET_DIR:+--dataset-dir "${DATASET_DIR}"}
+# python3 "${SCRIPT_DIR}/add_ultrachat.py" --ultrachat-split test_sft --output-split-name mix_test ${DATASET_DIR:+--dataset-dir "${DATASET_DIR}"}
+# python3 "${SCRIPT_DIR}/add_ultrachat.py" --ultrachat-split test_gen --output-split-name mix_test ${DATASET_DIR:+--dataset-dir "${DATASET_DIR}"}
+python3 "${SCRIPT_DIR}/add_mtbench.py" --output-split-name mix_test ${DATASET_DIR:+--dataset-dir "${DATASET_DIR}"} 

Note: If the add_* CLIs don’t support --dataset-dir, the parameter will be omitted; the mkdir remains safe.

examples/speculative_decoding/prepare_input_conversations/utils.py (1)

93-125: Make ratio check tolerant to FP error and avoid mutating global RNG state.

isclose avoids false negatives; using a local Random(seed) preserves global RNG for callers.

Apply:

+import math
@@
-    if train_ratio + val_ratio + test_ratio != 1.0:
+    if not math.isclose(train_ratio + val_ratio + test_ratio, 1.0, rel_tol=0.0, abs_tol=1e-9):
         msg = "Ratios must sum to 1.0"
         raise ValueError(msg)
@@
-    if shuffle:
-        random.seed(seed)
-        random.shuffle(conversations)
+    if shuffle:
+        random.Random(seed).shuffle(conversations)
examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh (1)

1-24: Add shebang/safe flags, fix debug flag name, and quote paths.

Without a shebang, running as an executable fails; ShellCheck SC2148 also flags this. The commented flag name doesn’t exist in the CLI (should be --debug-max-num-conversations). Also make invocation path-robust and quote args.

+#!/usr/bin/env bash
+set -euo pipefail
+
+# Resolve script directory to make path robust when invoked from anywhere.
+script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
-python3 collect_hidden_states/send_conversations_for_hiddens.py \
-  --model meta-llama/Llama-3.2-1B-Instruct \
-  --input-file synthetic_conversations/mtbench.jsonl \
-  --output-dir /mnt/md0/eagle-hidden-states/llama1b/mtbench/
-# --debug-max-num-conversations-per-split 1000
+python3 "$script_dir/send_conversations_for_hiddens.py" \
+  --model "meta-llama/Llama-3.2-1B-Instruct" \
+  --input-file "synthetic_conversations/mtbench.jsonl" \
+  --output-dir "/mnt/md0/eagle-hidden-states/llama1b/mtbench/"
+# Optional: limit processed conversations during debugging
+# --debug-max-num-conversations 1000
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (3)

93-95: Set eval() before inference.

Minor but standard to disable dropout and ensure deterministic behavior.

-    model = AutoModel.from_pretrained(args.model, torch_dtype="auto", device_map="auto")
+    model = AutoModel.from_pretrained(args.model, torch_dtype="auto", device_map="auto")
+    model.eval()

103-127: Counter name is misleading; it tracks too‑short and too‑long.

Tiny naming nit; adjust for clarity.

-    num_skipped_too_long = 0
+    num_filtered_by_length = 0
@@
-        if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len:
-            num_skipped_too_long += 1
+        if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len:
+            num_filtered_by_length += 1
             continue
@@
-    if num_skipped_too_long > 0:
-        print(f"Skipped {num_skipped_too_long} conversations due to length constraints.")
+    if num_filtered_by_length > 0:
+        print(f"Skipped {num_filtered_by_length} conversations due to length constraints.")

Also applies to: 163-164


82-83: Async isn’t used; simplify to synchronous main.

Removes unnecessary asyncio plumbing.

-async def main(args: argparse.Namespace) -> None:
+def main(args: argparse.Namespace) -> None:
@@
-if __name__ == "__main__":
-    cli_args = parse_args()
-    asyncio.run(main(cli_args))
+if __name__ == "__main__":
+    cli_args = parse_args()
+    main(cli_args)

Also applies to: 176-178

examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh (1)

1-23: Add shebang/safe flags, quote args, and create output dir.

Prevents exec failures (SC2148) and path issues; ensures output dir exists.

+#!/usr/bin/env bash
+set -euo pipefail
+
+# Resolve script directory
+script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
+# Output directory (edit as needed)
+OUTPUT_DIR="/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/"
+mkdir -p "$OUTPUT_DIR"
+
-python3 collect_hidden_states/compute_hidden_states_hf.py \
-  --model meta-llama/Llama-3.2-1B-Instruct \
-  --input-file synthetic_conversations/daring-anteater.jsonl \
-  --output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
+python3 "$script_dir/compute_hidden_states_hf.py" \
+  --model "meta-llama/Llama-3.2-1B-Instruct" \
+  --input-file "synthetic_conversations/daring-anteater.jsonl" \
+  --output-dir "$OUTPUT_DIR"
+# Optional:
+# --max-seq-len 3072
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh (1)

16-36: Harden DP runner: shebang/safe flags, temp workspace, quoting, and mkdir.

Avoids /tmp collisions, improves safety, and makes paths robust.

+#!/usr/bin/env bash
+set -euo pipefail
+
-INPUT_FILE=synthetic_conversations/daring-anteater.jsonl
-OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
+INPUT_FILE="synthetic_conversations/daring-anteater.jsonl"
+OUTPUT_DIR="/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/"
+mkdir -p "$OUTPUT_DIR"
+
-split -n l/8 --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part-
+tmpdir="$(mktemp -d)"
+trap 'rm -rf "$tmpdir"' EXIT
+split -n l/8 --numeric-suffixes=0 -d --additional-suffix=.jsonl "$INPUT_FILE" "$tmpdir/part-"
 
 for i in $(seq 0 7)
 do
-CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py --model meta-llama/Llama-3.2-1B-Instruct --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR &
+  CUDA_VISIBLE_DEVICES="$i" python3 collect_hidden_states/compute_hidden_states_hf.py \
+    --model "meta-llama/Llama-3.2-1B-Instruct" \
+    --input-file "$tmpdir/part-0${i}.jsonl" \
+    --output-dir "$OUTPUT_DIR" &
 done
 wait
-
-rm /tmp/part-*.jsonl
+# Temporary files cleaned via trap
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f92be76 and f74bf59.

📒 Files selected for processing (21)
  • examples/speculative_decoding/.gitignore (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/__init__.py (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (1 hunks)
  • examples/speculative_decoding/eagle_utils.py (4 hunks)
  • examples/speculative_decoding/launch.sh (0 hunks)
  • examples/speculative_decoding/launch_train.sh (4 hunks)
  • examples/speculative_decoding/main.py (4 hunks)
  • examples/speculative_decoding/prepare_input_conversations/__init__.py (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/add_mtbench.py (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/utils.py (1 hunks)
  • examples/speculative_decoding/train_eagle3_and_export.sh (2 hunks)
  • modelopt/torch/speculative/plugins/transformers.py (4 hunks)
💤 Files with no reviewable changes (1)
  • examples/speculative_decoding/launch.sh
🚧 Files skipped from review as they are similar to previous changes (13)
  • examples/speculative_decoding/.gitignore
  • examples/speculative_decoding/collect_hidden_states/__init__.py
  • examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py
  • examples/speculative_decoding/launch_train.sh
  • examples/speculative_decoding/prepare_input_conversations/__init__.py
  • examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py
  • examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py
  • examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py
  • examples/speculative_decoding/train_eagle3_and_export.sh
  • examples/speculative_decoding/prepare_input_conversations/add_mtbench.py
  • modelopt/torch/speculative/plugins/transformers.py
  • examples/speculative_decoding/eagle_utils.py
  • examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: benchislett
PR: NVIDIA/TensorRT-Model-Optimizer#300
File: examples/speculative_decoding/main.py:150-159
Timestamp: 2025-09-17T14:14:44.950Z
Learning: In examples/speculative_decoding/main.py, setting num_hidden_layers=0 for offline training has been successfully tested by benchislett with past models and works in their use case for EAGLE3 offline training.
📚 Learning: 2025-09-17T14:14:44.950Z
Learnt from: benchislett
PR: NVIDIA/TensorRT-Model-Optimizer#300
File: examples/speculative_decoding/main.py:150-159
Timestamp: 2025-09-17T14:14:44.950Z
Learning: In examples/speculative_decoding/main.py, setting num_hidden_layers=0 for offline training has been successfully tested by benchislett with past models and works in their use case for EAGLE3 offline training.

Applied to files:

  • examples/speculative_decoding/main.py
🧬 Code graph analysis (4)
examples/speculative_decoding/main.py (1)
examples/speculative_decoding/eagle_utils.py (1)
  • make_eagle_supervised_data_module (238-312)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (2)
examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (2)
  • parse_args (30-90)
  • main (93-206)
examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py (2)
  • parse_args (25-46)
  • main (49-84)
examples/speculative_decoding/prepare_input_conversations/utils.py (1)
modelopt/torch/utils/random.py (2)
  • random (59-61)
  • shuffle (148-150)
examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (1)
examples/speculative_decoding/main.py (1)
  • train (117-271)
🪛 Shellcheck (0.10.0)
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (3)
examples/speculative_decoding/main.py (2)

230-230: Good: data module gated by use_offline_training.
Clear separation between online/offline loaders and collators.


237-238: Nice: allow disabling AR validation with non-positive steps.
Useful for offline runs without inline eval.

examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (1)

27-29: Verify single source of truth for REMOVE_THINK_CHAT_TEMPLATE

rg returned "No files were searched" due to the glob filter; cannot confirm duplicates. Ensure no other copies exist and training/collection stay in lockstep — from the repo root run and paste outputs of:
rg -n --hidden --no-ignore 'REMOVE_THINK_CHAT_TEMPLATE'
rg -n --hidden --no-ignore ''
Location: examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py:27-29

Comment on lines +129 to +131
        # Get hidden states
        with torch.inference_mode():
            outputs = model(input_ids=input_ids.to(model.device), output_hidden_states=True)
        if num_hidden_layers is None:
            num_hidden_layers = len(outputs.hidden_states) - 1

⚠️ Potential issue

Fix device handling with device_map='auto' (model.device may be missing or wrong).

Pass inputs to the embedding device; relying on model.device can raise or send tensors to CPU.

-        with torch.inference_mode():
-            outputs = model(input_ids=input_ids.to(model.device), output_hidden_states=True)
+        with torch.inference_mode():
+            emb_device = model.get_input_embeddings().weight.device
+            outputs = model(input_ids=input_ids.to(emb_device), output_hidden_states=True)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
        # Get hidden states
        with torch.inference_mode():
            outputs = model(input_ids=input_ids.to(model.device), output_hidden_states=True)
        if num_hidden_layers is None:
            num_hidden_layers = len(outputs.hidden_states) - 1

        # Get hidden states
        with torch.inference_mode():
            emb_device = model.get_input_embeddings().weight.device
            outputs = model(input_ids=input_ids.to(emb_device), output_hidden_states=True)
        if num_hidden_layers is None:
            num_hidden_layers = len(outputs.hidden_states) - 1
🤖 Prompt for AI Agents
In
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py
around lines 129 to 133, the code moves inputs to model.device which can be
missing or incorrect when using device_map='auto'; instead obtain the embeddings
device (e.g. model.get_input_embeddings().weight.device) and move input_ids to
that device before calling the model; use that device value as a fallback if
model.device exists but prefer the embedding weight device to ensure tensors end
up on the correct device for sharded/auto-mapped models.

@kevalmorabia97 (Collaborator) commented:

@benchislett You need to sign your commits with an SSH key. Please take a look at https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md#%EF%B8%8F-signing-your-work

@benchislett force-pushed the bchislett/offline-eagle-training branch from bd27b86 to 8708a77 on September 22, 2025 at 17:14
@benchislett (Contributor, Author) commented:

"Code Quality" check failure seems erroneous to me. The namespace in question does seem to exist and be imported correctly:

error: Name "OfflineEagleDMRegistry" is not defined

@coderabbitai bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/speculative_decoding/launch_train.sh (1)

91-95: Guard division by zero when no GPUs are visible.

DEFAULT_SAVE_STEPS will divide by zero if torch reports 0 CUDA devices (e.g., CI, CPU-only env). Provide a sane fallback.

Apply:

 GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
-# Calculate save_steps
-DEFAULT_SAVE_STEPS=$((8192 / GPU_COUNT))
+# Calculate save_steps
+if [[ "${GPU_COUNT}" -gt 0 ]]; then
+  DEFAULT_SAVE_STEPS=$((8192 / GPU_COUNT))
+else
+  echo "Warning: No GPUs detected by torch; defaulting save_steps to 8192."
+  DEFAULT_SAVE_STEPS=8192
+fi
🧹 Nitpick comments (19)
examples/speculative_decoding/launch_train.sh (1)

127-136: Good: validates offline data path and builds args.

Consider also quoting the path to tolerate spaces.

Apply:

-    OFFLINE_TRAINING_ARGS="--offline-data-path $OFFLINE_DATA_PATH"
+    OFFLINE_TRAINING_ARGS="--offline-data-path \"$OFFLINE_DATA_PATH\""
modelopt/torch/speculative/plugins/transformers.py (1)

338-349: Clamp aux layer IDs when offline and num_layers <= 0.

If num_layers resolves to 0/None offline, current defaults may produce invalid indices (e.g., 1). Clamp to valid range and drop negatives.

Apply:

-        num_layers = self.config.num_hidden_layers
-        if self.eagle_offline and (num_layers is None or num_layers <= 0):
-            num_layers = getattr(self.config, "num_orig_hidden_layers", 0)
-
-        self.eagle_config.eagle_aux_hidden_state_layer_ids = [
-            1,
-            max(0, num_layers // 2 - 1),
-            max(0, num_layers - 4),
-        ]
-        self.eagle_config.eagle_aux_hidden_state_layer_ids = list(
-            set(self.eagle_config.eagle_aux_hidden_state_layer_ids)
-        )
+        num_layers = self.config.num_hidden_layers
+        if self.eagle_offline and (num_layers is None or num_layers <= 0):
+            num_layers = getattr(self.config, "num_orig_hidden_layers", 0)
+        # Propose defaults then clamp into [0, max(0, num_layers-1)] and dedupe/sort
+        candidate_ids = [1, max(0, num_layers // 2 - 1), max(0, num_layers - 4)]
+        hi = max(0, num_layers - 1)
+        clamped = sorted({min(max(i, 0), hi) for i in candidate_ids})
+        self.eagle_config.eagle_aux_hidden_state_layer_ids = clamped
examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh (1)

16-23: Add shebang and safe shell flags; optional: parametrize endpoint via env.

Apply:

+#!/usr/bin/env bash
+set -euo pipefail
+
 python3 collect_hidden_states/send_conversations_for_hiddens.py \
   --model meta-llama/Llama-3.2-1B-Instruct \
   --input-file synthetic_conversations/mtbench.jsonl \
   --output-dir /mnt/md0/eagle-hidden-states/llama1b/mtbench/
 # --debug-max-num-conversations-per-split 1000
examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (1)

16-26: Add shebang/safe flags and make calls cwd-agnostic.

Use script-relative paths so it works from any directory.

Apply:

+#!/usr/bin/env bash
+set -euo pipefail
+
+# Example script to prepare a dataset of prompts for generation
 # Lines in this script can be uncommented to include specific datasets/splits in the prompt dataset.
 
-python3 prepare_input_conversations/add_daring_anteater.py --output-split-name train
-# python3 prepare_input_conversations/add_sharegpt.py --output-split-name train
-# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split train_sft --output-split-name train
-# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split train_gen --output-split-name train
-# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split test_sft --output-split-name mix_test
-# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split test_gen --output-split-name mix_test
-python3 prepare_input_conversations/add_mtbench.py --output-split-name mix_test
+ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+python3 "$ROOT/prepare_input_conversations/add_daring_anteater.py" --output-split-name train
+# python3 "$ROOT/prepare_input_conversations/add_sharegpt.py" --output-split-name train
+# python3 "$ROOT/prepare_input_conversations/add_ultrachat.py" --ultrachat-split train_sft --output-split-name train
+# python3 "$ROOT/prepare_input_conversations/add_ultrachat.py" --ultrachat-split train_gen --output-split-name train
+# python3 "$ROOT/prepare_input_conversations/add_ultrachat.py" --ultrachat-split test_sft --output-split-name mix_test
+# python3 "$ROOT/prepare_input_conversations/add_ultrachat.py" --ultrachat-split test_gen --output-split-name mix_test
+python3 "$ROOT/prepare_input_conversations/add_mtbench.py" --output-split-name mix_test
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh (1)

16-36: Add shebang/safe flags; consider quoting vars and checking split results.

Apply:

+#!/usr/bin/env bash
+set -euo pipefail
+
 # Example usage of the script to compute the hidden states for a conversation dataset
 # This script computes hidden states using a Hugging Face model and saves them to
 # the specified output directory. It does so in a data-parallel manner across 8 GPUs, by splitting
 # the input file into 8 parts and running 8 processes in parallel, one on each GPU.
@@
-INPUT_FILE=synthetic_conversations/daring-anteater.jsonl
-OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
+INPUT_FILE="synthetic_conversations/daring-anteater.jsonl"
+OUTPUT_DIR="/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/"
 
-split -n l/8 --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part-
+split -n l/8 --numeric-suffixes=0 -d --additional-suffix=.jsonl "$INPUT_FILE" /tmp/part-
 
 for i in $(seq 0 7)
 do
-CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py --model meta-llama/Llama-3.2-1B-Instruct --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR &
+  CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py \
+    --model meta-llama/Llama-3.2-1B-Instruct \
+    --input-file "/tmp/part-0${i}.jsonl" \
+    --output-dir "$OUTPUT_DIR" &
 done
 wait
 
-rm /tmp/part-*.jsonl
+rm -f /tmp/part-*.jsonl
examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py (2)

63-63: Load tensors onto CPU to avoid CUDA dependency.

Apply:

-        data = torch.load(file)
+        data = torch.load(file, map_location="cpu")

64-73: Don’t require exact key set; accept supersets.

Future producers may add fields; only assert required keys present.

Apply:

-        expected_keys = [
+        expected_keys = [
             "input_ids",
             "hidden_states",
             "aux_hidden_states",
             "conversation_id",
         ]
-        if set(expected_keys) != set(data.keys()):
-            print(f"File {file} does not contain all expected keys: {expected_keys}")
-            print(f"  Found keys: {list(data.keys())}")
-            continue
+        missing = [k for k in expected_keys if k not in data]
+        if missing:
+            print(f"File {file} missing required keys: {missing}. Found: {list(data.keys())}")
+            continue
examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py (2)

23-27: Make imports robust to script working directory/package execution.

from utils import ... will fail if invoked outside this folder or as a module. Prefer relative import with a fallback for script execution.

-from utils import (
+try:
+    from .utils import (
+        dataset_splits_explanation,
+        id_for_conversation,
+        update_dataset_file_with_conversations,
+    )
+except ImportError:
+    from utils import (
         dataset_splits_explanation,
         id_for_conversation,
         update_dataset_file_with_conversations,
-)
+    )

61-65: Prefer iterating rows directly to avoid repeated random access.

Hugging Face datasets support direct iteration; it’s clearer and avoids repeated indexing overhead.

-    for i in tqdm(
-        range(len(ds)),
+    for i in tqdm(
+        range(len(ds)),  # or: enumerate(ds)
         desc=f"Loading UltraChat split {args.ultrachat_split}",
         total=len(ds),
     ):

If acceptable, switch to for i, row in enumerate(tqdm(ds, ...)) and use row[...].

examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py (2)

23-27: Harden import path (same concern as other prep scripts).

Apply the same relative-import fallback pattern suggested for add_ultrachat.py to avoid ImportError when run outside the folder.


63-76: Defensive role parsing.

.lower() on a missing/None role will raise. The diff above normalizes via str(...).lower() and falls back cleanly.

examples/speculative_decoding/eagle_utils.py (2)

176-236: Offline dataset: load safely on CPU and validate keys/shapes.

  • Use map_location="cpu" to avoid accidental CUDA deserialization errors.
  • Validate presence of expected keys and provide a concise error.
-        offline_data = torch.load(offline_file_path)
+        offline_data = torch.load(offline_file_path, map_location="cpu")
+        for k in ("input_ids", "hidden_states", "aux_hidden_states"):
+            if k not in offline_data:
+                raise ValueError(f"Missing key '{k}' in offline data: {offline_file_path}")
         offline_data["input_ids"] = offline_data["input_ids"][:max_length]
         offline_data["hidden_states"] = offline_data["hidden_states"][:max_length, :]
         offline_data["aux_hidden_states"] = offline_data["aux_hidden_states"][:max_length, :]
 
         # Make sure the input_ids have the same shape
         if preprocessed_base["input_ids"].shape != offline_data["input_ids"].shape:
-            msg = f"""Input IDs from offline data do not match the preprocessed input IDs
-                                for offline data sample at {offline_file_path}."""
-            raise ValueError(msg)
+            raise ValueError(
+                f"Shape mismatch for input_ids at {offline_file_path}: "
+                f"offline={tuple(offline_data['input_ids'].shape)} "
+                f"preproc={tuple(preprocessed_base['input_ids'].shape)}"
+            )

251-256: Handle blank lines in JSONL and set explicit encoding.

Minor robustness: ignore empty lines and open with utf‑8.

-    with open(data_args.data_path) as f:
-        if data_args.data_path.endswith("jsonl"):
-            data_json = [json.loads(line) for line in f]
+    with open(data_args.data_path, "r", encoding="utf-8") as f:
+        if data_args.data_path.endswith("jsonl"):
+            data_json = [json.loads(line) for line in f if line.strip()]
         else:
             data_json = json.load(f)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (1)

95-95: Null‑safe chat template handling.

Some tokenizers have no chat_template. Guard before replace.

-    tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")
+    if getattr(tokenizer, "chat_template", None):
+        tokenizer.chat_template = tokenizer.chat_template.replace(
+            REMOVE_THINK_CHAT_TEMPLATE, ""
+        )
examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py (3)

23-28: Import robustness (same pattern as other prep scripts).

Apply the relative‑import fallback used in prior comments to avoid ImportError when running from different CWDs.


62-85: Be lenient on unknown roles; prefer skip over raise to avoid aborting long runs.

Unexpected roles in ShareGPT can abort the entire job. Treat them as “skip this conversation” rather than raising.

-        else:
-            err_msg = f"Unknown role in conversation: {turn.get('from')}"
-            raise ValueError(err_msg)
+        else:
+            # Unknown role; skip this conversation
+            return None

87-101: Download step: ensure parent directory exists.

~/.cache may not exist. download_file creates parents, but direct existence check may fail earlier if parent is missing. Create dirs before checking.

-        args.sharegpt_file = Path("~/.cache/sharegpt.json").expanduser().resolve()
+        args.sharegpt_file = Path("~/.cache/sharegpt.json").expanduser()
+        args.sharegpt_file.parent.mkdir(parents=True, exist_ok=True)
+        args.sharegpt_file = args.sharegpt_file.resolve()
examples/speculative_decoding/prepare_input_conversations/utils.py (2)

103-110: Avoid strict float equality for ratio sum.

Use a small tolerance to prevent false negatives due to FP rounding.

-    if train_ratio + val_ratio + test_ratio != 1.0:
+    total = train_ratio + val_ratio + test_ratio
+    if abs(total - 1.0) > 1e-9:
         msg = "Ratios must sum to 1.0"
         raise ValueError(msg)

115-124: Don’t mutate caller’s list or global RNG state when shuffling.

Shuffle a copy with a local RNG for determinism without side effects.

-    if shuffle:
-        random.seed(seed)
-        random.shuffle(conversations)
+    if shuffle:
+        rng = random.Random(seed)
+        pool = conversations.copy()
+        rng.shuffle(pool)
+    else:
+        pool = conversations
@@
-    train_conversations = conversations[:train_count]
-    val_conversations = conversations[train_count : train_count + val_count]
-    test_conversations = conversations[train_count + val_count :]
+    train_conversations = pool[:train_count]
+    val_conversations = pool[train_count : train_count + val_count]
+    test_conversations = pool[train_count + val_count :]
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bd27b86 and 8708a77.

📒 Files selected for processing (21)
  • examples/speculative_decoding/.gitignore (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/__init__.py (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (1 hunks)
  • examples/speculative_decoding/eagle_utils.py (4 hunks)
  • examples/speculative_decoding/launch.sh (0 hunks)
  • examples/speculative_decoding/launch_train.sh (4 hunks)
  • examples/speculative_decoding/main.py (4 hunks)
  • examples/speculative_decoding/prepare_input_conversations/__init__.py (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/add_mtbench.py (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/utils.py (1 hunks)
  • examples/speculative_decoding/train_eagle3_and_export.sh (2 hunks)
  • modelopt/torch/speculative/plugins/transformers.py (4 hunks)
💤 Files with no reviewable changes (1)
  • examples/speculative_decoding/launch.sh
✅ Files skipped from review due to trivial changes (1)
  • examples/speculative_decoding/collect_hidden_states/__init__.py
🚧 Files skipped from review as they are similar to previous changes (5)
  • examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py
  • examples/speculative_decoding/prepare_input_conversations/__init__.py
  • examples/speculative_decoding/train_eagle3_and_export.sh
  • examples/speculative_decoding/.gitignore
  • examples/speculative_decoding/prepare_input_conversations/add_mtbench.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-17T14:14:44.961Z
Learnt from: benchislett
PR: NVIDIA/TensorRT-Model-Optimizer#300
File: examples/speculative_decoding/main.py:150-159
Timestamp: 2025-09-17T14:14:44.961Z
Learning: In examples/speculative_decoding/main.py, setting num_hidden_layers=0 for offline training has been successfully tested by benchislett with past models and works in their use case for EAGLE3 offline training.

Applied to files:

  • modelopt/torch/speculative/plugins/transformers.py
  • examples/speculative_decoding/main.py
🧬 Code graph analysis (9)
examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py (1)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (2)
  • parse_args (32-79)
  • main (82-171)
modelopt/torch/speculative/plugins/transformers.py (3)
modelopt/torch/opt/dynamic.py (2)
  • register (1069-1096)
  • config (1265-1278)
modelopt/torch/speculative/eagle/eagle_model.py (1)
  • EagleModel (23-51)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
  • _set_default_aux_hidden_state_layers (682-694)
examples/speculative_decoding/main.py (1)
examples/speculative_decoding/eagle_utils.py (1)
  • make_eagle_supervised_data_module (238-312)
examples/speculative_decoding/eagle_utils.py (1)
modelopt/torch/utils/logging.py (1)
  • print_rank_0 (92-95)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (2)
examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (2)
  • parse_args (30-90)
  • main (93-204)
examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py (2)
  • parse_args (25-46)
  • main (49-84)
examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py (1)
examples/speculative_decoding/prepare_input_conversations/utils.py (4)
  • dataset_splits_explanation (157-172)
  • download_file (26-35)
  • id_for_conversation (38-42)
  • update_dataset_file_with_conversations (127-154)
examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py (1)
examples/speculative_decoding/prepare_input_conversations/utils.py (3)
  • dataset_splits_explanation (157-172)
  • id_for_conversation (38-42)
  • update_dataset_file_with_conversations (127-154)
examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py (1)
examples/speculative_decoding/prepare_input_conversations/utils.py (3)
  • dataset_splits_explanation (157-172)
  • id_for_conversation (38-42)
  • update_dataset_file_with_conversations (127-154)
examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (1)
examples/speculative_decoding/main.py (1)
  • train (117-271)
🪛 Shellcheck (0.11.0)
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (22)
examples/speculative_decoding/launch_train.sh (3)

33-36: Good: adds offline-data plumbed through to main.

Flag and value parsing look fine and align with DataArguments.offline_data_path.


111-113: Increase default training sequence length to 2048 — OK.


168-169: OK: injects offline args into CMD.

modelopt/torch/speculative/plugins/transformers.py (3)

314-314: Parameter name corrected to past_key_values.

Aligns with newer Transformers API; resolves earlier confusion.


425-429: Good: device selection for offline models via lm_head.

Avoids indexing model.layers when depth is zero.
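
For illustration, a minimal sketch of that fallback, assuming a Llama-style layout (model.model.layers, model.lm_head); attribute names are assumptions, not the plugin's exact code:

def infer_model_device(model):
    # With num_hidden_layers == 0 (offline training) the decoder stack is empty,
    # so read the device from lm_head instead of indexing model.layers[0].
    if getattr(model.config, "num_hidden_layers", 0) > 0 and len(model.model.layers) > 0:
        return next(model.model.layers[0].parameters()).device
    return model.lm_head.weight.device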


811-813: Cache handling for offline/legacy paths looks correct.

Dummy DynamicCache for offline, and from_legacy conversion when needed.

Also applies to: 826-828
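
A hedged sketch of that pattern (helper name and offline flag are assumptions, not the plugin's actual code):

from transformers import DynamicCache

def normalize_cache(past_key_values, eagle_offline: bool):
    # Offline path: the base decoder stack is absent, so an empty cache suffices.
    if eagle_offline or past_key_values is None:
        return DynamicCache()
    # Legacy tuple-of-(key, value) caches are converted to the new Cache API.
    if isinstance(past_key_values, tuple):
        return DynamicCache.from_legacy_cache(past_key_values)
    return past_key_values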

examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh (1)

16-23: Add shebang and safe shell flags; ensure script fails fast.

Apply:

+#!/usr/bin/env bash
+set -euo pipefail
+
 python3 collect_hidden_states/compute_hidden_states_hf.py \
   --model meta-llama/Llama-3.2-1B-Instruct \
   --input-file synthetic_conversations/daring-anteater.jsonl \
   --output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
examples/speculative_decoding/main.py (6)

72-81: Docs read well; field name aligns with launch script.


144-145: Simple switch to enable offline training — OK.


188-190: Correct: set EagleConfig.eagle_offline via config when offline.


230-231: Good: pass use_offline_training into data module.


237-239: Nice: allow disabling AR validation via non-positive steps.


150-158: num_hidden_layers=0 override — keep, but verify across Transformers versions/models.

Attempted verification here, but the environment raised ModuleNotFoundError: No module named 'transformers'. Run the provided check locally or in CI with a Transformers version matching the repo constraint (>=4.48,<5.0) against these models: TinyLlama/TinyLlama-1.1B-Chat-v1.0, mistralai/Mistral-7B-v0.1, Qwen/Qwen2-1.5B, microsoft/Phi-3-mini-4k-instruct. If any fail, remove the override or add model-specific handling.
File: examples/speculative_decoding/main.py:150-158
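
A hedged, runnable version of that check (model list trimmed for brevity; the model.model.layers attribute is assumed for Llama/Qwen-style architectures):

from transformers import AutoConfig, AutoModelForCausalLM

for name in ["TinyLlama/TinyLlama-1.1B-Chat-v1.0", "Qwen/Qwen2-1.5B"]:
    cfg = AutoConfig.from_pretrained(name)
    cfg.num_hidden_layers = 0  # mimic the offline-training override
    model = AutoModelForCausalLM.from_config(cfg)
    assert len(model.model.layers) == 0, name  # decoder stack should be empty
    print(f"{name}: loads with an empty decoder stack")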

examples/speculative_decoding/eagle_utils.py (2)

271-281: Conversation ID fallback logic looks good and resolves prior “conv_id=0” bug.

Using is None avoids treating 0 as falsy. This addresses previous feedback.


353-380: Collator for offline hidden states looks correct.

Padding logic is consistent with base collator; keys match dataset outputs; batch structure is clear.

Please confirm downstream training expects the following (a hedged shape check is sketched after this list):

  • base_model_outputs.base_model_hidden_states: [B, L, D_base]
  • base_model_outputs.aux_hidden_states: [B, L, D_aux]
    and that both are on CPU at collate time and moved to device by the trainer.
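
A minimal sketch of that check, using the key names quoted in this comment (not verified against eagle_utils.py):

def check_offline_batch(batch):
    outs = batch["base_model_outputs"]
    hs, aux = outs["base_model_hidden_states"], outs["aux_hidden_states"]
    assert hs.dim() == 3 and aux.dim() == 3  # [B, L, D_base] / [B, L, D_aux]
    assert hs.shape[:2] == aux.shape[:2]  # same batch and sequence length
    assert hs.device.type == "cpu" and aux.device.type == "cpu"  # trainer moves to GPU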
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (2)

136-146: Clamp/dedupe auxiliary layer indices for small-layer models.

Index 2 can be out of range when num_hidden_layers < 2. Clamp to [0, num_hidden_layers] and dedupe in order.

-            selected_layer_indices = [
-                2,
-                max(0, num_hidden_layers // 2),
-                max(1, num_hidden_layers - 3),
-            ]
-            selected_layer_indices = sorted(set(selected_layer_indices))
+            candidates = [2, num_hidden_layers // 2, num_hidden_layers - 3]
+            lo, hi = 0, num_hidden_layers  # hidden_states has num_hidden_layers+1 entries
+            selected_layer_indices = []
+            seen = set()
+            for i in candidates:
+                i = min(max(i, lo), hi)
+                if i not in seen:
+                    seen.add(i)
+                    selected_layer_indices.append(i)
+            if not selected_layer_indices:
+                selected_layer_indices = [hi]

129-135: Fix device handling for device_map='auto'.

model.device may be missing/incorrect under sharded loading. Use the embeddings’ device.

-        with torch.inference_mode():
-            outputs = model(input_ids=input_ids.to(model.device), output_hidden_states=True)
+        with torch.inference_mode():
+            emb_device = model.get_input_embeddings().weight.device
+            outputs = model(input_ids=input_ids.to(emb_device), output_hidden_states=True)
examples/speculative_decoding/prepare_input_conversations/utils.py (5)

87-89: Resolved: output dir is created before writing.

Creating dataset_dir before opening the file addresses first-run failures.


38-43: Stable content hash looks good.

Deterministic ID via sorted JSON + fixed separators is appropriate here.

Confirm upstream producers don’t inject nondeterministic fields (timestamps, UUIDs) into conversation before hashing, or IDs will change run-to-run.
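
For reference, a minimal sketch of the hashing scheme described here (assumed to mirror utils.id_for_conversation, not copied from it):

import hashlib
import json

def id_for_conversation(conversation) -> str:
    # Deterministic serialization (sorted keys, fixed separators), then SHA-256.
    payload = json.dumps(conversation, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()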


157-172: Helpful split helptext.

Clear and accurate, including percent-escaping.


26-35: Add timeout and stream to disk to avoid hangs and large-memory reads.

Without a network timeout the request can hang indefinitely, and await response.read() buffers the entire file in RAM. Use ClientTimeout and stream chunks to disk.

 async def download_file(url: str, destination: Path) -> None:
     """Download a file from a URL to a specified destination."""
-    destination.parent.mkdir(parents=True, exist_ok=True)
-    async with aiohttp.ClientSession() as session, session.get(url) as response:
-        if response.status != 200:
-            msg = f"Failed to download {url}: {response.status}"
-            raise RuntimeError(msg)
-        content = await response.read()
-        destination.write_bytes(content)
-        print(f"Downloaded {url} to {destination}")
+    destination.parent.mkdir(parents=True, exist_ok=True)
+    timeout = aiohttp.ClientTimeout(total=600)
+    async with aiohttp.ClientSession(timeout=timeout) as session:
+        async with session.get(url) as response:
+            if response.status // 100 != 2:
+                raise RuntimeError(f"Failed to download {url}: {response.status}")
+            with destination.open("wb") as f:
+                async for chunk in response.content.iter_chunked(1 << 20):
+                    f.write(chunk)
+    print(f"Downloaded {url} to {destination}")

62-77: Prevent duplicate IDs within the same batch and validate required key.

If the input contains duplicate conversation_ids in the same call, both are written because existing_ids isn’t updated. Also, missing "conversations" will raise KeyError later; validate explicitly.

     existing_ids = {entry["conversation_id"] for entry in all_conversations}
@@
     for entry in conversations:
         if entry.get("conversation_id") is None:
             raise ValueError("Each conversation must have a 'conversation_id' field.")
+        if "conversations" not in entry:
+            raise ValueError("Each conversation must have a 'conversations' field.")
         if entry["conversation_id"] not in existing_ids:
             all_conversations.append(
                 {
                     "conversation_id": entry["conversation_id"],
                     "conversations": entry["conversations"],
                 }
             )
+            existing_ids.add(entry["conversation_id"])
             num_new_entries += 1
         else:
             num_duplicates += 1

Comment on lines +62 to +89
            processed_conversations = []
            for msg in conversations:
                if "from" in msg:
                    role = msg["from"].lower()
                elif "role" in msg:
                    role = msg["role"].lower()
                else:
                    continue
                if role == "human":
                    role = "user"
                elif role == "gpt":
                    role = "assistant"

                if "value" in msg:
                    content = msg["value"]
                elif "text" in msg:
                    content = msg["text"]
                elif "content" in msg:
                    content = msg["content"]
                else:
                    continue
                content = content.strip()
                if content:
                    processed_conversations.append({"role": role, "content": content})

            input_conversations.append(
                {"conversation_id": prompt_id, "conversations": processed_conversations}
            )

⚠️ Potential issue

Skip empty conversations and compute ID from processed messages.

  • If no valid messages are added, you’ll append an entry with an empty list; downstream preprocess() indexes source[0] and will crash.
  • The conversation_id should be derived from the processed messages to ensure deduplication matches the actual content used for training.
-        conversations = ds[i]["conversations"]
+        conversations = ds[i]["conversations"]
         if conversations and isinstance(conversations, list):
-            prompt_id = f"daring-anteater-{i:05}_" + id_for_conversation(conversations)
             processed_conversations = []
             for msg in conversations:
-                if "from" in msg:
-                    role = msg["from"].lower()
-                elif "role" in msg:
-                    role = msg["role"].lower()
-                else:
-                    continue
+                role = str(msg.get("from") or msg.get("role") or "").lower()
+                if not role:
+                    continue
                 if role == "human":
                     role = "user"
                 elif role == "gpt":
                     role = "assistant"
 
-                if "value" in msg:
-                    content = msg["value"]
-                elif "text" in msg:
-                    content = msg["text"]
-                elif "content" in msg:
-                    content = msg["content"]
+                if "value" in msg:
+                    content = msg["value"]
+                elif "text" in msg:
+                    content = msg["text"]
+                elif "content" in msg:
+                    content = msg["content"]
                 else:
                     continue
                 content = content.strip()
                 if content:
                     processed_conversations.append({"role": role, "content": content})
 
-            input_conversations.append(
-                {"conversation_id": prompt_id, "conversations": processed_conversations}
-            )
+            if processed_conversations:
+                cid = id_for_conversation(processed_conversations)
+                prompt_id = f"daring-anteater-{i:05}_{cid}"
+                input_conversations.append(
+                    {"conversation_id": prompt_id, "conversations": processed_conversations}
+                )

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In
examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py
around lines 62 to 89, the current logic can append entries with empty
conversations and uses prompt_id instead of an ID derived from the actual
processed messages; update the code to (1) skip appending when
processed_conversations is empty so downstream preprocess() does not index into
an empty source, and (2) compute conversation_id from the
processed_conversations (for example by deterministic serialization and hashing
of the sequence of role+content pairs) so deduplication is based on the actual
messages used for training.

Comment on lines +66 to +69
        prompt = ds[i]["prompt"].strip()
        prompt_id = ds[i]["prompt_id"].strip()
        if prompt and prompt_id:
            msgs = [{"role": "user", "content": prompt}]

⚠️ Potential issue

Guard against missing/typed fields before calling .strip().

Some rows may have None/missing keys or non‑string prompt_id. Current code will crash. Coerce to str and validate.

-        prompt = ds[i]["prompt"].strip()
-        prompt_id = ds[i]["prompt_id"].strip()
-        if prompt and prompt_id:
+        prompt_raw = ds[i].get("prompt", "")
+        prompt_id_raw = ds[i].get("prompt_id", "")
+        prompt = str(prompt_raw).strip() if prompt_raw is not None else ""
+        prompt_id = str(prompt_id_raw).strip() if prompt_id_raw is not None else ""
+        if prompt and prompt_id:
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-        prompt = ds[i]["prompt"].strip()
-        prompt_id = ds[i]["prompt_id"].strip()
-        if prompt and prompt_id:
-            msgs = [{"role": "user", "content": prompt}]
+        prompt_raw = ds[i].get("prompt", "")
+        prompt_id_raw = ds[i].get("prompt_id", "")
+        prompt = str(prompt_raw).strip() if prompt_raw is not None else ""
+        prompt_id = str(prompt_id_raw).strip() if prompt_id_raw is not None else ""
+        if prompt and prompt_id:
+            msgs = [{"role": "user", "content": prompt}]
🤖 Prompt for AI Agents
In examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py
around lines 66 to 69, the code calls .strip() directly on ds[i]["prompt"] and
ds[i]["prompt_id"], which can crash if fields are missing or not strings; change
to safely access with dict.get (or check "prompt" in ds[i]), coerce values to
str only if not None (e.g., val = ds[i].get("prompt"); prompt = str(val).strip()
if val is not None else ""), validate that prompt and prompt_id are non-empty
after stripping, and ensure prompt_id is of the expected form (or converted)
before using it to build msgs; only create msgs when both cleaned strings are
non-empty.

@yeyu-nvidia
Copy link
Contributor

@benchislett #300 (comment)
can you rebase your PR to main? I have deleted this namespace

@h-guo18 h-guo18 force-pushed the bchislett/offline-eagle-training branch from 8708a77 to bf8d1a8 on September 22, 2025 22:20
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
modelopt/torch/speculative/plugins/transformers.py (1)

569-585: Mask dtype mismatch will crash concat (bool vs float)

attention_mask_0 is float (from expand_mask), but the code creates boolean masks (.bool()) and concatenates them with floats, which errors at runtime; it also assigns booleans into float tensors. Keep all masks in the same dtype as attention_mask_0.

Pattern fix for this block (apply similarly to the later two blocks at Lines 602-627 and 642-681):

-            zero_mask = torch.ones_like(attention_mask_0).bool()
-            mask_2_1 = attention_mask_0.clone().detach()
+            zero_mask = torch.ones_like(attention_mask_0)
+            mask_2_1 = attention_mask_0.clone().detach()
             mask_2_1[:, :, :, :-1] = mask_2_1[:, :, :, 1:]
-            mask_2_2 = torch.ones_like(attention_mask_0).bool()
+            mask_2_2 = torch.ones_like(attention_mask_0)
             for i in range(1, seq_length - 1):
-                mask_2_2[:, :, i, i] = False
+                mask_2_2[:, :, i, i] = 0
             cat_attention_mask = torch.cat(
                 (
                     torch.cat((attention_mask_0, zero_mask), dim=-1),
                     torch.cat((mask_2_1, mask_2_2), dim=-1),
                 ),
                 dim=-2,
             )
-            cat_attention_mask = cat_attention_mask.masked_fill(cat_attention_mask == 1, dtypemin)
+            cat_attention_mask = cat_attention_mask.masked_fill(cat_attention_mask == 1, dtypemin)
examples/speculative_decoding/eagle_utils.py (1)

315-351: Pad labels with IGNORE_TOKEN_ID, not 0

Current collator pads labels with 0, which will be treated as valid targets and corrupt loss. Add a pad value parameter and use IGNORE_TOKEN_ID for labels.

 class DataCollatorWithPadding:
-    def paddingtensor2d(self, intensors, length):
+    def paddingtensor2d(self, intensors, length, pad_value=0):
         n, dim = intensors.shape
-        padding_tensor = torch.zeros(length - n, dim, dtype=intensors.dtype)
+        padding_tensor = torch.zeros(length - n, dim, dtype=intensors.dtype) + pad_value
         outtensors = torch.cat((intensors, padding_tensor))
         return outtensors
 
-    def paddingtensor(self, intensors, length):
-        padding_tensor = torch.zeros(length - intensors.shape[0], dtype=intensors.dtype)
+    def paddingtensor(self, intensors, length, pad_value=0):
+        padding_tensor = torch.zeros(length - intensors.shape[0], dtype=intensors.dtype) + pad_value
         outtensors = torch.cat((intensors, padding_tensor))
         return outtensors
 
     def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
         max_length = max(item["input_ids"].shape[0] for item in features)
         batch_input_ids = torch.stack(
-            [self.paddingtensor(item["input_ids"], max_length) for item in features]
+            [self.paddingtensor(item["input_ids"], max_length, pad_value=0) for item in features]
         )
         batch_attention_mask = torch.stack(
-            [self.paddingtensor(item["attention_mask"], max_length) for item in features]
+            [self.paddingtensor(item["attention_mask"], max_length, pad_value=0) for item in features]
         )
         batch_loss_mask = torch.stack(
-            [self.paddingtensor(item["loss_mask"], max_length) for item in features]
+            [self.paddingtensor(item["loss_mask"], max_length, pad_value=0) for item in features]
         )
 
         batch_labels = torch.stack(
-            [self.paddingtensor(item["labels"], max_length) for item in features]
+            [self.paddingtensor(item["labels"], max_length, pad_value=IGNORE_TOKEN_ID) for item in features]
         )
🧹 Nitpick comments (11)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (3)

89-91: Ensure eval mode for deterministic inference.

Set model.eval() after loading.

-    model = AutoModel.from_pretrained(args.model, torch_dtype="auto", device_map="auto")
+    model = AutoModel.from_pretrained(args.model, torch_dtype="auto", device_map="auto")
+    model.eval()

18-20: Drop asyncio (not used) or add awaits.

main is async but contains no awaits; simplify to a regular function.

-import asyncio
@@
-async def main(args: argparse.Namespace) -> None:
+def main(args: argparse.Namespace) -> None:
@@
 if __name__ == "__main__":
     cli_args = parse_args()
-    asyncio.run(main(cli_args))
+    main(cli_args)

Also applies to: 82-82, 174-176


112-113: Sanitize conversation_id for filenames.

Untrusted conversation_id could contain path separators or illegal chars.

+import re
@@
-        conversation_id = entry.get("conversation_id", "{:08d}".format(idx))
+        raw_id = entry.get("conversation_id", "{:08d}".format(idx))
+        conversation_id = re.sub(r"[^A-Za-z0-9._-]", "_", str(raw_id))

Also applies to: 148-159, 18-26

modelopt/torch/speculative/plugins/transformers.py (1)

337-343: Safer fallback when num_orig_hidden_layers missing

If offline and num_hidden_layers<=0 but num_orig_hidden_layers is absent, you set num_layers=0, yet later include layer id 1. Use a non‑zero fallback.

-        if self.eagle_offline and (num_layers is None or num_layers <= 0):
-            num_layers = getattr(self.config, "num_orig_hidden_layers", 0)
+        if self.eagle_offline and (num_layers is None or num_layers <= 0):
+            num_layers = getattr(self.config, "num_orig_hidden_layers", self.config.num_hidden_layers)
examples/speculative_decoding/prepare_input_conversations/utils.py (3)

26-36: Add a network timeout to avoid hanging downloads

aiohttp.ClientSession() without timeout can hang indefinitely.

-async def download_file(url: str, destination: Path) -> None:
+async def download_file(url: str, destination: Path) -> None:
     """Download a file from a URL to a specified destination."""
     destination.parent.mkdir(parents=True, exist_ok=True)
-    async with aiohttp.ClientSession() as session, session.get(url) as response:
+    timeout = aiohttp.ClientTimeout(total=600)
+    async with aiohttp.ClientSession(timeout=timeout) as session:
+        async with session.get(url) as response:
             if response.status != 200:
                 msg = f"Failed to download {url}: {response.status}"
                 raise RuntimeError(msg)
             content = await response.read()
             destination.write_bytes(content)
             print(f"Downloaded {url} to {destination}")

104-109: Float-sum equality is brittle

Use tolerance when validating ratios.

-    if train_ratio + val_ratio + test_ratio != 1.0:
+    if abs((train_ratio + val_ratio + test_ratio) - 1.0) > 1e-6:
         msg = "Ratios must sum to 1.0"
         raise ValueError(msg)

50-61: Validate 'conversations' presence on existing data

Guard against malformed prior entries to avoid KeyError later.

-    if any(not entry.get("conversation_id") for entry in all_conversations):
+    if any(not entry.get("conversation_id") for entry in all_conversations):
         msg = "All existing conversations must have a 'conversation_id' field."
         raise ValueError(msg)
+    if any("conversations" not in entry for entry in all_conversations):
+        msg = "All existing conversations must have a 'conversations' field."
+        raise ValueError(msg)
examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (1)

1-26: Add shebang and strict mode; ensure output dir exists

Improves portability and early failure.

+#!/usr/bin/env bash
+set -euo pipefail
+
+# Ensure default output dir exists (scripts default to input_conversations/)
+mkdir -p input_conversations
examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh (3)

1-1: Add a shebang (or ShellCheck directive) to satisfy SC2148 and make the script directly executable.

Without a shebang, shells and tooling can misinterpret the script.

Apply one of the following diffs (preferred: add a shebang):

+#!/usr/bin/env bash
 # SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0

Or, if you intentionally keep it sourced/run via bash file.sh, add a ShellCheck directive:

+# shellcheck shell=bash
 # SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0

16-23: Harden and parameterize the example so it runs from any cwd and fails fast on errors.

Resolve brittle relative paths, create output dir, and enforce required env for the OpenAI‑compatible endpoint.

Apply this diff:

-# Example usage of the script to send conversations for hidden state collection
-# This script sends conversations to a (local) OpenAI-compatible server for processing and collects hidden states.
-
-python3 collect_hidden_states/send_conversations_for_hiddens.py \
-  --model meta-llama/Llama-3.2-1B-Instruct \
-  --input-file synthetic_conversations/mtbench.jsonl \
-  --output-dir /mnt/md0/eagle-hidden-states/llama1b/mtbench/
-# --debug-max-num-conversations-per-split 1000
+# Example usage of the script to send conversations for hidden state collection.
+# This script sends conversations to a (local) OpenAI-compatible server for processing and collects hidden states.
+
+set -euo pipefail
+IFS=$'\n\t'
+
+# Resolve repo root relative to this script so it works from any cwd
+SCRIPT_DIR="$(cd -- "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+REPO_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)"
+
+# Required env for OpenAI-compatible server (fail early if missing)
+: "${OPENAI_API_BASE:?Set OPENAI_API_BASE to your server URL, e.g. http://127.0.0.1:8000/v1}"
+: "${OPENAI_API_KEY:?Set OPENAI_API_KEY (dummy value is fine for many local servers)}"
+
+# Tunables (override via env: MODEL, INPUT_FILE, OUTPUT_DIR)
+MODEL="${MODEL:-meta-llama/Llama-3.2-1B-Instruct}"
+INPUT_FILE="${INPUT_FILE:-$REPO_ROOT/examples/speculative_decoding/synthetic_conversations/mtbench.jsonl}"
+OUTPUT_DIR="${OUTPUT_DIR:-/mnt/md0/eagle-hidden-states/llama1b/mtbench}"
+
+# Sanity checks
+[ -f "$INPUT_FILE" ] || { echo "Input not found: $INPUT_FILE" >&2; exit 1; }
+mkdir -p "$OUTPUT_DIR"
+
+python3 "$REPO_ROOT/collect_hidden_states/send_conversations_for_hiddens.py" \
+  --model "$MODEL" \
+  --input-file "$INPUT_FILE" \
+  --output-dir "$OUTPUT_DIR"
+# Optional:
+#   --debug-max-num-conversations-per-split 1000

19-23: Make the script executable in git.

If this is intended to be run directly, mark it executable.

#!/bin/bash
# From repo root
git update-index --chmod=+x examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8708a77 and bf8d1a8.

📒 Files selected for processing (21)
  • examples/speculative_decoding/.gitignore (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/__init__.py (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (1 hunks)
  • examples/speculative_decoding/eagle_utils.py (4 hunks)
  • examples/speculative_decoding/launch.sh (0 hunks)
  • examples/speculative_decoding/launch_train.sh (4 hunks)
  • examples/speculative_decoding/main.py (4 hunks)
  • examples/speculative_decoding/prepare_input_conversations/__init__.py (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/add_mtbench.py (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/utils.py (1 hunks)
  • examples/speculative_decoding/train_eagle3_and_export.sh (2 hunks)
  • modelopt/torch/speculative/plugins/transformers.py (4 hunks)
💤 Files with no reviewable changes (1)
  • examples/speculative_decoding/launch.sh
🚧 Files skipped from review as they are similar to previous changes (10)
  • examples/speculative_decoding/.gitignore
  • examples/speculative_decoding/collect_hidden_states/__init__.py
  • examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py
  • examples/speculative_decoding/prepare_input_conversations/add_mtbench.py
  • examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py
  • examples/speculative_decoding/launch_train.sh
  • examples/speculative_decoding/train_eagle3_and_export.sh
  • examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py
  • examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py
  • examples/speculative_decoding/prepare_input_conversations/__init__.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-17T14:14:44.961Z
Learnt from: benchislett
PR: NVIDIA/TensorRT-Model-Optimizer#300
File: examples/speculative_decoding/main.py:150-159
Timestamp: 2025-09-17T14:14:44.961Z
Learning: In examples/speculative_decoding/main.py, setting num_hidden_layers=0 for offline training has been successfully tested by benchislett with past models and works in their use case for EAGLE3 offline training.

Applied to files:

  • modelopt/torch/speculative/plugins/transformers.py
  • examples/speculative_decoding/main.py
🧬 Code graph analysis (6)
examples/speculative_decoding/eagle_utils.py (1)
modelopt/torch/utils/logging.py (1)
  • print_rank_0 (92-95)
modelopt/torch/speculative/plugins/transformers.py (2)
modelopt/torch/speculative/eagle/eagle_model.py (1)
  • EagleModel (23-51)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
  • _set_default_aux_hidden_state_layers (682-694)
examples/speculative_decoding/main.py (1)
examples/speculative_decoding/eagle_utils.py (1)
  • make_eagle_supervised_data_module (238-312)
examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py (1)
examples/speculative_decoding/prepare_input_conversations/utils.py (4)
  • dataset_splits_explanation (157-172)
  • download_file (26-35)
  • id_for_conversation (38-42)
  • update_dataset_file_with_conversations (127-154)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (2)
examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (2)
  • parse_args (30-90)
  • main (93-204)
examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py (2)
  • parse_args (25-46)
  • main (49-84)
examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (1)
examples/speculative_decoding/main.py (1)
  • train (117-271)
🪛 Shellcheck (0.11.0)
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (11)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (3)

92-96: Guard chat_template access; avoid AttributeError and centralize constant.

.chat_template can be None; direct .replace(...) will crash. Also prefer importing the shared constant to avoid drift.

Apply:

-    tokenizer = AutoTokenizer.from_pretrained(args.model)
+    tokenizer = AutoTokenizer.from_pretrained(args.model)
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
-    tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")
+    if getattr(tokenizer, "chat_template", None):
+        tokenizer.chat_template = tokenizer.chat_template.replace(
+            REMOVE_THINK_CHAT_TEMPLATE, ""
+        )

And replace the local constant with a single source of truth:

-from transformers import AutoModel, AutoTokenizer
-
-REMOVE_THINK_CHAT_TEMPLATE = (
-    "{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}"
-)
+from transformers import AutoModel, AutoTokenizer
+from eagle_utils import REMOVE_THINK_CHAT_TEMPLATE  # keep centralized

127-131: Fix device placement with device_map='auto'.

model.device can be missing/wrong under sharding; send inputs to the embeddings device.

-        with torch.inference_mode():
-            outputs = model(input_ids=input_ids.to(model.device), output_hidden_states=True)
+        with torch.inference_mode():
+            emb_device = model.get_input_embeddings().weight.device
+            outputs = model(input_ids=input_ids.to(emb_device), output_hidden_states=True)

136-146: Harden layer selection; avoid OOB/duplicates for small-layer models.

Index 2 can be out of range; clamp/dedupe defensively.

-            selected_layer_indices = [
-                2,
-                max(0, num_hidden_layers // 2),
-                max(1, num_hidden_layers - 3),
-            ]
-            selected_layer_indices = sorted(set(selected_layer_indices))
+            candidates = [2, num_hidden_layers // 2, num_hidden_layers - 3]
+            selected_layer_indices = sorted(
+                {i for i in candidates if 0 <= i <= num_hidden_layers - 1}
+            )
+            if not selected_layer_indices:
+                raise ValueError(
+                    f"No valid auxiliary layer indices for num_hidden_layers={num_hidden_layers}."
+                )
modelopt/torch/speculative/plugins/transformers.py (3)

309-318: Double‑check kwarg name for decoder layer cache

Passing past_key_values= into LlamaDecoderLayer may break on some Transformers versions that still use past_key_value. Confirm the minimum version you support; otherwise gate by version or forward both via kwargs filtering.

I can generate a small gate to support both signatures if needed.
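
A hedged sketch of such a gate (hypothetical helper; it only assumes the layer's forward signature is introspectable):

import inspect

def cache_kwarg_name(decoder_layer) -> str:
    # Newer Transformers layers accept past_key_values; older ones use past_key_value.
    params = inspect.signature(decoder_layer.forward).parameters
    return "past_key_values" if "past_key_values" in params else "past_key_value"

# usage sketch:
#   kwargs = {cache_kwarg_name(layer): past_key_values}
#   hidden_states = layer(hidden_states, attention_mask=attn_mask, **kwargs)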


425-435: Good device selection for offline path

Reading device from lm_head when base model has 0 layers avoids attribute errors. Nice.


332-334: Fix NameError: import OfflineEagleDMRegistry

Decorator references OfflineEagleDMRegistry but it’s not imported; this fails at import time.

Apply:

-from ..eagle.conversion import EagleDMRegistry
+from ..eagle.conversion import EagleDMRegistry, OfflineEagleDMRegistry
examples/speculative_decoding/eagle_utils.py (1)

353-381: Collator extension for offline HS looks good

Builds batched hidden/aux tensors aligned to max HS length; composes with base batch cleanly.

examples/speculative_decoding/main.py (2)

150-158: Overriding num_hidden_layers=0: verify supported model families/versions

This trick can work but is not universally guaranteed across HF architectures/versions. Since you store num_orig_hidden_layers, the offline path is consistent; please confirm on the minimum Transformers version in setup and the families you target (Llama, Mistral, Qwen2, Phi‑3).


187-206: Correct use of EagleConfig.eagle_offline

Setting config["eagle_offline"]=True wires the offline path coherently through conversion. LGTM.

examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py (1)

87-96: Nice: async download with cache and deterministic IDs

End-to-end flow is clear and robust; IDs are stable via SHA-256 of messages.

examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh (1)

1-15: License header looks good.

SPDX identifiers and copyright block match project norms.

Comment on lines +112 to +117
        conversation_id = entry.get("conversation_id", "{:08d}".format(idx))
        conversations = entry["conversations"]
        if not conversations or not isinstance(conversations, list):
            num_invalid += 1
            continue


⚠️ Potential issue

Avoid KeyError when 'conversations' is missing.

Indexing with entry["conversations"] will crash; use .get() and validate.

-        conversations = entry["conversations"]
+        conversations = entry.get("conversations")
         if not conversations or not isinstance(conversations, list):
             num_invalid += 1
             continue
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
         conversation_id = entry.get("conversation_id", "{:08d}".format(idx))
-        conversations = entry["conversations"]
+        conversations = entry.get("conversations")
         if not conversations or not isinstance(conversations, list):
             num_invalid += 1
             continue
🤖 Prompt for AI Agents
In
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py
around lines 112 to 117, replace the unsafe indexing entry["conversations"] with
a safe lookup and validation: use entry.get("conversations") into the
conversations variable, then check if conversations is truthy and
isinstance(conversations, list); if not, increment num_invalid and continue.
Ensure you do not assume the key exists and handle None or non-list types
consistently.

Comment on lines +1 to +3
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#

⚠️ Potential issue

Fix split suffix/filenames; add shebang and strict mode.

As written, split likely creates part-0..7.jsonl, but the loop reads part-00..07.jsonl. Also add shebang/safety.

+#!/usr/bin/env bash
+set -euo pipefail
@@
-INPUT_FILE=synthetic_conversations/daring-anteater.jsonl
-OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
+INPUT_FILE=synthetic_conversations/daring-anteater.jsonl
+OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
+mkdir -p "$OUTPUT_DIR"
@@
-split -n l/8 --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part-
+split -n l/8 -d -a 2 --numeric-suffixes=00 --additional-suffix=.jsonl "$INPUT_FILE" /tmp/part-
@@
-for i in $(seq 0 7)
-do
-CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py --model meta-llama/Llama-3.2-1B-Instruct --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR &
-done
+for i in $(seq 0 7); do
+  part="/tmp/part-$(printf '%02d' "$i").jsonl"
+  CUDA_VISIBLE_DEVICES="$i" python3 collect_hidden_states/compute_hidden_states_hf.py \
+    --model meta-llama/Llama-3.2-1B-Instruct \
+    --input-file "$part" \
+    --output-dir "$OUTPUT_DIR" &
+done
 wait
@@
-rm /tmp/part-*.jsonl
+rm -f /tmp/part-*.jsonl

Also applies to: 25-36

🧰 Tools
🪛 Shellcheck (0.11.0)

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

🤖 Prompt for AI Agents
In
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh
around lines 1-3 (and similarly lines 25-36), the script lacks a shebang and
strict shell settings, and the split/loop filename suffixes are inconsistent:
split produces single-digit suffixes but the loop expects two-digit names. Add a
shebang (#!/usr/bin/env bash) and enable strict mode (set -euo pipefail), change
the split invocation to produce numeric two-digit suffixes and a .jsonl
extension (use split -d -a 2 --additional-suffix=.jsonl <...> part-), and update
the consumer loop to either glob part-*.jsonl or explicitly format two-digit
indices (e.g., part-00.jsonl..part-07.jsonl) so names match; apply the same
changes to the block at lines 25-36.

Comment on lines +1 to +3
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#

⚠️ Potential issue

Add shebang and safe shell flags; make script executable-friendly.

Without a shebang, shells/tools can misinterpret the file; enable strict mode.

+#!/usr/bin/env bash
+set -euo pipefail
@@
-python3 collect_hidden_states/compute_hidden_states_hf.py \
+python3 collect_hidden_states/compute_hidden_states_hf.py \
   --model meta-llama/Llama-3.2-1B-Instruct \
   --input-file synthetic_conversations/daring-anteater.jsonl \
   --output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/

Also applies to: 16-23

🧰 Tools
🪛 Shellcheck (0.11.0)

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

🤖 Prompt for AI Agents
In examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh
around lines 1-3 (and similarly apply to lines 16-23), the script lacks a
shebang and safe shell flags; add a top-line shebang (e.g., #!/usr/bin/env bash)
and enable strict mode by setting set -euo pipefail and a safe IFS (IFS=$'\n\t')
at the top of the file; ensure the file mode is executable (chmod +x) in the
repository or note it in the commit so the script runs robustly in CI and user
environments.

Comment on lines +214 to +221
        # Extend the data sample with the hidden states from the .pt file
        max_length = self.tokenizer.model_max_length
        offline_data = torch.load(offline_file_path)
        offline_data["input_ids"] = offline_data["input_ids"][:max_length]
        offline_data["hidden_states"] = offline_data["hidden_states"][:max_length, :]
        offline_data["aux_hidden_states"] = offline_data["aux_hidden_states"][:max_length, :]

        # Make sure the input_ids have the same shape

⚠️ Potential issue

Load offline tensors on CPU to avoid GPU spikes

torch.load(offline_file_path) may move tensors to GPU if saved with device tags, causing dataloader-side GPU OOMs.

-        offline_data = torch.load(offline_file_path)
+        offline_data = torch.load(offline_file_path, map_location="cpu")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
         # Extend the data sample with the hidden states from the .pt file
         max_length = self.tokenizer.model_max_length
-        offline_data = torch.load(offline_file_path)
+        offline_data = torch.load(offline_file_path, map_location="cpu")
         offline_data["input_ids"] = offline_data["input_ids"][:max_length]
         offline_data["hidden_states"] = offline_data["hidden_states"][:max_length, :]
         offline_data["aux_hidden_states"] = offline_data["aux_hidden_states"][:max_length, :]

         # Make sure the input_ids have the same shape
🤖 Prompt for AI Agents
In examples/speculative_decoding/eagle_utils.py around lines 214 to 221,
torch.load(offline_file_path) can deserialize tensors onto their saved device
and risk moving them to GPU; change the load to explicitly map tensors to CPU
(e.g., use torch.load(..., map_location="cpu")) and, if needed, ensure any
tensors used later are .to(device) at the appropriate time so the dataloader
never accidentally allocates GPU memory during loading.

@h-guo18 h-guo18 force-pushed the bchislett/offline-eagle-training branch from bf8d1a8 to 338d0e3 on September 22, 2025 22:27
benchislett and others added 6 commits September 22, 2025 22:33
@h-guo18 h-guo18 force-pushed the bchislett/offline-eagle-training branch from 338d0e3 to 5add4cb on September 22, 2025 22:33
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 7

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
examples/speculative_decoding/launch_train.sh (1)

91-95: Division by zero when no CUDA devices.

GPU_COUNT can be 0; DEFAULT_SAVE_STEPS=$((8192 / 0)) will abort the script on machines without visible GPUs (or before CUDA init in some CI).

Harden with a guard:

-GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
-# Calculate save_steps
-DEFAULT_SAVE_STEPS=$((8192 / GPU_COUNT))
+GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
+if [[ -z "$GPU_COUNT" || "$GPU_COUNT" -le 0 ]]; then
+  GPU_COUNT=1
+fi
+DEFAULT_SAVE_STEPS=$((8192 / GPU_COUNT))
examples/speculative_decoding/eagle_utils.py (1)

316-341: Fix NameError in collator padding helpers and pad labels with IGNORE_TOKEN_ID

A mismatched tensor variable name in the padding helpers causes a runtime NameError. Labels are also padded with 0; use IGNORE_TOKEN_ID so padding does not contribute to the loss. The diff below adds truncation guards as well.

-    def paddingtensor2d(self, intensors, length):
-        n, dim = intensors.shape
-        padding_tensor = torch.zeros(length - n, dim, dtype=intensors.dtype)
-        outtensors = torch.cat((intensors, padding_tensor))
-        return outtensors
+    def paddingtensor2d(self, intensors, length):
+        n, dim = intensors.shape
+        if length < n:
+            return intensors[:length]
+        padding_tensor = torch.zeros(length - n, dim, dtype=intensors.dtype)
+        outtensors = torch.cat((intensors, padding_tensor))
+        return outtensors
-
-    def paddingtensor(self, intensors, length):
-        padding_tensor = torch.zeros(length - intensors.shape[0], dtype=intensors.dtype)
-        outtensors = torch.cat((intensors, padding_tensor))
-        return outtensors
+    def paddingtensor(self, intensors, length, pad_value=0):
+        if length < intensors.shape[0]:
+            return intensors[:length]
+        padding_tensor = torch.full((length - intensors.shape[0],), pad_value, dtype=intensors.dtype)
+        outtensors = torch.cat((intensors, padding_tensor))
+        return outtensors
@@
-        batch_input_ids = torch.stack(
-            [self.paddingtensor(item["input_ids"], max_length) for item in features]
-        )
+        pad_id = 0  # collator lacks tokenizer context; 0 is typical for LLaMA-family
+        batch_input_ids = torch.stack(
+            [self.paddingtensor(item["input_ids"], max_length, pad_id) for item in features]
+        )
@@
-        batch_attention_mask = torch.stack(
-            [self.paddingtensor(item["attention_mask"], max_length) for item in features]
-        )
+        batch_attention_mask = torch.stack(
+            [self.paddingtensor(item["attention_mask"], max_length, 0) for item in features]
+        )
@@
-        batch_loss_mask = torch.stack(
-            [self.paddingtensor(item["loss_mask"], max_length) for item in features]
-        )
+        batch_loss_mask = torch.stack(
+            [self.paddingtensor(item["loss_mask"], max_length, 0) for item in features]
+        )
@@
-        batch_labels = torch.stack(
-            [self.paddingtensor(item["labels"], max_length) for item in features]
-        )
+        batch_labels = torch.stack(
+            [self.paddingtensor(item["labels"], max_length, IGNORE_TOKEN_ID) for item in features]
+        )
🧹 Nitpick comments (14)
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh (2)

25-27: Quote vars and ensure output directory exists.

Prevent word-splitting; create the output directory.

-INPUT_FILE=synthetic_conversations/daring-anteater.jsonl
-OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
+INPUT_FILE="synthetic_conversations/daring-anteater.jsonl"
+OUTPUT_DIR="/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/"
+mkdir -p "$OUTPUT_DIR"

36-36: Rely on trap for cleanup; avoid risky rm patterns.

The trap above safely removes the entire temp directory.

-rm /tmp/part-*.jsonl
+# cleanup handled by trap
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh (1)

16-19: Parameterize paths, quote, and ensure output dir exists.

Improves portability and robustness; keeps the example concise.

-# Example usage of the script to compute the hidden states for a conversation dataset
+# Example usage of the script to compute the hidden states for a conversation dataset
@@
-python3 collect_hidden_states/compute_hidden_states_hf.py \
-  --model meta-llama/Llama-3.2-1B-Instruct \
-  --input-file synthetic_conversations/daring-anteater.jsonl \
-  --output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
+MODEL="${MODEL:-meta-llama/Llama-3.2-1B-Instruct}"
+INPUT_FILE="${INPUT_FILE:-synthetic_conversations/daring-anteater.jsonl}"
+OUTPUT_DIR="${OUTPUT_DIR:-/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/}"
+mkdir -p "$OUTPUT_DIR"
+
+# Resolve Python entry relative to this script for robustness
+SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)"
+python3 "$SCRIPT_DIR/compute_hidden_states_hf.py" \
+  --model "$MODEL" \
+  --input-file "$INPUT_FILE" \
+  --output-dir "$OUTPUT_DIR"
examples/speculative_decoding/launch_train.sh (4)

33-36: CLI flag name is slightly misleading vs downstream arg; consider aliasing.

Downstream you pass --offline-data-path to main.py. Here the user-facing flag is --offline-data. Consider supporting both for clarity:

  • Keep --offline-data (backward-compat)
  • Add an alias --offline-data-path (sets the same var)

Apply this minimal alias:

-    --offline-data*)
+    --offline-data*|--offline-data-path*)
       if [[ "$1" != *=* ]]; then shift; fi
       OFFLINE_DATA_PATH="${1#*=}"
       ;;

81-84: Invalid-argument error prints the value, not the flag.

For inputs like --foo=bar, message shows “bar” instead of “--foo”. Print $1 directly.

-      >&2 printf "Error: Invalid argument ${1#*=}\n"
+      >&2 printf "Error: Invalid argument %s\n" "$1"

146-171: Quoting/arg safety when paths contain spaces.

You already quote OFFLINE_DATA_PATH and DATA in the proposed patches. For consistency, consider quoting other path-like vars (MODEL, OUTPUT_DIR) or build CMD as an array to avoid word-splitting. Non-blocking.


107-109: Unused vars (REDRAFTER_*).

REDRAFTER_TOKENS and REDRAFTER_NUM_LAYERS aren’t used in this script. Remove or wire them, or add a comment if they’re placeholders for future flags.

examples/speculative_decoding/eagle_utils.py (2)

35-35: Guard chat_template mutation

Some tokenizers may lack chat_template or it may be None; also consider not mutating per-call.

-    tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")
+    if getattr(tokenizer, "chat_template", None):
+        tokenizer.chat_template = tokenizer.chat_template.replace(
+            REMOVE_THINK_CHAT_TEMPLATE, ""
+        )

Optional follow-up: move this one-time adjustment to dataset init to avoid repeated mutation.


294-297: Shuffle before split to avoid ordering bias

Current 95/5 split preserves source order; shuffle with a seed.

-        num_train = int(len(valid_entries) * 0.95)
+        # Optional: shuffle to avoid order bias
+        import random
+        rnd = random.Random(getattr(data_args, "seed", 42))
+        rnd.shuffle(valid_entries)
+        num_train = int(len(valid_entries) * 0.95)
examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh (2)

1-1: Add a shebang (and strict mode) so the script runs as a shell script and satisfies ShellCheck.

Without a shebang (SC2148), shells may guess incorrectly; strict mode avoids silent failures.

+#!/usr/bin/env bash
+set -euo pipefail

19-19: Make invocation robust regardless of the current working directory.

Resolve the Python script relative to this file’s location to avoid path errors.

-python3 collect_hidden_states/send_conversations_for_hiddens.py \
+python3 "$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)/send_conversations_for_hiddens.py" \
examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (1)

189-191: Remove redundant continue.

Control flow already proceeds to the next iteration.

-        num_success += 1
-        continue
+        num_success += 1
examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py (2)

87-96: Resolve output directory path for consistency.

Normalize --output-dir (supports ~, avoids relative path surprises).

Apply this diff:

 async def main(args: argparse.Namespace) -> None:
@@
     if not args.sharegpt_file.exists():
         err_msg = f"ShareGPT file {args.sharegpt_file} does not exist."
         raise FileNotFoundError(err_msg)
 
+    # Normalize output dir
+    args.output_dir = args.output_dir.expanduser().resolve()

62-83: PII/safety filter hook (optional).

ShareGPT "unfiltered" may contain PII/toxicity. Consider an optional sanitizer callback applied to msgs before hashing/emit.

I can add a lightweight regex-based scrubber and a --drop-pii flag if you want.
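
For reference, a rough sketch of what such a scrubber could look like; the patterns and the --drop-pii flag name are assumptions and deliberately not exhaustive:

import re

# Illustrative patterns only; real PII detection needs a broader rule set.
PII_PATTERNS = [
    re.compile(r"[\w.+-]+@[\w-]+\.[\w.-]+"),  # email addresses
    re.compile(r"\+?\d[\d\s().-]{7,}\d"),     # phone-number-like digit runs
]

def scrub_pii(msgs: list[dict]) -> list[dict]:
    """Replace common PII-looking substrings in message contents with a placeholder."""
    cleaned = []
    for m in msgs:
        content = m["content"]
        for pat in PII_PATTERNS:
            content = pat.sub("[REDACTED]", content)
        cleaned.append({**m, "content": content})
    return cleaned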

📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between bf8d1a8 and 5add4cb.

📒 Files selected for processing (21)
  • examples/speculative_decoding/.gitignore (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/__init__.py (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (1 hunks)
  • examples/speculative_decoding/eagle_utils.py (4 hunks)
  • examples/speculative_decoding/launch.sh (0 hunks)
  • examples/speculative_decoding/launch_train.sh (4 hunks)
  • examples/speculative_decoding/main.py (4 hunks)
  • examples/speculative_decoding/prepare_input_conversations/__init__.py (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/add_mtbench.py (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/utils.py (1 hunks)
  • examples/speculative_decoding/train_eagle3_and_export.sh (2 hunks)
  • modelopt/torch/speculative/plugins/transformers.py (4 hunks)
💤 Files with no reviewable changes (1)
  • examples/speculative_decoding/launch.sh
🚧 Files skipped from review as they are similar to previous changes (12)
  • examples/speculative_decoding/prepare_input_conversations/__init__.py
  • examples/speculative_decoding/.gitignore
  • examples/speculative_decoding/train_eagle3_and_export.sh
  • examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py
  • modelopt/torch/speculative/plugins/transformers.py
  • examples/speculative_decoding/main.py
  • examples/speculative_decoding/prepare_input_conversations/utils.py
  • examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py
  • examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py
  • examples/speculative_decoding/prepare_input_conversations/add_mtbench.py
  • examples/speculative_decoding/collect_hidden_states/__init__.py
  • examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-17T14:13:57.091Z
Learnt from: benchislett
PR: NVIDIA/TensorRT-Model-Optimizer#300
File: examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py:154-166
Timestamp: 2025-09-17T14:13:57.091Z
Learning: The `apply_chat_template` method in HuggingFace transformers has `tokenize=True` as the default parameter, meaning it returns token IDs even when `return_tensors=None` is used (returns as Python list).

Applied to files:

  • examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py
🧬 Code graph analysis (4)
examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py (1)
examples/speculative_decoding/prepare_input_conversations/utils.py (4)
  • dataset_splits_explanation (157-172)
  • download_file (26-35)
  • id_for_conversation (38-42)
  • update_dataset_file_with_conversations (127-154)
examples/speculative_decoding/eagle_utils.py (1)
modelopt/torch/utils/logging.py (1)
  • print_rank_0 (92-95)
examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (3)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (1)
  • main (82-171)
examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py (1)
  • main (49-84)
modelopt/torch/speculative/utils.py (1)
  • tokenize (212-229)
examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (1)
examples/speculative_decoding/main.py (1)
  • train (117-271)
🪛 Shellcheck (0.11.0)
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: build-docs
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
🔇 Additional comments (18)
examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (1)

19-25: Fix broken relative paths; parameterize and create output dir.

This file lives in .../prepare_input_conversations/, so prefixing with prepare_input_conversations/ points to a non‑existent nested directory. Also ensure outputs land in a known dir. Flag name --output-split-name looks correct.

-python3 prepare_input_conversations/add_daring_anteater.py --output-split-name train
+DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)"
+OUTPUT_DIR="${OUTPUT_DIR:-data}"
+mkdir -p "$OUTPUT_DIR"
+python3 "$DIR/add_daring_anteater.py" --output-split-name train --output-dir "$OUTPUT_DIR"
@@
-python3 prepare_input_conversations/add_mtbench.py --output-split-name mix_test
+python3 "$DIR/add_mtbench.py" --output-split-name mix_test --output-dir "$OUTPUT_DIR"

To confirm flags exist in the repo, run:

#!/bin/bash
rg -n -C2 -g 'examples/**' -e 'add_argument\(\s*"--output-dir"' -e 'add_argument\(\s*"--output-split-name"'
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh (3)

1-15: Add shebang and strict mode; make script executable.

Without a shebang, shells/tools guess the interpreter; add safe flags. This also addresses SC2148.

Apply:

+#!/usr/bin/env bash
+set -euo pipefail
+IFS=$'\n\t'

Also ensure the file is executable (chmod +x).


28-28: Make split suffix length explicit; use temp dir and trap for cleanup.

Avoid filename mismatches and ensure cleanup even on failure.

-split -n l/8 --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part-
+# temp workspace for parts
+TMPDIR="$(mktemp -d)"
+trap 'rm -rf "$TMPDIR"' EXIT
+PREFIX="$TMPDIR/part-"
+# detect GPU count or default to 8
+NGPUS="${NGPUS:-8}"
+if command -v nvidia-smi >/dev/null 2>&1; then
+  NGPUS=$(nvidia-smi -L | wc -l | tr -d ' ')
+fi
+split -n "l/${NGPUS}" -d -a 2 --numeric-suffixes=00 --additional-suffix=.jsonl "$INPUT_FILE" "$PREFIX"

30-34: Robust parallel launch with zero-padded parts and per‑PID status check.

Fail fast if any worker fails; avoid silent errors.

-for i in $(seq 0 7)
-do
-CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py --model meta-llama/Llama-3.2-1B-Instruct --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR &
-done
-wait
+PIDS=()
+for ((i=0; i<NGPUS; i++)); do
+  part="${PREFIX}$(printf '%02d' "$i").jsonl"
+  [ -s "$part" ] || continue
+  CUDA_VISIBLE_DEVICES="$i" python3 collect_hidden_states/compute_hidden_states_hf.py \
+    --model meta-llama/Llama-3.2-1B-Instruct \
+    --input-file "$part" \
+    --output-dir "$OUTPUT_DIR" &
+  PIDS+=("$!")
+done
+
+ec=0
+for pid in "${PIDS[@]}"; do
+  if ! wait "$pid"; then ec=1; fi
+done
+exit "$ec"
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh (1)

1-3: Add shebang and strict mode.

Executable script should declare interpreter and fail fast (SC2148).

+#!/usr/bin/env bash
+set -euo pipefail
+IFS=$'\n\t'
examples/speculative_decoding/launch_train.sh (2)

111-113: Default seq len bumped to 2048; confirm headroom.

At 2048, memory can be tight on smaller GPUs depending on batch/heads. Fine by default; just confirm your example configs fit typical 24–48GB cards. If not, document an override in README/docs.


138-142: Accelerate --multi_gpu flag: confirm support with your pinned Accelerate version.

Some setups rely on config/num_processes rather than --multi_gpu. If your CI/README pins a version that supports it, fine; otherwise derive processes from NUM_GPU and pass --num_processes.

Would you like me to add a version check and fall back to --num_processes "$NUM_GPU" if --multi_gpu isn’t supported?
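
One way such a fallback could look (a sketch, assuming NUM_GPU is already set by this script; probing capabilities via --help is a heuristic):

# Use --multi_gpu only if the installed accelerate advertises it; otherwise pass an
# explicit process count derived from NUM_GPU.
if accelerate launch --help 2>/dev/null | grep -q -- '--multi_gpu'; then
  MULTI_GPU_ARGS=(--multi_gpu)
else
  MULTI_GPU_ARGS=(--num_processes "$NUM_GPU")
fi
accelerate launch "${MULTI_GPU_ARGS[@]}" main.py "$@"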

examples/speculative_decoding/eagle_utils.py (5)

270-281: Conversation ID handling: confirm naming and types

If offline files are named with zero-padding (e.g., 00000000.pt) or IDs are ints vs strings, this strict match will drop samples.

  • Are offline files named exactly f"{conv_id}.pt" without padding or casting differences?
  • If not, normalize: conv_id = str(conv_id) and/or support zero-padded matches. A rough sketch of such a normalizer is below.
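
A possible normalizer, hedged: the function and directory names are placeholders, and the eight-digit padding is only one example of the variants it could accept:

from pathlib import Path

def find_offline_file(offline_dir: Path, conv_id) -> Path | None:
    """Resolve a conversation ID to its offline .pt file, tolerating padding/type differences."""
    conv_id = str(conv_id)
    candidates = [offline_dir / f"{conv_id}.pt"]
    if conv_id.isdigit():
        candidates.append(offline_dir / f"{int(conv_id):08d}.pt")  # zero-padded variant
    for path in candidates:
        if path.exists():
            return path
    return None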

353-380: Validate hidden_size consistency across batch

Hidden-state second dims must match to stack cleanly; add cheap checks.

     def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
         base_batch = super().__call__(features)
         if "kwargs" not in features[0]:
             raise ValueError("No kwargs found in batch features. Offline data required.")
 
         features = [item["kwargs"]["base_model_outputs"] for item in features]
+        hs_dim = features[0]["base_model_hidden_states"].shape[1]
+        if any(f["base_model_hidden_states"].shape[1] != hs_dim for f in features):
+            raise ValueError("Inconsistent base_model_hidden_states hidden_size in batch.")
+        if any(f["aux_hidden_states"].shape[1] != hs_dim for f in features):
+            raise ValueError("Inconsistent aux_hidden_states hidden_size in batch.")

257-265: Offline data wiring looks good

Good gating, path assertion, and clear user feedback for offline mode.


272-279: LGTM on explicit None checks for conversation_id

This addresses the prior “conv_id=0” pitfall from earlier feedback.


199-235: CPU-load offline tensors, validate keys, and truncate preprocessed base before shape check

Avoid dataloader-side GPU allocations and make shape checks robust when generation/training max lengths differ.

-        # Extend the data sample with the hidden states from the .pt file
-        max_length = self.tokenizer.model_max_length
-        offline_data = torch.load(offline_file_path)
-        offline_data["input_ids"] = offline_data["input_ids"][:max_length]
-        offline_data["hidden_states"] = offline_data["hidden_states"][:max_length, :]
-        offline_data["aux_hidden_states"] = offline_data["aux_hidden_states"][:max_length, :]
+        # Extend the data sample with the hidden states from the .pt file
+        max_length = self.tokenizer.model_max_length
+        # Truncate base fields before validation (do not mutate cache)
+        preprocessed_base_trunc = {
+            "input_ids": preprocessed_base["input_ids"][:max_length],
+            "labels": preprocessed_base["labels"][:max_length],
+            "attention_mask": preprocessed_base["attention_mask"][:max_length],
+            "loss_mask": preprocessed_base["loss_mask"][:max_length],
+        }
+        offline_data = torch.load(offline_file_path, map_location="cpu")
+        required_keys = {"input_ids", "hidden_states", "aux_hidden_states"}
+        if not required_keys.issubset(offline_data):
+            missing = required_keys.difference(offline_data.keys())
+            raise ValueError(f"Offline data at {offline_file_path} missing keys: {missing}")
+        offline_data["input_ids"] = offline_data["input_ids"][:max_length]
+        offline_data["hidden_states"] = offline_data["hidden_states"][:max_length, :]
+        offline_data["aux_hidden_states"] = offline_data["aux_hidden_states"][:max_length, :]
@@
-        if preprocessed_base["input_ids"].shape != offline_data["input_ids"].shape:
+        if preprocessed_base_trunc["input_ids"].shape != offline_data["input_ids"].shape:
             msg = f"""Input IDs from offline data do not match the preprocessed input IDs
                                 for offline data sample at {offline_file_path}."""
             raise ValueError(msg)
@@
-        ret = {**preprocessed_base}  # Shallow copy so we don't accidentally modify the cache
+        ret = {**preprocessed_base_trunc}  # Shallow copy so we don't accidentally modify the cache
examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (2)

105-111: LGTM: tokenizer setup and BOS presence check.

Pad fallback and BOS enforcement are sensible for this workflow.


152-163: Avoid leaking /tmp/meta.json on skips; guard BOS check.

When skipping for length, the temp meta file persists and can confuse the watcher. Also guard against empty inputs.

         num_input_tokens = len(input_ids)
         if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len:
             num_too_long += 1
+            # Clean up meta file if we decide to skip
+            if temp_meta_file.exists():
+                temp_meta_file.unlink()
             continue
-        if input_ids[0] == bos_token_id:
+        if input_ids and input_ids[0] == bos_token_id:
             # Remove the leading BOS token. vLLM's completion generation
             # endpoint will prepend the BOS token automatically.
             input_ids = input_ids[1:]
examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh (1)

21-21: Fix incorrect input path: mtbench.jsonl lives under input_conversations by default.

Current path breaks after gen_synthetic_conversations removal; align with add_mtbench.py output or pre-generate the file.

-  --input-file synthetic_conversations/mtbench.jsonl \
+  --input-file input_conversations/mtbench.jsonl \
examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py (3)

62-85: Robustify role mapping; avoid halting on unknown roles.

ShareGPT variants sometimes use assistant as the sender name and may include rare tool/function roles. Raising on unknown roles aborts the whole run. Prefer skipping such conversations; also map assistant to the assistant role.
Apply this diff:

 def parse_sharegpt_conversation(sharegpt_conv: dict) -> list[dict] | None:
@@
-    for turn in sharegpt_conv.get("conversations", []):
-        if turn.get("from") in ["human", "user"]:
+    for turn in sharegpt_conv.get("conversations", []):
+        sender = turn.get("from")
+        if sender in ("human", "user"):
             role = "user"
-        elif turn.get("from") in ["gpt", "chatgpt", "bard"]:
+        elif sender in ("gpt", "chatgpt", "assistant", "bard"):
             role = "assistant"
-        elif turn.get("from") == "system":
+        elif sender == "system":
             # ShareGPT system messages are metadata, skip them
             continue
-        elif turn.get("from") == "bing":
+        elif sender == "bing":
             # Bing conversations are skipped for training, omit it
             return None
         else:
-            err_msg = f"Unknown role in conversation: {turn.get('from')}"
-            raise ValueError(err_msg)
+            # Unknown role in this conversation; skip the whole sample to be safe.
+            return None
@@
-        value = turn.get("value", "").strip()
+        value = (turn.get("value") or "").strip()
         if value:
             msgs.append({"role": role, "content": value})

127-131: Entry-point wiring looks good.

Async orchestration is correct and self-contained.


30-30: Validate dataset URL stability (optional).

HF “resolve/main” links occasionally move. Consider a fallback via the datasets library (load_dataset(..., streaming=True)) if download fails.

Would you like me to wire an optional --use-hf-datasets path that streams directly from the hub?
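
A hedged sketch of the fallback, streaming the same JSON through the datasets library when the direct download fails; the url parameter is whatever this script already downloads:

from datasets import load_dataset

def stream_sharegpt(url: str):
    """Yield ShareGPT records by streaming the JSON file instead of downloading it."""
    ds = load_dataset("json", data_files={"train": url}, streaming=True)
    yield from ds["train"]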

@h-guo18
Copy link
Contributor

h-guo18 commented Sep 22, 2025

/ok to test 5add4cb

@h-guo18 h-guo18 enabled auto-merge (squash) September 22, 2025 23:17
@h-guo18 h-guo18 merged commit add61db into main Sep 22, 2025
26 checks passed
@h-guo18 h-guo18 deleted the bchislett/offline-eagle-training branch September 22, 2025 23:54
yeyu-nvidia pushed a commit that referenced this pull request Oct 1, 2025
Signed-off-by: Benjamin Chislett <[email protected]>
Signed-off-by: Benjamin Chislett <[email protected]>
Signed-off-by: h-guo18 <[email protected]>
Co-authored-by: h-guo18 <[email protected]>
Signed-off-by: Ye Yu <[email protected]>