
Conversation

@jenchen13 (Contributor) commented Sep 4, 2025

What does this PR do?

Type of change: new feature

Overview: Support for launching the QAT/QAD Simplified Flow on Slurm, plus a Qwen3-8B QAT recipe

Usage

python qat/nemo_qat_flow.py  --log-dir /my/log/dir --experiment qat_experiment
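
A Slurm launch uses the same entry point with the `--use-slurm` flag (per the CLI summary below; other arguments unchanged, with cluster settings configured via `SLURM_CONFIG` as described in ADVANCED.md):

python qat/nemo_qat_flow.py --use-slurm --log-dir /my/log/dir --experiment qat_experiment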

Testing

  • Tested locally
  • Tested launch on Slurm

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • New Features

    • Slurm-capable, multi-stage QAT/QAD flow with configurable parallelism, hyperparameters, integrated MMLU evaluation, and local/Slurm executors.
    • Utilities for Slurm job configuration, remote execution, dataset download, chat-template reading, and finetune-recipe retrieval.
    • Export tool to convert the latest experiment checkpoint to Hugging Face format.
    • CLI for running MMLU evaluation and a script to fetch/convert NVIDIA OpenScience data into chat-style JSONL.
  • Documentation

    • Expanded QAT/QAD docs and advanced Slurm guide; added NeMo example entry and updated README with flow stages, defaults, and run instructions.

@jenchen13 jenchen13 requested review from a team as code owners September 4, 2025 17:12

copy-pr-bot bot commented Sep 4, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@jenchen13 jenchen13 requested a review from AAnoosheh September 4, 2025 17:12

codecov bot commented Sep 4, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.82%. Comparing base (d94fc1b) to head (410c7dd).
⚠️ Report is 8 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #285   +/-   ##
=======================================
  Coverage   73.82%   73.82%           
=======================================
  Files         172      172           
  Lines       17438    17438           
=======================================
+ Hits        12873    12874    +1     
+ Misses       4565     4564    -1     

☔ View full report in Codecov by Sentry.


coderabbitai bot commented Sep 4, 2025

Walkthrough

Refactors the NeMo QAT/QAD flow into a Slurm-capable multi-stage pipeline, adds Slurm and dataset utilities, an OpenScience-to-chat processor, an in-memory MMLU runner, and a NeMo-run checkpoint export helper; updates QAT docs and README entries.

Changes

| Cohort / File(s) | Summary of Changes |
| --- | --- |
| **Docs: QAT/QAD overview and advanced notes**<br>`examples/llm_qat/README.md`, `examples/nemo_run/qat/README.md`, `examples/nemo_run/qat/ADVANCED.md` | Added a NeMo QAT/QAD link row to examples/llm_qat/README.md; substantially revised QAT README with new flow stages, defaults, OpenScience dataset guidance, new CLI examples and run instructions; added ADVANCED.md documenting Slurm usage, per-stage layout, and run notes. |
| **Slurm & utility helpers**<br>`examples/nemo_run/common/utils.py` | New SlurmConfig dataclass and helpers: create_slurm_executor (Local/SSH tunnel selection), get_finetune_recipe, read_chat_template, and download_hf_dataset. Adds validation, preset env-vars, container config, and GitArchivePackager usage. |
| **OpenScience preprocessing**<br>`examples/nemo_run/common/process_openscience.py` | New CLI script to download NVIDIA OpenScience JSONL, convert rows to OpenAI-style chat messages (convert_row_oai), split train/validation (90/10 of train), and write training.jsonl/validation.jsonl. |
| **In-memory MMLU runner**<br>`examples/nemo_run/common/in_memory_mmlu.py` | New CLI to run MMLU evaluation with ModelOpt/Megatron: parses --nemo_ckpt or --finetuned_ckpt_dir, optional tensor/pipeline parallelism, resolves most-recent checkpoint when needed, sets up trainer/model via ModelOpt helper, and runs megatron_mmlu. |
| **QAT pipeline orchestration**<br>`examples/nemo_run/qat/nemo_qat_flow.py` | Large refactor: replaced parser with get_args()/main(args), added global constants (SEQUENCE_LENGTH, MBS, GBS, TRAIN_STEPS, VAL_INTERVAL), expanded CLI (--algorithm, --use-slurm, --log-dir, --ptq-gpus, --train-gpus, --train-nodes, --learning-rate, --tensor_parallelism, --pipeline_parallelism, --enable_kv_cache, --kv-cache-qformat), centralized exp_dir path handling, OpenScience integration, Slurm-aware executors, PTQ/train/export orchestration, integrated MMLU evaluation steps, and imports/export helpers moved to utilities and the modelopt export plugin. |
| **NeMo-run export plugin**<br>`modelopt/torch/export/plugins/nemo_run.py` | New module adding _get_most_recent_subdir, _get_most_recent_ckpt, and export_most_recent_ckpt(directory, output_path) to locate the latest NeMo Run checkpoint and export it to HuggingFace format via NeMo's export_ckpt, with validation and logging. |
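
As a rough illustration of the new export helper, a minimal sketch (only the `export_most_recent_ckpt(directory, output_path)` entry point is taken from the summary above; the paths are placeholders):

```python
from modelopt.torch.export.plugins.nemo_run import export_most_recent_ckpt

# Find the newest NeMo Run checkpoint under the experiment log directory
# and export it to Hugging Face format (placeholder paths).
export_most_recent_ckpt(
    directory="/my/log/dir/qat_experiment",
    output_path="/my/log/dir/qat_experiment_hf",
)
```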

Sequence Diagram(s)

```mermaid
sequenceDiagram
  autonumber
  actor User
  participant Flow as nemo_qat_flow.py
  participant Exec as Executor (Local/Slurm)
  participant Proc as OpenScience Processor
  participant Import as BF16 Import
  participant MMLU1 as MMLU Eval (BF16)
  participant PTQ as PTQ Job
  participant MMLU2 as MMLU Eval (PTQ)
  participant Train as QAT/QAD Trainer
  participant MMLU3 as MMLU Eval (SFT)
  participant Export as export_most_recent_ckpt

  User->>Flow: run with args (exp_dir, algorithm, --use-slurm, parallelism...)
  Flow->>Exec: configure executors (CPU / PTQ-GPU / Train-GPU)
  Flow->>Proc: download & process OpenScience -> produce train/val JSONL
  Proc-->>Flow: training.jsonl, validation.jsonl
  Flow->>Import: restore BF16 checkpoint
  Import-->>Flow: bf16 checkpoint path
  Flow->>MMLU1: evaluate BF16
  MMLU1-->>Flow: BF16 metrics
  Flow->>PTQ: run PTQ (algorithm, kv-cache options)
  PTQ-->>Flow: quantized artifacts
  Flow->>MMLU2: evaluate PTQ model
  MMLU2-->>Flow: PTQ metrics
  Flow->>Train: run QAT/QAD training (recipe, lr, devices)
  Train-->>Flow: SFT checkpoints
  Flow->>MMLU3: evaluate SFT model
  MMLU3-->>Flow: SFT metrics
  Flow->>Export: export_most_recent_ckpt(exp_dir, hf_out)
  Export-->>Flow: HF export
  Flow-->>User: logs, metrics, exported model
```

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60–90 minutes

Poem

I nibbled logs and hopped through queues,
OpenScience crumbs and Slurm-born views.
BF16 teacher, tiny quantized sprout,
I taught it tricks and checked metrics out.
Exported carrots to HF's house—hop, scout! 🐇✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 28.57% which is insufficient. The required threshold is 80.00%. | You can run `@coderabbitai generate docstrings` to improve docstring coverage. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title Check | ✅ Passed | The title "Slurm support for QAT Simplified Flow + Qwen3-8B recipe" succinctly and accurately summarizes the PR's main changes — adding Slurm orchestration for the QAT/QAD simplified flow and introducing a Qwen3-8B QAT recipe — and maps to the modified README, nemo_qat_flow.py, Slurm utilities, and recipe-related additions. It is concise, specific, and informative for a teammate scanning history. |


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 8

🧹 Nitpick comments (20)
modelopt/torch/export/plugins/nemo_run.py (3)

38-47: Type + stability nit: annotate return type and prefer higher-resolution mtime.

Add the return type and consider st_mtime_ns for better tie-breaking.

-def _get_most_recent_subdir(directory: Path):
+def _get_most_recent_subdir(directory: Path) -> Path:
@@
-    most_recent = max(subdirs, key=lambda x: x.stat().st_mtime)
+    most_recent = max(subdirs, key=lambda x: x.stat().st_mtime_ns)

24-35: Log destination and resolve absolute paths for clarity.

Small QoL: include output_path in logs; resolve paths to avoid ambiguity.

-def export_most_recent_ckpt(directory: str, output_path: str):
+def export_most_recent_ckpt(directory: str, output_path: str):
@@
-    logging.info(f"Exporting most recent NeMo Run checkpoint: {most_recent_ckpt}")
+    most_recent_ckpt = str(Path(most_recent_ckpt).resolve())
+    output_path = str(Path(output_path).resolve())
+    logging.info(f"Exporting most recent NeMo Run checkpoint: {most_recent_ckpt} -> {output_path}")

50-71: Avoid importing private helpers across modules; provide a public wrapper.

examples/nemo_run/common/in_memory_mmlu.py imports _get_most_recent_ckpt. Expose a public symbol to decouple callers from private helpers.

 def _get_most_recent_ckpt(directory: str):
@@
     return str(most_recent)
+
+def get_most_recent_ckpt(directory: str) -> str:
+    """Public wrapper for resolving the most recent NeMo Run checkpoint directory."""
+    return _get_most_recent_ckpt(directory)
examples/nemo_run/common/in_memory_mmlu.py (2)

20-21: Stop importing a private symbol; switch to a public helper.

Once get_most_recent_ckpt is added, import and use it here.

-from modelopt.torch.export.plugins.nemo_run import _get_most_recent_ckpt
+from modelopt.torch.export.plugins.nemo_run import get_most_recent_ckpt
@@
-        ckpt_path = _get_most_recent_ckpt(args.ckpt_dir)
+        ckpt_path = get_most_recent_ckpt(args.ckpt_dir)

Also applies to: 49-49


46-49: Simplify ckpt selection after making args mutually exclusive.

No need for an assert or pre-init; pick based on which arg is set.

-    assert args.nemo_ckpt or args.ckpt_dir, "Provide one of either --nemo_ckpt or --ckpt_dir."
-    ckpt_path = args.nemo_ckpt
-    if args.ckpt_dir:
-        ckpt_path = _get_most_recent_ckpt(args.ckpt_dir)
+    ckpt_path = args.nemo_ckpt if args.nemo_ckpt else get_most_recent_ckpt(args.ckpt_dir)
examples/nemo_run/common/process_openscience.py (3)

17-19: Remove unused import.

-import json
 import os
 from pathlib import Path

58-59: Ensure parent directories exist when creating the processed dir.

-        Path(proc_dir).mkdir(exist_ok=True)
+        Path(proc_dir).mkdir(parents=True, exist_ok=True)

43-45: Add a split seed for reproducibility.

-    split_ds = ds["train"].train_test_split(test_size=0.1)
+    split_ds = ds["train"].train_test_split(test_size=0.1, seed=42)
examples/nemo_run/qat/README.md (6)

11-17: Fix ordered list numbering.

Use 1..6 to silence linters and improve readability.

-1. Process Nvidia/OpenScience data (if `--data-path` is not specified)
-1. Import NeMo BF16 model checkpoint and evaluate 5% of MMLU on BF16 checkpoint
-1. PTQ the model and evaluate 5% of MMLU on PTQ Checkpoint
-1. SFT (finetune) the model
-1. Evaluate 5% of MMLU on the SFT checkpoint
-1. Export model to Unified checkpoint (HuggingFace) format in lower precision
+1. Process NVIDIA/OpenScience data (if `--data-path` is not specified)
+2. Import NeMo BF16 model checkpoint and evaluate 5% of MMLU on the BF16 checkpoint
+3. PTQ the model and evaluate 5% of MMLU on the PTQ checkpoint
+4. SFT (fine-tune) the model
+5. Evaluate 5% of MMLU on the SFT checkpoint
+6. Export model to Unified checkpoint (Hugging Face) format in lower precision

44-44: Grammar: duplicate “on”.

-To run the example locally, launch a [NeMo container](...) with version 25.07 or higher using Docker on on a Slurm interactive node.
+To run the example locally, launch a [NeMo container](...) with version 25.07 or higher using Docker on a Slurm interactive node.

50-51: Capitalize “Slurm” and tighten wording.

-To run the example on slurm, edit the `SLURM_CONFIG` ...
+To run the example on Slurm, edit the `SLURM_CONFIG` ...

85-88: Specify a language for fenced block (lint).

-```
+```text
 qat_flow_ckpts qat_flow_ckpts_1755708286

---

`91-128`: **Specify a language for directory tree (lint).**


```diff
-```
+```text
 ├── 00_openscience_data
 ...

---

`132-132`: **Minor grammar cleanup.**


```diff
-By default the script will use the model/tokenizer's chat template, which may not contain the `{% generation %}` and `{% endgeneration %}` tags around the assistant tokens which are needed to generate the assistant loss mask (see [this PR](https://github.com/huggingface/transformers/pull/30650)). To provide path to a custom chat template, use the `--chat-template <my_template.txt>` flag.
+By default, the script uses the model/tokenizer's chat template, which may not contain the `{% generation %}` and `{% endgeneration %}` tags around the assistant tokens needed to generate the assistant loss mask (see [this PR](https://github.com/huggingface/transformers/pull/30650)). To provide a path to a custom chat template, use the `--chat-template <my_template.txt>` flag.
examples/nemo_run/common/utils.py (2)

52-66: Require job_dir also for LocalTunnel.

LocalTunnel still needs a job directory; validate consistently.

-        if not self.use_local_tunnel:
+        if not self.use_local_tunnel:
             # Only validate SSH tunnel settings if not using local tunnel
@@
-            if not self.job_dir:
-                raise ValueError(
-                    "SlurmConfig.job_dir must be set to directory for storing runs on cluster"
-                )
+            if not self.job_dir:
+                raise ValueError("SlurmConfig.job_dir must be set to directory for storing runs on cluster")
+        else:
+            if not self.job_dir:
+                raise ValueError("SlurmConfig.job_dir must be set when use_local_tunnel=True")

126-129: Specify encoding when reading templates.

-def read_chat_template(template_path: str):
-    with open(template_path) as f:
+def read_chat_template(template_path: str):
+    with open(template_path, encoding="utf-8") as f:
         return f.read().strip()
examples/nemo_run/qat/nemo_qat_flow.py (4)

245-245: Remove debug limiter on validation.

limit_val_batches=2 will skew metrics. Either make it a CLI flag for dev only or remove.

-    train.trainer.limit_val_batches = 2  # TODO remove
+    # Consider exposing via CLI for quick dev runs:
+    # train.trainer.limit_val_batches = args.limit_val_batches

219-230: Constants defined under main are used inside main; import-time usage will crash.

If this module is imported and main(args) is called, SEQUENCE_LENGTH/GBS/MBS/TRAIN_STEPS/VAL_INTERVAL won’t exist. Hoist them to module scope or make them CLI args.

Also applies to: 367-372
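
A minimal sketch of the suggested hoisting (the constant names come from the walkthrough above; the values shown are placeholders, not the PR's actual settings):

```python
# Module scope: available to main() even when the module is imported elsewhere.
SEQUENCE_LENGTH = 4096  # placeholder
MBS = 1                 # placeholder
GBS = 128               # placeholder
TRAIN_STEPS = 200       # placeholder
VAL_INTERVAL = 50       # placeholder


def main(args):
    # ... build the recipe, then reference the module-level constants ...
    ...
```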


29-31: Avoid sys.path manipulation for intra-repo imports.

Prefer packaging examples as a module or using relative imports via an installed editable package to remove path hacks.


343-363: SLURM_CONFIG lifecycle and defaults.

SLURM_CONFIG exists only under main and only when --use-slurm is passed. If someone imports and calls main(args) with use_slurm=True, this will NameError. Also, verify that time="240" matches your site policy (some clusters require HH:MM:SS) and ensure HF_TOKEN isn’t logged.
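
One way to address this, sketched under the assumption that `SlurmConfig` accepts the fields used by `create_slurm_executor` (quoted later in this review); all values and env-var names here are illustrative:

```python
import os


def build_slurm_config(args) -> SlurmConfig:
    """Build Slurm settings lazily so importing the module has no side effects."""
    return SlurmConfig(
        account=os.environ["SLURM_ACCOUNT"],        # hypothetical env var names
        partition_gpu=os.environ.get("SLURM_PARTITION_GPU", "batch"),
        partition_cpu=os.environ.get("SLURM_PARTITION_CPU", "cpu"),
        time="04:00:00",                            # HH:MM:SS, for sites that reject bare minutes
        container_image="nvcr.io/nvidia/nemo:25.07",
        container_mounts=["/home/user:/home/user"],
        host=os.environ.get("SLURM_HOST", ""),
        user=os.environ.get("SLURM_USER", ""),
        job_dir=args.log_dir,
        use_local_tunnel=False,
    )


def main(args):
    slurm_cfg = build_slurm_config(args) if args.use_slurm else None
    ...
```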

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 6ec8cdc and 6a33bc5.

📒 Files selected for processing (6)
  • examples/nemo_run/common/in_memory_mmlu.py (1 hunks)
  • examples/nemo_run/common/process_openscience.py (1 hunks)
  • examples/nemo_run/common/utils.py (1 hunks)
  • examples/nemo_run/qat/README.md (2 hunks)
  • examples/nemo_run/qat/nemo_qat_flow.py (6 hunks)
  • modelopt/torch/export/plugins/nemo_run.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
examples/nemo_run/common/in_memory_mmlu.py (2)
modelopt/torch/export/plugins/nemo_run.py (1)
  • _get_most_recent_ckpt (50-71)
modelopt/torch/utils/plugins/megatron_mmlu.py (1)
  • megatron_mmlu (65-152)
examples/nemo_run/common/utils.py (1)
examples/nemo_run/common/process_lima.py (1)
  • download_hf_dataset (29-37)
examples/nemo_run/qat/nemo_qat_flow.py (2)
examples/nemo_run/common/utils.py (4)
  • SlurmConfig (24-72)
  • create_slurm_executor (75-117)
  • get_finetune_recipe (120-123)
  • read_chat_template (126-128)
examples/nemo_run/common/in_memory_mmlu.py (1)
  • parse_args (24-41)
🪛 LanguageTool
examples/nemo_run/qat/README.md

[grammar] ~12-~12: There might be a mistake here.
Context: ...odel checkpoint and evaluate 5% of MMLU on BF16 checkpoint 1. PTQ the model and ev...

(QB_NEW_EN)


[grammar] ~13-~13: There might be a mistake here.
Context: .... PTQ the model and evaluate 5% of MMLU on PTQ Checkpoint 1. SFT (finetune) the mo...

(QB_NEW_EN)


[grammar] ~16-~16: There might be a mistake here.
Context: ...ate 5% of MMLU on the SFT checkpoint 1. Export model to Unified checkpoint (HuggingFac...

(QB_NEW_EN)


[grammar] ~35-~35: There might be a mistake here.
Context: ... following model: - Model: Qwen3-8B - Recipe: qwen3_8b ## Usage ### Prereq...

(QB_NEW_EN)


[grammar] ~44-~44: There might be a mistake here.
Context: ...version 25.07 or higher using Docker on on a Slurm interactive node. Mount your cl...

(QB_NEW_EN)


[style] ~53-~53: Consider using a less common alternative to make your writing sound more unique and professional.
Context: ... recover the accuracy of the PTQ model. Feel free to try your own model and dataset combinat...

(FEEL_FREE_TO_STYLE_ME)


[style] ~62-~62: Three successive sentences begin with the same word. Consider rewording the sentence or use a thesaurus to find a synonym.
Context: ...dia/OpenScience) dataset will be used. To perform QAT, run: ```bash python qat/n...

(ENGLISH_WORD_REPEAT_BEGINNING_RULE)


[grammar] ~132-~132: There might be a mistake here.
Context: ...ggingface/transformers/pull/30650)). To provide path to a custom chat template, use the...

(QB_NEW_EN)


[grammar] ~132-~132: There might be a mistake here.
Context: ...--chat-template <my_template.txt>` flag.

(QB_NEW_EN)

🪛 markdownlint-cli2 (0.17.2)
examples/nemo_run/qat/README.md

52-52: Headings should be surrounded by blank lines
Expected: 1; Actual: 0; Below

(MD022, blanks-around-headings)


85-85: Fenced code blocks should have a language specified

(MD040, fenced-code-language)


91-91: Fenced code blocks should have a language specified

(MD040, fenced-code-language)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (2)
examples/nemo_run/common/in_memory_mmlu.py (1)

55-56: Confirm model.module is always present.

Depending on how the model is wrapped, .module may be absent on single-GPU runs.

Please run a quick smoke test on both single-GPU and DDP to ensure model.module is valid; otherwise pass model directly when .module is missing.
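
A defensive variant of the unwrap, assuming `model` is the object returned by the ModelOpt setup helper (sketch):

```python
# Works for both wrapped (e.g. DDP) and bare models: fall back to the model
# itself when no .module attribute is present.
eval_model = getattr(model, "module", model)
```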

examples/nemo_run/qat/nemo_qat_flow.py (1)

170-178: HuggingFace offline mode caveat.

With TRANSFORMERS_OFFLINE=1 in SlurmConfig defaults, importing from hf:// may fail unless the model/tokenizer are pre-cached inside the container. Confirm cache availability or override the env var.

Also applies to: 202-211
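
If the weights are not pre-cached in the container, one workaround is to relax the offline setting for the import stage (assuming `env_vars` is a plain dict on `SlurmConfig` and is propagated to the executor, as the next comment requests); the cache path is hypothetical:

```python
# Let the import stage reach the Hugging Face Hub instead of failing offline,
# and point the cache at shared storage visible to the compute nodes.
SLURM_CONFIG.env_vars["TRANSFORMERS_OFFLINE"] = "0"
SLURM_CONFIG.env_vars["HF_HOME"] = "/shared/hf_cache"  # hypothetical path
```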

Comment on lines +75 to +106
def create_slurm_executor(
    slurm_cfg: SlurmConfig, nodes: int = 1, ntasks_per_node: int = 1, num_gpus: int = 0
):
    # Configure tunnel
    if slurm_cfg.use_local_tunnel:
        # Use LocalTunnel when already on the cluster
        tunnel = run.LocalTunnel(job_dir=slurm_cfg.job_dir)
    else:
        # Use SSH tunnel when launching from local machine
        tunnel = run.SSHTunnel(
            host=slurm_cfg.host,
            user=slurm_cfg.user,
            job_dir=slurm_cfg.job_dir,
            identity=slurm_cfg.identity,  # can be None
        )

    if num_gpus > 0:
        return run.SlurmExecutor(
            account=slurm_cfg.account,
            partition=slurm_cfg.partition_gpu,
            ntasks_per_node=ntasks_per_node,
            gpus_per_node=num_gpus,
            nodes=nodes,
            tunnel=tunnel,
            container_image=slurm_cfg.container_image,
            container_mounts=slurm_cfg.container_mounts,
            time=slurm_cfg.time,
            packager=run.GitArchivePackager(),
            mem="0",
            gres=f"gpu:{num_gpus}",
        )
    else:

💡 Verification agent

🧩 Analysis chain

Propagate env vars to Slurm jobs.

SlurmConfig.env_vars is never applied; pass it to the executor (param name may differ by nemo-run version).

-        return run.SlurmExecutor(
+        return run.SlurmExecutor(
             account=slurm_cfg.account,
             partition=slurm_cfg.partition_gpu,
             ntasks_per_node=ntasks_per_node,
             gpus_per_node=num_gpus,
             nodes=nodes,
             tunnel=tunnel,
             container_image=slurm_cfg.container_image,
             container_mounts=slurm_cfg.container_mounts,
             time=slurm_cfg.time,
             packager=run.GitArchivePackager(),
             mem="0",
             gres=f"gpu:{num_gpus}",
+            env_vars=slurm_cfg.env_vars,  # verify parameter name: env_vars/env
         )
@@
-        return run.SlurmExecutor(
+        return run.SlurmExecutor(
             account=slurm_cfg.account,
             partition=slurm_cfg.partition_cpu,
             nodes=nodes,
             tunnel=tunnel,
             container_image=slurm_cfg.container_image,
             container_mounts=slurm_cfg.container_mounts,
             time=slurm_cfg.time,
             packager=run.GitArchivePackager(),
             mem="0",
+            env_vars=slurm_cfg.env_vars,  # verify parameter name: env_vars/env
         )

If constructor doesn’t support it, set on the instance if available (e.g., executor.env_vars.update(...)).

Also applies to: 107-117


🏁 Script executed:

#!/usr/bin/env bash
set -e

# 1. Show import lines to determine how `run` is imported
sed -n '1,30p' examples/nemo_run/common/utils.py

# 2. Search repository for a local SlurmExecutor definition
rg -nP --type=py "class\s+SlurmExecutor" -n -C3 .

# 3. If not found locally, install the package and introspect its signature
pip install nemo-run
python3 - << 'PYCODE'
import inspect
try:
    from nemo_run.common import run
except ImportError:
    import run
sig = inspect.signature(run.SlurmExecutor)
print("SlurmExecutor signature:", sig)
PYCODE

Length of output: 1194


🏁 Script executed:

#!/usr/bin/env bash
set -e

# 1. Confirm no local SlurmExecutor in repo
rg -nP --type=py "class\s+SlurmExecutor" -n -C3 .

# 2. Install nemo-run and inspect SlurmExecutor signature
pip install nemo-run
python3 - << 'PYCODE'
import inspect
import nemo_run as run
sig = inspect.signature(run.SlurmExecutor)
print("Signature:", sig)
print("Parameters:", list(sig.parameters.keys()))
PYCODE

Length of output: 60


🏁 Script executed:

#!/usr/bin/env bash
# Re-verify SlurmExecutor existence and signature without exiting on errors.

# 1. Check for local SlurmExecutor; ignore if none found.
rg -nP --type=py "class\s+SlurmExecutor" -n -C3 . || echo "No local SlurmExecutor definition in repo."

# 2. Install nemo-run package (suppressing verbose output).
pip install nemo-run >/dev/null 2>&1

# 3. Introspect the SlurmExecutor constructor signature and doc.
python3 - << 'PYCODE'
import inspect
import nemo_run as run

sig = inspect.signature(run.SlurmExecutor)
params = sig.parameters

print("Signature:", sig)
print("Parameter names:", list(params.keys()))
# Print each parameter with default if any
for name, param in params.items():
    print(f"  - {name}: default={param.default!r}")
PYCODE

Length of output: 4059


🏁 Script executed:

#!/usr/bin/env bash
set -e

# Locate SlurmConfig definition
grep -R "class SlurmConfig" -n examples/nemo_run/common/utils.py

# Display SlurmConfig fields
sed -n '1,200p' examples/nemo_run/common/utils.py | sed -n '/class SlurmConfig/,/^$/p'

# Search for env_vars usage in the file
grep -R "env_vars" -n examples/nemo_run/common/utils.py

Length of output: 427


Pass env_vars from SlurmConfig to SlurmExecutor
Add env_vars=slurm_cfg.env_vars to both run.SlurmExecutor(…) calls in the GPU and CPU branches of create_slurm_executor (examples/nemo_run/common/utils.py) so that user-configured environment variables are applied.

🤖 Prompt for AI Agents
In examples/nemo_run/common/utils.py around lines 75 to 106, the SlurmExecutor
calls do not pass the user-configured environment variables from SlurmConfig;
add env_vars=slurm_cfg.env_vars to both the GPU branch and the CPU (else) branch
SlurmExecutor(...) invocations so the executor is launched with the configured
environment variables.

Comment on lines +120 to +124
def get_finetune_recipe(recipe_name: str):
    if not hasattr(getattr(llm, recipe_name), "finetune_recipe"):
        raise ValueError(f"Recipe {recipe_name} does not have a Fine-Tuning recipe")
    return getattr(llm, recipe_name).finetune_recipe(peft_scheme=None)


⚠️ Potential issue

Fix AttributeError risk when recipe module is missing.

getattr(llm, recipe_name) will raise if the module doesn’t exist; guard it first.

-def get_finetune_recipe(recipe_name: str):
-    if not hasattr(getattr(llm, recipe_name), "finetune_recipe"):
-        raise ValueError(f"Recipe {recipe_name} does not have a Fine-Tuning recipe")
-    return getattr(llm, recipe_name).finetune_recipe(peft_scheme=None)
+def get_finetune_recipe(recipe_name: str):
+    recipe_mod = getattr(llm, recipe_name, None)
+    if recipe_mod is None or not hasattr(recipe_mod, "finetune_recipe"):
+        raise ValueError(f"Recipe {recipe_name} does not exist or lacks a fine-tuning recipe")
+    return recipe_mod.finetune_recipe(peft_scheme=None)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

-def get_finetune_recipe(recipe_name: str):
-    if not hasattr(getattr(llm, recipe_name), "finetune_recipe"):
-        raise ValueError(f"Recipe {recipe_name} does not have a Fine-Tuning recipe")
-    return getattr(llm, recipe_name).finetune_recipe(peft_scheme=None)
+def get_finetune_recipe(recipe_name: str):
+    recipe_mod = getattr(llm, recipe_name, None)
+    if recipe_mod is None or not hasattr(recipe_mod, "finetune_recipe"):
+        raise ValueError(f"Recipe {recipe_name} does not exist or lacks a fine-tuning recipe")
+    return recipe_mod.finetune_recipe(peft_scheme=None)
🤖 Prompt for AI Agents
In examples/nemo_run/common/utils.py around lines 120 to 124, the code calls
getattr(llm, recipe_name) directly which will raise AttributeError if the recipe
module is missing; first check that the recipe exists (e.g., use hasattr(llm,
recipe_name) or getattr(llm, recipe_name, None) and verify it’s not None), then
ensure the found object has a finetune_recipe attribute before accessing it; if
the recipe is missing or lacks finetune_recipe, raise a clear ValueError,
otherwise call finetune_recipe(peft_scheme=None).

@jenchen13 jenchen13 requested a review from a team as a code owner September 5, 2025 17:06
@jenchen13 jenchen13 force-pushed the jennifchen/qat_slurm branch from 2edf3fa to 0eac8f6 Compare September 5, 2025 17:07
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

🧹 Nitpick comments (9)
examples/llm_qat/README.md (1)

14-14: Add Docs link for consistency (optional).

Consider filling the “Docs” column with a pointer to the new NeMo QAT/QAD guide (e.g., ../nemo_run/qat/README.md) to match other rows that include docs.

examples/nemo_run/qat/ADVANCED.md (3)

11-11: Document how to install rsync inside the container.

Provide an explicit install step or a container tag that includes rsync.

Add:

# Inside container (Debian/Ubuntu base)
apt-get update && apt-get install -y rsync openssh-client

Or note the minimal container version that bundles rsync.


13-16: Specify fenced code language.

Add a language to satisfy MD040 and improve rendering.

-```
+```text
 qat_flow_ckpts qat_flow_ckpts_1755708286

---

`19-56`: **Specify fenced code language for the directory tree.**

Add a language to satisfy MD040 and improve rendering.

```diff
-```
+```text
 ├── 00_openscience_data
 │   ├── code
 ...
 │   └── configs

examples/nemo_run/qat/README.md (5)

`5-8`: **Header links: remove duplicate or rename.**

Both “Slurm Examples” and “Advanced Topics” point to ADVANCED.md. Either remove one or rename to distinct anchors.

```diff
-[Slurm Examples](ADVANCED.md) |
-[Advanced Topics](ADVANCED.md) |
+[Slurm Examples](ADVANCED.md) |

Or create section anchors in ADVANCED.md and link each separately.


44-44: Remove double space.

-You can run the example either locally  or on a [Slurm cluster](ADVANCED.md).
+You can run the example either locally or on a [Slurm cluster](ADVANCED.md).

46-55: Safer container mounting; avoid hardcoding dist-packages path.

Mount repos and use editable installs or PYTHONPATH to avoid Python-version-specific paths.

-Example docker command:
-```
-docker run -v  /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/modelopt/:/usr/local/lib/python3.12/dist-packages/modelopt --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.07 bash
-```
+Example docker command:
+```bash
+docker run --gpus all --shm-size 20g --rm -it \
+  -v /home/user/NeMo:/opt/NeMo \
+  -v /home/user/TensorRT-Model-Optimizer:/workspace/TRTMO \
+  -v /home/user:/home/user \
+  nvcr.io/nvidia/nemo:25.07 bash
+```
+Then inside the container:
+```bash
+pip install -e /workspace/TRTMO  # or: export PYTHONPATH=/workspace/TRTMO:$PYTHONPATH
+```

63-69: Clarify working directory.

Specify that the command is run from examples/nemo_run.

-From the `nemo_run` folder, launch the example with the `qat/nemo_qat_flow.py` script.
+From the `examples/nemo_run` folder, launch the example with the `qat/nemo_qat_flow.py` script.

96-96: Grammar: add article.

-To provide path to a custom chat template, use the `--chat-template <my_template.txt>` flag.
+To provide a path to a custom chat template, use the `--chat-template <my_template.txt>` flag.
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 6a33bc5 and 1ef040f.

📒 Files selected for processing (3)
  • examples/llm_qat/README.md (1 hunks)
  • examples/nemo_run/qat/ADVANCED.md (1 hunks)
  • examples/nemo_run/qat/README.md (2 hunks)
🧰 Additional context used
🪛 markdownlint-cli2 (0.17.2)
examples/nemo_run/qat/ADVANCED.md

13-13: Fenced code blocks should have a language specified

(MD040, fenced-code-language)


19-19: Fenced code blocks should have a language specified

(MD040, fenced-code-language)

🪛 LanguageTool
examples/nemo_run/qat/README.md

[grammar] ~5-~5: There might be a mistake here.
Context: ...Example Slurm Examples | Advanced Topics | [NeMo I...

(QB_NEW_EN)


[grammar] ~6-~6: There might be a mistake here.
Context: ...D.md) | Advanced Topics | [NeMo Integration](https://github.com/NV...

(QB_NEW_EN)


[grammar] ~22-~22: There might be a mistake here.
Context: ...odel checkpoint and evaluate 5% of MMLU on BF16 checkpoint 1. PTQ the model and ev...

(QB_NEW_EN)


[grammar] ~23-~23: There might be a mistake here.
Context: .... PTQ the model and evaluate 5% of MMLU on PTQ Checkpoint 1. SFT (finetune) the mo...

(QB_NEW_EN)


[grammar] ~26-~26: There might be a mistake here.
Context: ...ate 5% of MMLU on the SFT checkpoint 1. Export model to Unified checkpoint (HuggingFac...

(QB_NEW_EN)


[style] ~64-~64: Three successive sentences begin with the same word. Consider rewording the sentence or use a thesaurus to find a synonym.
Context: ...dia/OpenScience) dataset will be used. To perform QAT, run: ```bash python qat/n...

(ENGLISH_WORD_REPEAT_BEGINNING_RULE)


[grammar] ~90-~90: There might be a mistake here.
Context: ... following model: - Model: Qwen3-8B - Recipe: qwen3_8b ### Custom Chat Tem...

(QB_NEW_EN)


[grammar] ~96-~96: There might be a mistake here.
Context: ...ggingface/transformers/pull/30650)). To provide path to a custom chat template, use the...

(QB_NEW_EN)


[style] ~99-~99: Consider using a less common alternative to make your writing sound more unique and professional.
Context: ... recover the accuracy of the PTQ model. Feel free to try your own model and dataset combinat...

(FEEL_FREE_TO_STYLE_ME)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (2)
examples/nemo_run/qat/README.md (2)

28-37: Mermaid flow looks good.


48-50: Pin NeMo commit with full hash and date
In examples/nemo_run/qat/README.md (lines 48–50), replace the short hash ddcb75f with its full 40-character commit hash and include the commit date/message. You can retrieve these details by cloning the NeMo repo locally and running:

git clone https://github.com/NVIDIA-NeMo/NeMo.git
cd NeMo
git rev-parse ddcb75f
git show -s --format='%H %ad %s' ddcb75f

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (2)
examples/nemo_run/qat/README.md (2)

19-27: Align the stage list with the 8-stage flow and exact task names.

The list shows six generic steps while the code/diagram uses eight named stages. Please list all eight with exact IDs to match logs and ADVANCED.md.

Use:

-Currently the Simplified Flow runs the following steps in order:
-
-1. Process Nvidia/OpenScience data (if `--data-path` is not specified)
-1. Import NeMo BF16 model checkpoint and evaluate 5% of MMLU on BF16 checkpoint
-1. PTQ the model and evaluate 5% of MMLU on PTQ Checkpoint
-1. SFT (finetune) the model
-1. Evaluate 5% of MMLU on the SFT checkpoint
-1. Export model to Unified checkpoint (HuggingFace) format in lower precision
+Flow stages:
+
+1. 00_openscience_data — Process NVIDIA OpenScience data (skipped if `--data-path` is provided)
+2. 01_import_model — Import NeMo BF16 model checkpoint
+3. 02_mmlu_bf16 — Evaluate 5% MMLU on BF16 checkpoint
+4. 03_ptq — Apply PTQ
+5. 04_mmlu_ptq — Evaluate 5% MMLU on PTQ checkpoint
+6. 05_train — SFT/QAT (and optional QAD)
+7. 06_mmlu_sft — Evaluate 5% MMLU on SFT/QAT checkpoint
+8. 07_export_hf — Export to Hugging Face (Unified) format

15-16: Fix incomplete sentence after PTQ.

This reads as a fragment; complete it to explain why QAT/QAD follow PTQ.

Apply:

-After PTQ (post-training quantization), the quantized model may
+After PTQ (post-training quantization), the quantized model may exhibit accuracy degradation on tasks like MMLU; the subsequent QAT/QAD stages aim to recover that loss.
🧹 Nitpick comments (7)
examples/nemo_run/qat/README.md (7)

5-7: Avoid duplicate links to the same target in the header.

Both “Slurm Examples” and “Advanced Topics” point to ADVANCED.md. Either consolidate or point “Slurm Examples” to a section anchor.

If ADVANCED.md has a Slurm section, consider:

-[Slurm Examples](ADVANCED.md) |
-[Advanced Topics](ADVANCED.md) |
+[Slurm Examples](ADVANCED.md#slurm) |
+[Advanced Topics](ADVANCED.md) |

43-43: Fix minor spacing typo.

-You can run the example either locally  or on a [Slurm cluster](ADVANCED.md).
+You can run the example either locally or on a [Slurm cluster](ADVANCED.md).

45-49: Prefer PYTHONPATH over bind-mounting into site-packages.

Mounting into /usr/local/.../site-packages can be brittle across images/python versions. Using PYTHONPATH keeps the container clean and reduces surprises.

-Example docker command:
+Example docker command:
@@
-```
-docker run -v  /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/modelopt/:/usr/local/lib/python3.12/dist-packages/modelopt --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.07 bash
-```
+```bash
+docker run --gpus all --shm-size 20g --rm -it \
+  -v /home/user/NeMo:/opt/NeMo \
+  -v /home/user/TensorRT-Model-Optimizer:/workspace/TensorRT-Model-Optimizer \
+  -e PYTHONPATH=/opt/NeMo:/workspace/TensorRT-Model-Optimizer \
+  nvcr.io/nvidia/nemo:25.07 bash
+```

52-55: Add language to the fenced code block (markdownlint MD040).

-```
+```bash
 docker run -v  /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/modelopt/:/usr/local/lib/python3.12/dist-packages/modelopt --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.07 bash

---

`62-62`: **Minor wording polish and consistent naming.**

Use “Hugging Face” and ensure dataset name capitalization is consistent.


```diff
-... use the model's HuggingFace name ...
+... use the model's Hugging Face name ...

Also, earlier “Nvidia/OpenScience” → “NVIDIA OpenScience” (addressed in the stage list fix).


91-94: Tighten wording and fix article usage.

-By default the script will use the model/tokenizer's chat template, which may not contain the `{% generation %}` and `{% endgeneration %}` tags around the assistant tokens which are needed to generate the assistant loss mask (see [this PR](https://github.com/huggingface/transformers/pull/30650)). To provide path to a custom chat template, use the `--chat-template <my_template.txt>` flag.
+By default, the script uses the model/tokenizer's chat template, which may not contain the `{% generation %}` and `{% endgeneration %}` tags around assistant tokens that are needed to generate the assistant loss mask (see [this PR](https://github.com/huggingface/transformers/pull/30650)). To provide a path to a custom chat template, use the `--chat-template <my_template.txt>` flag.

97-98: Tone/style refinement.

Optional rephrase to avoid “Feel free to” and tighten the message.

-The current QAT recipe has been tuned for the Qwen3-8B model to improve accuracy on the MMLU benchmark after PTQ degradation. QAT/QAD results are highly dependent on the specific model, dataset, and hyperparameters. There is no guarantee that the same dataset will recover the accuracy of the PTQ model. Feel free to try your own model and dataset combinations and test which combination works best.
+The current QAT recipe is tuned for Qwen3-8B to improve MMLU after PTQ-induced degradation. QAT/QAD results depend on the model, dataset, and hyperparameters, and the same dataset may not recover PTQ accuracy. You are encouraged to try different model–dataset combinations and validate which works best in your environment.
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 1ef040f and 46e6c13.

📒 Files selected for processing (2)
  • examples/nemo_run/common/process_openscience.py (1 hunks)
  • examples/nemo_run/qat/README.md (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/nemo_run/common/process_openscience.py
🧰 Additional context used
🪛 LanguageTool
examples/nemo_run/qat/README.md

[grammar] ~5-~5: There might be a mistake here.
Context: ...Example Slurm Examples | Advanced Topics | [NeMo I...

(QB_NEW_EN)


[grammar] ~6-~6: There might be a mistake here.
Context: ...D.md) | Advanced Topics | [NeMo Integration](https://github.com/NV...

(QB_NEW_EN)


[grammar] ~22-~22: There might be a mistake here.
Context: ...odel checkpoint and evaluate 5% of MMLU on BF16 checkpoint 1. PTQ the model and ev...

(QB_NEW_EN)


[grammar] ~23-~23: There might be a mistake here.
Context: .... PTQ the model and evaluate 5% of MMLU on PTQ Checkpoint 1. SFT (finetune) the mo...

(QB_NEW_EN)


[grammar] ~26-~26: There might be a mistake here.
Context: ...ate 5% of MMLU on the SFT checkpoint 1. Export model to Unified checkpoint (HuggingFac...

(QB_NEW_EN)


[style] ~63-~63: Three successive sentences begin with the same word. Consider rewording the sentence or use a thesaurus to find a synonym.
Context: ...dia/OpenScience) dataset will be used. To perform QAT, run: ```bash python qat/n...

(ENGLISH_WORD_REPEAT_BEGINNING_RULE)


[grammar] ~88-~88: There might be a mistake here.
Context: ... following model: - Model: Qwen3-8B - Recipe: qwen3_8b ### Custom Chat Temp...

(QB_NEW_EN)


[grammar] ~93-~93: There might be a mistake here.
Context: ...ggingface/transformers/pull/30650)). To provide path to a custom chat template, use the...

(QB_NEW_EN)


[style] ~97-~97: Consider using a less common alternative to make your writing sound more unique and professional.
Context: ... recover the accuracy of the PTQ model. Feel free to try your own model and dataset combinat...

(FEEL_FREE_TO_STYLE_ME)

🪛 markdownlint-cli2 (0.17.2)
examples/nemo_run/qat/README.md

52-52: Fenced code blocks should have a language specified

(MD040, fenced-code-language)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (4)
examples/nemo_run/qat/README.md (4)

66-68: LGTM: local QAT invocation example is clear and minimal.


79-80: LGTM: QAD invocation example is clear and mirrors QAT usage.


28-37: No changes needed: diagram labels align with code stage IDs.
All Mermaid nodes in examples/nemo_run/qat/README.md correspond exactly to the stage names in nemo_qat_flow.py.


84-87: Resolved: CLI flags and defaults match documentation
Flags --train-nodes (default 1), --train-gpus (default 8), and --ptq-gpus (default 4) are correctly defined in nemo_qat_flow.py and align with the README.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (4)
examples/nemo_run/qat/ADVANCED.md (1)

5-5: Don’t instruct users to hardcode credentials; point to env-driven config.

Replace “edit SLURM_CONFIG in source” with guidance to set environment variables that SlurmConfig reads, and reference the utils.SlurmConfig schema. This avoids secrets in source and eases CI.

-To run the example on slurm, edit the `SLURM_CONFIG` at the bottom of `nemo_qat_flow.py` with the appropriate credentials, container, cluster name (host), and container mounts.
+Define Slurm settings via environment variables consumed by `SlurmConfig` (see `examples/nemo_run/common/utils.py`) rather than editing source. Set values like `SLURM_ACCOUNT`, `SLURM_PARTITION_{GPU,CPU}`, `SLURM_TIME`, `CONTAINER_IMAGE`, `CONTAINER_MOUNTS`, `SLURM_HOST`, and `SLURM_USER` before launching.
examples/nemo_run/qat/nemo_qat_flow.py (3)

146-153: Separate “recipe name” from “model name” and validate before getattr.

Avoid overloading model_name for both concepts; it risks attribute errors and wrong path prefixes.

-    model_name = args.finetune_recipe
-    model_module = getattr(llm, model_name)
-    if not model_name:
-        model_name = os.path.basename(args.model_name)
+    recipe_name = args.finetune_recipe
+    if not recipe_name:
+        raise ValueError("--finetune-recipe must be specified when --distill is not used")
+    model_module = getattr(llm, recipe_name)
+    # Use recipe name as the filesystem prefix for artifacts
+    model_name = recipe_name

138-142: KV-cache flag is forced to “disabled” by default and CLI uses underscore not hyphen.

Use a tri-state CLI: only pass a flag when explicitly set, and expose hyphenated flags to match README. This avoids overriding ptq.py defaults unintentionally and removes doc/code mismatch.

-    parser.add_argument(
-        "--enable_kv_cache",
-        help="Enables KV-cache quantization",
-        action="store_true",
-        default=False,
-    )
+    # Tri-state KV-cache control; pass a flag only if explicitly set
+    kv = parser.add_mutually_exclusive_group(required=False)
+    parser.set_defaults(enable_kv_cache=None)
+    kv.add_argument("--enable-kv-cache", dest="enable_kv_cache", action="store_true", help="Enable KV-cache quantization")
+    kv.add_argument("--disable-kv-cache", dest="enable_kv_cache", action="store_false", help="Disable KV-cache quantization")
@@
-    ptq = run.Script(
+    # Build KV-cache flag only when explicitly set
+    kv_cache_flag = (
+        ["--enable_kv_cache"] if args.enable_kv_cache is True
+        else (["--disable_kv_cache"] if args.enable_kv_cache is False else [])
+    )
+    ptq = run.Script(
@@
-            "--kv_cache_qformat",
-            args.kv_cache_qformat,
-            "--enable_kv_cache" if args.enable_kv_cache else "--disable_kv_cache",
+            "--kv_cache_qformat",
+            args.kv_cache_qformat,
+            *kv_cache_flag,

Also applies to: 176-194


266-275: Don’t reuse a single Slurm executor for PTQ, Train, and Export.

Reusing and mutating one executor risks wrong gres/gpu allocations across stages. Create dedicated executors per stage.

-    if args.use_slurm:
-        cpu_executor = create_slurm_executor(SLURM_CONFIG)
-        gpu_executor = create_slurm_executor(
-            SLURM_CONFIG, num_gpus=args.ptq_gpus, ntasks_per_node=args.ptq_gpus
-        )
-        single_gpu_executor = create_slurm_executor(SLURM_CONFIG, num_gpus=1, ntasks_per_node=1)
+    if args.use_slurm:
+        cpu_executor = create_slurm_executor(SLURM_CONFIG)
+        ptq_gpu_executor = create_slurm_executor(
+            SLURM_CONFIG, num_gpus=args.ptq_gpus, ntasks_per_node=args.ptq_gpus
+        )
+        train_gpu_executor = create_slurm_executor(
+            SLURM_CONFIG, nodes=args.train_nodes, num_gpus=args.train_gpus, ntasks_per_node=args.train_gpus
+        )
+        single_gpu_executor = create_slurm_executor(SLURM_CONFIG, num_gpus=1, ntasks_per_node=1)
     else:
-        cpu_executor = single_gpu_executor = run.LocalExecutor()
-        gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.ptq_gpus)
+        cpu_executor = single_gpu_executor = run.LocalExecutor()
+        ptq_gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.ptq_gpus)
+        train_gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.train_gpus)
@@
-        s2 = exp.add(ptq, tail_logs=True, name="03_ptq", executor=gpu_executor, dependencies=[s1])
+        s2 = exp.add(ptq, tail_logs=True, name="03_ptq", executor=ptq_gpu_executor, dependencies=[s1])
@@
-        if args.use_slurm:  # Set training arguments
-            gpu_executor.nodes = args.train_nodes
-            gpu_executor.devices = gpu_executor.ntasks_per_node = args.train_gpus
-        else:
-            gpu_executor.ntasks_per_node = args.train_gpus
+        # training resources already encoded in train_gpu_executor
@@
-        s4 = exp.add(
-            train, tail_logs=True, name="05_train", executor=gpu_executor, dependencies=train_dep
-        )
+        s4 = exp.add(train, tail_logs=True, name="05_train", executor=train_gpu_executor, dependencies=train_dep)
@@
-        gpu_executor.ntasks_per_node = 1  # will throw error if more than 1 task during export
-        exp.add(
-            export,
-            tail_logs=True,
-            name="07_export_hf",
-            executor=gpu_executor,
-            dependencies=[s5],
-        )
+        exp.add(export, tail_logs=True, name="07_export_hf", executor=single_gpu_executor, dependencies=[s5])

Also applies to: 293-301, 302-308, 311-314, 322-329

🧹 Nitpick comments (13)
examples/nemo_run/qat/ADVANCED.md (4)

3-3: Polish the intro sentence (comma + “one”).

Add a comma after “for example” and prefer “one” over “1” for docs consistency.

-If you need to run QAT/QAD on a Slurm cluster (for example to use more than 1 node), this guide covers how to configure and launch on Slurm.
+If you need to run QAT/QAD on a Slurm cluster (for example, to use more than one node), this guide covers how to configure and launch on Slurm.

11-11: Clarify how to get rsync inside the container.

Add a one-liner to install rsync if missing to prevent launch failures.

-**NOTE:** `rsync` may not currently be available in the NeMo container and will be added as a dependency.
+**NOTE:** If `rsync` is not available in the NeMo container, install it before launching:
+```bash
+apt-get update && apt-get install -y rsync
+```

13-16: Add a language to fenced code block (markdownlint MD040).

-```
+```text
 qat_flow_ckpts qat_flow_ckpts_1755708286

---

`19-56`: **Add a language to fenced code block (markdownlint MD040).**


```diff
-```
+```text
 ├── 00_openscience_data
 │   ├── code
 │   ├── configs
 …

examples/nemo_run/common/in_memory_mmlu.py (1)

`25-27`: **Align CLI help with actual flag name.**

Help text says “--ckpt_dir” but the flag is “--finetuned_ckpt_dir”. Update description to avoid confusion.


```diff
-        description="Run MMLU evaluation with ModelOpt Megatron model. Provide either --nemo_ckpt or --ckpt_dir"
+        description="Run MMLU evaluation with ModelOpt Megatron model. Provide either --nemo_ckpt or --finetuned_ckpt_dir"
examples/nemo_run/qat/nemo_qat_flow.py (3)

29-31: Avoid sys.path mutation; import via package.

Prefer making examples/nemo_run/common a package and importing utils normally.

-sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "common")))
-from utils import SlurmConfig, create_slurm_executor, get_finetune_recipe, read_chat_template
+from examples.nemo_run.common.utils import (
+    SlurmConfig,
+    create_slurm_executor,
+    get_finetune_recipe,
+    read_chat_template,
+)

239-239: Temporary validation cap left on.

train.trainer.limit_val_batches = 2 looks like a debug setting. Remove or guard with a flag before merging.


361-366: Module constants are defined after use; move to top-level for clarity.

They’re set before main() runs, but relocating them near imports improves readability and reduces ordering hazards.

examples/nemo_run/qat/README.md (5)

45-45: Fix double space and minor phrasing.

-You can run the example either locally  or on a [Slurm cluster](ADVANCED.md).
+You can run the example either locally or on a [Slurm cluster](ADVANCED.md).

54-56: Add language to the Docker command block (markdownlint MD040).

-```
+```bash
 docker run -v  /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/modelopt/:/usr/local/lib/python3.12/dist-packages/modelopt --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.07 bash

---

`64-70`: **Consider a more robust local install flow for modelopt.**

Mounting into `dist-packages` is brittle. Suggest editable installs.


```diff
-Example docker command:
+Example docker command (then install repos inside the container):
@@
-```bash
-docker run -v  /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/modelopt/:/usr/local/lib/python3.12/dist-packages/modelopt --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.07 bash
-```
+```bash
+docker run -v /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.07 bash
+pip install -e /opt/NeMo
+pip install -e /home/user/TensorRT-Model-Optimizer/modelopt
+```

72-72: CLI flag mismatch with code (--enable-kv-cache vs --enable_kv_cache).

Docs use hyphenated flag; current code defines underscore. Recommend aligning to hyphenated form per argparse conventions, or update docs once code changes.


95-95: Minor grammar.

-To provide path to a custom chat template, use the `--chat-template <my_template.txt>` flag.
+To provide a path to a custom chat template, use the `--chat-template <my_template.txt>` flag.
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 46e6c13 and 856159e.

📒 Files selected for processing (5)
  • examples/nemo_run/common/in_memory_mmlu.py (1 hunks)
  • examples/nemo_run/qat/ADVANCED.md (1 hunks)
  • examples/nemo_run/qat/README.md (2 hunks)
  • examples/nemo_run/qat/nemo_qat_flow.py (6 hunks)
  • modelopt/torch/export/plugins/nemo_run.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/export/plugins/nemo_run.py
🧰 Additional context used
🧬 Code graph analysis (2)
examples/nemo_run/common/in_memory_mmlu.py (2)
modelopt/torch/export/plugins/nemo_run.py (1)
  • _get_most_recent_ckpt (50-73)
modelopt/torch/utils/plugins/megatron_mmlu.py (1)
  • megatron_mmlu (65-152)
examples/nemo_run/qat/nemo_qat_flow.py (2)
modelopt/torch/export/plugins/nemo_run.py (1)
  • export_most_recent_ckpt (24-35)
examples/nemo_run/common/utils.py (4)
  • SlurmConfig (24-72)
  • create_slurm_executor (75-117)
  • get_finetune_recipe (120-123)
  • read_chat_template (126-128)
🪛 LanguageTool
examples/nemo_run/qat/README.md

[grammar] ~5-~5: There might be a mistake here.
Context: ...Example Slurm Examples | Advanced Topics | [NeMo I...

(QB_NEW_EN)


[grammar] ~6-~6: There might be a mistake here.
Context: ...D.md) | Advanced Topics | [NeMo Integration](https://github.com/NV...

(QB_NEW_EN)


[style] ~65-~65: Three successive sentences begin with the same word. Consider rewording the sentence or use a thesaurus to find a synonym.
Context: ...dia/OpenScience) dataset will be used. To perform QAT, run: ```bash python qat/n...

(ENGLISH_WORD_REPEAT_BEGINNING_RULE)


[grammar] ~90-~90: There might be a mistake here.
Context: ... following model: - Model: Qwen3-8B - Recipe: qwen3_8b ### Custom Chat Temp...

(QB_NEW_EN)


[grammar] ~95-~95: There might be a mistake here.
Context: ...ggingface/transformers/pull/30650)). To provide path to a custom chat template, use the...

(QB_NEW_EN)


[style] ~99-~99: Consider using a less common alternative to make your writing sound more unique and professional.
Context: ... recover the accuracy of the PTQ model. Feel free to try your own model and dataset combinat...

(FEEL_FREE_TO_STYLE_ME)

🪛 markdownlint-cli2 (0.17.2)
examples/nemo_run/qat/README.md

54-54: Fenced code blocks should have a language specified

(MD040, fenced-code-language)

examples/nemo_run/qat/ADVANCED.md

13-13: Fenced code blocks should have a language specified

(MD040, fenced-code-language)


19-19: Fenced code blocks should have a language specified

(MD040, fenced-code-language)

🔇 Additional comments (2)
examples/nemo_run/common/in_memory_mmlu.py (1)

28-35: Good switch to a mutually exclusive group for checkpoint inputs.

examples/nemo_run/qat/nemo_qat_flow.py (1)

336-357: I’ve initiated a shell check to verify SlurmConfig import and whether os is already imported in nemo_qat_flow.py.

train.trainer.max_steps = 200
train.log.log_dir = exp_dir
train.trainer.val_check_interval = VAL_INTERVAL
train.trainer.max_steps = TRAIN_STEPS
Contributor

Suggested change
train.trainer.max_steps = TRAIN_STEPS
train.trainer.max_steps = TRAIN_STEPS
train.trainer.strategy.tensor_model_parallel_size = args.train_gpus

Contributor Author

right now it's running DP by default which is what I want (since 8B model will fit in 1 GPU)

Contributor

I guess it could be helpful if someone has less GPU memory per device, but it's not 100% necessary

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/nemo_run/qat/README.md (1)

74-75: Wrong flag spelling: hyphen vs underscore.

The code defines --enable_kv_cache and --kv-cache-qformat; the README shows --enable-kv-cache. This will 100% fail at runtime.

Apply:

-> **_NOTE:_** To enable KV cache quantization, add `--enable-kv-cache` and specify qformat using `--kv-cache-qformat <fp8, nvfp4>`.
+> **_NOTE:_** To enable KV cache quantization, add `--enable_kv_cache` and specify qformat using `--kv-cache-qformat <fp8, nvfp4>`.
♻️ Duplicate comments (3)
examples/nemo_run/common/in_memory_mmlu.py (1)

48-49: Bug: unreachable finetuned path due to arg name mismatch.

Branch checks args.ckpt_dir which doesn’t exist; it should use args.finetuned_ckpt_dir.

Apply:

-    if args.ckpt_dir:
-        ckpt_path = _get_most_recent_ckpt(args.ckpt_dir)
+    if args.finetuned_ckpt_dir:
+        ckpt_path = _get_most_recent_ckpt(args.finetuned_ckpt_dir)
examples/nemo_run/qat/nemo_qat_flow.py (2)

137-142: Tri‑state KV‑cache flag handling; don’t force disable by default.

Currently --disable_kv_cache is always passed when the user doesn’t specify anything, changing defaults implicitly.

Apply:

@@ def get_args():
-    parser.add_argument(
-        "--enable_kv_cache",
-        help="Enables KV-cache quantization",
-        action="store_true",
-        default=False,
-    )
+    kv_group = parser.add_mutually_exclusive_group(required=False)
+    kv_group.add_argument(
+        "--enable_kv_cache",
+        dest="enable_kv_cache",
+        help="Enable KV-cache quantization",
+        action="store_true",
+    )
+    kv_group.add_argument(
+        "--disable_kv_cache",
+        dest="enable_kv_cache",
+        help="Disable KV-cache quantization",
+        action="store_false",
+    )
+    parser.set_defaults(enable_kv_cache=None)  # tri-state: True/False/None
@@
-    ptq = run.Script(
+    # Build KV-cache flag only if explicitly set
+    kv_flag = (
+        ["--enable_kv_cache"] if args.enable_kv_cache is True
+        else (["--disable_kv_cache"] if args.enable_kv_cache is False else [])
+    )
+    ptq = run.Script(
         "/opt/NeMo/scripts/llm/ptq.py",
         args=[
@@
-            "--kv_cache_qformat",
-            args.kv_cache_qformat,
-            "--enable_kv_cache" if args.enable_kv_cache else "--disable_kv_cache",
+            "--kv_cache_qformat",
+            args.kv_cache_qformat,
+            *kv_flag,

Also applies to: 183-200


274-283: Do not reuse one GPU Slurm/Local executor for PTQ, Train, and Export.

Mutating gpu_executor (nodes/ntasks_per_node/devices) is error‑prone; requests can leak across stages.

Apply:

@@
-    if args.use_slurm:
-        cpu_executor = create_slurm_executor(SLURM_CONFIG)
-        gpu_executor = create_slurm_executor(
-            SLURM_CONFIG, num_gpus=args.ptq_gpus, ntasks_per_node=args.ptq_gpus
-        )
-        single_gpu_executor = create_slurm_executor(SLURM_CONFIG, num_gpus=1, ntasks_per_node=1)
+    if args.use_slurm:
+        cpu_executor = create_slurm_executor(SLURM_CONFIG)
+        ptq_gpu_executor = create_slurm_executor(
+            SLURM_CONFIG, num_gpus=args.ptq_gpus, ntasks_per_node=args.ptq_gpus
+        )
+        train_gpu_executor = create_slurm_executor(
+            SLURM_CONFIG, nodes=args.train_nodes, num_gpus=args.train_gpus, ntasks_per_node=args.train_gpus
+        )
+        single_gpu_executor = create_slurm_executor(SLURM_CONFIG, num_gpus=1, ntasks_per_node=1)
     else:
-        cpu_executor = single_gpu_executor = run.LocalExecutor()
-        gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.ptq_gpus)
+        cpu_executor = single_gpu_executor = run.LocalExecutor()
+        ptq_gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.ptq_gpus)
+        train_gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.train_gpus)
@@
-        s2 = exp.add(ptq, tail_logs=True, name="03_ptq", executor=gpu_executor, dependencies=[s1])
+        s2 = exp.add(ptq, tail_logs=True, name="03_ptq", executor=ptq_gpu_executor, dependencies=[s1])
@@
-        if args.use_slurm:  # Set training arguments
-            gpu_executor.nodes = args.train_nodes
-            gpu_executor.devices = gpu_executor.ntasks_per_node = args.train_gpus
-        else:
-            gpu_executor.ntasks_per_node = args.train_gpus
         train_dep = [s3]
         if not args.data_path:
             train_dep.append(s0)
-        s4 = exp.add(
-            train, tail_logs=True, name="05_train", executor=gpu_executor, dependencies=train_dep
-        )
+        s4 = exp.add(train, tail_logs=True, name="05_train", executor=train_gpu_executor, dependencies=train_dep)
@@
-        gpu_executor.ntasks_per_node = 1  # will throw error if more than 1 task during export
-        exp.add(
-            export,
-            tail_logs=True,
-            name="07_export_hf",
-            executor=gpu_executor,
-            dependencies=[s5],
-        )
+        exp.add(export, tail_logs=True, name="07_export_hf", executor=single_gpu_executor, dependencies=[s5])

Also applies to: 301-336

🧹 Nitpick comments (10)
examples/nemo_run/common/in_memory_mmlu.py (2)

26-26: Fix help text to match the actual flag name.

Says “--ckpt_dir” but the CLI exposes “--finetuned_ckpt_dir”.

Apply:

-        description="Run MMLU evaluation with ModelOpt Megatron model. Provide either --nemo_ckpt or --ckpt_dir"
+        description="Run MMLU evaluation with ModelOpt Megatron model. Provide either --nemo_ckpt or --finetuned_ckpt_dir"

47-55: Optional: validate resolved checkpoint path early.

A quick exists-check gives clearer error messages before restoration.

Example:

     ckpt_path = args.nemo_ckpt
     if args.finetuned_ckpt_dir:
         ckpt_path = _get_most_recent_ckpt(args.finetuned_ckpt_dir)
+    if not ckpt_path or not os.path.exists(ckpt_path):
+        raise FileNotFoundError(f"Checkpoint path not found: {ckpt_path}")

(Remember to import os.)

examples/nemo_run/qat/README.md (4)

45-45: Minor: stray double space.

-You can run the example either locally  or on a [Slurm cluster](ADVANCED.md).
+You can run the example either locally or on a [Slurm cluster](ADVANCED.md).

97-97: Grammar nit: add article.

- To provide path to a custom chat template, use the `--chat-template <my_template.txt>` flag.
+ To provide a path to a custom chat template, use the `--chat-template <my_template.txt>` flag.

55-56: Fragile mount path to site-packages. Prefer editable install or compute path.

Hardcoding /usr/local/lib/python3.12/dist-packages/modelopt is container/Python-version dependent.

Consider replacing with one of:

  • Use editable install inside container: pip install -e /home/user/TensorRT-Model-Optimizer/modelopt.
  • Or compute site-packages path dynamically:
python - <<'PY'
import site, sys, os
print(next(p for p in site.getsitepackages() if 'site-packages' in p))
PY

Then mount to that returned path.

Would you like me to propose a concise Docker snippet reflecting this?


5-7: Duplicate links to the same target.

“Slurm Examples” and “Advanced Topics” both point to ADVANCED.md. If intentional, consider distinct anchors; otherwise dedupe.

examples/nemo_run/qat/nemo_qat_flow.py (4)

146-153: Separate “recipe name” from “model name/path prefix” to avoid overload.

Using model_name = args.finetune_recipe mixes two concepts and complicates fallbacks.

Apply:

 def main(args):
     if not args.distill and not args.finetune_recipe:
         raise ValueError("If distillation is not used, --finetune-recipe must be specified")
-    model_name = args.finetune_recipe
-    model_module = getattr(llm, model_name)
-    if not model_name:
-        model_name = os.path.basename(args.model_name)
+    recipe_name = args.finetune_recipe
+    model_module = getattr(llm, recipe_name)
+    # Use recipe name as path prefix; keep HF model name separate
+    path_prefix = recipe_name
@@
-    bf16_ckpt_path = f"{exp_dir}/{model_name}-nemo"
+    bf16_ckpt_path = f"{exp_dir}/{path_prefix}-nemo"
@@
-    ptq_model_out = f"{exp_dir}/{model_name}-{args.algorithm}"
+    ptq_model_out = f"{exp_dir}/{path_prefix}-{args.algorithm}"

Also update the export path at Line 249 accordingly (see next comment).

Also applies to: 171-182


248-250: Keep export path consistent with path prefix refactor.

-    export = run.Partial(
-        export_most_recent_ckpt, train.log.log_dir, output_path=f"{exp_dir}/{model_name}_hf"
-    )
+    export = run.Partial(
+        export_most_recent_ckpt, train.log.log_dir, output_path=f"{exp_dir}/{path_prefix}_hf"
+    )

351-352: Slurm time format should be HH:MM:SS.

"240" is ambiguous and contradicts SlurmConfig’s guidance.

-            time="240",
+            time="04:00:00",

338-338: Detach only when using Slurm; keep local runs blocking.

Improves local UX and log visibility.

-        exp.run(detach=True)
+        exp.run(detach=args.use_slurm)
📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between 856159e and 14f15ab.

📒 Files selected for processing (3)
  • examples/nemo_run/common/in_memory_mmlu.py (1 hunks)
  • examples/nemo_run/qat/README.md (2 hunks)
  • examples/nemo_run/qat/nemo_qat_flow.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
examples/nemo_run/qat/nemo_qat_flow.py (3)
modelopt/torch/export/plugins/nemo_run.py (1)
  • export_most_recent_ckpt (24-35)
examples/nemo_run/common/utils.py (4)
  • SlurmConfig (24-72)
  • create_slurm_executor (75-117)
  • get_finetune_recipe (120-123)
  • read_chat_template (126-128)
examples/nemo_run/common/in_memory_mmlu.py (1)
  • parse_args (24-42)
examples/nemo_run/common/in_memory_mmlu.py (2)
modelopt/torch/export/plugins/nemo_run.py (1)
  • _get_most_recent_ckpt (50-73)
modelopt/torch/utils/plugins/megatron_mmlu.py (1)
  • megatron_mmlu (65-152)
🪛 LanguageTool
examples/nemo_run/qat/README.md

[grammar] ~5-~5: There might be a mistake here.
Context: ...Example Slurm Examples | Advanced Topics | [NeMo I...

(QB_NEW_EN)


[grammar] ~6-~6: There might be a mistake here.
Context: ...D.md) | Advanced Topics | [NeMo Integration](https://github.com/NV...

(QB_NEW_EN)


[style] ~67-~67: Three successive sentences begin with the same word. Consider rewording the sentence or use a thesaurus to find a synonym.
Context: ...dia/OpenScience) dataset will be used. To perform QAT, run: ```bash python qat/n...

(ENGLISH_WORD_REPEAT_BEGINNING_RULE)


[grammar] ~92-~92: There might be a mistake here.
Context: ... following model: - Model: Qwen3-8B - Recipe: qwen3_8b ### Custom Chat Temp...

(QB_NEW_EN)


[grammar] ~97-~97: There might be a mistake here.
Context: ...ggingface/transformers/pull/30650)). To provide path to a custom chat template, use the...

(QB_NEW_EN)


[style] ~101-~101: Consider using a less common alternative to make your writing sound more unique and professional.
Context: ... recover the accuracy of the PTQ model. Feel free to try your own model and dataset combinat...

(FEEL_FREE_TO_STYLE_ME)

🔇 Additional comments (1)
examples/nemo_run/common/in_memory_mmlu.py (1)

28-35: Good: proper CLI exclusivity for checkpoint inputs.

Using argparse’s mutually exclusive group is the right fix; this prevents ambiguous invocations.
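For reference, a minimal sketch of the pattern (flag names mirror the CLI discussed above):

```python
import argparse

parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--nemo_ckpt", type=str, help="Path to a specific NeMo checkpoint")
group.add_argument("--finetuned_ckpt_dir", type=str, help="Experiment dir; latest checkpoint is used")
# Passing both flags (or neither) now fails with a standard argparse error.
args = parser.parse_args(["--nemo_ckpt", "/path/to/ckpt"])
```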

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (10)
modelopt/torch/export/plugins/nemo_run.py (1)

59-71: Good fix: explicit exceptions instead of asserts.

Replacing asserts with FileNotFoundError makes the checks reliable at runtime.

examples/nemo_run/common/in_memory_mmlu.py (1)

31-38: Nice: proper argparse mutually-exclusive group.

This removes the runtime assert and gives standard CLI errors.

examples/nemo_run/common/process_openscience.py (1)

20-21: Bug: loading on-disk dataset with load_dataset will fail; use load_from_disk.

You save with save_to_disk and must load with load_from_disk.

-from datasets import load_dataset
+from datasets import load_dataset, load_from_disk
@@
-    ds = load_dataset(raw_dir)
+    ds = load_from_disk(raw_dir)

Also applies to: 38-44

examples/nemo_run/common/utils.py (1)

120-124: Guard against missing recipe module to avoid AttributeError.

Accessing getattr(llm, recipe_name) unguarded can raise.

-def get_finetune_recipe(recipe_name: str):
-    if not hasattr(getattr(llm, recipe_name), "finetune_recipe"):
-        raise ValueError(f"Recipe {recipe_name} does not have a Fine-Tuning recipe")
-    return getattr(llm, recipe_name).finetune_recipe(peft_scheme=None)
+def get_finetune_recipe(recipe_name: str):
+    recipe_mod = getattr(llm, recipe_name, None)
+    if recipe_mod is None or not hasattr(recipe_mod, "finetune_recipe"):
+        raise ValueError(f"Recipe {recipe_name} does not exist or lacks a fine-tuning recipe")
+    return recipe_mod.finetune_recipe(peft_scheme=None)
examples/nemo_run/qat/nemo_qat_flow.py (3)

137-143: Fix KV-cache tri-state; don’t force disable by default.

Currently, --disable_kv_cache is always passed unless --enable_kv_cache is given. Make it tri-state and only append a flag when explicitly set.

Apply:

@@
-    parser.add_argument(
-        "--enable_kv_cache",
-        help="Enables KV-cache quantization",
-        action="store_true",
-        default=False,
-    )
+    kv_group = parser.add_mutually_exclusive_group()
+    kv_group.add_argument(
+        "--enable_kv_cache",
+        dest="enable_kv_cache",
+        help="Enable KV-cache quantization",
+        action="store_true",
+    )
+    kv_group.add_argument(
+        "--disable_kv_cache",
+        dest="enable_kv_cache",
+        help="Disable KV-cache quantization",
+        action="store_false",
+    )
+    parser.set_defaults(enable_kv_cache=None)
@@
-    ptq = run.Script(
+    # Build KV-cache flag only when explicitly set
+    kv_cache_flag = (
+        ["--enable_kv_cache"] if args.enable_kv_cache is True
+        else (["--disable_kv_cache"] if args.enable_kv_cache is False else [])
+    )
+    ptq = run.Script(
@@
-            "--kv_cache_qformat",
-            args.kv_cache_qformat,
-            "--enable_kv_cache" if args.enable_kv_cache else "--disable_kv_cache",
+            "--kv_cache_qformat",
+            args.kv_cache_qformat,
+            *kv_cache_flag,

Also applies to: 183-201


146-153: Separate “recipe name” from “filesystem prefix”; avoid getattr before validation.

Avoid overloading model_name for two concepts; use recipe_name for lookup and ckpt_prefix for paths.

 def main(args):
-    if not args.distill and not args.finetune_recipe:
+    if not args.distill and not args.finetune_recipe:
         raise ValueError("If distillation is not used, --finetune-recipe must be specified")
-    model_name = args.finetune_recipe
-    model_module = getattr(llm, model_name)
-    if not model_name:
-        model_name = os.path.basename(args.model_name)
+    recipe_name = args.finetune_recipe
+    model_module = getattr(llm, recipe_name)
+    ckpt_prefix = recipe_name or os.path.basename(args.model_name)
@@
-    bf16_ckpt_path = f"{exp_dir}/{model_name}-nemo"
+    bf16_ckpt_path = f"{exp_dir}/{ckpt_prefix}-nemo"
@@
-    ptq_model_out = f"{exp_dir}/{model_name}-{args.algorithm}"
+    ptq_model_out = f"{exp_dir}/{ckpt_prefix}-{args.algorithm}"
@@
-    export = run.Partial(
-        export_most_recent_ckpt, train.log.log_dir, output_path=f"{exp_dir}/{model_name}_hf"
-    )
+    export = run.Partial(
+        export_most_recent_ckpt, train.log.log_dir, output_path=f"{exp_dir}/{ckpt_prefix}_hf"
+    )

Also applies to: 171-173, 181-182, 248-251


274-283: Use dedicated executors for PTQ, Train, and Export; don’t mutate a single SlurmExecutor.

Reusing and mutating gpu_executor can request the wrong GPU count for later stages (and some fields aren’t honored after init). Create separate executors and use single_gpu_executor for export.

@@
-    if args.use_slurm:
-        cpu_executor = create_slurm_executor(SLURM_CONFIG)
-        gpu_executor = create_slurm_executor(
-            SLURM_CONFIG, num_gpus=args.ptq_gpus, ntasks_per_node=args.ptq_gpus
-        )
-        single_gpu_executor = create_slurm_executor(SLURM_CONFIG, num_gpus=1, ntasks_per_node=1)
+    if args.use_slurm:
+        cpu_executor = create_slurm_executor(SLURM_CONFIG)
+        ptq_gpu_executor = create_slurm_executor(
+            SLURM_CONFIG, num_gpus=args.ptq_gpus, ntasks_per_node=args.ptq_gpus
+        )
+        train_gpu_executor = create_slurm_executor(
+            SLURM_CONFIG, nodes=args.train_nodes, num_gpus=args.train_gpus, ntasks_per_node=args.train_gpus
+        )
+        single_gpu_executor = create_slurm_executor(SLURM_CONFIG, num_gpus=1, ntasks_per_node=1)
     else:
-        cpu_executor = single_gpu_executor = run.LocalExecutor()
-        gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.ptq_gpus)
+        cpu_executor = single_gpu_executor = run.LocalExecutor()
+        ptq_gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.ptq_gpus)
+        train_gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.train_gpus)
@@
-        s2 = exp.add(ptq, tail_logs=True, name="03_ptq", executor=gpu_executor, dependencies=[s1])
+        s2 = exp.add(ptq, tail_logs=True, name="03_ptq", executor=ptq_gpu_executor, dependencies=[s1])
@@
-        if args.use_slurm:  # Set training arguments
-            gpu_executor.nodes = args.train_nodes
-            gpu_executor.gpus_per_node = gpu_executor.ntasks_per_node = args.train_gpus
-        else:
-            gpu_executor.ntasks_per_node = args.train_gpus
+        # Use dedicated training executor; no mutation required
         train_dep = [s3]
@@
-        s4 = exp.add(
-            train, tail_logs=True, name="05_train", executor=gpu_executor, dependencies=train_dep
-        )
+        s4 = exp.add(
+            train, tail_logs=True, name="05_train", executor=train_gpu_executor, dependencies=train_dep
+        )
@@
-        gpu_executor.ntasks_per_node = 1  # will throw error if more than 1 task during export
-        exp.add(
+        exp.add(
             export,
             tail_logs=True,
             name="07_export_hf",
-            executor=gpu_executor,
+            executor=single_gpu_executor,
             dependencies=[s5],
         )

Also applies to: 301-303, 311-316, 320-321, 330-336

examples/nemo_run/qat/ADVANCED.md (1)

5-5: Avoid instructing users to hardcode credentials; prefer env/config-driven SlurmConfig.

Recommend documenting required env vars or a YAML config and loading it in nemo_qat_flow.py instead of editing source.

-To run the example on slurm, edit the `SLURM_CONFIG` at the bottom of `nemo_qat_flow.py` with the appropriate credentials, container, cluster name (host), and container mounts. Make sure you are mounting the NeMo and Megatron-LM repositories above in the Slurm cluster and that you've checked out the correct commits.
+To run on Slurm, set the required environment variables (e.g., SLURM_ACCOUNT, SLURM_PARTITION_CPU/GPU, SLURM_TIME, SLURM_HOST, SLURM_USER, SLURM_JOB_DIR, CONTAINER_IMAGE, CONTAINER_MOUNTS) and construct `SLURM_CONFIG` from them in `nemo_qat_flow.py`. Avoid committing secrets or editing source with credentials.

If you want, I can send a follow-up PR wiring env-var loading.
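As a sketch of that wiring (field names follow the `SlurmConfig` shown later in this review; the environment-variable names are suggestions, not an existing convention):

```python
import os

from utils import SlurmConfig  # examples/nemo_run/common/utils.py; adjust to how the example is packaged


def slurm_config_from_env() -> SlurmConfig:
    # Required values raise a clear KeyError if unset; optional ones fall back to defaults.
    mounts = os.environ.get("CONTAINER_MOUNTS", "")
    return SlurmConfig(
        account=os.environ["SLURM_ACCOUNT"],
        partition_gpu=os.environ.get("SLURM_PARTITION_GPU", "batch"),
        partition_cpu=os.environ.get("SLURM_PARTITION_CPU", "cpu"),
        time=os.environ.get("SLURM_TIME", "04:00:00"),
        container_image=os.environ.get("CONTAINER_IMAGE", "nvcr.io/nvidia/nemo:25.07"),
        env_vars={"HF_TOKEN": os.environ.get("HF_TOKEN", "")},
        use_local_tunnel=False,
        host=os.environ["SLURM_HOST"],
        user=os.environ["SLURM_USER"],
        container_mounts=mounts.split(",") if mounts else [],
        job_dir=os.environ["SLURM_JOB_DIR"],
        identity=os.environ.get("SLURM_SSH_IDENTITY"),
    )
```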

examples/nemo_run/qat/README.md (2)

86-94: Move Slurm flag details to ADVANCED.md; keep this section model-focused.
This repeats Slurm specifics already covered elsewhere.

-Locally this script currently supports models that can be trained on 1 node with 8 x 80GB GPUs. On Slurm you can configure the number of nodes/gpus for training and PTQ with the following flags: `--train-nodes`, `--train-gpus`, `--ptq-gpus`.
+Locally this script currently supports models that can be trained on 1 node with 8 x 80GB GPUs. For Slurm configuration (nodes/GPUs for training and PTQ), see [Slurm Examples](ADVANCED.md).
#!/bin/bash
# Verify flags mentioned in docs exist in the CLI.
rg -nP -C2 '(--train-nodes|--train-gpus|--ptq-gpus|--enable-kv-cache|--kv-cache-qformat|--distill)\b' examples -S

41-41: Consider moving “Usage” right after “Overview”.
Improves discoverability; aligns with prior feedback.

🧹 Nitpick comments (18)
modelopt/torch/export/plugins/nemo_run.py (3)

44-47: Make most-recent selection deterministic.

Tie on mtime yields nondeterministic picks. Add name as a secondary key.

-    most_recent = max(subdirs, key=lambda x: x.stat().st_mtime)
+    most_recent = max(subdirs, key=lambda x: (x.stat().st_mtime, x.name))

24-35: Add return annotation and clarify API intent.

Small polish: annotate return type and note that input expects a NeMo Run experiment root containing a default/checkpoints hierarchy.

-def export_most_recent_ckpt(directory: str, output_path: str):
+def export_most_recent_ckpt(directory: str, output_path: str) -> None:
     """Export most recent checkpoint from a NeMo Run experiment directory.
+    
+    `directory` should be the experiment root containing `default/` with
+    either `default/checkpoints/*` or `default/<run>/checkpoints/*`.
     """

50-53: Avoid importing a private helper across modules; promote to public API.

This module’s _get_most_recent_ckpt is imported from another file. Either rename to a public helper or re-export a public alias.

-def _get_most_recent_ckpt(directory: str):
+def get_most_recent_ckpt(directory: str):
@@
-    return str(most_recent)
+    return str(most_recent)

Then adjust imports in the caller to use get_most_recent_ckpt. If you prefer to keep the private name, add get_most_recent_ckpt = _get_most_recent_ckpt at module end.

Also applies to: 59-73

examples/nemo_run/common/in_memory_mmlu.py (3)

26-29: Fix spacing in help description.

The concatenated strings currently render “…--nemo_ckptor…”. Add a space.

-            "Run MMLU evaluation with ModelOpt Megatron model. Provide either --nemo_ckpt"
-            "or --finetuned_ckpt_dir"
+            "Run MMLU evaluation with ModelOpt Megatron model. Provide either --nemo_ckpt "
+            "or --finetuned_ckpt_dir"

50-53: Don’t import/use a private helper; switch to public API.

Use get_most_recent_ckpt (see plugin comment) instead of _get_most_recent_ckpt.

-from modelopt.torch.export.plugins.nemo_run import _get_most_recent_ckpt
+from modelopt.torch.export.plugins.nemo_run import get_most_recent_ckpt
@@
-    if args.finetuned_ckpt_dir:
-        ckpt_path = _get_most_recent_ckpt(args.finetuned_ckpt_dir)
+    if args.finetuned_ckpt_dir:
+        ckpt_path = get_most_recent_ckpt(args.finetuned_ckpt_dir)

50-52: Log chosen checkpoint for traceability.

Helpful when resolving “wrong model evaluated” issues.

     ckpt_path = args.nemo_ckpt
     if args.finetuned_ckpt_dir:
         ckpt_path = get_most_recent_ckpt(args.finetuned_ckpt_dir)
+    print(f"Evaluating checkpoint: {ckpt_path}")
examples/nemo_run/common/process_openscience.py (1)

57-59: Create parent dirs when making proc_dir.

Avoid failures if parent path is missing.

-        Path(proc_dir).mkdir(exist_ok=True)
+        Path(proc_dir).mkdir(parents=True, exist_ok=True)
examples/nemo_run/common/utils.py (2)

107-117: Apply env vars and consider ntasks_per_node for CPU branch.

Parity with GPU path and user expectations.

-        return run.SlurmExecutor(
+        return run.SlurmExecutor(
             account=slurm_cfg.account,
             partition=slurm_cfg.partition_cpu,
+            ntasks_per_node=ntasks_per_node,
             nodes=nodes,
             tunnel=tunnel,
             container_image=slurm_cfg.container_image,
             container_mounts=slurm_cfg.container_mounts,
             time=slurm_cfg.time,
             packager=run.GitArchivePackager(),
             mem="0",
+            env_vars=slurm_cfg.env_vars,  # verify exact param name
         )

79-82: Ensure LocalTunnel has a usable job_dir.

If use_local_tunnel=True and job_dir is empty, LocalTunnel may fail.

-        tunnel = run.LocalTunnel(job_dir=slurm_cfg.job_dir)
+        tunnel = run.LocalTunnel(job_dir=slurm_cfg.job_dir or ".")

Alternatively, validate job_dir for both tunnel modes in __post_init__.

examples/nemo_run/qat/nemo_qat_flow.py (2)

155-169: Remove stale TODO and clarify path handling.

Comments reference common/process.py, but the code uses process_openscience.py. Clean up to avoid confusion.

-    # TODO figure out path
-    # LOCALLY common/process.py works
-    # On slurm examples/nemo_run/common/process.py works

118-129: Naming clarity: consider --train-gpus-per-node (and possibly --ptq-gpus-per-node).

Avoid ambiguity between total GPUs vs per-node. Aligns with Slurm executor semantics.
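A hypothetical rename sketch (defaults are placeholders):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--train-gpus-per-node", type=int, default=8,
                    help="GPUs requested on each training node")
parser.add_argument("--ptq-gpus-per-node", type=int, default=8,
                    help="GPUs requested on the PTQ node")
parser.add_argument("--train-nodes", type=int, default=1,
                    help="Number of training nodes")
args = parser.parse_args([])
total_train_gpus = args.train_nodes * args.train_gpus_per_node  # total is now unambiguous
```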

examples/nemo_run/qat/ADVANCED.md (1)

11-16: Add fenced code languages for markdownlint (MD040).

Annotate blocks as text to satisfy linters.

-```
+```text
 qat_flow_ckpts qat_flow_ckpts_1755708286

@@
-```
+```text
├── 00_openscience_data
│   ├── code
│   ├── configs
│   ├── log-coreai_dlalgo_modelopt-modelopt.00_openscience_data_5345664_0.out
│   └── sbatch_coreai_dlalgo_modelopt-modelopt.00_openscience_data_5345664.out
...

Also applies to: 19-56

examples/llm_qat/README.md (1)

14-15: LGTM — helpful cross-link to NeMo QAT/QAD flow.

Consider adding a Docs link when available.

examples/nemo_run/qat/README.md (5)

5-7: Deduplicate header links to ADVANCED.md.
Both “Slurm Examples” and “Advanced Topics” point to the same file; keep one.

-[Slurm Examples](ADVANCED.md) |
-[Advanced Topics](ADVANCED.md) |
+[Advanced Topics (incl. Slurm)](ADVANCED.md) |

45-45: Fix double space.

-You can run the example either locally  or on a [Slurm cluster](ADVANCED.md).
+You can run the example either locally or on a [Slurm cluster](ADVANCED.md).

49-56: Avoid bind-mounting into site-packages; prefer PYTHONPATH.
Mounting into /usr/local/lib/python3.12/dist-packages is version-specific and brittle. Use PYTHONPATH to point at local checkouts (or pip install -e inside the container).

-Example docker command:
+Example docker command (avoid writing into site-packages; use PYTHONPATH to point to your local checkouts):
@@
-docker run -v  /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/modelopt/:/usr/local/lib/python3.12/dist-packages/modelopt --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.07 bash
+docker run --gpus all -it --shm-size 20g --rm \
+  -v /home/user/NeMo:/opt/NeMo \
+  -v /home/user/TensorRT-Model-Optimizer:/workspace/TensorRT-Model-Optimizer \
+  -e PYTHONPATH=/opt/NeMo:/workspace/TensorRT-Model-Optimizer \
+  nvcr.io/nvidia/nemo:25.07 bash

58-59: Branding and permissions guidance.
Use “Hugging Face” and avoid chmod 777; point users to a writable log dir instead.

-You will also need to set your Huggingface token with `export HF_TOKEN=<your-token>`. You may also need to enable write access to the docker container to the `examples/nemo_run` folder by doing `chmod 777 nemo_run` so that logs can be written.
+Set your Hugging Face token with `export HF_TOKEN=<your-token>`. Ensure your `--log-dir` is a writable path you own (e.g., `mkdir -p /my/log/dir`), rather than using `chmod 777`.

66-66: Minor wording/brand capitalization and punctuation.

-From the `nemo_run` folder, launch the example with the `qat/nemo_qat_flow.py` script. To use a different model than the default model (Qwen3-8B), you can add the `--model-name <hf-model-name> --finetune-recipe <recipe-name>` flags and use the model's HuggingFace name and NeMo recipe names listed [here](https://github.com/NVIDIA/NeMo/tree/main/nemo/collections/llm/recipes). To provide your own custom dataset, use the `--data-path` flag, otherwise the default [NVIDIA OpenScience](https://huggingface.co/datasets/nvidia/OpenScience) dataset will be used.
+From the `nemo_run` folder, launch the example with the `qat/nemo_qat_flow.py` script. To use a different model than the default model (Qwen3-8B), you can add the `--model-name <hf-model-name> --finetune-recipe <recipe-name>` flags and use the model's Hugging Face name and NeMo recipe names listed [here](https://github.com/NVIDIA/NeMo/tree/main/nemo/collections/llm/recipes). To provide your own custom dataset, use the `--data-path` flag; otherwise, the default [NVIDIA OpenScience](https://huggingface.co/datasets/nvidia/OpenScience) dataset will be used.
📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between 1cf78b2 and 9cc3625.

📒 Files selected for processing (8)
  • examples/llm_qat/README.md (1 hunks)
  • examples/nemo_run/common/in_memory_mmlu.py (1 hunks)
  • examples/nemo_run/common/process_openscience.py (1 hunks)
  • examples/nemo_run/common/utils.py (1 hunks)
  • examples/nemo_run/qat/ADVANCED.md (1 hunks)
  • examples/nemo_run/qat/README.md (2 hunks)
  • examples/nemo_run/qat/nemo_qat_flow.py (6 hunks)
  • modelopt/torch/export/plugins/nemo_run.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
examples/nemo_run/common/in_memory_mmlu.py (2)
modelopt/torch/export/plugins/nemo_run.py (1)
  • _get_most_recent_ckpt (50-73)
modelopt/torch/utils/plugins/megatron_mmlu.py (1)
  • megatron_mmlu (65-152)
examples/nemo_run/common/utils.py (1)
examples/nemo_run/common/process_lima.py (1)
  • download_hf_dataset (29-37)
examples/nemo_run/qat/nemo_qat_flow.py (3)
modelopt/torch/export/plugins/nemo_run.py (1)
  • export_most_recent_ckpt (24-35)
examples/nemo_run/common/utils.py (4)
  • SlurmConfig (24-72)
  • create_slurm_executor (75-117)
  • get_finetune_recipe (120-123)
  • read_chat_template (126-128)
examples/nemo_run/common/in_memory_mmlu.py (1)
  • parse_args (24-45)
🪛 markdownlint-cli2 (0.17.2)
examples/nemo_run/qat/ADVANCED.md

13-13: Fenced code blocks should have a language specified

(MD040, fenced-code-language)


19-19: Fenced code blocks should have a language specified

(MD040, fenced-code-language)

🔇 Additional comments (7)
examples/nemo_run/common/in_memory_mmlu.py (1)

57-58: Good: devices computed from TP x PP.

Matches the prior suggestion to wire devices to parallelism settings.
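The arithmetic being referenced, as a tiny helper (names are illustrative):

```python
def num_devices(tensor_parallelism: int, pipeline_parallelism: int) -> int:
    """One device per model-parallel rank: TP x PP."""
    return tensor_parallelism * pipeline_parallelism


assert num_devices(2, 2) == 4
```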

examples/nemo_run/common/utils.py (1)

91-106: Verify SlurmExecutor accepts an env_vars parameter
We weren’t able to locate or import SlurmExecutor in this repo or via nemo_run—please confirm (via your installed package or docs) that its constructor supports an env_vars argument before merging.
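One quick way to check locally, assuming `nemo_run` is installed in the environment (a **kwargs-based constructor would still need a docs check):

```python
import inspect

import nemo_run as run

sig = inspect.signature(run.SlurmExecutor.__init__)
print("env_vars accepted:", "env_vars" in sig.parameters)
```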

examples/nemo_run/qat/nemo_qat_flow.py (1)

183-201: Confirmed ptq.py CLI arguments: -nc maps to --nemo_checkpoint, -out to the export/save path, and -ctp to the calibration tensor-parallel size (torchrun’s nproc_per_node must match this). No changes needed.

examples/nemo_run/qat/ADVANCED.md (1)

11-12: Verify rsync availability in the referenced container/tag.

If missing in nvcr.io/nvidia/nemo:25.07, add install steps or suggest an alternative sync method.

examples/nemo_run/qat/README.md (3)

21-29: Flow stages list looks accurate and uses exact task names.
Matches the 8-stage design and prior feedback.


30-39: Mermaid diagram aligns with the listed stages.
Edges reflect the intended data/control flow.


15-16: Good fix on the PTQ sentence.
Clear explanation of expected accuracy drop and QAT/QAD intent.

Comment on lines +146 to +155
def main(args):
if not args.distill and not args.finetune_recipe:
raise ValueError("If distillation is not used, --finetune-recipe must be specified")
model_name = args.finetune_recipe
model_module = getattr(llm, model_name)
if not model_name:
model_name = os.path.basename(args.model_name)
exp_dir = f"{args.log_dir.rstrip('/')}/{args.experiment}"

🛠️ Refactor suggestion

Guard globals used by main() to support programmatic invocation.

SEQUENCE_LENGTH, MBS, GBS, TRAIN_STEPS, VAL_INTERVAL are set under __main__ but read inside main(). If imported and called directly, this raises NameError. Add fallbacks in main().

 def main(args):
+    # Fallbacks if module constants aren’t initialized (e.g., when imported)
+    global SEQUENCE_LENGTH, MBS, GBS, TRAIN_STEPS, VAL_INTERVAL
+    SEQUENCE_LENGTH = globals().get("SEQUENCE_LENGTH", 4096)
+    MBS = globals().get("MBS", 1)
+    GBS = globals().get("GBS", 512)
+    TRAIN_STEPS = globals().get("TRAIN_STEPS", 200)
+    VAL_INTERVAL = globals().get("VAL_INTERVAL", 50)

Optionally, hoist the constants to module scope instead of defining under __main__.

Also applies to: 220-231, 368-376

🤖 Prompt for AI Agents
In examples/nemo_run/qat/nemo_qat_flow.py around lines 146-153 (also apply same
change to blocks at 220-231 and 368-376), main() reads globals SEQUENCE_LENGTH,
MBS, GBS, TRAIN_STEPS, VAL_INTERVAL which are only defined under __main__,
causing NameError when the module is imported and main() is called
programmatically; add explicit fallbacks inside main() by reading each with
globals().get("NAME", <sensible_default>) or os.environ fallback (e.g.,
SEQUENCE_LENGTH = globals().get("SEQUENCE_LENGTH", 2048) etc.), or better
move/hoist the constant definitions from the __main__ block to module scope so
they are always defined before main() is used; update all referenced locations
(lines listed) consistently.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (4)
examples/nemo_run/qat/nemo_qat_flow.py (4)

351-351: Slurm time format fixed. LGTM.

The HH:MM:SS format resolves SlurmConfig validation errors.


138-143: Fix KV-cache flag tri-state; don’t force disable by default.

When the user passes neither flag, we should omit both PTQ flags. Currently, --disable_kv_cache is always injected because enable_kv_cache defaults to False. Build the flag list conditionally and parse a tri-state.

@@
-    parser.add_argument(
-        "--enable_kv_cache",
-        help="Enables KV-cache quantization",
-        action="store_true",
-        default=False,
-    )
+    kv_group = parser.add_mutually_exclusive_group(required=False)
+    kv_group.add_argument(
+        "--enable_kv_cache",
+        dest="enable_kv_cache",
+        action="store_true",
+        help="Enables KV-cache quantization",
+    )
+    kv_group.add_argument(
+        "--disable_kv_cache",
+        dest="enable_kv_cache",
+        action="store_false",
+        help="Disables KV-cache quantization",
+    )
+    parser.set_defaults(enable_kv_cache=None)
@@
-    ptq = run.Script(
+    # Build KV-cache flag only when explicitly set
+    kv_cache_flag = (
+        ["--enable_kv_cache"] if args.enable_kv_cache is True
+        else (["--disable_kv_cache"] if args.enable_kv_cache is False else [])
+    )
+    ptq = run.Script(
         "/opt/NeMo/scripts/llm/ptq.py",
         args=[
             "-nc",
             bf16_ckpt_path,
             "-out",
             ptq_model_out,
             "--export_format",
             "nemo",
             "--algorithm",
             args.algorithm,
             "--kv_cache_qformat",
             args.kv_cache_qformat,
-            "--enable_kv_cache" if args.enable_kv_cache else "--disable_kv_cache",
+            *kv_cache_flag,
             "-ctp",
             f"{args.ptq_gpus}",
         ],
         entrypoint="python",
     )

Also applies to: 183-199


274-283: Use separate executors for PTQ, Train, and Export; stop mutating a shared Slurm executor.

Mutating gpu_executor risks stale gres/gpus_per_node in Slurm and incorrect local torchrun settings. Create dedicated executors and use the single-GPU executor for export.

@@
-    if args.use_slurm:
-        cpu_executor = create_slurm_executor(SLURM_CONFIG)
-        gpu_executor = create_slurm_executor(
-            SLURM_CONFIG, num_gpus=args.ptq_gpus, ntasks_per_node=args.ptq_gpus
-        )
-        single_gpu_executor = create_slurm_executor(SLURM_CONFIG, num_gpus=1, ntasks_per_node=1)
-    else:
-        cpu_executor = single_gpu_executor = run.LocalExecutor()
-        gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.ptq_gpus)
+    if args.use_slurm:
+        cpu_executor = create_slurm_executor(SLURM_CONFIG)
+        ptq_gpu_executor = create_slurm_executor(
+            SLURM_CONFIG, num_gpus=args.ptq_gpus, ntasks_per_node=args.ptq_gpus
+        )
+        train_gpu_executor = create_slurm_executor(
+            SLURM_CONFIG, nodes=args.train_nodes, num_gpus=args.train_gpus, ntasks_per_node=args.train_gpus
+        )
+        single_gpu_executor = create_slurm_executor(SLURM_CONFIG, num_gpus=1, ntasks_per_node=1)
+    else:
+        cpu_executor = single_gpu_executor = run.LocalExecutor()
+        ptq_gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.ptq_gpus)
+        train_gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.train_gpus)
@@
-        s2 = exp.add(ptq, tail_logs=True, name="03_ptq", executor=gpu_executor, dependencies=[s1])
+        s2 = exp.add(ptq, tail_logs=True, name="03_ptq", executor=ptq_gpu_executor, dependencies=[s1])
@@
-        if args.use_slurm:  # Set training arguments
-            gpu_executor.nodes = args.train_nodes
-            gpu_executor.gpus_per_node = gpu_executor.ntasks_per_node = args.train_gpus
-        else:
-            gpu_executor.ntasks_per_node = args.train_gpus
+        # use dedicated training executor; no mutation needed
@@
-        s4 = exp.add(
-            train, tail_logs=True, name="05_train", executor=gpu_executor, dependencies=train_dep
-        )
+        s4 = exp.add(
+            train, tail_logs=True, name="05_train", executor=train_gpu_executor, dependencies=train_dep
+        )
@@
-        gpu_executor.ntasks_per_node = 1  # will throw error if more than 1 task during export
-        exp.add(
-            export,
-            tail_logs=True,
-            name="07_export_hf",
-            executor=gpu_executor,
-            dependencies=[s5],
-        )
+        exp.add(
+            export,
+            tail_logs=True,
+            name="07_export_hf",
+            executor=single_gpu_executor,
+            dependencies=[s5],
+        )

Also applies to: 302-303, 311-316, 320-321, 330-336


146-153: Guard module-level constants and SLURM_CONFIG for programmatic main() calls.

Importing and calling main(args) will raise NameError for undefined globals and SLURM_CONFIG. Add safe fallbacks and a clear error when Slurm config is missing.

 def main(args):
+    # Fallbacks if __main__ block wasn't executed
+    global SEQUENCE_LENGTH, MBS, GBS, TRAIN_STEPS, VAL_INTERVAL
+    SEQUENCE_LENGTH = globals().get("SEQUENCE_LENGTH", 4096)
+    MBS = globals().get("MBS", 1)
+    GBS = globals().get("GBS", 512)
+    TRAIN_STEPS = globals().get("TRAIN_STEPS", 200)
+    VAL_INTERVAL = globals().get("VAL_INTERVAL", 50)
+    if args.use_slurm and "SLURM_CONFIG" not in globals():
+        raise ValueError(
+            "SLURM_CONFIG must be defined before calling main() with --use-slurm. "
+            "See the __main__ section for an example."
+        )
🧹 Nitpick comments (3)
examples/nemo_run/qat/nemo_qat_flow.py (3)

147-149: Dead condition; can be removed.

--finetune-recipe has a default, so this branch never triggers.

-    if not args.distill and not args.finetune_recipe:
-        raise ValueError("If distillation is not used, --finetune-recipe must be specified")

151-153: Unreachable fallback for model_name.

model_name is set from args.finetune_recipe and is always truthy; this block never runs.

-    if not model_name:
-        model_name = os.path.basename(args.model_name)

29-31: Avoid mutating sys.path; prefer package-relative imports.

Relying on sys.path.append is brittle for both local and Slurm. If feasible, package examples and use absolute imports (or compute script paths without altering sys.path).
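A minimal sketch of computing the script paths without touching sys.path (directory layout assumed from the file list in this PR):

```python
from pathlib import Path

# nemo_qat_flow.py lives in examples/nemo_run/qat/, so the shared scripts sit in
# the sibling examples/nemo_run/common/ directory.
COMMON_DIR = Path(__file__).resolve().parent.parent / "common"
OPENSCIENCE_SCRIPT = COMMON_DIR / "process_openscience.py"
MMLU_SCRIPT = COMMON_DIR / "in_memory_mmlu.py"

if not OPENSCIENCE_SCRIPT.exists():
    raise FileNotFoundError(f"Expected script at {OPENSCIENCE_SCRIPT}")
```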

📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between 9cc3625 and d46399c.

📒 Files selected for processing (1)
  • examples/nemo_run/qat/nemo_qat_flow.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/nemo_run/qat/nemo_qat_flow.py (3)
modelopt/torch/export/plugins/nemo_run.py (1)
  • export_most_recent_ckpt (24-35)
examples/nemo_run/common/utils.py (4)
  • SlurmConfig (24-72)
  • create_slurm_executor (75-117)
  • get_finetune_recipe (120-123)
  • read_chat_template (126-128)
examples/nemo_run/common/in_memory_mmlu.py (1)
  • parse_args (24-45)
🔇 Additional comments (1)
examples/nemo_run/qat/nemo_qat_flow.py (1)

338-339: Confirm detach behavior.

exp.run(detach=True) backgrounds the run. If interactive logs are desired by default, consider making this a CLI flag (e.g., --detach).

@@ def get_args():
+    parser.add_argument("--detach", action="store_true", help="Run experiment detached", default=False)
@@
-        exp.run(detach=True)
+        exp.run(detach=args.detach)

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (5)
examples/nemo_run/qat/nemo_qat_flow.py (5)

138-142: Fix the KV-cache flag tri-state handling

The current implementation only uses --enable_kv_cache when the flag is set to True, but always defaults to --disable_kv_cache otherwise (line 198). This prevents a neutral state where neither flag is passed.

Apply this diff to properly handle the tri-state:

     parser.add_argument(
         "--enable_kv_cache",
         help="Enables KV-cache quantization",
         action="store_true",
-        default=False,
     )
+    parser.add_argument(
+        "--disable_kv_cache",
+        help="Disables KV-cache quantization",
+        action="store_true",
+    )

Then update the PTQ args construction:

+    # Build KV-cache flag only when explicitly set
+    kv_cache_args = []
+    if args.enable_kv_cache:
+        kv_cache_args.append("--enable_kv_cache")
+    elif args.disable_kv_cache:
+        kv_cache_args.append("--disable_kv_cache")
+
     ptq = run.Script(
         "/opt/NeMo/scripts/llm/ptq.py",
         args=[
             "-nc",
             bf16_ckpt_path,
             "-out",
             ptq_model_out,
             "--export_format",
             "nemo",
             "--algorithm",
             args.algorithm,
             "--kv_cache_qformat",
             args.kv_cache_qformat,
-            "--enable_kv_cache" if args.enable_kv_cache else "--disable_kv_cache",
+            *kv_cache_args,
             "-ctp",
             f"{args.ptq_gpus}",
         ],

247-249: Consider tensor parallelism configuration based on GPU count

Setting both tensor and pipeline parallelism to 1 by default means the model runs with data parallelism only. For larger models or when using fewer GPUs with limited memory, tensor parallelism might be required.

Consider automatically setting tensor parallelism based on the number of GPUs:

     train.trainer.strategy.tensor_model_parallel_size = args.tensor_parallelism
     train.trainer.strategy.pipeline_model_parallel_size = args.pipeline_parallelism
+    # Optionally, auto-configure TP if not specified and model is large
+    # if args.tensor_parallelism == 1 and "70B" in args.model_name:
+    #     train.trainer.strategy.tensor_model_parallel_size = min(8, args.train_gpus)

279-286: Verify that GPUs are not overprovisioned in Slurm executors

Creating separate executors for different GPU counts looks correct now. This properly addresses the previous issue of reusing executors with mutated parameters.


360-360: Slurm time format is correct

The time format "04:00:00" follows the HH:MM:SS format required by SlurmConfig validation.


226-233: Global constants are undefined - will cause NameError

The code references SEQUENCE_LENGTH, GBS, MBS, TRAIN_STEPS, and VAL_INTERVAL which are only defined within the __main__ block (lines 375-379). This will raise a NameError when main() is called.

Move the constants to module scope before the main() function:

 import nemo_run as run
 from nemo.collections import llm
 # ... other imports ...

+# Configurable parameters
+SEQUENCE_LENGTH = 4096
+MBS = 1
+GBS = 512
+TRAIN_STEPS = 400
+VAL_INTERVAL = 50
+
 def get_args():

And remove the duplicate definitions from the __main__ block (lines 373-380).

🧹 Nitpick comments (5)
examples/nemo_run/qat/nemo_qat_flow.py (5)

339-340: WAR comment indicates a NeMo export bug

The workaround mutates train_gpu_executor.ntasks_per_node = 1 to handle a NeMo bug. This is fragile and could break if the executor is reused.

Consider using a dedicated single-task executor for export to avoid mutation:

-        # WAR: Export needs access to all GPUs but only 1 task due to bug in NeMo
-        train_gpu_executor.ntasks_per_node = 1  # will throw error if more than 1 task during export
+        # WAR: Export needs access to all GPUs but only 1 task due to bug in NeMo
+        export_executor = create_slurm_executor(
+            SLURM_CONFIG, num_gpus=args.train_gpus, ntasks_per_node=1
+        ) if args.use_slurm else run.LocalExecutor()
         exp.add(
             export,
             tail_logs=True,
             name="07_export_hf",
-            executor=train_gpu_executor,
+            executor=export_executor,
             dependencies=[s5],
         )

Would you like me to open an issue to track the underlying NeMo export bug?


357-370: Empty Slurm configuration values will raise validation errors

The SlurmConfig has empty strings for required fields like account, host, and user. These will trigger validation errors in SlurmConfig.__post_init__().

Add a comment to guide users on what needs to be configured:

     if args.use_slurm:
+        # IMPORTANT: Fill in these required fields before running on Slurm
         SLURM_CONFIG = SlurmConfig(
-            account="",
+            account="",  # REQUIRED: Your Slurm account name
             partition_gpu="batch",
             partition_cpu="cpu",
             time="04:00:00",
             container_image="nvcr.io/nvidia/nemo:25.07",
             env_vars={
-                "HF_TOKEN": "",
+                "HF_TOKEN": "",  # REQUIRED if using gated HF models
             },
             use_local_tunnel=False,
-            host="",
-            user="",
+            host="",  # REQUIRED: Slurm cluster hostname (e.g., "cluster.example.com")
+            user="",  # REQUIRED: Your username on the cluster
             container_mounts=[],
-            job_dir="/path/to/logs",
+            job_dir="/path/to/logs",  # REQUIRED: Directory for job logs on cluster
             identity=None,
         )

143-144: Parallelism arguments lack help text

The tensor and pipeline parallelism arguments are missing help descriptions, making their purpose unclear to users.

-    parser.add_argument("--tensor_parallelism", type=int, default=1)
-    parser.add_argument("--pipeline_parallelism", type=int, default=1)
+    parser.add_argument(
+        "--tensor_parallelism", 
+        type=int, 
+        default=1,
+        help="Tensor parallelism degree for model training"
+    )
+    parser.add_argument(
+        "--pipeline_parallelism", 
+        type=int, 
+        default=1,
+        help="Pipeline parallelism degree for model training"
+    )

347-347: Consider adding timeout handling for long-running experiments

The experiment runs with detach=True, which is good for long-running jobs. However, there's no timeout or monitoring mechanism for stuck jobs.

Consider adding a comment about monitoring:

-        exp.run(detach=True)
+        exp.run(detach=True)  # Detached mode for long-running jobs
+        # Monitor job progress via: exp.status() or check logs at {exp_dir}

158-164: Resolve process_openscience.py path to an absolute path

File: examples/nemo_run/qat/nemo_qat_flow.py (lines 158–164)

Confirmed examples/nemo_run/common/process_openscience.py exists in the repo; the Slurm branch’s use of a CWD-relative literal can still fail at runtime — resolve the script path relative to this file and validate existence before calling run.Script. Suggested replacement:

openscience_path = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..", "common", "process_openscience.py")
)
if not os.path.exists(openscience_path):
    raise FileNotFoundError(f"OpenScience script not found at {openscience_path}")
openscience_data = run.Script(openscience_path)

📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between d46399c and c97f9bd.

📒 Files selected for processing (1)
  • examples/nemo_run/qat/nemo_qat_flow.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/nemo_run/qat/nemo_qat_flow.py (3)
modelopt/torch/export/plugins/nemo_run.py (1)
  • export_most_recent_ckpt (24-35)
examples/nemo_run/common/utils.py (4)
  • SlurmConfig (24-72)
  • create_slurm_executor (75-117)
  • get_finetune_recipe (120-123)
  • read_chat_template (126-128)
examples/nemo_run/common/in_memory_mmlu.py (1)
  • parse_args (24-45)
🔇 Additional comments (2)
examples/nemo_run/qat/nemo_qat_flow.py (2)

257-276: MMLU evaluation script path logic matches OpenScience pattern

The conditional path resolution for the MMLU script follows the same pattern as the OpenScience script, maintaining consistency.


235-239: Add graceful handling for missing distillation recipe

get_finetune_recipe already validates and raises ValueError (examples/nemo_run/common/utils.py:120–123). distillation_recipe is imported at examples/nemo_run/qat/nemo_qat_flow.py:24 but its implementation is not in this repo — wrap the distillation call (examples/nemo_run/qat/nemo_qat_flow.py:235–239) in a presence check or try/except to surface a clear error like the finetune path.
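A hedged sketch of such a guard; `distillation_recipe` here stands for whatever callable the script imports, and its real signature is not shown in this review:

```python
def build_distillation_recipe(distillation_recipe, *recipe_args, **recipe_kwargs):
    """Fail fast with an actionable message if the distillation helper is missing."""
    if distillation_recipe is None or not callable(distillation_recipe):
        raise ValueError(
            "Distillation was requested but no distillation_recipe is available; "
            "update the package providing it or run without --distill."
        )
    try:
        return distillation_recipe(*recipe_args, **recipe_kwargs)
    except TypeError as err:
        raise ValueError(f"Could not construct distillation recipe: {err}") from err
```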

@jenchen13 jenchen13 requested review from a team as code owners September 16, 2025 21:25
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 7

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (6)
modelopt/torch/_deploy/_runtime/tensorrt/tensorrt_utils.py (2)

65-74: Fix return type mismatch in load_engine (function returns a tuple).

The annotation/doc promise an engine, but the function returns (engine | None, error_str). Correct the signature/doc to avoid downstream type/logic bugs.

Apply:

-def load_engine(buffer: bytes, log_level: int = trt.Logger.ERROR) -> trt.tensorrt.ICudaEngine:
-    """Load a TensorRT engine from engine data and return."""
+def load_engine(
+    buffer: bytes, log_level: int = trt.Logger.ERROR
+) -> tuple[trt.tensorrt.ICudaEngine | None, str]:
+    """Load a TensorRT engine from engine data.
+
+    Returns:
+        (engine | None, error_message)
+    """

170-179: Avoid double‑hashing; fix doc typo.

hashlib.sha256(engine_bytes) already digests the data; calling update(engine_bytes) again computes SHA256(engine_bytes || engine_bytes).

-def prepend_hash_to_bytes(engine_bytes: bytes) -> bytes:
-    """Prepend the engine bytes with the SHA256 hash of the engine bytes
-    This has will serve as a unique identifier for the engine and will be used to manage
-    TRTSessions in the TRTClient.
-    """
-    hash_object = hashlib.sha256(engine_bytes)
-    hash_object.update(engine_bytes)
-    hash_bytes = hash_object.digest()
-    engine_bytes = hash_bytes + engine_bytes
-    return engine_bytes
+def prepend_hash_to_bytes(engine_bytes: bytes) -> bytes:
+    """Prepend the engine bytes with the SHA256 hash of the engine bytes.
+    This hash serves as a unique identifier for the engine and is used to manage
+    TRTSessions in the TRTClient.
+    """
+    hash_bytes = hashlib.sha256(engine_bytes).digest()
+    return hash_bytes + engine_bytes
examples/speculative_decoding/launch.sh (1)

92-95: Fix divide‑by‑zero when no GPU is present.

GPU_COUNT can be 0 causing an arithmetic error; fall back to 1 before dividing. Apply at examples/speculative_decoding/launch.sh (around lines 92–95):

 GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
-# Calculate save_steps
-DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT))
+# Calculate save_steps (fallback to 1 when no GPU is detected)
+if [[ -z "$GPU_COUNT" || "$GPU_COUNT" -le 0 ]]; then
+  GPU_COUNT=1
+fi
+DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT))

Test on a CPU-only node to confirm the script no longer exits early.

modelopt/onnx/quantization/qdq_utils.py (1)

639-642: Create FP8 ONNX tensors with correct dtype (bug fix).

numpy_helper.from_array will set dtype to UINT8, not FLOAT8. Use onnx.helper.make_tensor(..., data_type=FLOAT8, raw=True) like the MXFP8 path below.

Apply this diff:

-def _create_fp8_tensor(scaled: np.ndarray, weight_name: str) -> onnx.TensorProto:
-    """Create a FLOAT8E4M3FN tensor directly from numpy array."""
-    fp8_data = _cast_fp8(scaled)
-    return onnx.numpy_helper.from_array(fp8_data, weight_name)
+def _create_fp8_tensor(scaled: np.ndarray, weight_name: str) -> onnx.TensorProto:
+    """Create a FLOAT8E4M3FN tensor directly from numpy array."""
+    fp8_bytes = _cast_fp8(scaled).tobytes()
+    return onnx.helper.make_tensor(
+        name=weight_name,
+        data_type=onnx_dtype_map["Float8"],
+        dims=list(scaled.shape),
+        vals=fp8_bytes,
+        raw=True,
+    )
tests/unit/onnx/test_qdq_utils.py (1)

35-36: Fix MatMul shape inconsistency in the synthetic graph (pre/post-quantization).

As written, the MatMul inputs are dimensionally incompatible both before and after the Reshape/Transpose removal. Make the reshape produce (8, 32) and keep the post‑transpose as (32, 8), then drive MatMul with input [..., 32] to yield output [..., 8]. Also avoid orphaning the original scale initializer when constant_scale=True.

Apply:

-    reshape_shape = np.array([16, 16], dtype=np.int64)
+    reshape_shape = np.array([8, 32], dtype=np.int64)
-    reshape_output_info = helper.make_tensor_value_info(
-        "reshape_output", TensorProto.FLOAT, [16, 16]
-    )
+    reshape_output_info = helper.make_tensor_value_info(
+        "reshape_output", TensorProto.FLOAT, [8, 32]
+    )
-    graph = helper.make_graph(
-        nodes=nodes,
-        name="test_graph",
-        inputs=[input_tensor],
-        outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [None, 16])],
-        initializer=[weight_tensor, scale_tensor],
-        value_info=[reshape_output_info],
-    )
+    initializers = [weight_tensor] if constant_scale else [weight_tensor, scale_tensor]
+    graph = helper.make_graph(
+        nodes=nodes,
+        name="test_graph",
+        inputs=[input_tensor],  # make sure this is [None, 32] above where defined
+        outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [None, 8])],
+        initializer=initializers,
+        value_info=[reshape_output_info],
+    )

Also update the input tensor shape at line 38 to [None, 32] to match. If you prefer not to change I/O shapes, alternatively set reshape_shape = [32, 8] and perm=[1,0] to be a no‑op on shape.

Also applies to: 91-96, 100-106
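
For completeness, the matching graph-input declaration suggested above would look like this (names follow the synthetic test graph):

from onnx import TensorProto, helper

# Input updated to [None, 32] so MatMul([..., 32] x [32, 8]) yields [..., 8].
input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [None, 32])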

examples/speculative_decoding/export_hf_checkpoint.py (1)

38-49: Move execution under a main guard.

Avoids side effects on import and clarifies entrypoint.

Apply this diff:

-mto.enable_huggingface_checkpointing()
-
-args = parse_args()
-model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype="auto")
-model.eval()
-with torch.inference_mode():
-    export_hf_checkpoint(
-        model,  # The quantized model.
-        export_dir=args.export_path,  # The directory where the exported files will be stored.
-    )
-print(f"Exported checkpoint to {args.export_path}")
+def main():
+    mto.enable_huggingface_checkpointing()
+    args = parse_args()
+    model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype="auto")
+    model.eval()
+    with torch.inference_mode():
+        export_hf_checkpoint(model, export_dir=args.export_path)
+    print(f"Exported checkpoint to {args.export_path}")
+
+if __name__ == "__main__":
+    main()
♻️ Duplicate comments (5)
examples/nemo_run/qat/nemo_qat_flow.py (5)

137-145: KV‑cache flag should be tri‑state; avoid forcing disable by default.

Only append --enable_kv_cache/--disable_kv_cache when explicitly requested.

Apply:

-    parser.add_argument(
-        "--enable_kv_cache",
-        help="Enables KV-cache quantization",
-        action="store_true",
-        default=False,
-    )
+    kv_group = parser.add_mutually_exclusive_group()
+    kv_group.add_argument(
+        "--enable_kv_cache",
+        dest="enable_kv_cache",
+        action="store_true",
+        help="Enables KV-cache quantization",
+    )
+    kv_group.add_argument(
+        "--disable_kv_cache",
+        dest="enable_kv_cache",
+        action="store_false",
+        help="Disables KV-cache quantization",
+    )
+    parser.set_defaults(enable_kv_cache=None)
-    ptq = run.Script(
+    # Build KV-cache flag only when explicitly set
+    kv_cache_flag = (
+        ["--enable_kv_cache"] if args.enable_kv_cache is True
+        else (["--disable_kv_cache"] if args.enable_kv_cache is False else [])
+    )
+    ptq = run.Script(
         "/opt/NeMo/scripts/llm/ptq.py",
         args=[
@@
-            "--kv_cache_qformat",
-            args.kv_cache_qformat,
-            "--enable_kv_cache" if args.enable_kv_cache else "--disable_kv_cache",
+            "--kv_cache_qformat",
+            args.kv_cache_qformat,
+            *kv_cache_flag,

Also applies to: 185-203


360-361: Slurm time format fix acknowledged.

Using HH:MM:SS (“04:00:00”) resolves SlurmConfig validation.


338-346: Don’t mutate the train executor for export; use the single‑GPU executor to avoid Slurm resource mismatches.

Changing only ntasks_per_node leaves gres/gpus_per_node stale on Slurm and can cause mis-scheduled jobs. Reuse the already-created single_gpu_executor for export and drop the mutation.

Apply:

-        # WAR: Export needs access to all GPUs but only 1 task due to bug in NeMo
-        train_gpu_executor.ntasks_per_node = 1  # will throw error if more than 1 task during export
-        exp.add(
+        exp.add(
             export,
             tail_logs=True,
             name="07_export_hf",
-            executor=train_gpu_executor,
+            executor=single_gpu_executor,
             dependencies=[s5],
         )

Also applies to: 339-339


222-233: Guard globals and config for programmatic invocation.

SEQUENCE_LENGTH/GBS/MBS/TRAIN_STEPS/VAL_INTERVAL and SLURM_CONFIG are referenced inside main() but defined under main. Importing and calling main(args) will raise NameError.

Apply:

 def main(args):
+    # Ensure module-scope defaults exist when called programmatically
+    global SEQUENCE_LENGTH, MBS, GBS, TRAIN_STEPS, VAL_INTERVAL, SLURM_CONFIG
+    SEQUENCE_LENGTH = globals().get("SEQUENCE_LENGTH", 4096)
+    MBS = globals().get("MBS", 1)
+    GBS = globals().get("GBS", 512)
+    TRAIN_STEPS = globals().get("TRAIN_STEPS", 400)
+    VAL_INTERVAL = globals().get("VAL_INTERVAL", 50)
+    if args.use_slurm and "SLURM_CONFIG" not in globals():
+        raise ValueError(
+            "SLURM_CONFIG must be defined (see __main__ block) when --use-slurm is set."
+        )

Also applies to: 278-287, 375-380


283-286: Pass nodes to the Slurm training executor.

Multi-node training won’t be honored without nodes=args.train_nodes.

Apply:

-        train_gpu_executor = create_slurm_executor(
-            SLURM_CONFIG, num_gpus=args.train_gpus, ntasks_per_node=args.train_gpus
-        )
+        train_gpu_executor = create_slurm_executor(
+            SLURM_CONFIG,
+            nodes=args.train_nodes,
+            num_gpus=args.train_gpus,
+            ntasks_per_node=args.train_gpus,
+        )
🧹 Nitpick comments (77)
modelopt/onnx/autocast/precisionconverter.py (4)

179-184: Use HasField and clear the oneof when anonymizing dim_value.

if d.dim_value: misses dimensions explicitly set to 0 and doesn't make the oneof switch explicit. Prefer HasField("dim_value") and clear the field before setting dim_param. If the intent is to preserve explicit 0-dims, gate with and d.dim_value != 0.

Apply this diff in the value_info loop:

-            for vi in self.model.graph.value_info:
-                vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
-                for idx, d in enumerate(vi.type.tensor_type.shape.dim):
-                    if d.dim_value:
-                        vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
+            for vi in self.model.graph.value_info:
+                vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
+                for d in vi.type.tensor_type.shape.dim:
+                    # If preserving explicit 0-dims, add: and d.dim_value != 0
+                    if d.HasField("dim_value"):
+                        d.ClearField("dim_value")
+                        d.dim_param = "unk"

185-189: Mirror the safer oneof handling for graph outputs.

Same concern as above: use HasField("dim_value") and clear it before setting dim_param to avoid silently skipping explicit 0 or leaving implicit defaults ambiguous.

Apply this diff in the outputs loop:

-            for out in self.model.graph.output:
-                out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
-                for idx, d in enumerate(out.type.tensor_type.shape.dim):
-                    if d.dim_value:
-                        out.type.tensor_type.shape.dim[idx].dim_param = "unk"
+            for out in self.model.graph.output:
+                out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
+                for d in out.type.tensor_type.shape.dim:
+                    if d.HasField("dim_value"):
+                        d.ClearField("dim_value")
+                        d.dim_param = "unk"

191-194: Defaulting UNDEFINED types to low precision can mis-type non-float tensors.

Blindly setting all remaining UNDEFINED value_info to self.low_precision_type risks clobbering indices/bool/int tensors (e.g., TopK/ArgMax paths) and may only be caught late by strict checks. Gate this by the original element type (when known) or by neighboring float-only usage.

Apply this diff to constrain the assignment using the original value_info_map:

-            self._ensure_types_are_defined()
+            self._ensure_types_are_defined()

And update _ensure_types_are_defined per next comment.


207-213: Constrain fallback type assignment to originally-float tensors; add optional neighbor hint.

Use the original value_info_map to only default tensors that were float-typed before we cleared metadata. Optionally, fall back to a lightweight neighbor check.

Apply this diff:

-    def _ensure_types_are_defined(self):
-        """Ensure that all tensor types are defined."""
-        for vi in self.model.graph.value_info:
-            if vi.type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED:
-                vi.type.tensor_type.elem_type = self.low_precision_type.onnx_type
+    def _ensure_types_are_defined(self):
+        """Ensure float tensors have defined types without clobbering non-float tensors."""
+        # Build a quick lookup of original elem_types (pre-mutation)
+        orig_types = {
+            name: info.type.tensor_type.elem_type
+            for name, info in (self.value_info_map or {}).items()
+        }
+        for vi in self.model.graph.value_info:
+            if vi.type.tensor_type.elem_type != onnx.TensorProto.UNDEFINED:
+                continue
+            name = vi.name
+            # Prefer original knowledge
+            if name in orig_types and orig_types[name] in ONNX_TYPES:
+                vi.type.tensor_type.elem_type = self.low_precision_type.onnx_type
+                continue
+            # Lightweight neighbor hint: if any producer/consumer is a Cast or float op, assume float
+            producers = utils.get_producer_nodes(self.model, name)
+            consumers = utils.get_consumer_nodes(self.model, name)
+            neighbors = producers + consumers
+            if any(n.op_type in {"Cast", "Add", "Mul", "MatMul", "Conv"} for n in neighbors):
+                vi.type.tensor_type.elem_type = self.low_precision_type.onnx_type

Please confirm intent for 0-dim behavior (preserve vs. anonymize). Also consider adding a unit test covering:

  • value_info for ArgMax/TopK indices stays INT64;
  • tensors left UNDEFINED after first inference but belonging to float paths get defaulted to low_precision_type;
  • shapes with explicit 0 dims are preserved.
modelopt/torch/_deploy/utils/torch_onnx.py (2)

490-492: Clarify assertion message for MXFP8/INT4 mixed precision

Message reads as BF16-specific; make the constraint explicit for all quantized cases here.

-            assert weights_dtype == "fp16", "BF16 + MXFP8/INT4 mixed precision is not supported yet"
+            assert weights_dtype == "fp16", "Only FP16 weights are supported when the model is MXFP8 or INT4 quantized (BF16 unsupported)."

333-339: Avoid mutable default for dynamic_axes and pass only when provided

Using {} as a default can lead to unintended shared state; also skip passing empty dict to export.

-    dynamic_axes: dict = {},
+    dynamic_axes: dict | None = None,
@@
-        if not dynamo_export and Version(torch.__version__) >= Version("2.8"):
-            additional_kwargs["dynamic_axes"] = dynamic_axes
+        if not dynamo_export and Version(torch.__version__) >= Version("2.8") and dynamic_axes:
+            additional_kwargs["dynamic_axes"] = dynamic_axes

Also applies to: 435-437

examples/nemo_run/qat/nemo_qat_flow.py (2)

29-31: Avoid sys.path mutation; prefer package/relative imports.

Using sys.path.append can surprise downstream tools. If feasible, make examples a package and import via examples.nemo_run.common.utils or use relative imports.
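
A sketch of the package-import alternative, assuming examples/ and the intermediate directories gain __init__.py files and that both helpers live in common/utils.py:

# Hedged sketch: replaces the sys.path.append(...) workaround in nemo_qat_flow.py.
from examples.nemo_run.common.utils import (
    create_slurm_executor,
    get_finetune_recipe,
)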


149-156: Separate recipe name from model naming to avoid confusion.

model_name is overloaded (recipe vs HF model). The fallback branch is unreachable. Consider recipe_name = args.finetune_recipe for llm module lookup and model_dir = os.path.basename(args.model_name) for path naming (bf16_ckpt_path/ptq_model_out).

Also applies to: 173-184
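
A small sketch of the suggested split; the attribute names mirror the flags discussed in this review, and the output paths are purely illustrative:

import os

def resolve_names(args):
    """Hedged sketch: keep the recipe lookup and the on-disk naming separate."""
    recipe_name = args.finetune_recipe             # drives the llm.<recipe_name> lookup
    model_dir = os.path.basename(args.model_name)  # drives checkpoint path naming only
    bf16_ckpt_path = os.path.join(args.log_dir, f"{model_dir}-bf16")
    ptq_model_out = os.path.join(args.log_dir, f"{model_dir}-ptq")
    return recipe_name, bf16_ckpt_path, ptq_model_out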

examples/onnx_ptq/torch_quant_to_onnx.py (5)

86-92: Avoid re-instantiating/downloading the model just to read input_size

Constructing a model here duplicates work (and with pretrained=True may trigger weight downloads). Derive input_size from timm’s pretrained cfg, falling back to a lightweight model with pretrained=False.

Apply this diff:

 def get_model_input_shape(model_name, batch_size):
     """Get the input shape from timm model configuration."""
-    model = timm.create_model(model_name, pretrained=True, num_classes=1000)
-    data_config = timm.data.resolve_model_data_config(model)
-    input_size = data_config["input_size"]
-    return (batch_size, *tuple(input_size))  # Add batch dimension
+    # Prefer config path to avoid heavyweight instantiation / downloads.
+    try:
+        cfg = timm.get_pretrained_cfg(model_name)
+        input_size = tuple(getattr(cfg, "input_size", cfg.get("input_size")))
+    except Exception:
+        # Fallback: create a lightweight model without pretrained weights.
+        model = timm.create_model(model_name, pretrained=False, num_classes=1000)
+        data_config = timm.data.resolve_model_data_config(model)
+        input_size = data_config["input_size"]
+    return (batch_size, *input_size)  # Add batch dimension

122-127: Validate --batch_size (> 0) to avoid runtime errors

Negative/zero batch sizes will break DataLoader and export assumptions. Add an argparse validator.

Apply this diff:

-    parser.add_argument(
-        "--batch_size",
-        type=int,
-        default=1,
-        help="Batch size for calibration and ONNX model export.",
-    )
+    parser.add_argument(
+        "--batch_size",
+        type=positive_int,
+        default=1,
+        help="Batch size for calibration and ONNX model export (must be > 0).",
+    )

Add this helper near the other imports:

def positive_int(v: str) -> int:
    iv = int(v)
    if iv <= 0:
        raise argparse.ArgumentTypeError("batch_size must be a positive integer")
    return iv

132-132: Remove duplicate model construction; derive input_shape from the instantiated model

You construct the model again at Line 136. Compute input_shape from that instance instead of creating one inside get_model_input_shape.

Apply this diff:

-    # Get input shape from model config
-    input_shape = get_model_input_shape(args.timm_model_name, args.batch_size)
-
-    # Create model and move to appropriate device
+    # Create model and move to appropriate device
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model = timm.create_model(args.timm_model_name, pretrained=True, num_classes=1000).to(device)
+    # Derive input shape from the instantiated model (avoid extra construction)
+    data_config = timm.data.resolve_model_data_config(model)
+    input_shape = (args.batch_size, *tuple(data_config["input_size"]))

55-67: Don’t preload calibration samples onto GPU; move batches inside the forward loop under inference_mode and eval()

Preloading each sample to device inflates GPU memory and fights DataLoader workers. Keep tensors on CPU, use pin_memory, and move inside the loop with inference_mode. Also ensure eval() during calibration.

Apply this diff:

 def load_calibration_data(model_name, data_size, batch_size, device):
@@
-    images = dataset["train"][:data_size]["image"]
-    calib_tensor = [transforms(img) for img in images]
-    calib_tensor = [t.to(device) for t in calib_tensor]
+    images = dataset["train"][:data_size]["image"]
+    calib_tensor = [transforms(img) for img in images]
     return torch.utils.data.DataLoader(
-        calib_tensor, batch_size=batch_size, shuffle=True, num_workers=4
+        calib_tensor, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True
     )
@@
-        def forward_loop(model):
-            for batch in data_loader:
-                model(batch)
+        def forward_loop(model):
+            model.eval()
+            with torch.inference_mode():
+                for batch in data_loader:
+                    if isinstance(batch, (list, tuple)):
+                        batch = [b.to(next(model.parameters()).device, non_blocking=True) for b in batch]
+                    else:
+                        batch = batch.to(next(model.parameters()).device, non_blocking=True)
+                    model(batch)

Also applies to: 74-79


155-161: Expose dynamic_axes in export_to_onnx and forward it to the ONNX exporter

export_to_onnx (examples/onnx_ptq/download_example_onnx.py) builds a fixed-size dummy_input and calls get_onnx_bytes without any dynamic_axes, so the produced ONNX is fixed-batch.

  • Change export_to_onnx signature to accept an optional dynamic_axes and forward it to get_onnx_bytes / torch.onnx.export (or update modelopt/torch/_deploy/utils/torch_onnx.py if needed).
  • Update the call in examples/onnx_ptq/torch_quant_to_onnx.py to pass, e.g. dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}}.
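
A sketch of what the second bullet implies at the call site; it uses torch.onnx.export directly to illustrate the dynamic_axes shape, whereas the actual script goes through get_onnx_bytes, which would need the matching keyword threaded through (an assumption here).

import torch

def export_to_onnx(model, input_shape, onnx_save_path, device, dynamic_axes=None):
    """Hedged sketch: export with an optional dynamic batch dimension."""
    dummy_input = torch.randn(*input_shape, device=device)
    torch.onnx.export(
        model,
        dummy_input,
        onnx_save_path,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=dynamic_axes or {"input": {0: "batch"}, "output": {0: "batch"}},
    )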
modelopt/torch/_deploy/_runtime/tensorrt/tensorrt_utils.py (2)

60-63: Return bytes, not bytearray, to match type hint.

-def get_engine_bytes(engine: trt.tensorrt.ICudaEngine) -> bytes:
-    """Return serialized TensorRT engine bytes."""
-    return bytearray(engine.serialize())  # type: ignore[return-value]
+def get_engine_bytes(engine: trt.tensorrt.ICudaEngine) -> bytes:
+    """Return serialized TensorRT engine bytes."""
+    return bytes(engine.serialize())

134-141: Minor: avoid parsing ONNX twice in calib_data_generator.

Parse once and reuse for input names and batch size.

-def calib_data_generator(onnx_bytes: bytes, input_tensors: list[np.ndarray]):
+def calib_data_generator(onnx_bytes: bytes, input_tensors: list[np.ndarray]):
     """The calibation data generator that yields calibration feed_dict to tensorrt."""
-    input_names = get_onnx_input_names(onnx.load_from_string(onnx_bytes))
-
-    batch_size = get_batch_size(onnx.load_from_string(onnx_bytes))
+    model = onnx.load_from_string(onnx_bytes)
+    input_names = get_onnx_input_names(model)
+    batch_size = get_batch_size(model)
modelopt/torch/_deploy/_runtime/tensorrt/engine_builder.py (1)

103-110: STRONGLY_TYPED intentionally treated as low‑bit; mapping present — document rationale

Verified: TRTMode.STRONGLY_TYPED is included in _is_low_bit_mode and TRT_MODE_FLAGS maps it to ["--stronglyTyped"] (modelopt/torch/_deploy/_runtime/tensorrt/constants.py); the examples and README also use --builderOptimizationLevel=4 with --stronglyTyped, so forcing opt‑level=4 appears intentional.

  • Optional: add a brief inline comment explaining why strongly‑typed forces builderOptimizationLevel=4 or gate this behavior if different opt‑levels are expected for fp16/bf16 strongly‑typed builds.
examples/windows/onnx_ptq/genai_llm/requirements.txt (1)

2-2: Avoid ONNX version skew with setup.py (extras use ~=1.19.0).

Pinning 1.18.0 here while extras require ~=1.19.0 can cause resolver conflicts in mixed flows (dev installs extras + runs this example). Either align, or clearly scope this pin to Windows-only pipelines.

Two safe options:

-onnx==1.18.0
+onnx>=1.18,<1.20  # Aligns with repo extras (~=1.19.x) while keeping 1.18 compatibility

or keep 1.18.0 but document intent:

+# Note: ONNX pinned to 1.18.0 for Windows PTQ sample stability; repo extras target ~=1.19.x.
 onnx==1.18.0

Please confirm which constraint you prefer and that CI matrices won’t install both this file and the onnx extra together.

tests/unit/onnx/test_qdq_rules_int8.py (1)

98-101: Guard against graph inputs when inspecting Add inputs.

This list comp will crash if an Add input is a graph input (no producers). Mirror the new guard here.

-    add_input_ops = [inp.inputs[0].op for inp in add_node.inputs]
-    assert np.isin(add_input_ops, ["Conv", "DequantizeLinear"]).all(), (
+    add_inputs_with_producer = [inp for inp in add_node.inputs if inp.inputs]
+    add_input_ops = [inp.inputs[0].op for inp in add_inputs_with_producer]
+    assert set(add_input_ops) <= {"Conv", "DequantizeLinear"} and len(add_input_ops) == 2, (
         f"Add node {add_node.name} was not quantized correctly!"
     )

Please confirm test models never wire raw graph inputs into Add in this check path; if they can, the len==2 assertion is the right expectation.

examples/speculative_decoding/launch.sh (1)

153-154: Make reporting backend configurable (default remains tensorboard).

Hard‑coding tensorboard is OK, but a flag improves ergonomics and parity with wandb.

-    --report_to tensorboard \
+    --report_to ${REPORT_TO:-tensorboard} \

Add parser support:

@@
     --do_eval*)
       if [[ "$1" != *=* ]]; then shift; fi
       DO_EVAL="${1#*=}"
       ;;
+    --report_to*)
+      if [[ "$1" != *=* ]]; then shift; fi
+      REPORT_TO="${1#*=}"
+      ;;
modelopt/torch/export/plugins/__init__.py (1)

23-24: Avoid hard dependency via wildcard import; degrade gracefully.

Unconditional from .hf_spec_export import * can pull in heavy/optional deps and pollute the namespace. Prefer a guarded import.

-from .hf_spec_export import *
+try:
+    from .hf_spec_export import *
+except ImportError:
+    # hf_spec_export depends on optional stacks; keep plugins package importable without them
+    pass

If hf_spec_export has no optional deps, we can still import explicitly‑named symbols instead of * to keep the API surface tight. Want me to propose an explicit __all__?
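
For reference, a guarded, explicitly named variant could look like the sketch below; only set_config_if_spec_decoding is known from this review, so treat the export list as a placeholder to align with hf_spec_export's real API.

# modelopt/torch/export/plugins/__init__.py -- hedged sketch with explicit names.
try:
    from .hf_spec_export import set_config_if_spec_decoding  # noqa: F401
except ImportError:
    # Optional dependency stack not installed; keep the plugins package importable.
    __all__ = []
else:
    __all__ = ["set_config_if_spec_decoding"]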

modelopt/onnx/quantization/int8.py (1)

127-129: Defaulting to fp16 is a behavior change; confirm CPU/ORT compatibility.

This will convert models to fp16/bf16 by default. Ensure:

  • CPU‑only pipelines and DML EP handle the converted graph without perf/accuracy regressions.
  • CLI/docs reflect the new default and how to opt out.

Optionally allow disabling via None:

-    high_precision_dtype: str = "fp16",
+    high_precision_dtype: str | None = "fp16",
@@
-    if high_precision_dtype in ["fp16", "bf16"]:
+    if high_precision_dtype in ["fp16", "bf16"]:
         ...

Please run a quick smoke on a CPU runner with ORT CPU EP to verify no op‑type falls back or dtype unsupported errors.

.github/workflows/unit_tests.yml (1)

122-127: Aggregator job added — consider making it resilient to upstream skips/failures.

Looks good for enforcing a single PR “unit tests complete” signal. To make the job always run and clearly fail when any prerequisite fails or is skipped, add job-level if: always() and an explicit failure check against needs.*.result.

-  unit-pr-required-check:
-    if: github.event_name == 'pull_request'
+  unit-pr-required-check:
+    # Always evaluate this job for PRs so it can fail if any prereq failed/skipped
+    if: github.event_name == 'pull_request' && always()
     needs: [linux, windows, multi-py, multi-torch, multi-transformers, partial-install]
     runs-on: ubuntu-latest
     steps:
-      - run: echo "All PR unit test jobs completed"
+      - name: Verify prerequisite jobs
+        run: |
+          echo "linux: ${{ needs.linux.result }}"
+          echo "windows: ${{ needs.windows.result }}"
+          echo "multi-py: ${{ needs.multi-py.result }}"
+          echo "multi-torch: ${{ needs.multi-torch.result }}"
+          echo "multi-transformers: ${{ needs.multi-transformers.result }}"
+          echo "partial-install: ${{ needs['partial-install'].result }}"
+          if [[ "${{ needs.linux.result }}" != "success" || \
+                "${{ needs.windows.result }}" != "success" || \
+                "${{ needs.multi-py.result }}" != "success" || \
+                "${{ needs.multi-torch.result }}" != "success" || \
+                "${{ needs.multi-transformers.result }}" != "success" || \
+                "${{ needs['partial-install'].result }}" != "success" ]]; then
+            echo "One or more unit test jobs did not succeed"
+            exit 1
+          fi
+      - run: echo "All PR unit test jobs completed"
tests/_test_utils/onnx_quantization/utils.py (1)

23-29: Unwrap chains of Cast nodes, not just a single Cast.

Current logic handles only one Cast. Robustly skip over multiple Casts to reach DQ.

-                producer = node.i(inp_idx)
-                # Quantized path may include a Cast right after DQ
-                if producer and producer.op == "Cast":
-                    producer = producer.i(0)
+                producer = node.i(inp_idx)
+                # Quantized path may include one or more Casts right after DQ
+                while producer and producer.op == "Cast":
+                    producer = producer.i(0)
modelopt/onnx/trt_utils.py (1)

419-424: Good: avoid empty quantize mappings.

Creating custom_ops_to_quantize[op_type] only when IOs exist is correct and prevents misleading empty config. Consider mirroring this for custom_ops_to_cast (skip when both lists are empty) for symmetry and cleaner downstream handling.
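
A sketch of the mirrored guard; the parameter names and the dict layout are placeholders since custom_ops_to_cast's exact structure is not shown in this diff.

def maybe_record_cast(custom_ops_to_cast: dict, op_type: str,
                      inputs_to_cast: list, outputs_to_cast: list) -> None:
    """Hedged sketch: skip empty cast mappings, mirroring the quantize path."""
    if inputs_to_cast or outputs_to_cast:
        custom_ops_to_cast[op_type] = {"inp": inputs_to_cast, "out": outputs_to_cast}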

tests/_test_utils/import_helper.py (1)

80-93: Fix skip message logic and align skip behavior.

Message says “less than required” but you skip when ONNX is greater than 1.18. Also, add allow_module_level=True for consistency.

-def skip_if_onnx_version_above_1_18():
+def skip_if_onnx_version_above_1_18():
     package_name = "onnx"
     required_version = "1.18.0"
 
     try:
         installed_version = importlib.metadata.version(package_name)
     except importlib.metadata.PackageNotFoundError:
-        pytest.skip(f"{package_name} is not installed")
+        pytest.skip(f"{package_name} is not installed", allow_module_level=True)
 
     if version.parse(installed_version) > version.parse(required_version):
-        pytest.skip(
-            f"{package_name} version {installed_version} is less than required {required_version}"
-        )
+        pytest.skip(
+            f"{package_name} version {installed_version} is greater than allowed {required_version}",
+            allow_module_level=True,
+        )
tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (2)

43-44: Place version gating at module import time to save setup cost.

Optional: invoke the skip helper at module scope so the entire file is skipped before CUDA/model setup.

- def test_int4_awq(tmp_path):
-    skip_if_onnx_version_above_1_18()
+skip_if_onnx_version_above_1_18()
+
+def test_int4_awq(tmp_path):

119-121: Order of skips is fine; minor consistency nit.

Calling ONNX version skip before libcudnn skip is OK. If you move the version skip to module scope, keep skip_if_no_libcudnn() first inside the test to preserve current behavior for environments without cuDNN.

examples/speculative_decoding/ar_validate.py (2)

29-31: Default sample count increased — confirm runtime/CI budget.

Bumping num_samples to 80 increases runtime. If this is used in CI, consider parameterizing via env or keeping a lower default for CI.


58-63: CLI defaults updated — keep docs/launch scripts in sync.

Ensure any docs or scripts referencing --osl and --num_samples defaults are updated.

.github/workflows/gpu_tests.yml (1)

85-92: Required GPU check logic is sound; add a success echo for clarity.

The conditional failure is correct. Add a final unconditional echo so the job shows a green step on success.

   gpu-pr-required-check:
     # Run even if gpu-tests-pr is skipped
     if: ${{ startsWith(github.ref, 'refs/heads/pull-request/') && always() }}
     needs: [check-file-changes, gpu-tests-pr]
     runs-on: ubuntu-latest
     steps:
       - name: Required GPU tests did not succeed
         if: ${{ needs.check-file-changes.result != 'success' || (needs.check-file-changes.outputs.any_changed == 'true' && needs.gpu-tests-pr.result != 'success') }}
         run: exit 1
+      - run: echo "GPU test requirements satisfied"
examples/speculative_decoding/server_generate.py (1)

155-158: Avoid NameError when printing prompt in exceptions.

prompt is undefined in the chat path; the except handler can raise another error.

Apply this diff:

-    except Exception as e:
-        print(e)
-        print(prompt)
-        print("Failed to generate data")
+    except Exception as e:
+        print(e)
+        if "prompt" in locals():
+            print(prompt)
+        print("Failed to generate data")
examples/speculative_decoding/calibrate_draft_vocab.py (1)

31-35: Validate --draft_vocab_size and clarify help.

Add basic bounds checks to fail fast and document expectations.

Apply this diff:

-        "--draft_vocab_size",
-        type=int,
-        required=True,
-        help="Draft vocab size",
+        "--draft_vocab_size",
+        type=int,
+        required=True,
+        help="Draft vocab size (must be > 0 and <= tokenizer vocab size)",

And add after parsing (right after Line 45):

 args = parser.parse_args()
 
+if args.draft_vocab_size <= 0:
+    raise ValueError("--draft_vocab_size must be > 0")

Optionally check against tokenizer size after loading it:

 tokenizer = AutoTokenizer.from_pretrained(args.model)
+if hasattr(tokenizer, "vocab") and args.draft_vocab_size > len(tokenizer.vocab):
+    raise ValueError(f"--draft_vocab_size ({args.draft_vocab_size}) exceeds tokenizer size ({len(tokenizer.vocab)})")
modelopt/torch/export/plugins/hf_spec_export.py (4)

43-49: Report all missing keys at once for better diagnostics.

Collect and display the full set of missing required keys; current code raises on the first one.

 def _check_state_dict_keys_match(draft_model: nn.Module, required_items: dict):
     """Check if the state dict keys match."""
     draft_keys = set(draft_model.state_dict().keys())
-    for required_key in required_items:
-        if required_key not in draft_keys:
-            raise ValueError(f"State dict keys mismatch!\nMissing in draft model: {required_key}")
+    missing = [k for k in required_items if k not in draft_keys]
+    if missing:
+        raise ValueError(
+            "State dict keys mismatch! Missing in draft model: "
+            + ", ".join(sorted(missing))
+        )

63-75: Guard against missing eagle_module attribute.

If _modelopt_state indicates eagle but model.eagle_module is absent, this will raise AttributeError.

Apply this diff:

-    _check_state_dict_keys_match(model.eagle_module, EAGLE_MODELOPT_TO_OFFICIAL["required"])
+    if not hasattr(model, "eagle_module"):
+        raise ValueError("Eagle mode detected but model.eagle_module is missing")
+    _check_state_dict_keys_match(model.eagle_module, EAGLE_MODELOPT_TO_OFFICIAL["required"])

76-79: Validate LM head fallback shape to avoid silent mismatch.

When eagle_lm_head.weight is missing, copying base lm_head.weight may mis-shape the draft head (e.g., draft vs base vocab sizes). Validate and fail fast.

-    if "eagle_lm_head.weight" not in eagle_state:
-        export_state_dict["lm_head.weight"] = model.state_dict()["lm_head.weight"]
+    if "eagle_lm_head.weight" not in eagle_state:
+        base_lm = model.state_dict().get("lm_head.weight")
+        if base_lm is None:
+            raise ValueError("lm_head.weight not found in base model for fallback")
+        # Optional: if d2t present, ensure vocab alignment
+        d2t = export_state_dict.get("d2t")
+        if d2t is not None and base_lm.shape[0] != d2t.numel():
+            raise ValueError(
+                f"LM head vocab size ({base_lm.shape[0]}) does not match draft vocab size ({d2t.numel()})"
+            )
+        export_state_dict["lm_head.weight"] = base_lm

94-131: Set transformers_version automatically when missing.

Helps downstream tooling that expects this field.

 def set_config_if_spec_decoding(model: nn.Module, config_data: dict):
@@
-    template_config = {
+    template_config = {
@@
         "transformers_version": None,
@@
     }
@@
-    for key in template_config:
+    for key in template_config:
         value = template_config[key]
@@
             template_config[key] = new_value
 
+    # Populate transformers_version if available
+    if template_config.get("transformers_version") is None:
+        try:
+            import transformers
+            template_config["transformers_version"] = transformers.__version__
+        except Exception:
+            pass
modelopt/onnx/quantization/qdq_utils.py (3)

620-635: Verify FP4 packing axis; consider packing along the last dimension for consistency.

Current packing flattens and halves the first dimension, whereas NVFP4 tensor packing typically pairs along the last axis. If consumer kernels expect last-axis packing, this will mis-shape weights.

Alternative packing along the last axis:

-    array_f32_t_shape = array_f32_t.shape
-    assert array_f32_t_shape[0] % 2 == 0, "array_f32_t_shape[0] must be divisible by 2"
-    array_f4_t_shape = (array_f32_t_shape[0] // 2, *array_f32_t_shape[1:])
+    shape = list(array_f32_t.shape)
+    assert shape[-1] % 2 == 0, "last dimension must be divisible by 2 for FP4 packing"
+    packed_shape = [*shape[:-1], shape[-1] // 2]
@@
-    array_f4_t = array_f4_t.flatten()
-    array_f4_t_packed = (array_f4_t[::2] | (array_f4_t[1::2] << 4)).reshape(array_f4_t_shape)
+    array_f4_t_packed = (array_f4_t[..., 0::2] | (array_f4_t[..., 1::2] << 4)).contiguous()
+    array_f4_t_packed = array_f4_t_packed.reshape(packed_shape)

If first-axis packing is intentional for downstream ORT/TRT consumption, please confirm and ignore this suggestion.


943-945: Don’t rely on “Constant” substring in tensor names; check producer op_type.

Name-based matching is brittle. Use the producer map to find and remove the Constant node.

Apply this diff:

-        # Remove constant node from reshape node
-        shape_constant_name = next(input for input in reshape_node.input if "Constant" in input)
-        nodes_to_remove.append(tensor_producer_map[shape_constant_name].name)
+        # Remove Constant node feeding the Reshape shape input
+        shape_input = next(inp for inp in reshape_node.input if inp != node.output[0])
+        shape_producer = tensor_producer_map.get(shape_input)
+        if shape_producer and shape_producer.op_type == "Constant":
+            nodes_to_remove.append(shape_producer.name)

869-869: Avoid forcing ir_version downgrade unless strictly required.

Setting onnx_model.ir_version = 10 can regress compatibility/features. Prefer preserving original IR or conditionally lowering only when exporters/ORT demand it.

If this is required for current ORT targets, please document the constraint and link the issue.

tests/unit/onnx/test_qdq_utils.py (1)

70-77: Redundant Cast to FLOAT.

DequantizeLinear already produces FLOAT. Casting to FLOAT again is harmless but unnecessary noise for the unit graph. Consider casting to FLOAT16 if the aim is to exercise downcast logic, or remove this node to keep the graph minimal.
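
If exercising the downcast path is the goal, a sketch of the FLOAT16 variant (tensor/node names follow the synthetic graph in this test and are illustrative):

from onnx import TensorProto, helper

# Cast the dequantized FLOAT output down to FLOAT16 so the test covers a real
# dtype change instead of a FLOAT -> FLOAT no-op.
cast_node = helper.make_node(
    "Cast",
    inputs=["dq_output"],
    outputs=["cast_output"],
    to=TensorProto.FLOAT16,
    name="cast_to_fp16",
)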

modelopt/torch/speculative/plugins/transformers.py (2)

817-819: Guard DynamicCache conversion when cache is None.

DynamicCache.from_legacy_cache(None) may not be supported across all HF versions. Add a None check.

-        if not isinstance(past_key_values, Cache):
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+        if past_key_values is not None and not isinstance(past_key_values, Cache):
+            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            if not isinstance(eagle_cache, Cache):
-                eagle_cache = DynamicCache.from_legacy_cache(eagle_cache)
+            if eagle_cache is not None and not isinstance(eagle_cache, Cache):
+                eagle_cache = DynamicCache.from_legacy_cache(eagle_cache)

Also applies to: 855-857


720-725: Make draft‑vocab mapping device‑agnostic.

Index tensor should live on the same device as full_logits.

-        reverse_mapping = (
-            torch.arange(len(self.eagle_module.d2t)).to(self.eagle_module.d2t.device)
-            + self.eagle_module.d2t
-        )
-        return full_logits[:, :, reverse_mapping]
+        device = full_logits.device
+        d2t = self.eagle_module.d2t.to(device)
+        reverse_mapping = torch.arange(d2t.numel(), device=device, dtype=d2t.dtype) + d2t
+        return full_logits.index_select(dim=2, index=reverse_mapping)
examples/vlm_ptq/README.md (1)

39-45: Fix NVFP4 support inconsistency for Qwen2.5‑VL.

Support matrix marks NVFP4 as unsupported (❌), but the HF example advertises --quant nvfp4. Please align one of them.

If unsupported, apply:

- scripts/huggingface_example.sh --type qwen --model Qwen2.5-VL-7B-Instruct --export_fmt hf --quant [fp8|nvfp4|int8_sq|int4_awq|w4a8_awq]
+ scripts/huggingface_example.sh --type qwen --model Qwen2.5-VL-7B-Instruct --export_fmt hf --quant [fp8|int8_sq|int4_awq|w4a8_awq]

Also applies to: 80-85

examples/speculative_decoding/main.py (3)

211-224: Callback uses processing_class key; may not exist in HF callbacks.

HF usually passes tokenizer, not processing_class; this can KeyError. Prefer a safe fallback.

Apply:

-                ars = validate_ar(
-                    model=kwargs["model"],
-                    tokenizer=kwargs["processing_class"],
+                ars = validate_ar(
+                    model=kwargs["model"],
+                    tokenizer=kwargs.get("tokenizer") or kwargs.get("processing_class"),
                     ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"],
                     device=kwargs["model"].device,
                 )

Optionally cache the dataset in init to avoid re-loading each interval.
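
A sketch of that caching, reusing the callback and dataset names from this diff (the rest of the constructor is assumed):

from datasets import load_dataset
from transformers import TrainerCallback

class ARValidationCallback(TrainerCallback):
    def __init__(self, ar_validate_steps: int = 500):
        self.ar_validate_steps = ar_validate_steps
        # Load once up front; every validation interval reuses the cached split.
        self.validate_ds = load_dataset("HuggingFaceH4/mt_bench_prompts")["train"]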


211-214: Default mismatch: ARValidationCallback(500) vs TrainingArguments default (1000).

Not harmful since you pass the value, but confusing; align defaults.

-        def __init__(self, ar_validate_steps: int = 500):
+        def __init__(self, ar_validate_steps: int = 1000):

236-236: Avoid using Trainer._move_model_to_device (private API).

Trainer handles device placement; calling a private method can break under FSDP/DeepSpeed.

-    trainer._move_model_to_device(model, trainer.args.device)
+    # Rely on Trainer to handle device placement
modelopt/onnx/quantization/__main__.py (1)

181-189: Behavior change: default high_precision_dtype now fp16 (was mode-dependent).

This is a user-visible default change (e.g., INT8 used to keep fp32). Confirm docs/changelog call this out and consider emitting a runtime INFO when quantize_mode == "int8" and user didn't override.

Possible guard:

if args.quantize_mode == "int8" and not any(a.startswith("--high_precision_dtype") for a in sys.argv):
    print("INFO: defaulting high_precision_dtype=fp16; set --high_precision_dtype=fp32 to keep previous behavior.")
tests/examples/speculative_decoding/test_eagle.py (1)

37-51: Skip gracefully when no GPU to reduce CI flakiness.

-def test_llama_eagle3(tiny_llama_path, num_gpus, tiny_daring_anteater_path, tmp_path):
+import pytest
+def test_llama_eagle3(tiny_llama_path, num_gpus, tiny_daring_anteater_path, tmp_path):
+    if num_gpus < 1:
+        pytest.skip("No GPU available")
modelopt/onnx/quantization/quantize.py (1)

289-296: Docstring tweak: clarify activations vs. weights conversion.

Minor clarity: “weights and activations” not “weight and activation”.

-            and the input model is of dtype fp32, model's weight and activation will be converted to
-            'fp16' or 'bf16'.
+            and the input model is fp32, the model's weights and activations will be converted to
+            'fp16' or 'bf16'.
examples/vlm_ptq/scripts/huggingface_example.sh (1)

94-99: Batch size 20 for qwen may OOM on smaller GPUs.

Consider making this env/flag tunable or auto-scaling by GPU memory.

: "${BUILD_MAX_BATCH_SIZE:=$([ "$MODEL_TYPE" = "llava" ] || [ "$MODEL_TYPE" = "vila" ] || [ "$MODEL_TYPE" = "qwen" ] && echo 20 || echo 4)}"
modelopt/torch/export/unified_export_hf.py (1)

513-519: Gate saving hf_quant_config.json when no quantization is applied.

Writing hf_quant_config.json even for QUANTIZATION_NONE creates confusing artifacts. Suggest saving only when a quant scheme is present.

Apply this diff:

-        # NOTE: (hg) Should we save hf_quant_config when there's no quantization applied?
-        # Save hf_quant_config.json for backward compatibility
-        with open(f"{export_dir}/hf_quant_config.json", "w") as file:
-            json.dump(hf_quant_config, file, indent=4)
+        # Save hf_quant_config.json only if any quantization is applied (backward compatibility)
+        quant = hf_quant_config.get("quantization", {})
+        if quant.get("quant_algo") or quant.get("kv_cache_quant_algo") != QUANTIZATION_NONE:
+            with open(f"{export_dir}/hf_quant_config.json", "w") as file:
+                json.dump(hf_quant_config, file, indent=4)
examples/onnx_ptq/evaluate.py (2)

52-58: CLI arg rename looks good; minor help text nit.

"all other modes have been deprecated in TensorRT" is broad; if you keep it, consider “deprecated here” to avoid implying upstream deprecations. No functional change needed.


85-87: Use deterministic evaluation dataloader.

Shuffling eval data isn’t typical and makes runs non‑reproducible.

Apply this diff:

-        val_loader = torch.utils.data.DataLoader(
-            val_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4
-        )
+        val_loader = torch.utils.data.DataLoader(
+            val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4
+        )
tests/examples/test_onnx_ptq.sh (1)

164-173: Fix array matching for latency-only models (ShellCheck SC2199/SC2076).

[[ " ${arr[@]} " =~ " $item " ]] is brittle. Use a loop to test membership.

Apply this diff:

-        if [[ " ${latency_models[@]} " =~ " $model_name " ]]; then
+        is_latency_model=false
+        for lm in "${latency_models[@]}"; do
+            if [[ "$lm" == "$model_name" ]]; then
+                is_latency_model=true
+                break
+            fi
+        done
+        if $is_latency_model; then
examples/speculative_decoding/train_eagle3_and_export.sh (1)

49-55: Guard against requesting more GPUs than available.

If NUM_GPU exceeds available devices, CUDA_VISIBLE_DEVICES will reference non-existent IDs.

Apply this diff:

-if [[ "$NUM_GPU" == 1 ]]; then
+avail="$(nvidia-smi --query-gpu=count --format=csv,noheader,nounits 2>/dev/null | head -n1)"
+avail="${avail:-0}"
+if [[ "$NUM_GPU" -gt "$avail" ]]; then
+  echo "Requested NUM_GPU=$NUM_GPU exceeds available GPUs ($avail)"; exit 1
+fi
+if [[ "$NUM_GPU" == 1 ]]; then
   export CUDA_VISIBLE_DEVICES=0
 else
   # Export as 0,1,...,N-1 for NUM_GPU GPUs
   devs="$(seq -s, 0 $((NUM_GPU-1)))"
   export CUDA_VISIBLE_DEVICES="$devs"
 fi
examples/speculative_decoding/launch_train.sh (3)

77-80: Improve error message for invalid args.

Print the actual flag that was invalid, not the post-‘=’ substring.

Apply this diff:

-      >&2 printf "Error: Invalid argument ${1#*=}\n"
+      >&2 printf "Error: Invalid argument: %s\n" "$1"

88-91: Avoid divide-by-zero when no GPUs are visible.

GPU_COUNT can be 0 on CPU-only envs; protect DEFAULT_SAVE_STEPS.

Apply this diff:

-GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
-# Calculate save_steps
-DEFAULT_SAVE_STEPS=$((8192 / GPU_COUNT))
+GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
+GPU_COUNT=$(( GPU_COUNT > 0 ? GPU_COUNT : 1 ))
+DEFAULT_SAVE_STEPS=$((8192 / GPU_COUNT))

112-116: Validate EAGLE config path when provided.

Early fail produces clearer errors.

Apply this diff:

   if [[ -n "$EAGLE_CONFIG" ]]; then
-    SPECULATIVE_ARGS="--eagle_config $EAGLE_CONFIG"
+    if [[ ! -f "$EAGLE_CONFIG" ]]; then
+      echo "eagle_config not found: $EAGLE_CONFIG"; exit 1
+    fi
+    SPECULATIVE_ARGS="--eagle_config $EAGLE_CONFIG"
tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py (5)

35-39: Beware: mutating shared default configs across tests.

ALGO_TO_CONFIG points to module-level defaults; the code below mutates nested fields, risking cross‑test leakage. Clone before mutation.

Apply this diff:

-        mtsp_config = ALGO_TO_CONFIG[algo]
+        from copy import deepcopy
+        mtsp_config = deepcopy(ALGO_TO_CONFIG[algo])

If prior guidance asserts these are already deepcopied per access, please confirm; otherwise prefer defensive deepcopy here.


88-89: Tighten error message formatting.

Use f-string or remove braces for clarity.

Apply this diff:

-        raise ValueError("Only algo={eagle1, eagle3, medusa} are supported!")
+        raise ValueError("Only algo in {eagle1, eagle3, medusa} are supported!")

123-124: Use integer division for shape checks.

Avoid float comparison by using //.

Apply this diff:

-    assert logits.shape[2] == vocab_size / size
+    assert logits.shape[2] == vocab_size // size

95-101: Typo in comment (non-functional).

“extrat” → “extract”.

Apply this diff:

-        # Eagle3 last layer has a forward hook to extrat the pre_norm hidden_state
+        # Eagle3 last layer has a forward hook to extract the pre_norm hidden_state

159-165: Dead code: 'algo == "eagle"' branch never hit.

Parametrization doesn’t include "eagle" anymore. Remove or update the skip logic.

-    if algo == "eagle":
-        try:
-            import megatron.core.post_training  # noqa: F401
-        except ImportError:
-            pytest.skip("megatron.core.post_training not found")
+    # If specific dependencies are required for eagle variants, add checks for "eagle1"/"eagle3" here if needed.
examples/speculative_decoding/README.md (14)

5-10: Define α and γ once, and consider ASCII fallbacks.
A one‑liner clarifying α=accepted tokens per step and γ=draft length avoids ambiguity in downstream sections and helps readers who can’t render Unicode.

Apply this diff:

-Speculative decoding accelerates auto-regressive generation in large language models (LLMs) by leveraging a lightweight draft model to predict the next γ tokens. The main LLM then verifies these candidate tokens in a single forward pass. If the draft model correctly predicts α tokens, the LLM can accept and generate α+1 tokens per verification step, significantly improving generation speed.
+Speculative decoding accelerates auto‑regressive generation by using a lightweight draft model to predict the next γ (gamma) tokens. The main LLM verifies these candidates in one forward pass. If the draft model is correct for α (alpha) tokens, the LLM accepts and generates α+1 tokens per step, improving throughput.

15-22: Table wording polish and consistency.

  • “Pre‑Requisites” → “Prerequisites”.
  • Capitalization: “eagle model” → “EAGLE model”; keep section names in Title Case.

Apply this diff:

-| Pre-Requisites | Required & optional dependencies | \[[Link](#pre-requisites)\] |
+| Prerequisites | Required & optional dependencies | \[[Link](#prerequisites)\] |
-| Simplified Workflow | Train, evaluate, and export eagle model with one-line command | \[[Link](#getting-started-simplified-workflow)\] |
+| Simplified Workflow | Train, evaluate, and export an EAGLE model with a one‑line command | \[[Link](#getting-started-simplified-workflow)\] |

And change the section header at Line 26 accordingly:

-## Pre-Requisites
+## Prerequisites

28-39: Installation placeholders and dataset fetch details need to be actionable.

  • Replace pip install -e ... with the actual path/package, or show both editable‑install and wheel options.
  • Cloning HF datasets via git requires git‑lfs; provide an alternative using datasets.load_dataset to avoid LFS issues.

Apply this diff:

-Install Modelopt with `hf` dependencies and other requirements for this example:
+Install ModelOpt with Hugging Face dependencies and other requirements for this example:

 ```bash
-pip install -e ...
+pip install -e .[hf]
 pip install -r requirements.txt

-We use Daring-Anteater dataset in this example. Download by:
+We use the Daring‑Anteater dataset. You can either clone with git‑lfs:

-git clone https://huggingface.co/datasets/nvidia/Daring-Anteater
+git lfs install
+git clone https://huggingface.co/datasets/nvidia/Daring-Anteater

+Or programmatically download with the datasets library inside your script:
+
+```python
+from datasets import load_dataset
+ds = load_dataset("nvidia/Daring-Anteater", split="train")
+```


93-99: Link target inconsistency for default configs.

Earlier you link to `eagle/default_config.py#L18`; here you link to `speculative/config.py#L37`. Use one stable link (without line anchors) to avoid rot.

Apply this diff:

-For EAGLE‑1 and EAGLE‑3 we provide a [default model architecture config](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/speculative/config.py#L37) in ModelOpt.
+For EAGLE‑1 and EAGLE‑3 we provide a default model architecture config in ModelOpt ([default_config.py](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/speculative/eagle/default_config.py)).

106-111: Hugging Face model load snippet: add tokenizer and dtype/device hints.
Minimal, but readers often copy‑paste. Adding tokenizer load and torch dtype makes it runnable.

Apply this diff:

-model = transformers.AutoModelForCausalLM.from_pretrained(
-    "<path to your pretrained model>"
-)
+tokenizer = transformers.AutoTokenizer.from_pretrained("<base model or path>", use_fast=True)
+model = transformers.AutoModelForCausalLM.from_pretrained(
+    "<base model or path>", torch_dtype="auto", device_map="auto"
+)

136-138: Show imports for mtsp.convert.
Without imports, mtsp is undefined. Add a one‑liner import.

Apply this diff:

-```python
-mtsp.convert(model, [("eagle", config)])
-```
+```python
+from modelopt.torch import speculative as mtsp
+mtsp.convert(model, [("eagle", config)])
+```

157-164: Script invocation LGTM; minor line-break alignment nit.
The multi‑line backslash formatting is readable; keep comments spaced by two spaces after #.


170-175: Validation command LGTM.
Clear and minimal. Consider noting expected output (acceptance rate summary) for quick sanity checks.


180-185: Export step LGTM; add note on target format.
Specify whether the export produces a Hugging Face‑compatible directory, a safetensors file, or TRT‑LLM artifacts.


229-236: Support matrix naming consistency and scope disclaimer.

  • Normalize model names: “Llama 2”, “Llama 3/3.1”, “Qwen 1.5/2/2.5”.
  • Add a note that support depends on upstream ecosystem versions and is subject to change.

Apply this diff:

-| LLAMA 2 | ✅ | ✅ | ✅ |
-| LLAMA 3, 3.1 | ✅ | ✅ | ✅ |
+| Llama 2 | ✅ | ✅ | ✅ |
+| Llama 3, 3.1 | ✅ | ✅ | ✅ |
 ...
-| QWen 1.5,2,2.5 | ✅ | ✅ | ✅ |
+| Qwen 1.5/2/2.5 | ✅ | ✅ | ✅ |

And append under the table:

+Note: Support may vary by framework/runtime versions and will evolve over time. Refer to the release notes for the most up‑to‑date matrix.

239-241: Checkpoint collection link LGTM.
Consider adding a note about licenses and usage terms for individual checkpoints.


64-70: Update vLLM example: keep --api-key, use env var, and correct quantization note

  • vLLM's serve exposes --api-key — do not remove it. Replace the hardcoded token with an env-var pattern and show passing it to the CLI, e.g.:
    export VLLM_API_KEY="token-abc123"
    vllm serve meta-llama/Llama-3.2-1B-Instruct --api-key "$VLLM_API_KEY" --port 8000 --tensor-parallel-size 1
  • Correct the quantization note: vLLM supports --quantization (including modelopt) and often auto-detects quantized models via hf_quant_config.json; update wording to reflect this and advise verifying the exact flag for the vLLM version used.

File: examples/speculative_decoding/README.md (lines 64–70)


142-153: Avoid private Trainer API; import ModelOpt checkpoint helper from modelopt.torch.opt.

  • Remove trainer._move_model_to_device(...) — it's a private Trainer helper; replace with model.to(trainer.args.device) or rely on Trainer's device-placement (TrainingArguments) so examples don't call a private API. (huggingface.co)
  • The HF-checkpoint helper is modelopt.torch.opt.enable_huggingface_checkpointing — use either import modelopt.torch.opt as mto; mto.enable_huggingface_checkpointing() or from modelopt.torch.opt import enable_huggingface_checkpointing; enable_huggingface_checkpointing(). Do not use modelopt.torch.export for this function. (nvidia.github.io)
  • Location: examples/speculative_decoding/README.md (lines 142–153). (github.com)

41-53: Simplified workflow example is valid — optional clarity tweak recommended.
train_eagle3_and_export.sh parses --data (default Daring-Anteater/train.jsonl), --num_gpu, and auto-creates OUTPUT_DIR, so the README one-liner works as-is; optionally add explicit --data and --output_dir in examples/speculative_decoding/README.md to make inputs/outputs clearer.

Comment on lines 2 to 9
"rope_scaling": {
"factor": 32.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3"
},
"initializer_range": 0.02

⚠️ Potential issue

🧩 Analysis chain

Validate removal of draft_vocab_size and new rope_scaling fields usage.

If any code still reads draft_vocab_size, this change could break speculative decoding config parsing. Please verify references and update defaults.

Run to check for references:


🏁 Script executed:

#!/bin/bash
fd -HI --strip-cwd-prefix -t f '.*' | xargs rg -nP 'draft_vocab_size\b' || true

Length of output: 5961


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "Searching for rope_scaling and related keys..."
fd -HI --strip-cwd-prefix -t f '.*' | xargs rg -nP 'rope_scaling\b|rope_type\b|low_freq_factor\b|high_freq_factor\b|original_max_position_embeddings\b' || true

echo
echo "Searching for eagle_architecture_config usages (context +-2 lines)..."
fd -HI --strip-cwd-prefix -t f '.*' | xargs rg -nP -n -C2 'eagle_architecture_config\b' || true

Length of output: 18106


Restore draft_vocab_size or update all references and defaults

draft_vocab_size is still referenced across the codebase and tests — removing it from configs will break parsing/initialization. Update or restore the key and align defaults/exports/tests.

Key locations to change:

  • modelopt/torch/speculative/eagle/default_config.py
  • modelopt/torch/speculative/plugins/transformers.py
  • modelopt/torch/speculative/plugins/megatron_eagle.py
  • modelopt/torch/speculative/config.py
  • modelopt/torch/export/unified_export_megatron.py
  • modelopt/torch/export/plugins/hf_spec_export.py
  • examples/speculative_decoding/main.py
  • examples/speculative_decoding/calibrate_draft_vocab.py
  • tests/unit/torch/speculative/plugins/test_hf_speculative.py
  • tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py

Fix options: (A) keep "draft_vocab_size" in example eagle_config.json/defaults (recommended), or (B) remove it everywhere and update parsing, exporters, and tests to use the new rope_scaling fields.

Comment on lines 27 to 35
def parse_args():
    parser = argparse.ArgumentParser(
        description="Export a HF checkpoint (with ModelOpt state) for deployment."
    )
    parser.add_argument("--model_path", type=str, default="Path of the trained checkpoint.")
    parser.add_argument(
        "--export_path", type=str, default="Destination directory for exported files."
    )
    return parser.parse_args()

⚠️ Potential issue

Fix argparse defaults and missing help; mark args required.

Current defaults are descriptive strings, leading to load failures when args aren’t passed.

Apply this diff:

-def parse_args():
+def parse_args():
     parser = argparse.ArgumentParser(
         description="Export a HF checkpoint (with ModelOpt state) for deployment."
     )
-    parser.add_argument("--model_path", type=str, default="Path of the trained checkpoint.")
-    parser.add_argument(
-        "--export_path", type=str, default="Destination directory for exported files."
-    )
+    parser.add_argument("--model_path", type=str, required=True,
+                        help="Path to the trained checkpoint directory.")
+    parser.add_argument("--export_path", type=str, required=True,
+                        help="Destination directory for exported files.")
     return parser.parse_args()
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def parse_args():
parser = argparse.ArgumentParser(
description="Export a HF checkpoint (with ModelOpt state) for deployment."
)
parser.add_argument("--model_path", type=str, default="Path of the trained checkpoint.")
parser.add_argument(
"--export_path", type=str, default="Destination directory for exported files."
)
return parser.parse_args()
def parse_args():
parser = argparse.ArgumentParser(
description="Export a HF checkpoint (with ModelOpt state) for deployment."
)
parser.add_argument("--model_path", type=str, required=True,
help="Path to the trained checkpoint directory.")
parser.add_argument("--export_path", type=str, required=True,
help="Destination directory for exported files.")
return parser.parse_args()
🤖 Prompt for AI Agents
In examples/speculative_decoding/export_hf_checkpoint.py around lines 27 to 35,
the argparse arguments use descriptive strings as defaults which causes load
failures and lack help text; change each parser.add_argument to remove the
descriptive default, set required=True, and provide a proper help="..." string
(e.g., help="Path of the trained checkpoint." for --model_path and
help="Destination directory for exported files." for --export_path) so the args
are mandatory and have correct help guidance.

Comment on lines 50 to 55
try:
import wandb

wandb.init()
except ImportError:
wandb = None

⚠️ Potential issue

Don't init wandb at import time; broaden error handling.

Module-level init causes side effects, multi-process duplication, and uncaught init errors (non-ImportError). Move init into train(), gate on main process, and tolerate failures.

Apply this diff here and add in-train init per snippet below:

-try:
-    import wandb
-
-    wandb.init()
-except ImportError:
-    wandb = None
+wandb = None

Add inside train() after parsing args (use global to keep callback check working):

# inside train(), after args parsing
global wandb
if os.environ.get("WANDB_DISABLED", "").lower() not in ("1", "true", "yes"):
    try:
        import wandb as _wandb
        if getattr(training_args, "local_rank", -1) in (-1, 0):
            _wandb.init()  # optionally project/name from args/env
        wandb = _wandb
    except Exception as e:
        print_rank_0(f"wandb disabled: {e}")
        wandb = None
🤖 Prompt for AI Agents
In examples/speculative_decoding/main.py around lines 50-55, the module
currently imports and calls wandb.init() at import time and only catches
ImportError; move the side-effectful initialization into train() (after args
parsing) and replace the module-level init with a simple import-or-None
assignment (wandb = None or import without init). Inside train(), use global
wandb, check WANDB_DISABLED env var, attempt to import wandb as a temporary name
(e.g., _wandb), and only call _wandb.init() when on the main process
(training_args.local_rank in (-1,0)); catch broad exceptions during import/init,
print a rank-0 message on failure, and set wandb = None so training proceeds
without crashing.

Comment on lines 225 to 226
if wandb:
wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)

⚠️ Potential issue

Gate wandb.log to main process to avoid duplicate logs.

All ranks invoke callbacks; without gating you'll spam W&B.

-                if wandb:
-                    wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
+                if wandb and getattr(args, "local_rank", -1) in (-1, 0):
+                    wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if wandb:
wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
if wandb and getattr(args, "local_rank", -1) in (-1, 0):
wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
🤖 Prompt for AI Agents
In examples/speculative_decoding/main.py around lines 225-226, wandb.log is
called from every rank which causes duplicate logs; gate the call so only the
main process logs (e.g., wrap wandb.log with a check for main process such as if
accelerator.is_main_process: or if dist.get_rank() == 0 / rank == 0), ensuring
that the main-process check is available/imported from your launcher
(Accelerator or torch.distributed) and leave all other ranks skipping the
wandb.log call.

Comment on lines 83 to 90
We can optionally use smaller vocab size for the draft model for faster training and inference. E.g. Llama3.2-1B has a vocab size of 128256. In this example, we construct a draft vocab mapping of size 32k by finding the most commonly appeared vocabs in our training set:

```bash
python calibrate_draft_vocab.py --model meta-llama/Llama-3.2-1B-Instruct --data Daring-Anteater/train.jsonl --draft_vocab_size 32000 --save_dir draft_vocab_cache
```

`/scripts/` is the absolute path to `modelopt/examples/speculative_decoding` which contains `server_generate.py` and `distributed_generate`.
This will launch a vllm server (sglang is also available) on each node. Each node will work through 10 shards of data (10\*max_lines_per_shard number of samples).
In this case, the first 40 shards of data will be processed.
To process the next 40 shards
This will produce a `d2t.pt` file in `save_dir`, which is the mapping from draft token to target token. During inference, draft tokens can be mapped back to target tokens by `target_token = draft_token + d2t[draft_token]`.


⚠️ Potential issue

d2t mapping formula is likely incorrect.
A draft→target token mapping generally isn’t an additive offset; it’s an index map. The expression target_token = draft_token + d2t[draft_token] risks invalid IDs. Prefer direct lookup target_token = d2t[draft_token].

Apply this diff:

-This will produce a `d2t.pt` file in `save_dir`, which is the mapping from draft token to target token. During inference, draft tokens can be mapped back to target tokens by `target_token = draft_token + d2t[draft_token]`.
+This produces a `d2t.pt` file in `save_dir`, which maps draft token IDs to target token IDs. During inference, map with `target_token = d2t[draft_token]`.
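For illustration, a minimal lookup sketch under the direct-index convention described above (whether `d2t.pt` stores absolute target IDs or additive offsets should still be confirmed against `calibrate_draft_vocab.py`):

```python
import torch

# Load the draft-to-target mapping produced by calibrate_draft_vocab.py.
d2t = torch.load("draft_vocab_cache/d2t.pt")  # assumed: 1-D LongTensor indexed by draft-token ID

draft_tokens = torch.tensor([5, 42, 1017])   # example draft-token IDs
target_tokens = d2t[draft_tokens]            # direct index lookup, one target ID per draft ID
```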
🤖 Prompt for AI Agents
In examples/speculative_decoding/README.md around lines 83–90, the documented
draft→target mapping formula is incorrect; replace the additive formula with a
direct index lookup. Update the text to say that d2t is an index map and during
inference you obtain the target token by looking up the draft token in d2t
(target token = the value at d2t[draft_token]) rather than adding an offset, and
ensure the example command and resulting file description remain unchanged.

"--max_tokens", type=int, default=2048, help="Maximum number of tokens to generate"
)
parser.add_argument("--chat", action="store_true", help="Use chat mode")
parser.add_argument("--chat", default=True, type=bool, help="Use chat mode")

⚠️ Potential issue

Fix argparse boolean parsing for --chat (current code always evaluates to True).

type=bool on argparse treats any non-empty string as True, so --chat False still enables chat. Use BooleanOptionalAction (Python ≥3.9) or paired flags.

Apply this diff:

-parser.add_argument("--chat", default=True, type=bool, help="Use chat mode")
+parser.add_argument(
+    "--chat",
+    action=argparse.BooleanOptionalAction,
+    default=True,
+    help="Use chat mode",
+)

If older Python needs support, I can wire a str2bool helper or add --no-chat.
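For reference, a generic `str2bool` helper for Python < 3.9 could look like the sketch below (not code from this PR):

```python
import argparse


def str2bool(value: str) -> bool:
    # Accept common truthy/falsy spellings so "--chat False" actually disables chat.
    if isinstance(value, bool):
        return value
    if value.lower() in ("yes", "true", "t", "1"):
        return True
    if value.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument(
    "--chat", type=str2bool, nargs="?", const=True, default=True, help="Use chat mode"
)
args = parser.parse_args(["--chat", "False"])
print(args.chat)  # False
```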

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
parser.add_argument("--chat", default=True, type=bool, help="Use chat mode")
parser.add_argument(
"--chat",
action=argparse.BooleanOptionalAction,
default=True,
help="Use chat mode",
)
🤖 Prompt for AI Agents
In examples/speculative_decoding/server_generate.py around line 49, the argparse
argument parser uses type=bool which makes any non-empty string evaluate True so
"--chat False" still enables chat; replace this with a proper boolean flag:
either use argparse.BooleanOptionalAction (Python ≥3.9) by changing the
add_argument call to use action=argparse.BooleanOptionalAction and remove
type/default pairing, or implement a str2bool helper and set type=str2bool with
default=True, or add paired flags (--chat / --no-chat) to correctly parse false
values; update help text accordingly.

Comment on lines 396 to 400
use_torch_autocast = not (
is_fp4_quantized(model) or is_mxfp8_quantized(model) or weights_dtype == "fp32"
)
autocast = torch.autocast("cuda") if use_torch_autocast else nullcontext()


⚠️ Potential issue

Include INT4 in autocast gating to avoid unintended mixed precision during export

Autocast remains enabled for INT4-quantized models, which can inject casts during tracing/export and skew ONNX dtypes. Align with FP4/MXFP8 behavior by disabling autocast for INT4 as well.

Apply:

-    use_torch_autocast = not (
-        is_fp4_quantized(model) or is_mxfp8_quantized(model) or weights_dtype == "fp32"
-    )
+    use_torch_autocast = not (
+        is_fp4_quantized(model)
+        or is_mxfp8_quantized(model)
+        or is_int4_quantized(model)
+        or weights_dtype == "fp32"
+    )

Optional guard to prevent crashes on CPU-only hosts:

-    autocast = torch.autocast("cuda") if use_torch_autocast else nullcontext()
+    autocast = (
+        torch.autocast("cuda") if use_torch_autocast and torch.cuda.is_available() else nullcontext()
+    )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
use_torch_autocast = not (
is_fp4_quantized(model) or is_mxfp8_quantized(model) or weights_dtype == "fp32"
)
autocast = torch.autocast("cuda") if use_torch_autocast else nullcontext()
use_torch_autocast = not (
is_fp4_quantized(model)
or is_mxfp8_quantized(model)
or is_int4_quantized(model)
or weights_dtype == "fp32"
)
autocast = (
torch.autocast("cuda") if use_torch_autocast and torch.cuda.is_available() else nullcontext()
)
🤖 Prompt for AI Agents
In modelopt/torch/_deploy/utils/torch_onnx.py around lines 396 to 400, the
autocast gating does not include INT4 quantized models so autocast can remain
enabled for INT4 and introduce unwanted casts during export; update the
use_torch_autocast condition to also check for INT4 (e.g., add
is_int4_quantized(model) or equivalent) so autocast is disabled for INT4 like
FP4/MXFP8, and when creating autocast only use torch.autocast("cuda") if CUDA is
available (torch.cuda.is_available()) otherwise fall back to nullcontext() to
avoid crashes on CPU-only hosts.

05_train-->07_export_hf;
```

## Results
Contributor Author

@omrialmog I added QAT results here

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

♻️ Duplicate comments (2)
examples/nemo_run/qat/README.md (2)

19-29: Flow stage names now match exact task IDs — LGTM.

Stages align with logs/ADVANCED.md naming (e.g., 00_openscience_data).


51-52: Move “Usage” right after “Overview.”

Previous feedback requested this; it’s still after “Results.” Reordering improves UX.

🧹 Nitpick comments (5)
examples/nemo_run/qat/README.md (5)

5-7: Consolidate duplicate ADVANCED.md links.

Both “Slurm Examples” and “Advanced Topics” point to the same ADVANCED.md. Collapse into one, or point them to distinct anchors.

Apply this diff:

-[Slurm Examples](ADVANCED.md) |
-[Advanced Topics](ADVANCED.md) |
+[Advanced Topics (incl. Slurm)](ADVANCED.md) |

41-50: Add reproducibility metadata to results.

Include the seed, NeMo container tag, NeMo commit, Transformer Engine (TE) version, and the exact eval subset definition (which "5%" slice; shuffled or fixed). This improves comparability of the MMLU numbers.


55-55: Nit: extra space and wording.

“…either locally or…” → single space; also prefer “on a Slurm cluster (see ADVANCED.md)”.

Apply this diff:

-You can run the example either locally  or on a [Slurm cluster](ADVANCED.md).
+You can run the example either locally or on a [Slurm cluster](ADVANCED.md).

76-82: Make the working directory explicit in the example.

Reduce ambiguity by including a cd.

Apply this diff:

-```bash
-python qat/nemo_qat_flow.py  --log-dir /my/log/dir --experiment qat_experiment
-```
+```bash
+cd examples/nemo_run
+python qat/nemo_qat_flow.py --log-dir /my/log/dir --experiment qat_experiment
+```

88-94: Mirror the explicit cd in QAD example.

Consistency helps copy‑paste.

Apply this diff:

-```bash
-python qat/nemo_qat_flow.py --distill --log-dir /my/log/dir --experiment qad_experiment
-```
+```bash
+cd examples/nemo_run
+python qat/nemo_qat_flow.py --distill --log-dir /my/log/dir --experiment qad_experiment
+```
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e6f3ed7 and 43718bc.

📒 Files selected for processing (2)
  • examples/nemo_run/qat/README.md (2 hunks)
  • modelopt/torch/export/plugins/nemo_run.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/export/plugins/nemo_run.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (3)
examples/nemo_run/qat/README.md (3)

30-39: Mermaid graph matches the 8-stage flow — LGTM.

Edges reflect the described execution order.


98-104: Defaults and HW notes — LGTM.

Clear default (Qwen3-8B/qwen3_8b) and GPU/node requirements.


107-108: Confirmed: CLI flags exist as documented.

Found: --tensor_parallelism and --pipeline_parallelism (examples/nemo_run/qat/nemo_qat_flow.py; examples/nemo_run/common/in_memory_mmlu.py); --ptq-gpus, --train-gpus, --train-nodes (examples/nemo_run/qat/nemo_qat_flow.py).

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (4)
examples/nemo_run/qat/nemo_qat_flow.py (3)

138-145: Fix KV‑cache tri‑state; don’t force disable by default.

Current CLI always passes --disable_kv_cache when the user doesn’t opt in. Make it an explicit tri‑state and build the flag only if set. (Matches prior feedback.)

@@
-    parser.add_argument(
-        "--enable_kv_cache",
-        help="Enables KV-cache quantization",
-        action="store_true",
-        default=False,
-    )
+    # Tri-state KV cache flag: None (unspecified), True (enable), False (disable)
+    kv_group = parser.add_mutually_exclusive_group()
+    kv_group.add_argument(
+        "--enable_kv_cache",
+        dest="enable_kv_cache",
+        help="Enable KV-cache quantization",
+        action="store_true",
+    )
+    kv_group.add_argument(
+        "--disable_kv_cache",
+        dest="enable_kv_cache",
+        help="Disable KV-cache quantization",
+        action="store_false",
+    )
+    parser.set_defaults(enable_kv_cache=None)
@@
-            "--kv_cache_qformat",
-            args.kv_cache_qformat,
-            "--enable_kv_cache" if args.enable_kv_cache else "--disable_kv_cache",
+            "--kv_cache_qformat",
+            args.kv_cache_qformat,
+            *(
+                ["--enable_kv_cache"] if args.enable_kv_cache is True
+                else (["--disable_kv_cache"] if args.enable_kv_cache is False else [])
+            ),

Also applies to: 196-201


148-156: Guard globals used in main() for programmatic invocation.

SEQUENCE_LENGTH, MBS, GBS, TRAIN_STEPS, VAL_INTERVAL are defined under main only; importing this module and calling main(args) will raise NameError. (Matches prior feedback.)

 def main(args):
+    # Fallbacks if constants aren’t initialized (e.g., when called programmatically)
+    global SEQUENCE_LENGTH, MBS, GBS, TRAIN_STEPS, VAL_INTERVAL
+    SEQUENCE_LENGTH = globals().get("SEQUENCE_LENGTH", 4096)
+    MBS = globals().get("MBS", 1)
+    GBS = globals().get("GBS", 512)
+    TRAIN_STEPS = globals().get("TRAIN_STEPS", 200)
+    VAL_INTERVAL = globals().get("VAL_INTERVAL", 50)

Also applies to: 224-233, 242-249, 375-379


283-285: Slurm multi‑node training bug: nodes not set on executor.

Training will always request 1 node because nodes isn’t passed to create_slurm_executor. NeMo config alone won’t fix the Slurm allocation.

-        train_gpu_executor = create_slurm_executor(
-            SLURM_CONFIG, num_gpus=args.train_gpus, ntasks_per_node=args.train_gpus
-        )
+        train_gpu_executor = create_slurm_executor(
+            SLURM_CONFIG,
+            nodes=args.train_nodes,
+            num_gpus=args.train_gpus,
+            ntasks_per_node=args.train_gpus,
+        )
examples/nemo_run/qat/README.md (1)

70-71: Token env var and permissions guidance.

Use the canonical env var and avoid chmod 777; suggest running as host UID and/or mounting a writable logs dir.

-You will also need to set your Huggingface token with `export HF_TOKEN=<your-token>`. You may also need to enable write access to the docker container to the `examples/nemo_run` folder by doing `chmod 777 nemo_run` so that logs can be written.
+Set your Hugging Face token (HF_TOKEN is also recognized):
+`export HUGGING_FACE_HUB_TOKEN=<your-token>`  # or `export HF_TOKEN=<your-token>`
+
+For write access, avoid `chmod 777`. Run the container as your user (`-u $(id -u):$(id -g)`) and/or mount a writable logs directory (e.g., `-v /home/user/logs:/logs`) and pass `--log-dir /logs`.
🧹 Nitpick comments (3)
examples/nemo_run/qat/nemo_qat_flow.py (2)

278-287: Avoid mutating the training executor for export; use a dedicated export executor.

Changing ntasks_per_node on train_gpu_executor is brittle and can affect subsequent tasks. Create a dedicated export executor that requests all GPUs but runs a single task.

@@
-        cpu_executor = create_slurm_executor(SLURM_CONFIG)
+        cpu_executor = create_slurm_executor(SLURM_CONFIG)
         ptq_gpu_executor = create_slurm_executor(
             SLURM_CONFIG, num_gpus=args.ptq_gpus, ntasks_per_node=args.ptq_gpus
         )
         train_gpu_executor = create_slurm_executor(
-            SLURM_CONFIG, num_gpus=args.train_gpus, ntasks_per_node=args.train_gpus
+            SLURM_CONFIG, nodes=args.train_nodes, num_gpus=args.train_gpus, ntasks_per_node=args.train_gpus
         )
-        single_gpu_executor = create_slurm_executor(SLURM_CONFIG, num_gpus=1, ntasks_per_node=1)
+        single_gpu_executor = create_slurm_executor(SLURM_CONFIG, num_gpus=1, ntasks_per_node=1)
+        # Export needs visibility to all GPUs but only 1 task due to NeMo bug
+        export_gpu_executor = create_slurm_executor(
+            SLURM_CONFIG, num_gpus=args.train_gpus, ntasks_per_node=1
+        )
@@
-        cpu_executor = single_gpu_executor = run.LocalExecutor()
+        cpu_executor = single_gpu_executor = run.LocalExecutor()
         ptq_gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.ptq_gpus)
         train_gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.train_gpus)
+        export_gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=1)
@@
-        # WAR: Export needs access to all GPUs but only 1 task due to bug in NeMo
-        train_gpu_executor.ntasks_per_node = 1  # will throw error if more than 1 task during export
-        exp.add(
+        exp.add(
             export,
             tail_logs=True,
             name="07_export_hf",
-            executor=train_gpu_executor,
+            executor=export_gpu_executor,
             dependencies=[s5],
         )

Also applies to: 288-291, 339-346


149-156: Clarify naming to avoid confusion between recipe and HF model name.

model_name is first a recipe (e.g., qwen3_8b) then potentially overwritten with HF basename; this risks mistakes. Use distinct variables (e.g., recipe_name and hf_model_basename).

examples/nemo_run/qat/README.md (1)

59-63: Add editable install step after cloning to ensure mounted sources are importable.

Without installs, Python may import preinstalled packages from the container.

 To run the example locally, launch a NeMo container ... then mount it onto your docker container.
 
 - `git clone https://github.com/NVIDIA/TensorRT-Model-Optimizer.git`
 - `git clone https://github.com/NVIDIA-NeMo/NeMo.git && cd NeMo && git checkout 676ed1a`
+
+Inside the container, install the mounted repos in editable mode (or set PYTHONPATH):
+
+```bash
+python -m pip install -U pip
+python -m pip install -e /opt/NeMo
+python -m pip install -e /workspace/TensorRT-Model-Optimizer/modelopt
+# Alternatively:
+# export PYTHONPATH=/opt/NeMo:/workspace/TensorRT-Model-Optimizer/modelopt:$PYTHONPATH
+```
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 43718bc and 16f61f1.

📒 Files selected for processing (2)
  • examples/nemo_run/qat/README.md (2 hunks)
  • examples/nemo_run/qat/nemo_qat_flow.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/nemo_run/qat/nemo_qat_flow.py (3)
modelopt/torch/export/plugins/nemo_run.py (1)
  • export_most_recent_ckpt (24-33)
examples/nemo_run/common/utils.py (4)
  • SlurmConfig (24-72)
  • create_slurm_executor (75-117)
  • get_finetune_recipe (120-123)
  • read_chat_template (126-128)
examples/nemo_run/common/in_memory_mmlu.py (1)
  • parse_args (24-45)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (1)
examples/nemo_run/qat/nemo_qat_flow.py (1)

272-276: Verify finetuned_ckpt_dir value for MMLU SFT eval.

Passing exp_dir may include multiple runs; if in_memory_mmlu.py expects a single experiment directory, prefer train.log.log_dir (or most recent subdir) for determinism.

-        args=["--finetuned_ckpt_dir", exp_dir, "--tensor_parallelism", f"{args.ptq_gpus}"],
+        args=["--finetuned_ckpt_dir", train.log.log_dir, "--tensor_parallelism", f"{args.ptq_gpus}"],

If multiple runs are present under log_dir, consider pointing to the most recent subdir explicitly.
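A small helper along these lines could resolve the newest run directory (a sketch; the flat run-per-subdirectory layout is an assumption):

```python
from pathlib import Path


def most_recent_subdir(log_dir: str) -> str:
    """Return the most recently modified run directory under log_dir (assumed layout)."""
    subdirs = [p for p in Path(log_dir).iterdir() if p.is_dir()]
    if not subdirs:
        raise FileNotFoundError(f"No run directories found under {log_dir}")
    return str(max(subdirs, key=lambda p: p.stat().st_mtime))
```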

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (6)
examples/nemo_run/qat/README.md (6)

59-63: Mounted sources won’t be importable without installs; add editable installs

Add steps so Python resolves mounted NeMo/TMO sources inside the container.

 To run the example locally, launch a [NeMo container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo) with version 25.07 or higher. Clone the `TensorRT-Model-Optimizer` repository and `NeMo` repository (checkout a specific commit for NeMo), then mount it onto your docker container.
 - `git clone https://github.com/NVIDIA/TensorRT-Model-Optimizer.git`
 - `git clone https://github.com/NVIDIA-NeMo/NeMo.git && cd NeMo && git checkout 676ed1a`
+
+Inside the container, install the mounted repos in editable mode (or set PYTHONPATH):
+
+```bash
+python -m pip install -U pip
+python -m pip install -e /opt/NeMo
+python -m pip install -e /workspace/TensorRT-Model-Optimizer/modelopt
+# Alternatively:
+# export PYTHONPATH=/opt/NeMo:/workspace/TensorRT-Model-Optimizer/modelopt:$PYTHONPATH
+```

66-68: Avoid mounting into site-packages; run as non-root and use editable installs

Mounting over site-packages is brittle; prefer workspace mounts + pip -e.

-```bash
-docker run -v  /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/modelopt/:/usr/local/lib/python3.12/dist-packages/modelopt --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.07 bash
-```
+```bash
+docker run --rm -it --gpus=all --shm-size=20g \
+  -u $(id -u):$(id -g) \
+  -v /home/user/NeMo:/opt/NeMo \
+  -v /home/user/TensorRT-Model-Optimizer:/workspace/TensorRT-Model-Optimizer \
+  -v /home/user/logs:/logs \
+  nvcr.io/nvidia/nemo:25.07 bash
+
+# Inside the container:
+python -m pip install -U pip
+python -m pip install -e /opt/NeMo
+python -m pip install -e /workspace/TensorRT-Model-Optimizer/modelopt
+```

70-71: Use canonical HF token env var; don’t recommend chmod 777

Prefer HUGGING_FACE_HUB_TOKEN (HF_TOKEN as alias). Avoid world-writable perms; run as host user and/or mount a writable logs dir.

-You will also need to set your Huggingface token with `export HF_TOKEN=<your-token>`. You may also need to enable write access to the docker container to the `examples/nemo_run` folder by doing `chmod 777 nemo_run` so that logs can be written.
+Set your Hugging Face token (HF_TOKEN is also recognized):
+`export HUGGING_FACE_HUB_TOKEN=<your-token>`  # or `export HF_TOKEN=<your-token>`
+
+For write access, avoid `chmod 777`. Instead run the container as your user (`-u $(id -u):$(id -g)`) and/or mount a writable logs directory (e.g., `-v /home/user/logs:/logs`) and pass `--log-dir /logs`.

86-86: CLI flag mismatch: use --enable_kv_cache (underscore)

Docs should match the script; also use pipe in the choice list and mention disable flag.

-> **_NOTE:_** To enable KV cache quantization, add `--enable-kv-cache` and specify qformat using `--kv-cache-qformat <fp8, nvfp4>`.
+> **_NOTE:_** To enable KV cache quantization, add `--enable_kv_cache` and specify qformat using `--kv-cache-qformat <fp8|nvfp4>`. To explicitly disable it, use `--disable_kv_cache`.

53-53: Move “Usage” nearer the top (after Overview/Flow)

Improves discoverability; this aligns with earlier feedback.


98-101: Keep Slurm details in ADVANCED.md; link instead of inline flags

Reduce duplication by referencing the Slurm doc here.

-Locally this script currently supports models that can be trained on 1 node with 8 x 80GB GPUs. On Slurm you can configure the number of nodes/gpus for training and PTQ with the following flags: `--train-nodes`, `--train-gpus`, `--ptq-gpus`.
+Locally this script supports 1 node with 8 × 80GB GPUs. For Slurm configuration (nodes/GPUs for training and PTQ), see [ADVANCED.md](ADVANCED.md).
🧹 Nitpick comments (3)
examples/nemo_run/qat/README.md (3)

5-7: Dedup header links; both point to ADVANCED.md

Avoid redundant links; keep one combined link.

-[Slurm Examples](ADVANCED.md) |
-[Advanced Topics](ADVANCED.md) |
+[Slurm & Advanced Topics](ADVANCED.md) |
 [NeMo Integration](https://github.com/NVIDIA-NeMo/NeMo/tree/main/nemo/collections/llm/modelopt)

51-51: Grammar/consistency nit

Use “is also” and consistent spacing.

-The resulting exported checkpoint also is much smaller in memory at 6.4GB compared to the original BF16 checkpoint which is 16.4 GB.
+The resulting exported checkpoint is also much smaller in memory at 6.4 GB, compared to the original BF16 checkpoint at 16.4 GB.

57-57: Extra space

Minor spacing fix.

-You can run the example either locally  or on a [Slurm cluster](ADVANCED.md).
+You can run the example either locally or on a [Slurm cluster](ADVANCED.md).
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 16f61f1 and 246a765.

📒 Files selected for processing (8)
  • examples/llm_qat/README.md (1 hunks)
  • examples/nemo_run/common/in_memory_mmlu.py (1 hunks)
  • examples/nemo_run/common/process_openscience.py (1 hunks)
  • examples/nemo_run/common/utils.py (1 hunks)
  • examples/nemo_run/qat/ADVANCED.md (1 hunks)
  • examples/nemo_run/qat/README.md (2 hunks)
  • examples/nemo_run/qat/nemo_qat_flow.py (6 hunks)
  • modelopt/torch/export/plugins/nemo_run.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (7)
  • examples/nemo_run/common/process_openscience.py
  • examples/nemo_run/common/utils.py
  • examples/nemo_run/common/in_memory_mmlu.py
  • modelopt/torch/export/plugins/nemo_run.py
  • examples/nemo_run/qat/nemo_qat_flow.py
  • examples/llm_qat/README.md
  • examples/nemo_run/qat/ADVANCED.md
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (2)
examples/nemo_run/qat/README.md (2)

59-63: Pin/validate NeMo commit — confirmed
Commit 676ed1a exists in NVIDIA‑NeMo/NeMo and is compatible with the nvcr.io/nvidia/nemo:25.07 container; no update required.


109-110: Verify CLI flags exist in nemo_qat_flow.py

Automated search failed (missing "nemo" path); confirm --tensor_parallelism and --pipeline_parallelism are defined as argparse CLI args in nemo_qat_flow.py, or update the README to point to the correct flags.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (6)
examples/nemo_run/common/utils.py (2)

120-124: Guard missing recipe module; avoid AttributeError.

Use a safe getattr and raise a clear ValueError.

-def get_finetune_recipe(recipe_name: str):
-    if not hasattr(getattr(llm, recipe_name), "finetune_recipe"):
-        raise ValueError(f"Recipe {recipe_name} does not have a Fine-Tuning recipe")
-    return getattr(llm, recipe_name).finetune_recipe(peft_scheme=None)
+def get_finetune_recipe(recipe_name: str):
+    recipe_mod = getattr(llm, recipe_name, None)
+    if recipe_mod is None or not hasattr(recipe_mod, "finetune_recipe"):
+        raise ValueError(f"Recipe {recipe_name} does not exist or lacks a fine-tuning recipe")
+    return recipe_mod.finetune_recipe(peft_scheme=None)

91-105: Propagate env vars to Slurm jobs (GPU + CPU executors).

SlurmConfig.env_vars is never applied; pass through to SlurmExecutor.

         return run.SlurmExecutor(
             account=slurm_cfg.account,
             partition=slurm_cfg.partition_gpu,
             ntasks_per_node=ntasks_per_node,
             gpus_per_node=num_gpus,
             nodes=nodes,
             tunnel=tunnel,
             container_image=slurm_cfg.container_image,
             container_mounts=slurm_cfg.container_mounts,
             time=slurm_cfg.time,
             packager=run.GitArchivePackager(),
             mem="0",
             gres=f"gpu:{num_gpus}",
+            env_vars=slurm_cfg.env_vars,
         )
@@
         return run.SlurmExecutor(
             account=slurm_cfg.account,
             partition=slurm_cfg.partition_cpu,
             nodes=nodes,
             tunnel=tunnel,
             container_image=slurm_cfg.container_image,
             container_mounts=slurm_cfg.container_mounts,
             time=slurm_cfg.time,
             packager=run.GitArchivePackager(),
             mem="0",
+            env_vars=slurm_cfg.env_vars,
         )

Run to confirm the constructor param name (env_vars/env) for your installed nemo-run:

#!/usr/bin/env bash
python - << 'PY'
import inspect
import nemo_run as run
sig = inspect.signature(run.SlurmExecutor)
print(sig)
print([p for p in sig.parameters])
PY

Also applies to: 107-117

examples/nemo_run/qat/nemo_qat_flow.py (4)

185-203: Only pass KV‑cache flag when explicitly set.

Build the flag conditionally; avoid hard “disable”.

-    ptq = run.Script(
+    kv_cache_flag = (
+        ["--enable_kv_cache"] if args.enable_kv_cache is True
+        else (["--disable_kv_cache"] if args.enable_kv_cache is False else [])
+    )
+    ptq = run.Script(
         "/opt/NeMo/scripts/llm/ptq.py",
         args=[
             "-nc",
             bf16_ckpt_path,
             "-out",
             ptq_model_out,
             "--export_format",
             "nemo",
             "--algorithm",
             args.algorithm,
             "--kv_cache_qformat",
             args.kv_cache_qformat,
-            "--enable_kv_cache" if args.enable_kv_cache else "--disable_kv_cache",
+            *kv_cache_flag,
             "-ctp",
             f"{args.ptq_gpus}",
         ],
         entrypoint="python",
     )

138-142: Make KV‑cache flag tri‑state; don’t force disable by default.

Expose both enable/disable flags and default to “unspecified”.

-    parser.add_argument(
-        "--enable_kv_cache",
-        help="Enables KV-cache quantization",
-        action="store_true",
-        default=False,
-    )
+    kv = parser.add_mutually_exclusive_group()
+    kv.add_argument("--enable_kv_cache", dest="enable_kv_cache", action="store_true", help="Enable KV-cache quantization")
+    kv.add_argument("--disable_kv_cache", dest="enable_kv_cache", action="store_false", help="Disable KV-cache quantization")
+    parser.set_defaults(enable_kv_cache=None)

148-156: Guard globals and support programmatic invocation.

Avoid NameError when called outside main; validate Slurm config presence.

 def main(args):
+    # Fallbacks if constants/SLURM_CONFIG aren’t defined at import time
+    global SEQUENCE_LENGTH, MBS, GBS, TRAIN_STEPS, VAL_INTERVAL
+    SEQUENCE_LENGTH = globals().get("SEQUENCE_LENGTH", 4096)
+    MBS = globals().get("MBS", 1)
+    GBS = globals().get("GBS", 512)
+    TRAIN_STEPS = globals().get("TRAIN_STEPS", 200)
+    VAL_INTERVAL = globals().get("VAL_INTERVAL", 50)
+    if args.use_slurm and "SLURM_CONFIG" not in globals():
+        raise ValueError("SLURM_CONFIG must be defined when --use-slurm is set (provide via __main__ or module scope).")

338-346: Use dedicated single‑GPU executor for export; don’t mutate the train executor.

Prevents stale Slurm settings and surprises in multi‑stage runs.

-        # WAR: Export needs access to all GPUs but only 1 task due to bug in NeMo
-        train_gpu_executor.ntasks_per_node = 1  # will throw error if more than 1 task during export
-        exp.add(
-            export,
-            tail_logs=True,
-            name="07_export_hf",
-            executor=train_gpu_executor,
-            dependencies=[s5],
-        )
+        # WAR: Export needs access to all GPUs but only 1 task due to bug in NeMo
+        exp.add(
+            export,
+            tail_logs=True,
+            name="07_export_hf",
+            executor=single_gpu_executor,
+            dependencies=[s5],
+        )
🧹 Nitpick comments (4)
examples/nemo_run/common/utils.py (2)

40-51: (Optional) Validate Slurm time format early.

Prevent misconfigurations by checking HH:MM:SS.

+import re
@@
         if not self.time:
             raise ValueError("SlurmConfig.time must be set to job time limit (e.g., '02:00:00')")
+        elif not re.fullmatch(r"\d{1,2}:\d{2}:\d{2}", self.time):
+            raise ValueError("SlurmConfig.time must be in HH:MM:SS format (e.g., '04:00:00')")

126-128: Open chat template with UTF‑8 to avoid locale issues.

-def read_chat_template(template_path: str):
-    with open(template_path) as f:
+def read_chat_template(template_path: str):
+    with open(template_path, encoding="utf-8") as f:
         return f.read().strip()
examples/nemo_run/qat/nemo_qat_flow.py (2)

151-153: Handle missing/invalid recipe module gracefully.

Avoid AttributeError; emit a clear message.

-    model_name = args.finetune_recipe
-    model_module = getattr(llm, model_name)
+    model_name = args.finetune_recipe
+    model_module = getattr(llm, model_name, None)
+    if model_module is None:
+        raise ValueError(f"Unknown finetune recipe module: {model_name}")

264-276: Tie MMLU TP/PP to user flags (optional).

Use args.tensor_parallelism/pipeline_parallelism for eval to match train; reduces confusion when PTQ GPU count differs.

-    eval_ptq = run.Script(
+    eval_ptq = run.Script(
         mmlu_script_path,
-        args=["--nemo_ckpt", ptq_model_out, "--tensor_parallelism", f"{args.ptq_gpus}"],
+        args=[
+            "--nemo_ckpt", ptq_model_out,
+            "--tensor_parallelism", f"{args.tensor_parallelism}",
+            "--pipeline_parallelism", f"{args.pipeline_parallelism}",
+        ],
         entrypoint="python",
     )
-    eval_bf16 = run.Script(
+    eval_bf16 = run.Script(
         mmlu_script_path,
-        args=["--nemo_ckpt", bf16_ckpt_path, "--tensor_parallelism", f"{args.ptq_gpus}"],
+        args=[
+            "--nemo_ckpt", bf16_ckpt_path,
+            "--tensor_parallelism", f"{args.tensor_parallelism}",
+            "--pipeline_parallelism", f"{args.pipeline_parallelism}",
+        ],
         entrypoint="python",
     )
-    eval_sft = run.Script(
+    eval_sft = run.Script(
         mmlu_script_path,
-        args=["--finetuned_ckpt_dir", exp_dir, "--tensor_parallelism", f"{args.ptq_gpus}"],
+        args=[
+            "--finetuned_ckpt_dir", exp_dir,
+            "--tensor_parallelism", f"{args.tensor_parallelism}",
+            "--pipeline_parallelism", f"{args.pipeline_parallelism}",
+        ],
         entrypoint="python",
     )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 246a765 and 1249de1.

📒 Files selected for processing (8)
  • examples/llm_qat/README.md (1 hunks)
  • examples/nemo_run/common/in_memory_mmlu.py (1 hunks)
  • examples/nemo_run/common/process_openscience.py (1 hunks)
  • examples/nemo_run/common/utils.py (1 hunks)
  • examples/nemo_run/qat/ADVANCED.md (1 hunks)
  • examples/nemo_run/qat/README.md (2 hunks)
  • examples/nemo_run/qat/nemo_qat_flow.py (6 hunks)
  • modelopt/torch/export/plugins/nemo_run.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (6)
  • examples/nemo_run/common/in_memory_mmlu.py
  • examples/nemo_run/common/process_openscience.py
  • examples/nemo_run/qat/ADVANCED.md
  • examples/llm_qat/README.md
  • modelopt/torch/export/plugins/nemo_run.py
  • examples/nemo_run/qat/README.md
🧰 Additional context used
🧬 Code graph analysis (2)
examples/nemo_run/common/utils.py (1)
examples/nemo_run/common/process_lima.py (1)
  • download_hf_dataset (29-37)
examples/nemo_run/qat/nemo_qat_flow.py (3)
modelopt/torch/export/plugins/nemo_run.py (1)
  • export_most_recent_ckpt (24-33)
examples/nemo_run/common/utils.py (4)
  • SlurmConfig (24-72)
  • create_slurm_executor (75-117)
  • get_finetune_recipe (120-123)
  • read_chat_template (126-128)
examples/nemo_run/common/in_memory_mmlu.py (1)
  • parse_args (24-45)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-docs
  • GitHub Check: code-quality

Comment on lines +67 to +72
self.env_vars |= {
"CUDA_DEVICE_MAX_CONNECTIONS": "1", # Disable GPU communication/computation overlap for performance
"TRANSFORMERS_OFFLINE": "1", # Disable online downloads from HuggingFace
"TORCH_NCCL_AVOID_RECORD_STREAMS": "1", # Disable caching NCCL communication buffer memory
"NCCL_NVLS_ENABLE": "0", # Disable NVLink SHARP to save memory
}

⚠️ Potential issue

Don’t override caller-provided env vars; reverse the merge.

self.env_vars |= {...} overwrites user settings. Defaults should be baseline, user wins on conflicts.

-        self.env_vars |= {
-            "CUDA_DEVICE_MAX_CONNECTIONS": "1",  # Disable GPU communication/computation overlap for performance
-            "TRANSFORMERS_OFFLINE": "1",  # Disable online downloads from HuggingFace
-            "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",  # Disable caching NCCL communication buffer memory
-            "NCCL_NVLS_ENABLE": "0",  # Disable NVLink SHARP to save memory
-        }
+        defaults = {
+            "CUDA_DEVICE_MAX_CONNECTIONS": "1",  # Disable GPU communication/computation overlap for performance
+            "TRANSFORMERS_OFFLINE": "1",        # Disable online downloads from HuggingFace
+            "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",  # Disable caching NCCL communication buffer memory
+            "NCCL_NVLS_ENABLE": "0",            # Disable NVLink SHARP to save memory
+        }
+        # User-specified values take precedence
+        self.env_vars = defaults | self.env_vars
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
self.env_vars |= {
"CUDA_DEVICE_MAX_CONNECTIONS": "1", # Disable GPU communication/computation overlap for performance
"TRANSFORMERS_OFFLINE": "1", # Disable online downloads from HuggingFace
"TORCH_NCCL_AVOID_RECORD_STREAMS": "1", # Disable caching NCCL communication buffer memory
"NCCL_NVLS_ENABLE": "0", # Disable NVLink SHARP to save memory
}
defaults = {
"CUDA_DEVICE_MAX_CONNECTIONS": "1", # Disable GPU communication/computation overlap for performance
"TRANSFORMERS_OFFLINE": "1", # Disable online downloads from HuggingFace
"TORCH_NCCL_AVOID_RECORD_STREAMS": "1", # Disable caching NCCL communication buffer memory
"NCCL_NVLS_ENABLE": "0", # Disable NVLink SHARP to save memory
}
# User-specified values take precedence
self.env_vars = defaults | self.env_vars
🤖 Prompt for AI Agents
In examples/nemo_run/common/utils.py around lines 67 to 72, the current use of
self.env_vars |= {...} forces these defaults to override any caller-provided
environment variables; change the merge order so caller values win. Create a
defaults dict with those keys, then merge so existing self.env_vars take
precedence (e.g., apply defaults first and then overlay self.env_vars, or set
each default only if the key is not already present). Update the code to perform
a non-destructive merge rather than overwriting user-supplied entries.

Signed-off-by: Jennifer Chen <[email protected]>
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/nemo_run/qat/nemo_qat_flow.py (1)

186-201: Fix PTQ CLI flags (NeMo 25.07)

  • File: examples/nemo_run/qat/nemo_qat_flow.py (lines 186–201) — the flags used are incorrect for NeMo 25.07 and will break runs. Replace the KV-cache flags and the short -ctp alias with the canonical names: e.g. enable FP8 KV-cache with --use_fp8_kv_cache (or disable paged KV cache with --no_paged_kv_cache as appropriate), and replace -ctp with the appropriate full TP flag (--calibration_tp, --calibration_pp or --inference_tp) depending on intent. Update accordingly before running.
♻️ Duplicate comments (11)
examples/nemo_run/qat/README.md (4)

53-57: Move “Usage” directly after “Overview”

Improves discoverability; align with prior review.


66-69: Do not mount into site-packages or run as root; use editable installs and user UID/GID

Replace the docker run example with a safer pattern; also add editable installs for NeMo and modelopt inside the container.

-```bash
-docker run -v  /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/modelopt/:/usr/local/lib/python3.12/dist-packages/modelopt --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.07 bash
-```
+```bash
+docker run --rm -it --gpus=all --shm-size=20g \
+  -u $(id -u):$(id -g) \
+  -v /home/user/NeMo:/opt/NeMo \
+  -v /home/user/TensorRT-Model-Optimizer:/workspace/TensorRT-Model-Optimizer \
+  -v /home/user/logs:/logs \
+  nvcr.io/nvidia/nemo:25.07 bash
+
+# Inside the container, install mounted repos in editable mode:
+python -m pip install -U pip
+python -m pip install -e /opt/NeMo
+python -m pip install -e /workspace/TensorRT-Model-Optimizer/modelopt
+```

70-70: Use HUGGING_FACE_HUB_TOKEN and avoid chmod 777; suggest safer write-access guidance

Update token env var and replace the 777 advice with running as host UID and/or mounting a writable logs dir.

-You will also need to set your Huggingface token with `export HF_TOKEN=<your-token>`. You may also need to enable write access to the docker container to the `examples/nemo_run` folder by doing `chmod 777 nemo_run` so that logs can be written.
+Set your Hugging Face token (HF_TOKEN is also recognized):
+`export HUGGING_FACE_HUB_TOKEN=<your-token>`  # or `export HF_TOKEN=<your-token>`
+
+For write access, avoid `chmod 777`. Run the container as your user (`-u $(id -u):$(id -g)`) and/or mount a writable logs directory (e.g., `-v /home/user/logs:/logs`) and pass `--log-dir /logs`.

86-86: Flag name mismatch: use --enable_kv_cache (underscore); also document --disable_kv_cache

The script defines underscore style; current docs use hyphens.

-> **_NOTE:_** To enable KV cache quantization, add `--enable-kv-cache` and specify qformat using `--kv-cache-qformat <fp8, nvfp4>`.
+> **_NOTE:_** To enable KV cache quantization, add `--enable_kv_cache` and specify qformat using `--kv-cache-qformat <fp8|nvfp4>`. To explicitly disable it, use `--disable_kv_cache`.
examples/nemo_run/common/utils.py (3)

120-124: Guard missing recipe module to avoid AttributeError

Use a safe getattr and clearer error.

-def get_finetune_recipe(recipe_name: str):
-    if not hasattr(getattr(llm, recipe_name), "finetune_recipe"):
-        raise ValueError(f"Recipe {recipe_name} does not have a Fine-Tuning recipe")
-    return getattr(llm, recipe_name).finetune_recipe(peft_scheme=None)
+def get_finetune_recipe(recipe_name: str):
+    recipe_mod = getattr(llm, recipe_name, None)
+    if recipe_mod is None or not hasattr(recipe_mod, "finetune_recipe"):
+        raise ValueError(f"Recipe {recipe_name} does not exist or lacks a fine-tuning recipe")
+    return recipe_mod.finetune_recipe(peft_scheme=None)

67-72: Don’t overwrite user env vars; merge defaults so user wins

Current self.env_vars |= {...} overrides caller-provided values. Reverse the merge.

-        self.env_vars |= {
-            "CUDA_DEVICE_MAX_CONNECTIONS": "1",  # Disable GPU communication/computation overlap for performance
-            "TRANSFORMERS_OFFLINE": "1",  # Disable online downloads from HuggingFace
-            "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",  # Disable caching NCCL communication buffer memory
-            "NCCL_NVLS_ENABLE": "0",  # Disable NVLink SHARP to save memory
-        }
+        defaults = {
+            "CUDA_DEVICE_MAX_CONNECTIONS": "1",  # Disable GPU communication/computation overlap for performance
+            "TRANSFORMERS_OFFLINE": "1",        # Disable online downloads from HuggingFace
+            "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",  # Disable caching NCCL communication buffer memory
+            "NCCL_NVLS_ENABLE": "0",            # Disable NVLink SHARP to save memory
+        }
+        # User-specified values take precedence
+        self.env_vars = defaults | self.env_vars

91-106: Pass env vars into SlurmExecutor

SlurmConfig.env_vars is never applied; forward it to the executor.

         return run.SlurmExecutor(
             account=slurm_cfg.account,
             partition=slurm_cfg.partition_gpu,
             ntasks_per_node=ntasks_per_node,
             gpus_per_node=num_gpus,
             nodes=nodes,
             tunnel=tunnel,
             container_image=slurm_cfg.container_image,
             container_mounts=slurm_cfg.container_mounts,
             time=slurm_cfg.time,
             packager=run.GitArchivePackager(),
             mem="0",
             gres=f"gpu:{num_gpus}",
+            env_vars=slurm_cfg.env_vars,
         )
@@
         return run.SlurmExecutor(
             account=slurm_cfg.account,
             partition=slurm_cfg.partition_cpu,
             nodes=nodes,
             tunnel=tunnel,
             container_image=slurm_cfg.container_image,
             container_mounts=slurm_cfg.container_mounts,
             time=slurm_cfg.time,
             packager=run.GitArchivePackager(),
             mem="0",
+            env_vars=slurm_cfg.env_vars,
         )

To confirm parameter name on your installed nemo-run:

#!/usr/bin/env bash
python - <<'PY'
import inspect, nemo_run as run
sig = inspect.signature(run.SlurmExecutor)
print("SlurmExecutor params:", list(sig.parameters))
PY

Also applies to: 107-117

examples/nemo_run/qat/nemo_qat_flow.py (4)

138-145: KV‑cache tri‑state: only add enable/disable flag when explicitly set

Current logic always appends “--disable_kv_cache” when not enabling, which removes neutrality. Use a mutually exclusive group with default None and build the flag conditionally.

-    parser.add_argument(
-        "--enable_kv_cache",
-        help="Enables KV-cache quantization",
-        action="store_true",
-        default=False,
-    )
+    kv_group = parser.add_mutually_exclusive_group()
+    kv_group.add_argument("--enable_kv_cache", dest="enable_kv_cache", action="store_true",
+                          help="Enable KV-cache quantization")
+    kv_group.add_argument("--disable_kv_cache", dest="enable_kv_cache", action="store_false",
+                          help="Disable KV-cache quantization")
+    parser.set_defaults(enable_kv_cache=None)
@@
-    ptq = run.Script(
+    # Build KV-cache flag only when explicitly set
+    kv_cache_flag = (
+        ["--enable_kv_cache"] if args.enable_kv_cache is True
+        else (["--disable_kv_cache"] if args.enable_kv_cache is False else [])
+    )
+    ptq = run.Script(
         "/opt/NeMo/scripts/llm/ptq.py",
         args=[
@@
-            "--kv_cache_qformat",
-            args.kv_cache_qformat,
-            "--enable_kv_cache" if args.enable_kv_cache else "--disable_kv_cache",
+            "--kv_cache_qformat", args.kv_cache_qformat,
+            *kv_cache_flag,

Also applies to: 185-201


222-233: Make constants robust when imported programmatically

Guard against NameError if main() is called without main initializers.

 def main(args):
+    # Fallbacks if constants are unset (e.g., when imported)
+    global SEQUENCE_LENGTH, MBS, GBS, TRAIN_STEPS, VAL_INTERVAL
+    SEQUENCE_LENGTH = globals().get("SEQUENCE_LENGTH", 4096)
+    MBS = globals().get("MBS", 1)
+    GBS = globals().get("GBS", 512)
+    TRAIN_STEPS = globals().get("TRAIN_STEPS", 200)
+    VAL_INTERVAL = globals().get("VAL_INTERVAL", 50)

283-286: Honor --train-nodes on Slurm

Training executor currently hardcodes nodes=1. Pass nodes=args.train_nodes.

-        train_gpu_executor = create_slurm_executor(
-            SLURM_CONFIG, num_gpus=args.train_gpus, ntasks_per_node=args.train_gpus
-        )
+        train_gpu_executor = create_slurm_executor(
+            SLURM_CONFIG,
+            nodes=args.train_nodes,
+            num_gpus=args.train_gpus,
+            ntasks_per_node=args.train_gpus,
+        )

338-346: Use the single‑GPU executor for export; don’t mutate the train executor

Mutating ntasks_per_node is brittle and may be ignored by the backend.

-        # WAR: Export needs access to all GPUs but only 1 task due to bug in NeMo
-        train_gpu_executor.ntasks_per_node = 1  # will throw error if more than 1 task during export
-        exp.add(
-            export,
-            tail_logs=True,
-            name="07_export_hf",
-            executor=train_gpu_executor,
-            dependencies=[s5],
-        )
+        # Export with a dedicated single‑GPU executor
+        exp.add(
+            export,
+            tail_logs=True,
+            name="07_export_hf",
+            executor=single_gpu_executor,
+            dependencies=[s5],
+        )
🧹 Nitpick comments (8)
examples/nemo_run/qat/README.md (3)

59-63: Pin NeMo commit with context

Consider adding a brief note why commit 676ed1a is required and the date/tag, or switch to a tag for stability.

What NeMo features fixed by 676ed1a are required here (PTQ CLI flags, training hooks, etc.)?


41-51: Results reproducibility note

Add hardware/software versions (driver/CUDA/NeMo/TMO) and random seed for reproducibility.


51-51: Grammar nit

“also is much smaller” → “is also much smaller.”

-The resulting exported checkpoint also is much smaller in memory at 6.4GB compared to the original BF16 checkpoint which is 16.4 GB.
+The resulting exported checkpoint is also much smaller in memory: 6.4 GB vs 16.4 GB for the original BF16 checkpoint.
examples/nemo_run/common/utils.py (2)

52-66: Validate job_dir when using LocalTunnel

LocalTunnel requires job_dir as well; add a check for use_local_tunnel=True.

         if not self.use_local_tunnel:
             # Only validate SSH tunnel settings if not using local tunnel
@@
                 )
+        else:
+            if not self.job_dir:
+                raise ValueError("SlurmConfig.job_dir must be set when use_local_tunnel is True")

131-139: Optional: reuse existing helper to avoid duplication

download_hf_dataset duplicates logic in examples/nemo_run/common/process_lima.py. Consider importing/reusing to DRY.

examples/nemo_run/qat/nemo_qat_flow.py (3)

148-156: Fix model/recipe naming confusion; use HF basename for output dirs

model_name = args.finetune_recipe is misleading and the fallback never triggers. Use distinct names and drive paths from HF model basename.

-    model_name = args.finetune_recipe
-    model_module = getattr(llm, model_name)
-    if not model_name:
-        model_name = os.path.basename(args.model_name)
-    exp_dir = f"{args.log_dir.rstrip('/')}/{args.experiment}"
+    recipe_name = args.finetune_recipe
+    try:
+        model_module = getattr(llm, recipe_name)
+    except AttributeError as e:
+        raise ValueError(f"Unknown recipe: {recipe_name}") from e
+    hf_basename = os.path.basename(args.model_name)
+    model_name = hf_basename  # used for directory naming
+    exp_dir = f"{args.log_dir.rstrip('/')}/{args.experiment}"

157-171: Remove stale TODO and path ambiguity; rely on resolved paths

Comments suggest uncertainty; the code already resolves absolute vs Slurm paths. Clean up.

-    # 1. Process data
-    # TODO figure out path
-    # LOCALLY common/process.py works
-    # On slurm examples/nemo_run/common/process.py works
+    # 1. Process data

362-364: Prefer HUGGING_FACE_HUB_TOKEN; keep HF_TOKEN as alias

Align with README and common tooling.

-            env_vars={
-                "HF_TOKEN": "",
-            },
+            env_vars={
+                "HUGGING_FACE_HUB_TOKEN": "",  # HF_TOKEN also recognized by some tools
+                "HF_TOKEN": "",
+            },
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1249de1 and 410c7dd.

📒 Files selected for processing (8)
  • examples/llm_qat/README.md (1 hunks)
  • examples/nemo_run/common/in_memory_mmlu.py (1 hunks)
  • examples/nemo_run/common/process_openscience.py (1 hunks)
  • examples/nemo_run/common/utils.py (1 hunks)
  • examples/nemo_run/qat/ADVANCED.md (1 hunks)
  • examples/nemo_run/qat/README.md (2 hunks)
  • examples/nemo_run/qat/nemo_qat_flow.py (6 hunks)
  • modelopt/torch/export/plugins/nemo_run.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
  • examples/llm_qat/README.md
🚧 Files skipped from review as they are similar to previous changes (4)
  • examples/nemo_run/common/process_openscience.py
  • examples/nemo_run/common/in_memory_mmlu.py
  • examples/nemo_run/qat/ADVANCED.md
  • modelopt/torch/export/plugins/nemo_run.py
🧰 Additional context used
🧬 Code graph analysis (2)
examples/nemo_run/common/utils.py (1)
examples/nemo_run/common/process_lima.py (1)
  • download_hf_dataset (29-37)
examples/nemo_run/qat/nemo_qat_flow.py (3)
modelopt/torch/export/plugins/nemo_run.py (1)
  • export_most_recent_ckpt (24-33)
examples/nemo_run/common/utils.py (4)
  • SlurmConfig (24-72)
  • create_slurm_executor (75-117)
  • get_finetune_recipe (120-123)
  • read_chat_template (126-128)
examples/nemo_run/common/in_memory_mmlu.py (1)
  • parse_args (24-45)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (3)
examples/nemo_run/qat/README.md (3)

19-39: Flow stages list + diagram: aligned and clear

Eight stages match task names; diagram edges are coherent.


5-7: Two links pointing to the same ADVANCED.md

If intentional, rename one (e.g., “Slurm Guide | Advanced Topics”) or deduplicate.


66-71: Doc inconsistent with AI summary and prior commits

The file still shows the old Docker guidance and KV-cache flag, although the AI summary and past review comments indicate these were addressed. Update the README accordingly.

Also applies to: 86-86

@jenchen13 jenchen13 merged commit 1c7c16e into NVIDIA:main Sep 18, 2025
22 checks passed
@jenchen13 jenchen13 deleted the jennifchen/qat_slurm branch September 18, 2025 18:24
yeyu-nvidia pushed a commit that referenced this pull request Sep 18, 2025