Commit 5d46a7b

respond coderabbit
Signed-off-by: Jennifer Chen <[email protected]>
1 parent 46e6c13 commit 5d46a7b

5 files changed: +21 -23 lines changed

examples/nemo_run/common/in_memory_mmlu.py

Lines changed: 4 additions & 4 deletions
@@ -25,9 +25,10 @@ def parse_args():
     parser = argparse.ArgumentParser(
         description="Run MMLU evaluation with ModelOpt Megatron model. Provide either --nemo_ckpt or --ckpt_dir"
     )
-    parser.add_argument("--nemo_ckpt", type=str, required=False, help="Path to NeMo checkpoint.")
-    parser.add_argument(
-        "--ckpt_dir",
+    group = parser.add_mutually_exclusive_group(required=True)
+    group.add_argument("--nemo_ckpt", type=str, required=False, help="Path to NeMo checkpoint.")
+    group.add_argument(
+        "--finetuned_ckpt_dir",
         required=False,
         type=str,
         help="Checkpoint directory of 1 or more finetuned models",
@@ -43,7 +44,6 @@ def parse_args():
 
 if __name__ == "__main__":
     args = parse_args()
-    assert args.nemo_ckpt or args.ckpt_dir, "Provide one of either --nemo_ckpt or --ckpt_dir."
     ckpt_path = args.nemo_ckpt
     if args.ckpt_dir:
         ckpt_path = _get_most_recent_ckpt(args.ckpt_dir)
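
The effect of the mutually exclusive group in this hunk can be sketched with the standard-library `argparse` alone. This is an illustration under assumptions, not the file's actual code: the `build_parser` and `try_parse` helper names are hypothetical, and the help strings are shortened. Note that argparse derives the attribute name from the flag unless `dest` is set, so the value would be read as `args.finetuned_ckpt_dir`; the unchanged context lines in the hunk still read `args.ckpt_dir`, which may be worth checking against the renamed flag.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical helper mirroring the two flags in this hunk."""
    parser = argparse.ArgumentParser()
    # required=True on the group means exactly one of the two flags must be given.
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--nemo_ckpt", type=str, help="Path to NeMo checkpoint.")
    group.add_argument("--finetuned_ckpt_dir", type=str, help="Checkpoint directory.")
    return parser


def try_parse(argv: list[str]):
    """Return the parsed namespace, or None if argparse rejects argv."""
    try:
        return build_parser().parse_args(argv)
    except SystemExit:
        # On invalid input argparse prints a usage message to stderr and exits.
        return None


if __name__ == "__main__":
    ok = try_parse(["--finetuned_ckpt_dir", "/tmp/exp"])
    print(ok.finetuned_ckpt_dir, ok.nemo_ckpt)
    print(try_parse([]))  # neither flag given
    print(try_parse(["--nemo_ckpt", "a", "--finetuned_ckpt_dir", "b"]))  # both given
```

With the group marked `required=True`, the removed `assert` becomes redundant: argparse rejects both the "neither flag" and the "both flags" cases before the script body runs.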

examples/nemo_run/qat/ADVANCED.md

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 # NeMo QAT/QAD Flow: Advanced Topics
 
-If you need to run QAT/QAD on a Slurm cluster (for example to use more than 1 node)
+If you need to run QAT/QAD on a Slurm cluster (for example to use more than 1 node), this guide covers how to configure and launch on Slurm.
 
 To run the example on slurm, edit the `SLURM_CONFIG` at the bottom of `nemo_qat_flow.py` with the appropriate credentials, container, cluster name (host), and container mounts. Make sure you are mounting the NeMo and Megatron-LM repositories above in the Slurm cluster and that you've checked out the correct commits.
 

examples/nemo_run/qat/README.md

Lines changed: 10 additions & 8 deletions
@@ -12,18 +12,20 @@
 
 This directory contains an end-to-end QAT Simplified Flow example using NeMo for model training. It supports both QAT with cross-entropy loss and QAD (quantization-aware distillation) with knowledge-distillation loss between the BF16 teacher and quantized student models.
 
-After PTQ (post-training quantization), the quantized model may
+After PTQ (post-training quantization), the quantized model may show some accuracy degradation on tasks like MMLU; the QAT/QAD stages aim to recover that loss.
 
 ## Flow Stages
 
-Currently the Simplified Flow runs the following steps in order:
+The Simplified Flow runs the following steps in order:
 
-1. Process Nvidia/OpenScience data (if `--data-path` is not specified)
-1. Import NeMo BF16 model checkpoint and evaluate 5% of MMLU on BF16 checkpoint
-1. PTQ the model and evaluate 5% of MMLU on PTQ Checkpoint
-1. SFT (finetune) the model
-1. Evaluate 5% of MMLU on the SFT checkpoint
-1. Export model to Unified checkpoint (HuggingFace) format in lower precision
+1. 00_openscience_data — Process NVIDIA/OpenScience data (skipped if `--data-path` is given)
+1. 01_import_model — Import NeMo BF16 model checkpoint
+1. 02_mmlu_bf16 — Evaluate 5% MMLU on BF16 checkpoint
+1. 03_ptq — Apply PTQ
+1. 04_mmlu_ptq — Evaluate 5% MMLU on PTQ checkpoint
+1. 05_train — SFT/QAT (and optional QAD)
+1. 06_mmlu_sft — Evaluate 5% MMLU on SFT/QAT checkpoint
+1. 07_export_hf — Export to Hugging Face (Unified) format
 
 ```mermaid
 graph TD;

examples/nemo_run/qat/nemo_qat_flow.py

Lines changed: 2 additions & 8 deletions
@@ -138,14 +138,8 @@ def get_args():
         "--enable_kv_cache",
         help="Enables KV-cache quantization",
         action="store_true",
+        default=False
     )
-    parser.add_argument(
-        "--disable_kv_cache",
-        dest="enable_kv_cache",
-        action="store_false",
-    )
-
-    parser.set_defaults(enable_kv_cache=None)
     return parser.parse_args()
 
 
@@ -265,7 +259,7 @@ def main(args):
     )
     eval_sft = run.Script(
         mmlu_script_path,
-        args=["--ckpt_dir", exp_dir],
+        args=["--finetuned_ckpt_dir", exp_dir],
         entrypoint="python",
     )
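
The `--enable_kv_cache` change in the first hunk turns a tri-state option into a plain boolean. A minimal sketch of the difference, using only the standard-library `argparse`: the names `new_parser` and `old_parser` are hypothetical and stand in for the parser after and before this commit.

```python
import argparse

# After this commit: a single boolean flag. With action="store_true" argparse
# already defaults to False, so default=False only makes that explicit.
new_parser = argparse.ArgumentParser()
new_parser.add_argument("--enable_kv_cache", action="store_true", default=False)

# Before this commit: paired flags sharing one dest, with None meaning
# "neither flag was passed".
old_parser = argparse.ArgumentParser()
old_parser.add_argument("--enable_kv_cache", action="store_true")
old_parser.add_argument("--disable_kv_cache", dest="enable_kv_cache", action="store_false")
old_parser.set_defaults(enable_kv_cache=None)

if __name__ == "__main__":
    print(new_parser.parse_args([]).enable_kv_cache)
    print(old_parser.parse_args([]).enable_kv_cache)
    print(old_parser.parse_args(["--disable_kv_cache"]).enable_kv_cache)
```

Any downstream code that previously distinguished `None` from `False` would see only `False` after this change; whether the flow relies on that distinction is not visible in this diff.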

modelopt/torch/export/plugins/nemo_run.py

Lines changed: 4 additions & 2 deletions
@@ -57,15 +57,17 @@ def _get_most_recent_ckpt(directory: str):
         str: Path to the most recent subdirectory.
     """
     exp_dir = Path(directory) / "default"
-    assert exp_dir.exists(), f"Experiment directory {exp_dir} does not exist"
+    if not exp_dir.exists():
+        raise FileNotFoundError(f"Experiment directory {exp_dir} does not exist")
 
     checkpoint_dir = exp_dir / "checkpoints"
     if checkpoint_dir.exists():
         most_recent = _get_most_recent_subdir(checkpoint_dir)
     else:
         most_recent = _get_most_recent_subdir(exp_dir)
         checkpoint_dir = most_recent / "checkpoints"
-        assert checkpoint_dir.exists(), f"Checkpoint directory {checkpoint_dir} does not exist"
+        if not checkpoint_dir.exists():
+            raise FileNotFoundError(f"Checkpoint directory {checkpoint_dir} does not exist")
         most_recent = _get_most_recent_subdir(checkpoint_dir)
 
     return str(most_recent)
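
Replacing `assert` with an explicit `raise` matters because Python drops `assert` statements when run with `-O`, so the old checks could silently disappear. The pattern can be sketched as below; `require_dir` is a hypothetical helper for illustration and does not exist in `nemo_run.py`.

```python
import tempfile
from pathlib import Path


def require_dir(path: Path, label: str) -> Path:
    """Hypothetical helper: return path, or raise FileNotFoundError if it is missing."""
    if not path.exists():
        # Unlike assert, this check still runs under `python -O`.
        raise FileNotFoundError(f"{label} directory {path} does not exist")
    return path


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        print(require_dir(Path(tmp), "Experiment"))
        try:
            require_dir(Path(tmp) / "default", "Experiment")
        except FileNotFoundError as err:
            print(f"caught: {err}")
```

Callers can also catch `FileNotFoundError` specifically, which is harder to do cleanly with the generic `AssertionError`.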
