Commit 9269c34

More suggestions
Signed-off-by: Asha Anoosheh <[email protected]>
1 parent 3090a98 commit 9269c34

2 files changed (+42 −31 lines)

examples/nemo_run/prune_distill/README.md

Lines changed: 21 additions & 17 deletions
````diff
@@ -12,35 +12,37 @@ After structured pruning, the compressed model may show some accuracy degradation
 
 ## Flow Stages
 
-The Simplified Flow runs the following steps in order:
+The Simplified Flow runs the following steps:
 
 1. 01_import — Import HuggingFace model to NeMo format
-1. 02_prune — Apply structured pruning to create a compressed student model
+1. 02a_eval_teacher — Evaluate teacher model on 5% of MMLU benchmark
+1. 02b_prune — Apply structured pruning to create a compressed student model
 1. 03_distill — Knowledge distillation from teacher to pruned student model
-1. 04_export — Export final compressed model to HuggingFace format
-1. eval_teacher — Evaluate teacher model on 5% of MMLU benchmark
-1. eval_student — Evaluate student model on 5% of MMLU benchmark
+1. 04a_eval_student — Evaluate student model on 5% of MMLU benchmark
+1. 04b_export — Export final compressed model to HuggingFace format
 
 ```mermaid
 graph TD;
-01_import-->02_prune;
-01_import-->eval_teacher;
-02_prune-->03_distill;
-03_distill-->eval_student;
-03_distill-->04_export;
+01_import-->02a_eval_teacher;
+01_import-->02b_prune;
+02b_prune-->03_distill;
+03_distill-->04a_eval_student;
+03_distill-->04b_export;
 ```
 
 ## Results
 
 Pruning + Knowledge Distillation of Qwen3-8B achieves significant model compression while recovering most of the accuracy through distillation. We depth-prune the model from 32 to 24 layers (reducing from 8B to 6B parameters) and distill for ~28,000 steps (determined by sequence length, default 4096) with a learning rate of 1e-4 and global batch size of 768 using a 25% subset of the [ClimbMix dataset](https://huggingface.co/datasets/OptimalScale/ClimbMix). (This is about 90 billion tokens and takes a total of ~6k H100 GPU hours)
 
-|                                   | Tokens per Second | MMLU |
-|-----------------------------------|-------------------|------|
-| Qwen3-8B Original                 | 4420              | 74.9 |
-| Qwen3-6B Pruned+Distilled from 8B | 6950              | 72.5 |
-| Qwen3-4B Original (comparison)    | 5210              | 70.0 |
+|                                   | Tokens per Second * | MMLU |
+|-----------------------------------|---------------------|------|
+| Qwen3-8B Original                 | 4420                | 74.9 |
+| Qwen3-6B Pruned+Distilled from 8B | 6950                | 72.5 |
+| Qwen3-4B Original (comparison)    | 5210                | 70.0 |
 
-The resulting compressed student maintains competitive performance while being significantly faster with a smaller memory footprint than the teacher. It also happens to have both better performance and throughput than the existing Qwen3-4B model!
+The resulting compressed student maintains competitive performance while being significantly faster with fewer parameters than the teacher. It also happens to have both better performance and throughput than the existing Qwen3-4B model!
+
+\* _Measured on H100 using TRT-LLM, FP8 precision_
 
 ## Usage
 
````
````diff
@@ -76,14 +78,16 @@ This will download and process the ClimbMix dataset, creating the necessary data
 After launching the NeMo container with the specified mounts, change the contents of the `SLURM_CONFIG` in `nemo_prune_kd_flow.py`
 to reflect your environment, and then perform the following:
 
-From the `nemo_run` folder, launch the example with the `nemo_prune_kd_flow.py` script. To use a different model than the default model (Qwen3-8B), you can add the `--model-name <hf-model-name> --base-recipe <recipe-name>` flags and use the model's HuggingFace name and NeMo recipe names listed [here](https://github.com/NVIDIA/NeMo/tree/main/nemo/collections/llm/recipes). Provide the processed dataset path using the `--data-dir` flag.
+Launch the example with the `nemo_prune_kd_flow.py` script. To use a different model than the default model (Qwen3-8B), you can add the `--model-name <hf-model-name> --base-recipe <recipe-name>` flags and use the model's HuggingFace name and NeMo recipe names listed [here](https://github.com/NVIDIA/NeMo/tree/main/nemo/collections/llm/recipes). Provide the processed dataset path using the `--data-dir` flag.
 
 To perform Pruning + Knowledge Distillation, run:
 
 ```bash
 python prune_distill/nemo_prune_kd_flow.py --log-dir /my/log/dir --data-dir /path/to/climbmix_proc --use-slurm
 ```
 
+> **_NOTE:_** You can omit the `--use-slurm` flag to run locally for testing, and optionally with `--mock-run` to use a mock dataset.
+
 ## Supported models
 
 Locally this script currently supports models that can be trained on 1 node with 8 x 80GB GPUs. On Slurm you can configure the number of nodes/gpus for training and pruning with the following flags: `--nodes`, `--train-gpus`.
````
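As a sanity check on the "~28,000 steps" figure quoted in the Results section of this README, the step count follows directly from the token budget, global batch size, and sequence length; it is the same formula the flow script computes below with `NUM_TOKENS`, `DISTILL_GBS`, and `SEQUENCE_LENGTH`. A minimal Python sketch of that arithmetic, using the default values stated in the README:

```python
# Back-of-the-envelope check of the ~28,000 distillation steps quoted in the
# Results section. Values mirror the README defaults; variable names mirror
# the constants in nemo_prune_kd_flow.py.
NUM_TOKENS = int(90e9)   # ~25% subset of ClimbMix, roughly 90B tokens
DISTILL_GBS = 768        # global batch size
SEQUENCE_LENGTH = 4096   # default sequence length

# Each optimizer step consumes (global batch size * sequence length) tokens.
distill_steps = int(NUM_TOKENS / DISTILL_GBS / SEQUENCE_LENGTH)
print(distill_steps)  # 28610, i.e. "~28,000 steps"
```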

examples/nemo_run/prune_distill/nemo_prune_kd_flow.py

Lines changed: 21 additions & 14 deletions
```diff
@@ -73,7 +73,14 @@ def get_args():
     parser.add_argument(
         "--data-dir",
         type=str,
-        help="Path the preprocessed dataset",
+        help="Path to the preprocessed dataset",
+    )
+    parser.add_argument(
+        "--data-prefixes",
+        type=str,
+        nargs="*",
+        help="Prefixes of the .bin and .idx files in the data directory",
+        default=[f"part_{i}_text_document" for i in SUBSET_IDX],
     )
     parser.add_argument(
         "--train-gpus",
```
```diff
@@ -122,7 +129,7 @@ def main(args):
     )
     data = run.Config(
         PreTrainingDataModule,
-        paths=[f"{args.data_dir}/part_{i}_text_document" for i in SUBSET_IDX],
+        paths=[f"{args.data_dir}/{prefix}" for prefix in args.data_prefixes],
         seq_length=SEQUENCE_LENGTH,
         tokenizer=tokenizer,
         global_batch_size=DISTILL_GBS,
```
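The two hunks above introduce a `--data-prefixes` argument and use it to build the `PreTrainingDataModule` paths, replacing the hard-coded `part_{i}_text_document` pattern. A self-contained sketch of how the flag behaves, using plain argparse with `SUBSET_IDX` stubbed to illustrative values:

```python
import argparse

# Stand-in for the script's SUBSET_IDX constant (illustrative values only).
SUBSET_IDX = [0, 1, 2]

parser = argparse.ArgumentParser()
parser.add_argument("--data-dir", type=str, help="Path to the preprocessed dataset")
parser.add_argument(
    "--data-prefixes",
    type=str,
    nargs="*",
    help="Prefixes of the .bin and .idx files in the data directory",
    default=[f"part_{i}_text_document" for i in SUBSET_IDX],
)

# With no --data-prefixes on the command line, the default prefixes apply.
args = parser.parse_args(["--data-dir", "/path/to/climbmix_proc"])
print([f"{args.data_dir}/{prefix}" for prefix in args.data_prefixes])
# ['/path/to/climbmix_proc/part_0_text_document', ..., '.../part_2_text_document']

# Custom prefixes replace the default list entirely.
args = parser.parse_args(["--data-dir", "/data", "--data-prefixes", "my_corpus_text_document"])
print([f"{args.data_dir}/{prefix}" for prefix in args.data_prefixes])
# ['/data/my_corpus_text_document']
```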
```diff
@@ -236,11 +243,18 @@ def main(args):
             tail_logs=True,
             name="01_import",
         )
+        _ = exp.add(
+            eval_teacher,
+            executor=multi_gpu_executor,
+            tail_logs=True,
+            name="02a_eval_teacher",
+            dependencies=[s1],
+        )
         s2 = exp.add(
             prune,
             executor=gpu_executor,
             tail_logs=True,
-            name="02_prune",
+            name="02b_prune",
             dependencies=[s1],
         )
         s3 = exp.add(
```
```diff
@@ -250,18 +264,11 @@ def main(args):
             name="03_distill",
             dependencies=[s2],
         )
-        _ = exp.add(
-            eval_teacher,
-            executor=multi_gpu_executor,
-            tail_logs=True,
-            name="eval_teacher",
-            dependencies=[s1],
-        )
         _ = exp.add(
             eval_student,
             executor=multi_gpu_executor,
             tail_logs=True,
-            name="eval_student",
+            name="04a_eval_student",
             dependencies=[s3],
         )
         # WAR: Export needs access to all GPUs but only 1 task due to bug in NeMo
```
```diff
@@ -270,7 +277,7 @@ def main(args):
             export_model,
             executor=multi_gpu_executor,
             tail_logs=True,
-            name="04_export",
+            name="04b_export",
             dependencies=[s3],
         )
         exp.run(detach=True)
```
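Taken together, the renamed `exp.add` calls above keep the dependency structure the README's mermaid graph shows: the teacher eval and the prune stage both hang off `01_import` (`s1`), while the student eval and the export both hang off `03_distill` (`s3`). A plain-Python summary of that wiring (illustrative only; the real flow builds these stages as NeMo-Run experiment tasks):

```python
# Stage -> upstream dependencies, as wired by the exp.add(...) calls above.
flow_dag = {
    "01_import": [],
    "02a_eval_teacher": ["01_import"],   # dependencies=[s1]
    "02b_prune": ["01_import"],          # dependencies=[s1]
    "03_distill": ["02b_prune"],         # dependencies=[s2]
    "04a_eval_student": ["03_distill"],  # dependencies=[s3]
    "04b_export": ["03_distill"],        # dependencies=[s3]
}

# Sanity check: the declaration order is a valid topological order,
# i.e. every stage appears after all of its dependencies.
order = list(flow_dag)
assert all(
    order.index(dep) < order.index(stage)
    for stage, deps in flow_dag.items()
    for dep in deps
)
```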
```diff
@@ -316,8 +323,8 @@ def main(args):
 else:
     PRUNE_SAMPLES = 512
     DISTILL_GBS = 768
-    _NUM_TOKENS = int(90e9)
-    DISTILL_STEPS = int(_NUM_TOKENS / DISTILL_GBS / SEQUENCE_LENGTH)
+    NUM_TOKENS = int(90e9)
+    DISTILL_STEPS = int(NUM_TOKENS / DISTILL_GBS / SEQUENCE_LENGTH)
     VAL_INTERVAL = 1000
     # # # # # # # # # # # # # # # # # # # # # #
 
```