Commit 1c7c16e

Slurm support for QAT Simplified Flow + Qwen3-8B recipe (#285)
Signed-off-by: Jennifer Chen <[email protected]>
1 parent b7ed8cd commit 1c7c16e

10 files changed, +674 -127 lines changed

examples/llm_qat/README.md

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ Quantization Aware Training (QAT) helps to improve the model accuracy beyond pos
 | Support Matrix | View the support matrix to see quantization compatibility and feature availability across different models | \[[Link](#support-matrix)\] | |
 | End to End QAT | Example scripts demonstrating quantization techniques for optimizing Hugging Face models | \[[Link](#end-to-end-qat-example)\] | \[[docs](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/1_quantization.html)\] |
 | End to End QAD | Example scripts demonstrating quantization aware distillation techniques for optimizing Hugging Face models | \[[Link](#end-to-end-qad-example)\] | \[[docs](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/1_quantization.html)\] |
+| NeMo QAT/QAD Simplified Flow | Example script demonstrating end-to-end QAT/QAD in NeMo | \[[Link](../nemo_run/qat/README.md)\] | |
 | Evaluate Accuracy | Evaluating model accuracy after QAT/QAD (with fake quantization) | \[[Link](#testing-qat-model-with-llm-benchmarks-for-accuracy-evaluation)\] | |
 | Deployment | Deploying the model after QAT/QAD | \[[Link](#deployment)\] | |
 | QLoRA | Model training with reduced GPU memory | \[[Link](#end-to-end-qlora-with-real-quantization)\] | |
Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

from nemo.collections.llm.modelopt import setup_trainer_and_restore_model_with_modelopt_spec

from modelopt.torch.export.plugins.nemo_run import _get_most_recent_ckpt
from modelopt.torch.utils.plugins.megatron_mmlu import megatron_mmlu


def parse_args():
    parser = argparse.ArgumentParser(
        description=(
            "Run MMLU evaluation with ModelOpt Megatron model. Provide either --nemo_ckpt "
            "or --finetuned_ckpt_dir"
        )
    )
    # Exactly one checkpoint source must be given.
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--nemo_ckpt", type=str, required=False, help="Path to NeMo checkpoint.")
    group.add_argument(
        "--finetuned_ckpt_dir",
        required=False,
        type=str,
        help="Checkpoint directory of 1 or more finetuned models",
    )
    parser.add_argument(
        "--tensor_parallelism", type=int, default=1, help="Tensor parallelism size."
    )
    parser.add_argument(
        "--pipeline_parallelism", type=int, default=1, help="Pipeline parallelism size."
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    ckpt_path = args.nemo_ckpt
    if args.finetuned_ckpt_dir:
        # Evaluate the most recent checkpoint found in the finetuning output directory.
        ckpt_path = _get_most_recent_ckpt(args.finetuned_ckpt_dir)
    model, trainer = setup_trainer_and_restore_model_with_modelopt_spec(
        ckpt_path,
        tensor_model_parallel_size=args.tensor_parallelism,
        pipeline_model_parallel_size=args.pipeline_parallelism,
        devices=args.tensor_parallelism * args.pipeline_parallelism,
    )
    tokenizer = model.tokenizer.tokenizer
    megatron_mmlu(model.module, tokenizer)
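The checkpoint selection in this script can be tried without NeMo installed. The sketch below rebuilds only the mutually exclusive argument group and replaces `_get_most_recent_ckpt` with a hypothetical stand-in, `pick_latest`, that returns the most recently modified entry in a directory. The real helper's selection logic is not shown in this diff and may differ, so treat this as an illustration of the control flow only.

```python
import argparse
from pathlib import Path


def build_parser() -> argparse.ArgumentParser:
    """Rebuild only the checkpoint arguments of the evaluation script."""
    parser = argparse.ArgumentParser(description="Checkpoint selection sketch")
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--nemo_ckpt", type=str)
    group.add_argument("--finetuned_ckpt_dir", type=str)
    return parser


def pick_latest(ckpt_dir: str) -> str:
    """Hypothetical stand-in for _get_most_recent_ckpt: newest entry by mtime."""
    entries = list(Path(ckpt_dir).iterdir())
    if not entries:
        raise FileNotFoundError(f"No checkpoints found in {ckpt_dir}")
    return str(max(entries, key=lambda p: p.stat().st_mtime))


def resolve_ckpt(argv: list[str]) -> str:
    """Return the checkpoint path the evaluation would load for these arguments."""
    args = build_parser().parse_args(argv)
    if args.finetuned_ckpt_dir:
        return pick_latest(args.finetuned_ckpt_dir)
    return args.nemo_ckpt
```

Passing both flags, or neither, makes `argparse` exit with a usage error, which is what `required=True` on the group provides.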
File renamed without changes.
File renamed without changes.
Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
from pathlib import Path

from datasets import load_dataset


def get_parser():
    parser = argparse.ArgumentParser(description="Process nvidia/OpenScience dataset")
    parser.add_argument("--output-dir", type=str, default=".")
    return parser


def convert_row_oai(row: dict):
    """Convert an input/output row into OpenAI-style chat messages."""
    return {
        "messages": [
            {"role": "user", "content": row["input"]},
            {"role": "assistant", "content": row["output"]},
        ]
    }


def process_subset(raw_dir, proc_dir):
    ds = load_dataset(raw_dir)
    ds = ds.map(convert_row_oai, remove_columns=["input", "output"])

    # Hold out 10% of the rows for validation.
    split_ds = ds["train"].train_test_split(test_size=0.1)
    split_ds["train"].to_json(os.path.join(proc_dir, "training.jsonl"))
    split_ds["test"].to_json(os.path.join(proc_dir, "validation.jsonl"))


if __name__ == "__main__":
    args = get_parser().parse_args()
    raw_dir = f"{args.output_dir}/openscience_raw"
    proc_dir = f"{args.output_dir}/openscience_proc"

    # Download the subset only if it is not already on disk.
    if not os.path.exists(raw_dir):
        q235_subset = load_dataset("nvidia/OpenScience", data_files="OS-Q3-235B-4.jsonl")
        q235_subset.save_to_disk(raw_dir)

    if not os.path.exists(proc_dir):
        Path(proc_dir).mkdir(exist_ok=True)
        print("Processing OpenScience dataset")
        process_subset(raw_dir, proc_dir)
    else:
        print(f"Processed OpenScience dataset exists in: {proc_dir}, skipped processing")
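To make the output format concrete, here is a small sketch of the row conversion applied to a single made-up record and serialized as one JSONL line. It copies `convert_row_oai` from the script; the `rows_to_jsonl` helper and the sample row are illustrative assumptions and stand in for the `datasets` `map` and `to_json` calls.

```python
import json


def convert_row_oai(row: dict) -> dict:
    """Same mapping as in the script: input/output become a two-turn chat."""
    return {
        "messages": [
            {"role": "user", "content": row["input"]},
            {"role": "assistant", "content": row["output"]},
        ]
    }


def rows_to_jsonl(rows: list[dict]) -> str:
    """Hypothetical helper: serialize converted rows, one JSON object per line."""
    return "\n".join(json.dumps(convert_row_oai(row)) for row in rows)


# Made-up example row, not taken from nvidia/OpenScience.
sample = [{"input": "What is 2 + 2?", "output": "4"}]
print(rows_to_jsonl(sample))
# {"messages": [{"role": "user", "content": "What is 2 + 2?"}, {"role": "assistant", "content": "4"}]}
```

Each line of `training.jsonl` and `validation.jsonl` is expected to carry a `messages` list of this shape.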

examples/nemo_run/common/utils.py

Lines changed: 139 additions & 0 deletions
@@ -0,0 +1,139 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import subprocess
from dataclasses import dataclass, field

import nemo_run as run
from nemo.collections import llm


@dataclass
class SlurmConfig:
    """Configuration for SlurmExecutor."""

    account: str = ""  # Your Slurm account
    partition_cpu: str = ""  # Slurm CPU partition to use
    partition_gpu: str = ""  # Slurm GPU partition to use
    time: str = ""  # Job time limit (HH:MM:SS)
    container_image: str = ""  # Container image for jobs
    env_vars: dict[str, str] = field(default_factory=dict)  # Environment variables to set
    container_mounts: list[str] = field(default_factory=list)  # Container mounts
    use_local_tunnel: bool = False  # Set to True if running from within the cluster
    host: str = ""  # Required for SSH tunnel: Slurm cluster hostname
    user: str = ""  # Required for SSH tunnel: Your username
    job_dir: str = ""  # Required for SSH tunnel: Directory to store runs on cluster
    identity: str | None = None  # Optional for SSH tunnel: Path to SSH key for authentication

    def __post_init__(self):
        """Validate the configuration and raise descriptive errors."""
        if not self.account:
            raise ValueError("SlurmConfig.account must be set to your actual Slurm account")
        if not self.partition_cpu:
            raise ValueError("SlurmConfig.partition_cpu must be set")
        if not self.partition_gpu:
            raise ValueError("SlurmConfig.partition_gpu must be set")
        if not self.time:
            raise ValueError("SlurmConfig.time must be set to job time limit (e.g., '02:00:00')")
        if not self.container_image:
            raise ValueError("SlurmConfig.container_image must be set to container image for jobs")
        if not self.use_local_tunnel:
            # Only validate SSH tunnel settings if not using local tunnel
            if not self.host:
                raise ValueError(
                    "SlurmConfig.host must be set to your actual cluster hostname when using SSH tunnel"
                )
            if not self.user:
                raise ValueError(
                    "SlurmConfig.user must be set to your actual username when using SSH tunnel"
                )
            if not self.job_dir:
                raise ValueError(
                    "SlurmConfig.job_dir must be set to directory for storing runs on cluster"
                )

        self.env_vars |= {
            "CUDA_DEVICE_MAX_CONNECTIONS": "1",  # Disable GPU communication/computation overlap for performance
            "TRANSFORMERS_OFFLINE": "1",  # Disable online downloads from HuggingFace
            "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",  # Disable caching NCCL communication buffer memory
            "NCCL_NVLS_ENABLE": "0",  # Disable NVLink SHARP to save memory
        }


def create_slurm_executor(
    slurm_cfg: SlurmConfig, nodes: int = 1, ntasks_per_node: int = 1, num_gpus: int = 0
):
    # Configure tunnel
    if slurm_cfg.use_local_tunnel:
        # Use LocalTunnel when already on the cluster
        tunnel = run.LocalTunnel(job_dir=slurm_cfg.job_dir)
    else:
        # Use SSH tunnel when launching from local machine
        tunnel = run.SSHTunnel(
            host=slurm_cfg.host,
            user=slurm_cfg.user,
            job_dir=slurm_cfg.job_dir,
            identity=slurm_cfg.identity,  # can be None
        )

    if num_gpus > 0:
        return run.SlurmExecutor(
            account=slurm_cfg.account,
            partition=slurm_cfg.partition_gpu,
            ntasks_per_node=ntasks_per_node,
            gpus_per_node=num_gpus,
            nodes=nodes,
            tunnel=tunnel,
            container_image=slurm_cfg.container_image,
            container_mounts=slurm_cfg.container_mounts,
            time=slurm_cfg.time,
            packager=run.GitArchivePackager(),
            mem="0",
            gres=f"gpu:{num_gpus}",
        )
    else:
        return run.SlurmExecutor(
            account=slurm_cfg.account,
            partition=slurm_cfg.partition_cpu,
            nodes=nodes,
            tunnel=tunnel,
            container_image=slurm_cfg.container_image,
            container_mounts=slurm_cfg.container_mounts,
            time=slurm_cfg.time,
            packager=run.GitArchivePackager(),
            mem="0",
        )


def get_finetune_recipe(recipe_name: str):
    if not hasattr(getattr(llm, recipe_name), "finetune_recipe"):
        raise ValueError(f"Recipe {recipe_name} does not have a Fine-Tuning recipe")
    return getattr(llm, recipe_name).finetune_recipe(peft_scheme=None)


def read_chat_template(template_path: str):
    with open(template_path) as f:
        return f.read().strip()


def download_hf_dataset(dataset_name: str, output_dir: str | None = None):
    """Download a dataset from HuggingFace Hub using huggingface-cli."""
    cmd = ["huggingface-cli", "download", dataset_name, "--repo-type", "dataset"]

    if output_dir:
        cmd.extend(["--local-dir", output_dir])

    subprocess.run(cmd, check=True)
    print(f"Successfully downloaded dataset: {dataset_name}")
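One detail of `SlurmConfig.__post_init__` is easier to see in isolation: the in-place merge `self.env_vars |= {...}`. With Python's dict union operators the right-hand operand wins on duplicate keys, so the built-in defaults replace any caller-supplied value for the same key, while unrelated keys pass through untouched. The sketch below demonstrates this with a trimmed, hypothetical `EnvConfig` dataclass and only two of the four defaults; it is not part of the commit.

```python
from dataclasses import dataclass, field

# Two of the defaults from SlurmConfig, reused here for illustration.
DEFAULT_ENV = {"CUDA_DEVICE_MAX_CONNECTIONS": "1", "TRANSFORMERS_OFFLINE": "1"}


@dataclass
class EnvConfig:
    """Hypothetical stand-in that keeps only the env_vars merge of SlurmConfig."""

    env_vars: dict[str, str] = field(default_factory=dict)

    def __post_init__(self):
        # Right-hand side wins: defaults overwrite caller values for shared keys.
        self.env_vars |= DEFAULT_ENV


cfg = EnvConfig(env_vars={"TRANSFORMERS_OFFLINE": "0", "HF_HOME": "/hf_cache"})
assert cfg.env_vars["TRANSFORMERS_OFFLINE"] == "1"  # caller value was replaced
assert cfg.env_vars["HF_HOME"] == "/hf_cache"  # unrelated key is kept
assert cfg.env_vars["CUDA_DEVICE_MAX_CONNECTIONS"] == "1"  # default was added
```

If caller values were meant to take precedence instead, the operands would be swapped, for example `DEFAULT_ENV | user_env`. Whether that is desirable for this flow is a design choice the commit does not discuss.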

examples/nemo_run/qat/ADVANCED.md

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
# NeMo QAT/QAD Flow: Advanced Topics

If you need to run QAT/QAD on a Slurm cluster (for example, to use more than one node), this guide covers how to configure and launch the flow on Slurm.

To run the example on Slurm, edit the `SLURM_CONFIG` at the bottom of `nemo_qat_flow.py` with the appropriate credentials, container, cluster name (host), and container mounts. Make sure the NeMo and Megatron-LM repositories are mounted in the Slurm cluster and that you've checked out the correct commits.

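As a rough guide to what has to be filled in, the sketch below mirrors a subset of the `SlurmConfig` fields from `common/utils.py` in a trimmed stand-in class, `SlurmConfigSketch`, so that it runs without NeMo-Run installed. Every value (account, partitions, image tag, mount paths, host, user, job directory) is a placeholder assumption, and the `missing_fields` helper is hypothetical; the real class raises `ValueError` in `__post_init__` instead.

```python
from dataclasses import dataclass, field


@dataclass
class SlurmConfigSketch:
    """Trimmed stand-in for SlurmConfig; field names follow common/utils.py."""

    account: str = ""
    partition_cpu: str = ""
    partition_gpu: str = ""
    time: str = ""
    container_image: str = ""
    container_mounts: list[str] = field(default_factory=list)
    use_local_tunnel: bool = False
    host: str = ""
    user: str = ""
    job_dir: str = ""

    def missing_fields(self) -> list[str]:
        """Hypothetical helper: list required fields that are still empty."""
        required = ["account", "partition_cpu", "partition_gpu", "time", "container_image"]
        if not self.use_local_tunnel:
            # SSH tunnel settings are only checked when launching from outside the cluster.
            required += ["host", "user", "job_dir"]
        return [name for name in required if not getattr(self, name)]


# All values below are placeholders; substitute your own cluster details.
SLURM_CONFIG = SlurmConfigSketch(
    account="my_account",
    partition_cpu="cpu",
    partition_gpu="batch",
    time="04:00:00",
    container_image="nvcr.io/nvidia/nemo:<tag>",
    container_mounts=[
        "/lustre/me/NeMo:/opt/NeMo",
        "/lustre/me/Megatron-LM:/opt/megatron-lm",
    ],
    host="login.cluster.example.com",
    user="me",
    job_dir="/lustre/me/qat_runs",
)
print(SLURM_CONFIG.missing_fields())
```

When launching from a login node of the cluster itself, `use_local_tunnel=True` removes the need for `host` and `user`.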
## Running the Flow on Slurm

To launch the Flow on a Slurm cluster, set your Slurm credentials at the bottom of `nemo_qat_flow.py` and add the `--use-slurm` flag to the command. On a different server (e.g. your local server), launch the NeMo container as described in the [README](README.md), then run `python qat/nemo_qat_flow.py --use-slurm --log-dir /slurm/log/dir`, which will `ssh` into the Slurm cluster, `rsync` your files over, and launch the tasks. After an experiment has run, the log directory on the Slurm cluster should look like this (assuming your experiment name is `qat_flow_ckpts`):

```bash
qat_flow_ckpts qat_flow_ckpts_1755708286
```

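When the log directory holds several runs, a small helper can pick the newest one. The sketch below assumes that run directories are named `<experiment>_<number>` as in the listing above and that the numeric suffix is a Unix timestamp in seconds; the guide does not state this, so treat both `latest_run` and `run_started_at` as hypothetical helpers built on that assumption.

```python
import re
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional


def latest_run(log_dir: str, experiment: str) -> Optional[Path]:
    """Return the run directory with the largest numeric suffix, if any."""
    pattern = re.compile(rf"^{re.escape(experiment)}_(\d+)$")
    runs = []
    for entry in Path(log_dir).iterdir():
        match = pattern.match(entry.name)
        if match and entry.is_dir():
            runs.append((int(match.group(1)), entry))
    if not runs:
        return None
    return max(runs, key=lambda item: item[0])[1]


def run_started_at(run_dir: Path) -> datetime:
    """Interpret the suffix as epoch seconds (an assumption) and return UTC time."""
    suffix = int(run_dir.name.rsplit("_", 1)[1])
    return datetime.fromtimestamp(suffix, tz=timezone.utc)
```

For example, `latest_run("/slurm/log/dir", "qat_flow_ckpts")` would skip the bare `qat_flow_ckpts` entry and return the suffixed run directory.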
If you `cd` into the experiment itself, e.g. `cd qat_flow_ckpts_1755708286`, you'll find a directory structure like the following. Each folder is for a stage of the Simplified Flow, and in each stage you can see the logs for that stage as well as the sbatch command that was run. You can `cd` into each stage and `tail -f` the log file to see the logs while the stage is running.

```bash
├── 00_openscience_data
│   ├── code
│   ├── configs
│   ├── log-coreai_dlalgo_modelopt-modelopt.00_openscience_data_5345664_0.out
│   └── sbatch_coreai_dlalgo_modelopt-modelopt.00_openscience_data_5345664.out
├── 01_import_model
│   ├── code
│   ├── configs
│   ├── log-coreai_dlalgo_modelopt-modelopt.01_import_model_5345665_0.out
│   └── sbatch_coreai_dlalgo_modelopt-modelopt.01_import_model_5345665.out
├── 02_mmlu_bf16
│   ├── code
│   ├── configs
│   ├── log-coreai_dlalgo_modelopt-modelopt.02_mmlu_bf16_5345666_0.out
│   └── sbatch_coreai_dlalgo_modelopt-modelopt.02_mmlu_bf16_5345666.out
├── 03_ptq
│   ├── code
│   ├── configs
│   ├── log-coreai_dlalgo_modelopt-modelopt.03_ptq_5345667_0.out
│   └── sbatch_coreai_dlalgo_modelopt-modelopt.03_ptq_5345667.out
├── 04_mmlu_ptq
│   ├── code
│   ├── configs
│   ├── log-coreai_dlalgo_modelopt-modelopt.04_mmlu_ptq_5345668_0.out
│   └── sbatch_coreai_dlalgo_modelopt-modelopt.04_mmlu_ptq_5345668.out
├── 05_train
│   ├── code
│   ├── configs
│   ├── log-coreai_dlalgo_modelopt-modelopt.05_train_5345669_0.out
│   └── sbatch_coreai_dlalgo_modelopt-modelopt.05_train_5345669.out
├── 06_mmlu_sft
│   ├── code
│   └── configs
├── 07_export_hf
│   ├── code
│   └── configs
```

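The layout above can also be inspected programmatically. This sketch assumes each stage folder holds zero or more files matching `log-*.out`, as in the tree, and shows one way to list them per stage and read the last few lines of a log. The helper names `stage_logs` and `tail` are illustrative and not part of the flow.

```python
from collections import deque
from pathlib import Path


def stage_logs(run_dir: str) -> dict[str, list[Path]]:
    """Map each stage folder name to its log files (empty if the stage has none yet)."""
    logs = {}
    for stage in sorted(p for p in Path(run_dir).iterdir() if p.is_dir()):
        logs[stage.name] = sorted(stage.glob("log-*.out"))
    return logs


def tail(path: Path, n: int = 20) -> list[str]:
    """Return the last n lines of a log file without loading it all into memory."""
    with open(path, errors="replace") as f:
        return [line.rstrip("\n") for line in deque(f, maxlen=n)]
```

Stages with an empty list, such as `06_mmlu_sft` and `07_export_hf` in the tree above, have not produced a log file yet.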
**NOTE:** `rsync` may not currently be available in the NeMo container and will be added as a dependency.
