[Eagle Offline] multinode support for hidden states dumper #422
Conversation
Signed-off-by: h-guo18 <[email protected]>
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here.
Walkthrough: Replaces file-based JSONL ingestion with HuggingFace Datasets loading (file or directory), adds DP sharding and filtering of already-dumped conversations, caps tokenization to 256 tokens, updates progress/success reporting, simplifies the local DP launcher, and adds a new SLURM submission script using containerized trtllm-llmapi-launch.
Sequence Diagram(s)

```mermaid
sequenceDiagram
  autonumber
  participant User
  participant Compute as compute_hidden_states_trtllm.py
  participant HFDS as HuggingFace Datasets
  participant Tokenizer
  participant TRTLLM as TRT-LLM Runtime
  participant FS as Filesystem
  User->>Compute: launch with --input (file/dir), --output, --dp_rank/world_size
  Compute->>HFDS: load_dataset(input)
  Compute->>HFDS: shard(dp_rank, dp_world_size)
  Compute->>HFDS: filter(existing .pt by conversation_id/uuid)
  loop per conversation (dataset)
    Compute->>Tokenizer: tokenize(prompt, max_length=256)
    Tokenizer-->>Compute: input_ids
    Compute->>TRTLLM: run forward to collect hidden states
    TRTLLM-->>Compute: hidden states
    Compute->>FS: save hidden states as .pt
  end
  Compute-->>User: print completion with processed count (len(dataset))
```
```mermaid
sequenceDiagram
  autonumber
  participant SLURM as SLURM Array Task
  participant srun
  participant Container
  participant Launcher as trtllm-llmapi-launch
  participant Compute as compute_hidden_states_trtllm.py
  SLURM->>SLURM: derive TP/DP from ARRAY_TASK_ID/COUNT
  SLURM->>srun: invoke with container image, mounts, and env
  srun->>Container: start shell
  Container->>Launcher: start TRT-LLM server/context
  Launcher->>Compute: execute with model/input/output + parallel opts
  Compute-->>Launcher: exit status
  Launcher-->>Container: stop
  Container-->>srun: exit
```

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
Signed-off-by: h-guo18 <[email protected]>
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️  Outside diff range comments (1)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py (1)
261-272: Fix prompt tokenization before handing to TRTLLM.

Line 261 currently slices the string returned by `tokenizer.apply_chat_template`, so we end up truncating by characters, computing "token" length on characters, and then passing a string into `generate_async`. That breaks token limits and can crash depending on backend expectations. Convert to real token IDs (and cap after tokenization) before dispatch.

```diff
-            input_ids = tokenizer.apply_chat_template(conversations, add_generation_template=False)[
-                :256
-            ]
-            num_input_tokens = (
-                input_ids.shape[1] if isinstance(input_ids, torch.Tensor) else len(input_ids)
-            )
+            encoded = tokenizer.apply_chat_template(
+                conversations,
+                add_generation_template=False,
+                tokenize=True,
+                return_tensors="pt",
+            )
+            input_ids = encoded["input_ids"][0][:256]
+            num_input_tokens = input_ids.numel()
@@
-            tasks.append(dump_hidden_states(idx, conversation_id, input_ids))
+            tasks.append(dump_hidden_states(idx, conversation_id, input_ids.tolist()))
```
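The character-versus-token distinction the comment describes can be sketched without any model. A plain whitespace split stands in for a real tokenizer here; the names are purely illustrative, not from the PR.

```python
# Illustrative sketch: slicing the templated *string* is not the same as
# capping token IDs. A whitespace split stands in for a real tokenizer.
def fake_tokenize(text: str) -> list[str]:
    return text.split()

prompt = " ".join(f"tok{i}" for i in range(500))

char_sliced = prompt[:256]                  # truncates by characters
token_capped = fake_tokenize(prompt)[:256]  # truncates by tokens

print(len(token_capped))                # exactly 256 tokens
print(len(fake_tokenize(char_sliced)))  # far fewer tokens than 256
```

The string slice keeps only as many tokens as fit in 256 characters, which is why the review asks for the cap to be applied after tokenization.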
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py (8 hunks)
- examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh (0 hunks)
- examples/speculative_decoding/collect_hidden_states/slurm_dump.sh (1 hunks)
💤 Files with no reviewable changes (1)
- examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh
 
🧰 Additional context used
🪛 Shellcheck (0.11.0)
examples/speculative_decoding/collect_hidden_states/slurm_dump.sh
[warning] 29-29: Use "${var:?}" to ensure this never expands to /* .
(SC2115)
[warning] 35-35: Word is of the form "A"B"C" (B indicated). Did you mean "ABC" or "A"B"C"?
(SC2140)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
 - GitHub Check: code-quality
 - GitHub Check: build-docs
 
```python
if args.debug_max_num_conversations is not None:
    dataset = dataset.select(range(args.debug_max_num_conversations))
```
Guard the debug cap against short datasets.
If --debug-max-num-conversations exceeds len(dataset), dataset.select(range(...)) raises an IndexError. Clamp the range to the dataset length so the debug flag remains safe to use.
```diff
-        dataset = dataset.select(range(args.debug_max_num_conversations))
+        limit = min(args.debug_max_num_conversations, len(dataset))
+        dataset = dataset.select(range(limit))
```
🤖 Prompt for AI Agents
In
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py
around lines 166 to 168, the debug cap uses
dataset.select(range(args.debug_max_num_conversations)) which will raise
IndexError if the arg exceeds len(dataset); clamp the requested count to the
dataset length before selecting (e.g., compute count =
min(args.debug_max_num_conversations, len(dataset)) and call
dataset.select(range(count)) or skip selection if count <= 0) so the debug flag
is safe for shorter datasets.
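A minimal sketch of the clamping logic, independent of the datasets library (the helper name is invented for illustration; `Dataset.select(range(n))` fails the same way plain indexing would when n exceeds the length):

```python
# Sketch of the guarded debug cap: clamp the requested count to the
# dataset length so short datasets never raise.
def debug_cap_indices(requested: int, dataset_len: int) -> range:
    limit = min(requested, dataset_len)
    return range(max(limit, 0))

print(list(debug_cap_indices(10, 3)))  # [0, 1, 2]
print(list(debug_cap_indices(2, 3)))   # [0, 1]
```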
```shell
fi

MOUNTS=$INPUT_DIR:/input,$DUMP_DIR:/output,$MODELOPT_DIR:/modelopt,$TEACHER_MODEL:/model
rm -rf $DUMP_DIR/*
```
Harden the cleanup step before launching.
rm -rf $DUMP_DIR/* is dangerous—if $DUMP_DIR is unset/empty the glob resolves to / and nukes the host. Add the usual ${var:?} guard (or bail out explicitly) before wiping the directory.
```diff
-rm -rf $DUMP_DIR/*
+rm -rf "${DUMP_DIR:?}/"*
```

🧰 Tools
🪛 Shellcheck (0.11.0)
[warning] 29-29: Use "${var:?}" to ensure this never expands to /* .
(SC2115)
🤖 Prompt for AI Agents
In examples/speculative_decoding/collect_hidden_states/slurm_dump.sh around line
29, the cleanup command uses an unguarded rm -rf $DUMP_DIR/* which can delete
root if DUMP_DIR is empty; update the script to validate or bail if DUMP_DIR is
unset/empty (e.g., use the parameter expansion guard ${DUMP_DIR:?} before using
it) or add an explicit check that DUMP_DIR is non-empty and points to the
expected path, then perform a safe removal only after the guard/check passes.
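The `${var:?}` guard can be demonstrated standalone; the temporary directory below is created on the fly and is not taken from the script:

```shell
#!/bin/sh
# ${DUMP_DIR:?} expands normally when DUMP_DIR is set, but aborts the
# command with an error when it is unset or empty -- so the glob can
# never resolve to /*.
DUMP_DIR=$(mktemp -d)
touch "$DUMP_DIR/a.pt" "$DUMP_DIR/b.pt"
rm -rf "${DUMP_DIR:?}"/*                 # safe: variable is non-empty
[ -z "$(ls -A "$DUMP_DIR")" ] && echo "cleanup ok"

unset DUMP_DIR
( : "${DUMP_DIR:?}" ) 2>/dev/null || echo "guard tripped"
```

The subshell around the second expansion keeps the failed `${DUMP_DIR:?}` from terminating the outer script, which makes the guard easy to observe.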
```shell
RUN_DUMPER="export TLLM_LOG_LEVEL="error";
trtllm-llmapi-launch python3 /modelopt/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py \
  --model /model \
  --input-data /input/ \
  --output-dir /output \
  $PARALLEL_ARGS \
  "
```
Repair RUN_DUMPER quoting so the script actually runs.
The current assignment ends the string at the inner "error"; the remainder is interpreted as separate tokens, so the script fails before launching anything. Use a heredoc (or escape the inner quotes) to produce a well-formed command string.
```diff
-RUN_DUMPER="export TLLM_LOG_LEVEL="error";
-trtllm-llmapi-launch python3 /modelopt/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py \
-  --model /model \
-  --input-data /input/ \
-  --output-dir /output \
-  $PARALLEL_ARGS \
-  "
+read -r -d '' RUN_DUMPER <<EOF
+export TLLM_LOG_LEVEL="error"
+trtllm-llmapi-launch python3 /modelopt/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py \
+  --model /model \
+  --input-data /input/ \
+  --output-dir /output \
+  $PARALLEL_ARGS
+EOF
```
🧰 Tools
🪛 Shellcheck (0.11.0)
[warning] 35-35: Word is of the form "A"B"C" (B indicated). Did you mean "ABC" or "A"B"C"?
(SC2140)
🤖 Prompt for AI Agents
In examples/speculative_decoding/collect_hidden_states/slurm_dump.sh around
lines 35 to 41, the RUN_DUMPER string is prematurely terminated by the inner
"error" quotes, breaking the command; fix it by constructing a single
well-formed string that contains the TLLM_LOG_LEVEL assignment and the
subsequent command — either escape the inner quotes (e.g., \"error\") or use a
heredoc or single-quoted wrapper so the whole export and trtllm-llmapi-launch
python invocation are part of the RUN_DUMPER value, preserving $PARALLEL_ARGS
and newlines as needed.
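The heredoc approach can be exercised in isolation; the `echo` below stands in for the real launcher invocation, and the variable values are illustrative:

```shell
#!/bin/sh
# Building a multi-line command string with a heredoc sidesteps the
# nested-quote breakage of RUN_DUMPER="export X="error"; ...".
# $PARALLEL_ARGS expands when the string is built; \$TLLM_LOG_LEVEL is
# deferred to the inner shell that eventually runs the string.
PARALLEL_ARGS="--tp 4 --dp-rank 0"
RUN_DUMPER=$(cat <<EOF
export TLLM_LOG_LEVEL=error
echo "launch: \$TLLM_LOG_LEVEL $PARALLEL_ARGS"
EOF
)
sh -c "$RUN_DUMPER"
```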
Actionable comments posted: 5
🧹 Nitpick comments (5)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py (2)
82-92: Clarify SLURM variable names in help text.

The help text mentions "TASK_ID on SLURM" and "Number of tasks on SLURM" but could be more precise. Consider updating it to reference the specific SLURM environment variables (`SLURM_ARRAY_TASK_ID` and `SLURM_ARRAY_TASK_COUNT`) for clarity.

Apply this diff:

```diff
 parser.add_argument(
     "--dp-rank",
     type=int,
     default=0,
-    help="""Data parallel rank. TASK_ID on SLURM.""",
+    help="""Data parallel rank. Set to SLURM_ARRAY_TASK_ID when using SLURM arrays.""",
 )
 parser.add_argument(
     "--dp-world-size",
     type=int,
     default=1,
-    help="""Data parallel world size. Number of tasks on SLURM.""",
+    help="""Data parallel world size. Set to SLURM_ARRAY_TASK_COUNT when using SLURM arrays.""",
 )
```
283-286: Success reporting may be misleading with skipped conversations.

The success message compares `num_success` against `len(dataset)`, but this doesn't account for conversations skipped due to invalid data or length constraints. This could be confusing when debugging failed runs.

Consider more accurate reporting:

```diff
+expected_success = len(dataset) - num_invalid - num_skipped_too_long
-if num_success == len(dataset):
-    print(f"Successfully processed all {num_success} conversations.")
+if num_success == expected_success:
+    print(f"Successfully processed all {num_success} valid conversations.")
 else:
-    print(f"Successfully processed {num_success} out of {len(dataset)} conversations.")
+    print(f"Successfully processed {num_success} out of {expected_success} valid conversations "
+          f"({num_invalid} invalid, {num_skipped_too_long} skipped due to length).")
```

examples/speculative_decoding/collect_hidden_states/slurm_dump.sh (3)
8-12: Make account/job-name more configurable.

The SBATCH account and job-name are hardcoded to NVIDIA-internal values. Consider either:

- Using placeholder variables like `<YOUR_ACCOUNT>` and `<YOUR_JOB_NAME>` to match the pattern used for INPUT_DIR, DUMP_DIR, etc.
- Adding a comment explicitly stating these must be updated by users.

This would make the script more clearly a template that requires customization.

Apply this diff:

```diff
-#SBATCH -A coreai_dlalgo_modelopt
-#SBATCH --job-name=coreai_dlalgo_modelopt-mcore.modelopt
+#SBATCH -A <YOUR_ACCOUNT>
+#SBATCH --job-name=<YOUR_JOB_NAME>
 #SBATCH --nodes=1 --ntasks-per-node=4 --gpus-per-node=4
 #SBATCH -p batch
 #SBATCH -t 04:00:00
```
17-17: Consider making the container version configurable.

The container version is hardcoded to `1.2.0rc0`, which is a release candidate. Consider:

- Moving this to a variable at the top of the script for easy updates
- Updating to a stable release if available
- Adding a comment about compatible versions

Example:

```diff
+# Container version - update as needed
+CONTAINER_VERSION="1.2.0rc0"
+CONTAINER="nvcr.io#nvidia/tensorrt-llm/release:${CONTAINER_VERSION}"
-CONTAINER="nvcr.io#nvidia/tensorrt-llm/release:1.2.0rc0"
```
17-17: Consider using a stable container release instead of RC.

The script uses a release candidate container version (`1.2.0rc0`). For production use, consider switching to a stable release once available.

```diff
-CONTAINER="nvcr.io#nvidia/tensorrt-llm/release:1.2.0rc0"
+# TODO: Update to stable release when 1.2.0 is GA
+CONTAINER="nvcr.io#nvidia/tensorrt-llm/release:1.2.0"
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py (8 hunks)
- examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh (0 hunks)
- examples/speculative_decoding/collect_hidden_states/slurm_dump.sh (1 hunks)
💤 Files with no reviewable changes (1)
- examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh
 
🧰 Additional context used
🪛 Shellcheck (0.11.0)
examples/speculative_decoding/collect_hidden_states/slurm_dump.sh
[warning] 34-34: Word is of the form "A"B"C" (B indicated). Did you mean "ABC" or "A"B"C"?
(SC2140)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
 - GitHub Check: code-quality
 - GitHub Check: build-docs
 
🔇 Additional comments (9)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py (6)
131-168: LGTM! Robust data loading with resumability.

The refactored data loading logic properly:

- Supports both single files and directories of JSONL files
- Implements DP sharding for distributed processing
- Filters out already-processed conversations for resumability
- Provides debug capability to limit conversation count

The conversation_id extraction handles both "conversation_id" and "uuid" fields with proper fallback and validation.
261-263: Verify the 256-token tokenization cap.

The tokenization input is hard-capped at 256 tokens, which could truncate longer conversations before they're processed. This seems inconsistent with the `--max-seq-len` parameter (default 3072).

Please clarify:

- Why is 256 tokens chosen as the cap?
- Should this be configurable via a command-line argument?
- How does this interact with the `--max-seq-len` validation on line 267?

If this cap is intentional and specific to the dumper's requirements, consider adding an inline comment explaining the rationale:

```diff
 input_ids = tokenizer.apply_chat_template(conversations, add_generation_template=False)[
-    :256
+    :256  # Cap prompt to 256 tokens for hidden state collection
 ]
```
26-26: LGTM!

The addition of `load_dataset` from the datasets library is appropriate for the refactoring to HuggingFace datasets-based loading.
144-148: LGTM!

The dataset sharding logic correctly distributes data across DP ranks using the datasets library's built-in `shard` method, which ensures even distribution.
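The even-distribution property can be sketched without the library. Round-robin assignment (one of the schemes `Dataset.shard` supports) gives each rank a disjoint slice whose sizes differ by at most one; the helper name is invented for illustration:

```python
# Round-robin DP sharding sketch: rank r takes items r, r+W, r+2W, ...
def shard_indices(num_items: int, dp_rank: int, dp_world_size: int) -> list[int]:
    return list(range(dp_rank, num_items, dp_world_size))

shards = [shard_indices(10, r, 4) for r in range(4)]
print(shards)  # [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]
```

Every item lands on exactly one rank, so independent SLURM array tasks never dump the same conversation twice.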
165-167: LGTM!

The debug cap feature is well-implemented and will be useful for testing without processing the entire dataset.
131-141: Verify load_dataset handles empty input directory.

Ensure that passing an empty directory (with no `.jsonl` files) to `load_dataset("json", data_files={"train": ".../*.jsonl"}, split="train")` raises a clear, descriptive error rather than returning an empty dataset.

examples/speculative_decoding/collect_hidden_states/slurm_dump.sh (3)
42-47: LGTM! Appropriate timeout margin.

The srun configuration is well-structured:

- The 235-minute timeout provides a 5-minute buffer before the 4-hour SBATCH limit, allowing graceful cleanup
- PMIx MPI and overlap flags are appropriate for distributed execution
- Container configuration correctly references the defined variables
- Log file naming includes job metadata for tracking
8-12: LGTM!

The SBATCH directives are correctly configured for single-node, 4-GPU execution with appropriate time limits and partition settings.
19-22: LGTM!

The placeholder values are clearly marked and provide helpful guidance to users on what needs to be configured.
```diff
 # Remove already dumped conversations
 def keep_conversation(entry):
     conversation_id = entry.get("conversation_id", entry.get("uuid", None))
     assert conversation_id is not None, "conversation_id is required"
     output_file = args.output_dir / f"{conversation_id}.pt"
-    if output_file.exists():
-        continue
-    filtered_conversations.append(entry)
+    return not output_file.exists()

 original_num = len(dataset)
 dataset = dataset.filter(keep_conversation)
 print(
     "Removed",
-    len(all_conversations) - len(filtered_conversations),
+    original_num - len(dataset),
     "conversations due to existing output files",
 )
```
Replace assertion with proper error handling.
The keep_conversation filter function uses an assertion to validate that conversation_id exists, which will cause the entire process to crash if any conversation lacks this field. In a distributed DP setting, this would fail the entire SLURM job.
Replace the assertion with proper error handling:
```diff
 def keep_conversation(entry):
     conversation_id = entry.get("conversation_id", entry.get("uuid", None))
-    assert conversation_id is not None, "conversation_id is required"
+    if conversation_id is None:
+        return False  # Skip conversations without valid ID
     output_file = args.output_dir / f"{conversation_id}.pt"
     return not output_file.exists()
```

Additionally, consider logging a warning when conversations are skipped due to missing IDs, similar to how other invalid conversations are tracked.
🤖 Prompt for AI Agents
In
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py
around lines 150-163, replace the assertion that enforces a conversation_id with
non-crashing error handling: if conversation_id is missing, log a warning (or
print) indicating the skipped entry and return False from keep_conversation so
the entry is filtered out instead of crashing the job; optionally increment or
track a skipped counter for reporting, then continue to check for existing
output files and return not output_file.exists() for valid IDs.
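The suggested non-crashing filter can be tried standalone; `output_dir` here is any writable directory, and the entry dicts are hypothetical stand-ins for dataset rows:

```python
from pathlib import Path

def keep_conversation(entry: dict, output_dir: Path) -> bool:
    # Missing IDs are skipped instead of asserting; already-dumped
    # conversations (an existing <id>.pt file) are skipped too.
    conversation_id = entry.get("conversation_id", entry.get("uuid"))
    if conversation_id is None:
        return False
    return not (output_dir / f"{conversation_id}.pt").exists()
```

With `dataset.filter`, one predicate like this drops both malformed entries and already-completed work in a single pass, which is what makes interrupted DP runs resumable.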
```python
input_ids = tokenizer.apply_chat_template(conversations, add_generation_template=False)[
    :256
]
```
🛠️ Refactor suggestion | 🟠 Major
Clarify the magic number 256 for tokenization truncation.
The hardcoded limit of 256 tokens appears arbitrary and may not align with the model's actual context window or the user's --max-seq-len parameter. This could lead to confusion when conversations are unexpectedly truncated.
Consider one of the following approaches:
- Make this configurable via a CLI argument:

```diff
+parser.add_argument(
+    "--max-input-tokens",
+    type=int,
+    default=256,
+    help="Maximum number of tokens to use from conversation input for context."
+)
```

- Or at minimum, add a constant and a comment explaining the rationale:

```diff
+# Limit input tokens to reduce memory usage during hidden state collection
+MAX_INPUT_TOKENS = 256
+
 input_ids = tokenizer.apply_chat_template(conversations, add_generation_template=False)[
-    :256
+    :MAX_INPUT_TOKENS
 ]
```
🤖 Prompt for AI Agents
In
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py
around lines 261 to 263, replace the hardcoded token truncation slice [:256]
with a named constant or CLI-configurable value tied to the model/context window
(e.g., use the --max-seq-len argument if available, or
model.config.max_position_embeddings minus reserved tokens) and add a comment
explaining why that limit was chosen; update the tokenizer.apply_chat_template
call to slice by that variable so truncation respects the model's actual context
length or the user-provided parameter.
```shell
#EP optionally available by setting --moe-ep-size and --moe-tp-size. See compute_hidden_states_trtllm.py.
PARALLEL_ARGS="--tp 4 --dp-rank $SLURM_ARRAY_TASK_ID --dp-world-size $SLURM_ARRAY_TASK_COUNT"

RUN_DUMPER="export TLLM_LOG_LEVEL="error";
```
Fix shell quoting issue.
The line has incorrect nested quoting that will result in the string "error" (including quotes) being assigned to TLLM_LOG_LEVEL, rather than just error.
Apply this diff:
```diff
-RUN_DUMPER="export TLLM_LOG_LEVEL="error";
+RUN_DUMPER="export TLLM_LOG_LEVEL=error;
 trtllm-llmapi-launch python3 /modelopt/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py \
   --model /model \
   --input-data /input/ \
   --output-dir /output \
   $PARALLEL_ARGS \
   "
```

🧰 Tools
🪛 Shellcheck (0.11.0)
[warning] 34-34: Word is of the form "A"B"C" (B indicated). Did you mean "ABC" or "A"B"C"?
(SC2140)
🤖 Prompt for AI Agents
In examples/speculative_decoding/collect_hidden_states/slurm_dump.sh around line
34, the RUN_DUMPER assignment uses nested double quotes causing the literal
string "\"error\"" to be assigned; change it so the environment variable is
assigned the plain word error (e.g., RUN_DUMPER="export TLLM_LOG_LEVEL=error;"
or RUN_DUMPER='export TLLM_LOG_LEVEL=error;') by removing the inner quotes or
using single outer quotes.
This seems legit.
@ChenhanYu, thank you for confirming! The fix will prevent the environment variable from receiving the literal string "error" (with quotes) and ensure it gets the plain value error as intended.
```shell
RUN_DUMPER="export TLLM_LOG_LEVEL="error";
trtllm-llmapi-launch python3 /modelopt/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py \
  --model /model \
  --input-data /input/ \
  --output-dir /output \
  $PARALLEL_ARGS \
  "
```
Fix shell quoting syntax error.
The export statement has incorrect nested quotes that will cause a shell syntax error.
As per static analysis
Apply this diff to fix the quoting:
```diff
-RUN_DUMPER="export TLLM_LOG_LEVEL="error";
+RUN_DUMPER="export TLLM_LOG_LEVEL=error;
 trtllm-llmapi-launch python3 /modelopt/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py \
   --model /model \
   --input-data /input/ \
   --output-dir /output \
   $PARALLEL_ARGS \
   "
```

🧰 Tools
🪛 Shellcheck (0.11.0)
[warning] 34-34: Word is of the form "A"B"C" (B indicated). Did you mean "ABC" or "A"B"C"?
(SC2140)
🤖 Prompt for AI Agents
In examples/speculative_decoding/collect_hidden_states/slurm_dump.sh around
lines 34 to 40, the RUN_DUMPER assignment contains incorrectly nested double
quotes causing a shell syntax error; fix by wrapping the whole command in single
quotes (or escape the inner quotes) so the export uses TLLM_LOG_LEVEL="error"
correctly and the rest of the command string remains intact, e.g., assign
RUN_DUMPER='export TLLM_LOG_LEVEL="error"; trtllm-llmapi-launch python3
/modelopt/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py
--model /model --input-data /input/ --output-dir /output $PARALLEL_ARGS' (ensure
no trailing unmatched quotes).
```shell
timeout 235m srun -l \
    --mpi=pmix --overlap \
    --output=%x_%j_$DATETIME.log \
    --container-image ${CONTAINER} \
    --container-mounts ${MOUNTS} \
    bash -c "$RUN_DUMPER"
```
Fix undefined variable in log filename.
Line 44 references $DATETIME which is never defined, causing the log filename to be incomplete or use a literal string.
Define the variable before use or remove it from the log filename:
```diff
+DATETIME=$(date +%Y%m%d_%H%M%S)
+
 timeout 235m srun -l \
     --mpi=pmix --overlap \
     --output=%x_%j_$DATETIME.log \
     --container-image ${CONTAINER} \
     --container-mounts ${MOUNTS} \
     bash -c "$RUN_DUMPER"
```
     bash -c "$RUN_DUMPER"Alternatively, if timestamps aren't needed, simplify to:
```diff
 timeout 235m srun -l \
     --mpi=pmix --overlap \
-    --output=%x_%j_$DATETIME.log \
+    --output=%x_%j.log \
     --container-image ${CONTAINER} \
     --container-mounts ${MOUNTS} \
     bash -c "$RUN_DUMPER"
```
     bash -c "$RUN_DUMPER"🤖 Prompt for AI Agents
In examples/speculative_decoding/collect_hidden_states/slurm_dump.sh around
lines 42 to 47 the --output pattern references $DATETIME which is never defined;
either define DATETIME before the srun call (e.g. set DATETIME to a timestamp
string) or remove $DATETIME from the --output argument so the filename does not
include an undefined variable; update the script to export or assign DATETIME
prior to using it, or simplify the --output value to rely on existing SLURM
placeholders only.
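Defining the timestamp once before the srun call is a one-liner; the exact format below is an assumption, any unambiguous one works:

```shell
#!/bin/sh
# Define DATETIME once so the --output pattern expands to a complete
# log name instead of leaving an empty slot where $DATETIME was.
DATETIME=$(date +%Y%m%d_%H%M%S)
echo "example log name: myjob_1234_${DATETIME}.log"
```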
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@           Coverage Diff           @@
##             main     #422   +/-   ##
=======================================
  Coverage   73.36%   73.36%           
=======================================
  Files         180      180           
  Lines       17919    17919           
=======================================
  Hits        13147    13147           
  Misses       4772     4772           
```

☔ View full report in Codecov by Sentry.
Resolved review thread (outdated): examples/speculative_decoding/collect_hidden_states/slurm_dump.sh
Some comments
Signed-off-by: h-guo18 <[email protected]>
Actionable comments posted: 2
♻️ Duplicate comments (2)
examples/speculative_decoding/collect_hidden_states/slurm_dump.sh (2)
34-40: Fix RUN_DUMPER quoting; current form is syntactically broken.

The nested quotes terminate the string at "error", so the script fails before launch. Build the command with a heredoc to avoid quoting pitfalls.

```diff
-RUN_DUMPER="export TLLM_LOG_LEVEL="error";
-trtllm-llmapi-launch python3 /modelopt/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py \
-  --model /model \
-  --input-data /input/ \
-  --output-dir /output \
-  $PARALLEL_ARGS \
-  "
+RUN_DUMPER=$(cat <<EOF
+export TLLM_LOG_LEVEL=error
+trtllm-llmapi-launch python3 /modelopt/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py \
+  --model /model \
+  --input-data /input/ \
+  --output-dir /output \
+  $PARALLEL_ARGS
+EOF
+)
```
42-47: Define DATETIME (or remove it) in the srun --output pattern.

$DATETIME is undefined, producing odd filenames.

```diff
+DATETIME=$(date +%Y%m%d_%H%M%S)
 timeout 235m srun -l \
     --mpi=pmix --overlap \
     --output=%x_%j_$DATETIME.log \
     --container-image ${CONTAINER} \
-    --container-mounts ${MOUNTS} \
+    --container-mounts "${MOUNTS}" \
     bash -c "$RUN_DUMPER"
```
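The timestamp format used in the suggested fix can be sanity-checked on its own (a trivial standalone sketch):

```shell
# Same format string as in the suggested fix: YYYYMMDD_HHMMSS.
DATETIME=$(date +%Y%m%d_%H%M%S)
echo "$DATETIME" | grep -Eq '^[0-9]{8}_[0-9]{6}$' && echo "format ok"
```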
🧹 Nitpick comments (2)
examples/speculative_decoding/collect_hidden_states/slurm_dump.sh (2)
19-23: Validate required paths before running.

Placeholders/defaults will cause failures or bad mounts. Add guards and existence checks.

```diff
 INPUT_DIR="<Can be directory containing the .jsonl files, or path to single .jsonl file>"
 DUMP_DIR="<Directory for output hidden states>"
 MODELOPT_DIR="<Path to Modelopt repo>"
 TEACHER_MODEL="<Path to teacher model>"
+# Basic validation
+for var in INPUT_DIR DUMP_DIR MODELOPT_DIR TEACHER_MODEL; do
+    val="${!var}"
+    if [[ -z "$val" || "$val" == \<*\>* ]]; then
+        echo "ERROR: $var is not set. Please edit slurm_dump.sh."
+        exit 1
+    fi
+done
+
+[[ -e "$INPUT_DIR" ]] || { echo "ERROR: INPUT_DIR not found: $INPUT_DIR"; exit 1; }
+[[ -d "$MODELOPT_DIR" ]] || { echo "ERROR: MODELOPT_DIR not found: $MODELOPT_DIR"; exit 1; }
+[[ -d "$TEACHER_MODEL" ]] || { echo "ERROR: TEACHER_MODEL not found: $TEACHER_MODEL"; exit 1; }
```
28-28: Quote mounts to avoid word splitting.

Protect paths with spaces/commas by quoting the value and usage.

```diff
-MOUNTS=$INPUT_DIR:/input,$DUMP_DIR:/output,$MODELOPT_DIR:/modelopt,$TEACHER_MODEL:/model
+MOUNTS="${INPUT_DIR}:/input,${DUMP_DIR}:/output,${MODELOPT_DIR}:/modelopt,${TEACHER_MODEL}:/model"
```

And in srun:

```diff
-    --container-mounts ${MOUNTS} \
+    --container-mounts "${MOUNTS}" \
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
examples/speculative_decoding/collect_hidden_states/slurm_dump.sh (1 hunks)
🧰 Additional context used
🪛 Shellcheck (0.11.0)
examples/speculative_decoding/collect_hidden_states/slurm_dump.sh
[warning] 34-34: Word is of the form "A"B"C" (B indicated). Did you mean "ABC" or "A\"B\"C"? (SC2140)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
 - GitHub Check: build-docs
 - GitHub Check: code-quality
 
```shell
#SBATCH -A coreai_dlalgo_modelopt
#SBATCH --job-name=coreai_dlalgo_modelopt-generate_eagle_hidden_states
#SBATCH --nodes=1 --ntasks-per-node=4 --gpus-per-node=4
```
Align allocation with intended topology (1 proc using 4 GPUs).
You set ntasks-per-node=4 but run a single command that does TP=4. Allocate 1 task and give it 4 GPUs to avoid idle tasks and binding ambiguity.
```diff
-#SBATCH --nodes=1 --ntasks-per-node=4 --gpus-per-node=4
+#SBATCH --nodes=1 --ntasks-per-node=1 --gpus-per-task=4
```

Optionally, also make the step explicit:

```diff
-timeout 235m srun -l \
+timeout 235m srun -l -n 1 \
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```shell
# In examples/speculative_decoding/collect_hidden_states/slurm_dump.sh
#SBATCH --nodes=1 --ntasks-per-node=1 --gpus-per-task=4

timeout 235m srun -l -n 1 \
    --mpi=pmix --overlap \
    --output=%x_%j_$DATETIME.log \
    --container-image ${CONTAINER} \
    --container-mounts ${MOUNTS} \
    bash -c "$RUN_DUMPER"
```
🤖 Prompt for AI Agents
In examples/speculative_decoding/collect_hidden_states/slurm_dump.sh around line
10, the SBATCH line currently specifies --ntasks-per-node=4 which creates four
tasks but you run a single process that needs 4 GPUs; change it to
--ntasks-per-node=1 --gpus-per-node=4 (and optionally add --cpus-per-task=<num>
if you need CPU binding) so one task owns all 4 GPUs, and if desired make the
launch step explicit by using srun --ntasks=1 (or sbatch step) to run the one
process.
```shell
#EP optionally available by setting --moe-ep-size and --moe-tp-size. See compute_hidden_states_trtllm.py.
PARALLEL_ARGS="--tp 4 --dp-rank $SLURM_ARRAY_TASK_ID --dp-world-size $SLURM_ARRAY_TASK_COUNT"
```
Provide safe defaults when not using an array job.
If SLURM_ARRAY_TASK_ID/COUNT are unset, dp args become empty and the Python launcher may fail. Default to a single‑rank DP.
```diff
-#By default: TP inside node, and DP across slurm array
-#EP optionally available by setting --moe-ep-size and --moe-tp-size. See compute_hidden_states_trtllm.py.
-PARALLEL_ARGS="--tp 4 --dp-rank $SLURM_ARRAY_TASK_ID --dp-world-size $SLURM_ARRAY_TASK_COUNT"
+# By default: TP inside node, and DP across slurm array
+# EP optionally available by setting --moe-ep-size and --moe-tp-size. See compute_hidden_states_trtllm.py.
+DP_RANK="${SLURM_ARRAY_TASK_ID:-0}"
+DP_WORLD="${SLURM_ARRAY_TASK_COUNT:-1}"
+PARALLEL_ARGS="--tp 4 --dp-rank ${DP_RANK} --dp-world-size ${DP_WORLD}"
```
+PARALLEL_ARGS="--tp 4 --dp-rank ${DP_RANK} --dp-world-size ${DP_WORLD}"🤖 Prompt for AI Agents
In examples/speculative_decoding/collect_hidden_states/slurm_dump.sh around
lines 31 to 33, SLURM_ARRAY_TASK_ID and SLURM_ARRAY_TASK_COUNT may be unset
causing PARALLEL_ARGS to be empty and the Python launcher to fail; add safe
defaults (e.g., SLURM_ARRAY_TASK_ID=${SLURM_ARRAY_TASK_ID:-0} and
SLURM_ARRAY_TASK_COUNT=${SLURM_ARRAY_TASK_COUNT:-1}) and construct PARALLEL_ARGS
to include "--tp 4 --dp-rank $SLURM_ARRAY_TASK_ID --dp-world-size
$SLURM_ARRAY_TASK_COUNT" only after applying those defaults so a non-array run
defaults to single-rank DP.
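The `${VAR:-default}` expansion that the suggestion relies on can be verified in isolation (a standalone sketch, not the script itself): with the SLURM array variables unset, the DP arguments fall back to a single-rank configuration instead of collapsing to empty strings.

```shell
# Simulate a non-array submission: the SLURM array variables are unset.
unset SLURM_ARRAY_TASK_ID SLURM_ARRAY_TASK_COUNT
DP_RANK="${SLURM_ARRAY_TASK_ID:-0}"
DP_WORLD="${SLURM_ARRAY_TASK_COUNT:-1}"
PARALLEL_ARGS="--tp 4 --dp-rank ${DP_RANK} --dp-world-size ${DP_WORLD}"
echo "$PARALLEL_ARGS"   # prints: --tp 4 --dp-rank 0 --dp-world-size 1
```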
What does this PR do?
Type of change: New feature
Overview:
Added a SLURM multi-node script for the TRT-LLM EAGLE hidden states dumper.
Usage
sbatch --array=0-n slurm_dump.sh

Testing
Tested on HSG cluster;
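The array-to-shard mapping implied by the usage above can be sketched as follows (a hypothetical 4-task array, simulated locally with a loop; in a real run SLURM sets these variables per task):

```shell
# Hypothetical sbatch --array=0-3 run: each array task becomes one
# data-parallel rank over the shared input dataset.
SLURM_ARRAY_TASK_COUNT=4
for SLURM_ARRAY_TASK_ID in 0 1 2 3; do
    echo "task $SLURM_ARRAY_TASK_ID -> --dp-rank $SLURM_ARRAY_TASK_ID --dp-world-size $SLURM_ARRAY_TASK_COUNT"
done
# last line: task 3 -> --dp-rank 3 --dp-world-size 4
```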
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Refactor
Chores