
Conversation

@h-guo18 (Contributor) commented Oct 10, 2025

What does this PR do?

Type of change: New feature

Overview:

Added a SLURM multi-node script for the TRT-LLM EAGLE hidden states dumper.

Usage

sbatch --array=0-n slurm_dump.sh

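For concreteness, a hedged example submission (the array size of 8 data-parallel shards is illustrative, and the script's placeholder paths must be filled in first):

# Each array task becomes one data-parallel rank:
# dp_rank = SLURM_ARRAY_TASK_ID, dp_world_size = SLURM_ARRAY_TASK_COUNT
sbatch --array=0-7 examples/speculative_decoding/collect_hidden_states/slurm_dump.sh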

Testing

Tested on the HSG cluster.

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • New Features

    • Load JSONL via HuggingFace datasets (single file or directory).
    • Distributed sharding across data-parallel ranks.
    • Skip already-processed conversations; optional debug cap; progress bar reflects dataset size.
    • Added SLURM submission script for containerized, distributed hidden-state collection.
  • Refactor

    • Dataset-driven iteration and filtering; tokenization input capped at 256 tokens; final summary reflects dataset length.
  • Chores

    • Simplified per-GPU runner to direct Python invocation, removing an alternate launcher.

@copy-pr-bot bot commented Oct 10, 2025

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai bot (Contributor) commented Oct 10, 2025

Walkthrough

Replaces file-based JSONL ingestion with HuggingFace Datasets loading (file or directory), adds DP sharding and filtering of already-dumped conversations, caps tokenization to 256 tokens, updates progress/success reporting, simplifies local DP launcher, and adds a new SLURM submission script using containerized trtllm-llmapi-launch.
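As a rough sketch of the loading-and-sharding pattern described above (illustrative only; argument names follow the script's CLI as quoted in the review comments below, and the exact code lives in compute_hidden_states_trtllm.py):

from pathlib import Path
from datasets import load_dataset

def load_shard(input_data: str, output_dir: Path, dp_rank: int, dp_world_size: int):
    # Accept either a single .jsonl file or a directory of .jsonl files.
    path = Path(input_data)
    data_files = str(path) if path.is_file() else str(path / "*.jsonl")
    dataset = load_dataset("json", data_files=data_files, split="train")

    # Even split across data-parallel ranks.
    dataset = dataset.shard(num_shards=dp_world_size, index=dp_rank)

    # Resumability: drop conversations whose hidden states were already dumped.
    def keep(entry):
        cid = entry.get("conversation_id", entry.get("uuid"))
        return cid is not None and not (output_dir / f"{cid}.pt").exists()

    return dataset.filter(keep)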

Changes

  • Hidden states computation refactor (examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py): Replace file-based JSONL ingestion with HuggingFace Datasets (single file or directory), add dp_rank/dp_world_size sharding, filter out conversations whose output .pt already exists (by conversation_id/uuid), iterate over the dataset rather than an in-memory list, add an optional debug cap, constrain tokenization to a maximum of 256 tokens, and update progress and completion messaging to use the dataset length.
  • Local DP runner (examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh): Remove the SLURM-specific trtllm-llmapi-launch/IPC launcher line; retain the per-GPU direct python3 compute_hidden_states_trtllm.py invocations and the final wait/cleanup logic.
  • SLURM submission script, new (examples/speculative_decoding/collect_hidden_states/slurm_dump.sh): Add an SBATCH-configured script that prepares the container image, mounts input/output/model paths, derives TP/DP from SLURM array indices (sketched below), and runs the dumper inside the container via trtllm-llmapi-launch using srun with PMIx/overlap and logging.
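A minimal sketch of that array-to-parallelism mapping (illustrative, using the non-array defaults suggested later in this review thread; not the script verbatim):

# TP spans the 4 GPUs inside a node; DP spans the SLURM array tasks.
DP_RANK="${SLURM_ARRAY_TASK_ID:-0}"
DP_WORLD="${SLURM_ARRAY_TASK_COUNT:-1}"
PARALLEL_ARGS="--tp 4 --dp-rank ${DP_RANK} --dp-world-size ${DP_WORLD}"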

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant User
  participant Compute as compute_hidden_states_trtllm.py
  participant HFDS as HuggingFace Datasets
  participant Tokenizer
  participant TRTLLM as TRT-LLM Runtime
  participant FS as Filesystem

  User->>Compute: launch with --input (file/dir), --output, --dp_rank/world_size
  Compute->>HFDS: load_dataset(input)
  Compute->>HFDS: shard(dp_rank, dp_world_size)
  Compute->>HFDS: filter(existing .pt by conversation_id/uuid)
  loop per conversation (dataset)
    Compute->>Tokenizer: tokenize(prompt, max_length=256)
    Tokenizer-->>Compute: input_ids
    Compute->>TRTLLM: run forward to collect hidden states
    TRTLLM-->>Compute: hidden states
    Compute->>FS: save hidden states as `.pt`
  end
  Compute-->>User: print completion with processed count (len(dataset))
sequenceDiagram
  autonumber
  participant SLURM as SLURM Array Task
  participant srun
  participant Container
  participant Launcher as trtllm-llmapi-launch
  participant Compute as compute_hidden_states_trtllm.py

  SLURM->>SLURM: derive TP/DP from ARRAY_TASK_ID/COUNT
  SLURM->>srun: invoke with container image, mounts, and env
  srun->>Container: start shell
  Container->>Launcher: start TRT-LLM server/context
  Launcher->>Compute: execute with model/input/output + parallel opts
  Compute-->>Launcher: exit status
  Launcher-->>Container: stop
  Container-->>srun: exit

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Poem

I nibbled through JSONL fields with care,
Sharded my snacks so each bun gets a share.
If a carrot’s saved, I hop right past—
New crates appear, I pack them fast.
On SLURM winds I launch and zoom—hidden states gathered, thump-thump, boom! 🥕✨

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title precisely highlights the main enhancement of adding multi-node support to the hidden states dumper within the Eagle Offline framework, aligning with the introduction of the SLURM script and associated updates in this pull request.
  • Docstring Coverage: ✅ Passed. No functions found in the changes; docstring coverage check skipped.

@h-guo18 h-guo18 marked this pull request as ready for review October 10, 2025 01:35
@h-guo18 h-guo18 requested a review from a team as a code owner October 10, 2025 01:35
@h-guo18 h-guo18 changed the title slurm multinode support for trtllm dumper feat: slurm multinode support for trtllm dumper Oct 10, 2025
Signed-off-by: h-guo18 <[email protected]>
@coderabbitai bot left a comment (Contributor)

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py (1)

261-272: Fix prompt tokenization before handing to TRTLLM.

Line 261 currently slices the string returned by tokenizer.apply_chat_template, so we end up truncating by characters, compute “token” length on characters, and then pass a string into generate_async. That breaks token limits and can crash depending on backend expectations. Convert to real token IDs (and cap after tokenization) before dispatch.

-            input_ids = tokenizer.apply_chat_template(conversations, add_generation_template=False)[
-                :256
-            ]
-            num_input_tokens = (
-                input_ids.shape[1] if isinstance(input_ids, torch.Tensor) else len(input_ids)
-            )
+            encoded = tokenizer.apply_chat_template(
+                conversations,
+                add_generation_template=False,
+                tokenize=True,
+                return_dict=True,
+                return_tensors="pt",
+            )
+            input_ids = encoded["input_ids"][0][:256]
+            num_input_tokens = input_ids.numel()
@@
-            tasks.append(dump_hidden_states(idx, conversation_id, input_ids))
+            tasks.append(dump_hidden_states(idx, conversation_id, input_ids.tolist()))
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5b02483 and 4dce25f.

📒 Files selected for processing (3)
  • examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py (8 hunks)
  • examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh (0 hunks)
  • examples/speculative_decoding/collect_hidden_states/slurm_dump.sh (1 hunks)
💤 Files with no reviewable changes (1)
  • examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh
🧰 Additional context used
🪛 Shellcheck (0.11.0)
examples/speculative_decoding/collect_hidden_states/slurm_dump.sh

[warning] 29-29: Use "${var:?}" to ensure this never expands to /* .

(SC2115)


[warning] 35-35: Word is of the form "A"B"C" (B indicated). Did you mean "ABC" or "A"B"C"?

(SC2140)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: code-quality
  • GitHub Check: build-docs

Comment on lines +166 to 168
    if args.debug_max_num_conversations is not None:
        dataset = dataset.select(range(args.debug_max_num_conversations))


⚠️ Potential issue | 🟠 Major

Guard the debug cap against short datasets.

If --debug-max-num-conversations exceeds len(dataset), dataset.select(range(...)) raises an IndexError. Clamp the range to the dataset length so the debug flag remains safe to use.

-        dataset = dataset.select(range(args.debug_max_num_conversations))
+        limit = min(args.debug_max_num_conversations, len(dataset))
+        dataset = dataset.select(range(limit))

fi

MOUNTS=$INPUT_DIR:/input,$DUMP_DIR:/output,$MODELOPT_DIR:/modelopt,$TEACHER_MODEL:/model
rm -rf $DUMP_DIR/*

⚠️ Potential issue | 🟠 Major

Harden the cleanup step before launching.

rm -rf $DUMP_DIR/* is dangerous—if $DUMP_DIR is unset/empty the glob resolves to / and nukes the host. Add the usual ${var:?} guard (or bail out explicitly) before wiping the directory.

-rm -rf $DUMP_DIR/*
+rm -rf "${DUMP_DIR:?}/"*
🧰 Tools
🪛 Shellcheck (0.11.0)

[warning] 29-29: Use "${var:?}" to ensure this never expands to /* .

(SC2115)


Comment on lines +35 to +41
RUN_DUMPER="export TLLM_LOG_LEVEL="error";
trtllm-llmapi-launch python3 /modelopt/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py \
--model /model \
--input-data /input/ \
--output-dir /output \
$PARALLEL_ARGS \
"

⚠️ Potential issue | 🔴 Critical

Repair RUN_DUMPER quoting so the script actually runs.

The current assignment ends the string at the inner "error"; the remainder is interpreted as separate tokens, so the script fails before launching anything. Use a heredoc (or escape the inner quotes) to produce a well-formed command string.

-RUN_DUMPER="export TLLM_LOG_LEVEL="error";
-trtllm-llmapi-launch python3 /modelopt/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py \
-  --model /model \
-  --input-data /input/ \
-  --output-dir /output \
-  $PARALLEL_ARGS \
-  "
+read -r -d '' RUN_DUMPER <<EOF
+export TLLM_LOG_LEVEL="error"
+trtllm-llmapi-launch python3 /modelopt/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py \
+  --model /model \
+  --input-data /input/ \
+  --output-dir /output \
+  $PARALLEL_ARGS
+EOF
🧰 Tools
🪛 Shellcheck (0.11.0)

[warning] 35-35: Word is of the form "A"B"C" (B indicated). Did you mean "ABC" or "A"B"C"?

(SC2140)


@h-guo18 h-guo18 changed the title feat: slurm multinode support for trtllm dumper [Eagle Offline] multinode support for hidden states dumper Oct 10, 2025
@coderabbitai bot left a comment (Contributor)

Actionable comments posted: 5

🧹 Nitpick comments (5)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py (2)

82-92: Clarify SLURM variable names in help text.

The help text mentions "TASK_ID on SLURM" and "Number of tasks on SLURM" but could be more precise. Consider updating to reference the specific SLURM environment variables (SLURM_ARRAY_TASK_ID and SLURM_ARRAY_TASK_COUNT) for clarity.

Apply this diff:

     parser.add_argument(
         "--dp-rank",
         type=int,
         default=0,
-        help="""Data parallel rank. TASK_ID on SLURM.""",
+        help="""Data parallel rank. Set to SLURM_ARRAY_TASK_ID when using SLURM arrays.""",
     )
     parser.add_argument(
         "--dp-world-size",
         type=int,
         default=1,
-        help="""Data parallel world size. Number of tasks on SLURM.""",
+        help="""Data parallel world size. Set to SLURM_ARRAY_TASK_COUNT when using SLURM arrays.""",
     )

283-286: Success reporting may be misleading with skipped conversations.

The success message compares num_success against len(dataset), but this doesn't account for conversations skipped due to invalid data or length constraints. This could be confusing when debugging failed runs.

Consider more accurate reporting:

+expected_success = len(dataset) - num_invalid - num_skipped_too_long
-if num_success == len(dataset):
-    print(f"Successfully processed all {num_success} conversations.")
+if num_success == expected_success:
+    print(f"Successfully processed all {num_success} valid conversations.")
 else:
-    print(f"Successfully processed {num_success} out of {len(dataset)} conversations.")
+    print(f"Successfully processed {num_success} out of {expected_success} valid conversations "
+          f"({num_invalid} invalid, {num_skipped_too_long} skipped due to length).")
examples/speculative_decoding/collect_hidden_states/slurm_dump.sh (3)

8-12: Make account/job-name more configurable.

The SBATCH account and job-name are hardcoded to NVIDIA-internal values. Consider either:

  1. Using placeholder variables like <YOUR_ACCOUNT> and <YOUR_JOB_NAME> to match the pattern used for INPUT_DIR, DUMP_DIR, etc.
  2. Adding a comment explicitly stating these must be updated by users.

This would make the script more clearly a template that requires customization.

Apply this diff:

-#SBATCH -A coreai_dlalgo_modelopt
-#SBATCH --job-name=coreai_dlalgo_modelopt-mcore.modelopt
+#SBATCH -A <YOUR_ACCOUNT>
+#SBATCH --job-name=<YOUR_JOB_NAME>
 #SBATCH --nodes=1 --ntasks-per-node=4 --gpus-per-node=4
 #SBATCH -p batch
 #SBATCH -t 04:00:00

17-17: Consider making the container version configurable.

The container version is hardcoded to 1.2.0rc0, which is a release candidate. Consider:

  1. Moving this to a variable at the top of the script for easy updates
  2. Updating to a stable release if available
  3. Adding a comment about compatible versions

Example:

+# Container version - update as needed
+CONTAINER_VERSION="1.2.0rc0"
+CONTAINER="nvcr.io#nvidia/tensorrt-llm/release:${CONTAINER_VERSION}"
-CONTAINER="nvcr.io#nvidia/tensorrt-llm/release:1.2.0rc0"

17-17: Consider using a stable container release instead of RC.

The script uses a release candidate container version (1.2.0rc0). For production use, consider switching to a stable release once available.

-CONTAINER="nvcr.io#nvidia/tensorrt-llm/release:1.2.0rc0"
+# TODO: Update to stable release when 1.2.0 is GA
+CONTAINER="nvcr.io#nvidia/tensorrt-llm/release:1.2.0"
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5b02483 and de4ecc7.

📒 Files selected for processing (3)
  • examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py (8 hunks)
  • examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh (0 hunks)
  • examples/speculative_decoding/collect_hidden_states/slurm_dump.sh (1 hunks)
💤 Files with no reviewable changes (1)
  • examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh
🧰 Additional context used
🪛 Shellcheck (0.11.0)
examples/speculative_decoding/collect_hidden_states/slurm_dump.sh

[warning] 34-34: Word is of the form "A"B"C" (B indicated). Did you mean "ABC" or "A"B"C"?

(SC2140)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (9)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py (6)

131-168: LGTM! Robust data loading with resumability.

The refactored data loading logic properly:

  • Supports both single files and directories of JSONL files
  • Implements DP sharding for distributed processing
  • Filters out already-processed conversations for resumability
  • Provides debug capability to limit conversation count

The conversation_id extraction handles both "conversation_id" and "uuid" fields with proper fallback and validation.


261-263: Verify the 256-token tokenization cap.

The tokenization input is hard-capped at 256 tokens, which could truncate longer conversations before they're processed. This seems inconsistent with the --max-seq-len parameter (default 3072).

Please clarify:

  1. Why is 256 tokens chosen as the cap?
  2. Should this be configurable via a command-line argument?
  3. How does this interact with the --max-seq-len validation on line 267?

If this cap is intentional and specific to the dumper's requirements, consider adding an inline comment explaining the rationale:

         input_ids = tokenizer.apply_chat_template(conversations, add_generation_template=False)[
-            :256
+            :256  # Cap prompt to 256 tokens for hidden state collection
         ]

26-26: LGTM!

The addition of load_dataset from the datasets library is appropriate for the refactoring to HuggingFace datasets-based loading.


144-148: LGTM!

The dataset sharding logic correctly distributes data across DP ranks using the datasets library's built-in shard method, which ensures even distribution.


165-167: LGTM!

The debug cap feature is well-implemented and will be useful for testing without processing the entire dataset.


131-141: Verify load_dataset handles empty input directory.

Ensure that passing an empty directory (with no .jsonl files) to load_dataset("json", data_files={"train": ".../*.jsonl"}, split="train") raises a clear, descriptive error rather than returning an empty dataset; a possible guard is sketched below.
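One possible guard, assuming the dataset and args variables from the loading step (illustrative only, not part of the PR):

if len(dataset) == 0:
    raise ValueError(
        f"No conversations found under {args.input_data}; check the path and the .jsonl glob."
    )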

examples/speculative_decoding/collect_hidden_states/slurm_dump.sh (3)

42-47: LGTM! Appropriate timeout margin.

The srun configuration is well-structured:

  • The 235-minute timeout provides a 5-minute buffer before the 4-hour SBATCH limit, allowing graceful cleanup
  • PMIx MPI and overlap flags are appropriate for distributed execution
  • Container configuration correctly references the defined variables
  • Log file naming includes job metadata for tracking

8-12: LGTM!

The SBATCH directives are correctly configured for single-node, 4-GPU execution with appropriate time limits and partition settings.


19-22: LGTM!

The placeholder values are clearly marked and provide helpful guidance to users on what needs to be configured.

Comment on lines +150 to 163
# Remove already dumped conversations
def keep_conversation(entry):
    conversation_id = entry.get("conversation_id", entry.get("uuid", None))
    assert conversation_id is not None, "conversation_id is required"
    output_file = args.output_dir / f"{conversation_id}.pt"
    return not output_file.exists()

original_num = len(dataset)
dataset = dataset.filter(keep_conversation)
print(
    "Removed",
    original_num - len(dataset),
    "conversations due to existing output files",
)

⚠️ Potential issue | 🟠 Major

Replace assertion with proper error handling.

The keep_conversation filter function uses an assertion to validate that conversation_id exists, which will cause the entire process to crash if any conversation lacks this field. In a distributed DP setting, this would fail the entire SLURM job.

Replace the assertion with proper error handling:

 def keep_conversation(entry):
     conversation_id = entry.get("conversation_id", entry.get("uuid", None))
-    assert conversation_id is not None, "conversation_id is required"
+    if conversation_id is None:
+        return False  # Skip conversations without valid ID
     output_file = args.output_dir / f"{conversation_id}.pt"
     return not output_file.exists()

Additionally, consider logging a warning when conversations are skipped due to missing IDs, similar to how other invalid conversations are tracked.


Comment on lines +261 to +263
            input_ids = tokenizer.apply_chat_template(conversations, add_generation_template=False)[
                :256
            ]

🛠️ Refactor suggestion | 🟠 Major

Clarify the magic number 256 for tokenization truncation.

The hardcoded limit of 256 tokens appears arbitrary and may not align with the model's actual context window or the user's --max-seq-len parameter. This could lead to confusion when conversations are unexpectedly truncated.

Consider one of the following approaches:

  1. Make this configurable via a CLI argument:
+parser.add_argument(
+    "--max-input-tokens",
+    type=int,
+    default=256,
+    help="Maximum number of tokens to use from conversation input for context."
+)
  2. Or at minimum, add a constant and a comment explaining the rationale:
+# Limit input tokens to reduce memory usage during hidden state collection
+MAX_INPUT_TOKENS = 256
+
 input_ids = tokenizer.apply_chat_template(conversations, add_generation_template=False)[
-    :256
+    :MAX_INPUT_TOKENS
 ]

#EP optionally available by setting --moe-ep-size and --moe-tp-size. See compute_hidden_states_trtllm.py.
PARALLEL_ARGS="--tp 4 --dp-rank $SLURM_ARRAY_TASK_ID --dp-world-size $SLURM_ARRAY_TASK_COUNT"

RUN_DUMPER="export TLLM_LOG_LEVEL="error";
@coderabbitai bot commented Oct 10, 2025

⚠️ Potential issue | 🟡 Minor

Fix shell quoting issue.

The line has incorrect nested quoting that will result in the string "error" (including quotes) being assigned to TLLM_LOG_LEVEL, rather than just error.

Apply this diff:

-RUN_DUMPER="export TLLM_LOG_LEVEL="error";
+RUN_DUMPER="export TLLM_LOG_LEVEL=error;
 trtllm-llmapi-launch python3 /modelopt/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py \
   --model /model \
   --input-data /input/ \
   --output-dir /output \
   $PARALLEL_ARGS \
   "
🧰 Tools
🪛 Shellcheck (0.11.0)

[warning] 34-34: Word is of the form "A"B"C" (B indicated). Did you mean "ABC" or "A"B"C"?

(SC2140)


@ChenhanYu (Collaborator) commented:

This seems legit.

@coderabbitai bot replied:

@ChenhanYu, thank you for confirming! The fix will prevent the environment variable from receiving the literal string "error" (with quotes) and ensure it gets the plain value error as intended.

Comment on lines +34 to +40
RUN_DUMPER="export TLLM_LOG_LEVEL="error";
trtllm-llmapi-launch python3 /modelopt/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py \
--model /model \
--input-data /input/ \
--output-dir /output \
$PARALLEL_ARGS \
"

⚠️ Potential issue | 🔴 Critical

Fix shell quoting syntax error.

The export statement has incorrect nested quotes that will cause a shell syntax error.

As per static analysis

Apply this diff to fix the quoting:

-RUN_DUMPER="export TLLM_LOG_LEVEL="error";
+RUN_DUMPER="export TLLM_LOG_LEVEL=error;
 trtllm-llmapi-launch python3 /modelopt/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py \
   --model /model \
   --input-data /input/ \
   --output-dir /output \
   $PARALLEL_ARGS \
   "
🧰 Tools
🪛 Shellcheck (0.11.0)

[warning] 34-34: Word is of the form "A"B"C" (B indicated). Did you mean "ABC" or "A"B"C"?

(SC2140)


Comment on lines +42 to +47
timeout 235m srun -l \
    --mpi=pmix --overlap \
    --output=%x_%j_$DATETIME.log \
    --container-image ${CONTAINER} \
    --container-mounts ${MOUNTS} \
    bash -c "$RUN_DUMPER"

⚠️ Potential issue | 🟠 Major

Fix undefined variable in log filename.

Line 44 references $DATETIME which is never defined, causing the log filename to be incomplete or use a literal string.

Define the variable before use or remove it from the log filename:

+DATETIME=$(date +%Y%m%d_%H%M%S)
+
 timeout 235m srun -l \
     --mpi=pmix --overlap \
     --output=%x_%j_$DATETIME.log \
     --container-image ${CONTAINER} \
     --container-mounts ${MOUNTS} \
     bash -c "$RUN_DUMPER"

Alternatively, if timestamps aren't needed, simplify to:

 timeout 235m srun -l \
     --mpi=pmix --overlap \
-    --output=%x_%j_$DATETIME.log \
+    --output=%x_%j.log \
     --container-image ${CONTAINER} \
     --container-mounts ${MOUNTS} \
     bash -c "$RUN_DUMPER"

@codecov bot commented Oct 10, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.36%. Comparing base (5b02483) to head (59d93d8).
⚠️ Report is 4 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #422   +/-   ##
=======================================
  Coverage   73.36%   73.36%           
=======================================
  Files         180      180           
  Lines       17919    17919           
=======================================
  Hits        13147    13147           
  Misses       4772     4772           


@ChenhanYu (Collaborator) left a comment

Some comments

Signed-off-by: h-guo18 <[email protected]>
@coderabbitai bot left a comment (Contributor)

Actionable comments posted: 2

♻️ Duplicate comments (2)
examples/speculative_decoding/collect_hidden_states/slurm_dump.sh (2)

34-40: Fix RUN_DUMPER quoting; current form is syntactically broken.

The nested quotes terminate the string at "error", so the script fails before launch. Build the command with a heredoc to avoid quoting pitfalls.

-RUN_DUMPER="export TLLM_LOG_LEVEL="error";
-trtllm-llmapi-launch python3 /modelopt/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py \
-  --model /model \
-  --input-data /input/ \
-  --output-dir /output \
-  $PARALLEL_ARGS \
-  "
+RUN_DUMPER=$(cat <<EOF
+export TLLM_LOG_LEVEL=error
+trtllm-llmapi-launch python3 /modelopt/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py \
+  --model /model \
+  --input-data /input/ \
+  --output-dir /output \
+  $PARALLEL_ARGS
+EOF
+)

42-47: Define DATETIME (or remove it) in the srun --output pattern.

$DATETIME is undefined, producing odd filenames.

+DATETIME=$(date +%Y%m%d_%H%M%S)
 timeout 235m srun -l \
     --mpi=pmix --overlap \
     --output=%x_%j_$DATETIME.log \
     --container-image ${CONTAINER} \
-    --container-mounts ${MOUNTS} \
+    --container-mounts "${MOUNTS}" \
     bash -c "$RUN_DUMPER"
🧹 Nitpick comments (2)
examples/speculative_decoding/collect_hidden_states/slurm_dump.sh (2)

19-23: Validate required paths before running.

Placeholders/defaults will cause failures or bad mounts. Add guards and existence checks.

 INPUT_DIR="<Can be directory containing the .jsonl files, or path to single .jsonl file>"
 DUMP_DIR="<Directory for output hidden states>"
 MODELOPT_DIR="<Path to Modelopt repo>"
 TEACHER_MODEL="<Path to teacher model>"
 
+# Basic validation
+for var in INPUT_DIR DUMP_DIR MODELOPT_DIR TEACHER_MODEL; do
+  val="${!var}"
+  if [[ -z "$val" || "$val" == \<*\> ]]; then
+    echo "ERROR: $var is not set. Please edit slurm_dump.sh."
+    exit 1
+  fi
+done
+
+[[ -e "$INPUT_DIR" ]] || { echo "ERROR: INPUT_DIR not found: $INPUT_DIR"; exit 1; }
+[[ -d "$MODELOPT_DIR" ]] || { echo "ERROR: MODELOPT_DIR not found: $MODELOPT_DIR"; exit 1; }
+[[ -d "$TEACHER_MODEL" ]] || { echo "ERROR: TEACHER_MODEL not found: $TEACHER_MODEL"; exit 1; }

28-28: Quote mounts to avoid word splitting.

Protect paths with spaces/commas by quoting the value and usage.

-MOUNTS=$INPUT_DIR:/input,$DUMP_DIR:/output,$MODELOPT_DIR:/modelopt,$TEACHER_MODEL:/model
+MOUNTS="${INPUT_DIR}:/input,${DUMP_DIR}:/output,${MODELOPT_DIR}:/modelopt,${TEACHER_MODEL}:/model"

And in srun:

-    --container-mounts ${MOUNTS} \
+    --container-mounts "${MOUNTS}" \
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between de4ecc7 and 59d93d8.

📒 Files selected for processing (1)
  • examples/speculative_decoding/collect_hidden_states/slurm_dump.sh (1 hunks)
🧰 Additional context used
🪛 Shellcheck (0.11.0)
examples/speculative_decoding/collect_hidden_states/slurm_dump.sh

[warning] 34-34: Word is of the form "A"B"C" (B indicated). Did you mean "ABC" or "A"B"C"?

(SC2140)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: build-docs
  • GitHub Check: code-quality


Comment on lines +8 to +10

#SBATCH -A coreai_dlalgo_modelopt
#SBATCH --job-name=coreai_dlalgo_modelopt-generate_eagle_hidden_states
#SBATCH --nodes=1 --ntasks-per-node=4 --gpus-per-node=4

⚠️ Potential issue | 🟠 Major

Align allocation with intended topology (1 proc using 4 GPUs).

You set ntasks-per-node=4 but run a single command that does TP=4. Allocate 1 task and give it 4 GPUs to avoid idle tasks and binding ambiguity.

-#SBATCH --nodes=1 --ntasks-per-node=4 --gpus-per-node=4
+#SBATCH --nodes=1 --ntasks-per-node=1 --gpus-per-task=4

Optionally, also make the step explicit:

-timeout 235m srun -l \
+timeout 235m srun -l -n 1 \

Comment on lines +31 to +33
#EP optionally available by setting --moe-ep-size and --moe-tp-size. See compute_hidden_states_trtllm.py.
PARALLEL_ARGS="--tp 4 --dp-rank $SLURM_ARRAY_TASK_ID --dp-world-size $SLURM_ARRAY_TASK_COUNT"


⚠️ Potential issue | 🟠 Major

Provide safe defaults when not using an array job.

If SLURM_ARRAY_TASK_ID/COUNT are unset, dp args become empty and the Python launcher may fail. Default to a single‑rank DP.

-#By default: TP inside node, and DP across slurm array
-#EP optionally available by setting --moe-ep-size and --moe-tp-size. See compute_hidden_states_trtllm.py.
-PARALLEL_ARGS="--tp 4 --dp-rank $SLURM_ARRAY_TASK_ID --dp-world-size $SLURM_ARRAY_TASK_COUNT"
+# By default: TP inside node, and DP across slurm array
+# EP optionally available by setting --moe-ep-size and --moe-tp-size. See compute_hidden_states_trtllm.py.
+DP_RANK="${SLURM_ARRAY_TASK_ID:-0}"
+DP_WORLD="${SLURM_ARRAY_TASK_COUNT:-1}"
+PARALLEL_ARGS="--tp 4 --dp-rank ${DP_RANK} --dp-world-size ${DP_WORLD}"

@ChenhanYu ChenhanYu self-requested a review October 10, 2025 22:53
@h-guo18 h-guo18 merged commit 40a7d24 into main Oct 10, 2025
27 checks passed
@h-guo18 h-guo18 deleted the haoguo/multinode-dumper branch October 10, 2025 22:58