Updated Prune-KD NeMo flow #382
base: main
Conversation
Walkthrough

Adds a new ClimbMix preprocessing script; refactors the prune+distill orchestration into a staged, SLURM-aware experiment graph (import → prune → distill → eval → export); and updates README entries to reference the simplified flow and the new data-prep step.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant U as User
    participant F as KD Flow Script (nemo_prune_kd_flow.py)
    participant D as DataModule (HF / Mock)
    participant ExecL as Local Executor
    participant ExecS as Slurm Executor
    participant CK as Checkpoints
    participant EX as Exporter
    U->>F: Run with args (model-id, base-recipe, data-dir, use-slurm, …)
    F->>D: Initialize data (tokenizer, subset paths)
    Note right of F: s1: 01_import — prepare initial checkpoint
    F->>ExecL: s2: 02b_prune (single-device config)
    ExecL-->>CK: Write pruned checkpoint
    F->>ExecS: s3: 03_distill (multi-node/GPU via Slurm if enabled)
    ExecS-->>CK: Save distilled checkpoints
    F->>ExecL: s4: 02a_eval_teacher (MMLU)
    F->>ExecL: s5: 04a_eval_student (MMLU)
    F->>ExecL: s6: 04b_export (use recent ckpt, single-task workaround)
    ExecL-->>EX: Produce exported model artifacts
    EX-->>U: Artifacts in exp_dir
    Note over ExecL,ExecS: Executor chosen by --use-slurm flag
```
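To make the staged wiring concrete, here is a minimal sketch of how such an experiment graph is typically assembled with NeMo-Run. The stage names, the `dependencies` handles, and the `LocalExecutor(launcher="torchrun", ...)` call mirror what the flow script uses, but the placeholder task function and the executor sizes are illustrative assumptions, not the actual code in `nemo_prune_kd_flow.py`.

```python
import nemo_run as run


def placeholder(stage: str) -> None:
    """Stand-in for the real import/prune/distill/export payloads."""
    print(f"running {stage}")


# In the real flow these are run.Partial(...) objects built from NeMo recipes.
import_model = run.Partial(placeholder, stage="01_import")
prune = run.Partial(placeholder, stage="02b_prune")
distill = run.Partial(placeholder, stage="03_distill")
export_model = run.Partial(placeholder, stage="04b_export")

single_gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=1)
multi_gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=8)  # assumed size

with run.Experiment("prune_kd_flow") as exp:
    s1 = exp.add(import_model, executor=single_gpu_executor, name="01_import")
    s2 = exp.add(prune, executor=single_gpu_executor, name="02b_prune", dependencies=[s1])
    s3 = exp.add(distill, executor=multi_gpu_executor, name="03_distill", dependencies=[s2])
    exp.add(export_model, executor=multi_gpu_executor, name="04b_export", dependencies=[s3])
    exp.run(detach=False)  # detach=True when submitting to Slurm
```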
```mermaid
sequenceDiagram
    autonumber
    participant U as User
    participant P as process_climbmix.py
    participant HF as HuggingFace Hub
    participant PP as megatron_preprocess_data
    participant FS as Filesystem
    U->>P: Run with --output-dir --tokenizer [--subset-index]
    P->>HF: snapshot_download (allow_patterns for subset files)
    HF-->>FS: Raw subset files placed in out_dir/raw
    P->>PP: Tokenize & pack (seq_len, workers, append_eod)
    PP-->>FS: Processed data in out_dir/processed
    P-->>U: Print paths to processed data
```
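As a rough illustration of that second diagram, the script presumably does something along these lines. The `snapshot_download` call is the real `huggingface_hub` API; the dataset repo id, the `allow_patterns` globs, and the keyword names passed to `megatron_preprocess_data` are assumptions for illustration rather than the script's exact signature.

```python
from pathlib import Path

from huggingface_hub import snapshot_download

from modelopt.torch.utils.plugins.megatron_preprocess_data import megatron_preprocess_data

out_dir = Path("./climbmix")  # corresponds to --output-dir
subset_idx = [0, 1, 6]        # corresponds to --subset-index

# 1) Pull only the requested raw subset files from the Hub.
snapshot_download(
    repo_id="nvidia/ClimbMix",                            # assumed dataset repo id
    repo_type="dataset",
    local_dir=str(out_dir / "raw"),
    allow_patterns=[f"*part_{i}.*" for i in subset_idx],  # assumed file naming
)

# 2) Tokenize and pack into Megatron's indexed .bin/.idx format.
megatron_preprocess_data(                                 # keyword names are assumptions
    input_path=str(out_dir / "raw"),
    output_dir=str(out_dir / "processed"),
    tokenizer_name_or_path="Qwen/Qwen3-8B",               # corresponds to --tokenizer
    append_eod=True,
    workers=8,
)

print(f"Processed data written to {out_dir / 'processed'}")
```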
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning, 1 inconclusive) | ✅ Passed checks (1 passed)
Signed-off-by: Asha Anoosheh <[email protected]>
Force-pushed from eb4e041 to d632578 (compare).
Actionable comments posted: 3
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)

- examples/llm_distill/README.md (1 hunks)
- examples/nemo_run/common/process_climbmix.py (1 hunks)
- examples/nemo_run/prune_distill/README.md (1 hunks)
- examples/nemo_run/prune_distill/nemo_prune_kd_flow.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
examples/nemo_run/prune_distill/nemo_prune_kd_flow.py (2)
- modelopt/torch/export/plugins/nemo_run.py (1): export_most_recent_ckpt (24-33)
- examples/nemo_run/common/utils.py (2): SlurmConfig (24-72), create_slurm_executor (75-117)

examples/nemo_run/common/process_climbmix.py (1)
- modelopt/torch/utils/plugins/megatron_preprocess_data.py (1): megatron_preprocess_data (160-203)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (1)
examples/llm_distill/README.md (1)
19: Link addition looks good. The new row cleanly surfaces the prune+distill flow, and the relative link resolves to the new README as expected.
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@           Coverage Diff            @@
##            main     #382     +/-  ##
=========================================
+ Coverage   73.46%   73.79%   +0.32%
=========================================
  Files         172      171       -1
  Lines       17640    17583      -57
=========================================
+ Hits        12959    12975      +16
+ Misses       4681     4608      -73
```

☔ View full report in Codecov by Sentry.
```python
    *[0, 1, 6, 10, 11],
    *[12, 13, 14, 21, 24],
    *[33, 35, 38, 40, 48],
    *[49, 52, 66, 70, 76],
    *[83, 88, 91, 94, 99],
```
Do you know why we selected these numbers in the first place? Was it randomly generated or had some significance?
I started to download/process them all before realizing it didn't need them all, so it had already downloaded 0, 1, 10, 11, 12, 13, 14; the rest are random.
Is it possible to randomly sample 25 numbers out of 0-99 instead of hardcoding this? If necessary to hardcode, can this just be a normal list instead of having strange syntax (multiple sub-lists with asterisks to unpack them)?
We have better reproducibility and control if they are hardcoded. For example, I can now import this list of indices and create the `data_paths` argument to the `PretrainDataModule` in the Nemo-Run script. The sub-list formatting is apparently Ruff's ideal way to make it fit within the line limit.
You can add `# fmt: off` and `# fmt: on` lines before and after the code block to avoid auto-formatting.
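Concretely, the suggestion would look something like the sketch below in `process_climbmix.py`; the constant name and index values come from the hunk above, and the `# fmt: off` / `# fmt: on` guards keep the formatter from re-wrapping the flat list.

```python
# fmt: off
SUBSET_IDX = [
    0, 1, 6, 10, 11, 12, 13, 14, 21, 24, 33, 35, 38, 40, 48,
    49, 52, 66, 70, 76, 83, 88, 91, 94, 99,
]
# fmt: on
```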
Signed-off-by: Asha Anoosheh <[email protected]>
Actionable comments posted: 0
🧹 Nitpick comments (4)
examples/nemo_run/prune_distill/nemo_prune_kd_flow.py (4)
30-32: Fragile path manipulation for imports. The `sys.path.append` with relative path construction assumes the script is run from a specific location. If the script is invoked from a different directory or the repository structure changes, the import will fail. Consider making the common utilities a proper package or using explicit relative imports.
135-142: Add validation for model recipe name. The code uses `getattr(llm, model_name)` without checking if the recipe exists. An invalid `--base-recipe` value will cause an `AttributeError` at runtime. Consider adding a try-except or validating against known recipes. Apply this diff to add validation:

```diff
     initial_model_out = f"{exp_dir}/{model_name}_initial"
-    model_module = getattr(llm, model_name)
+    try:
+        model_module = getattr(llm, model_name)
+    except AttributeError:
+        raise ValueError(f"Recipe '{model_name}' not found in NeMo llm module. Check --base-recipe value.")
     import_model = run.Partial(
```
267-275: Executor modification may cause issues if reused. Line 268 modifies `multi_gpu_executor.ntasks_per_node` in place as a workaround for a NeMo bug. If the executor is reused after the export stage or if stages run out of order, this could cause unexpected behavior. Since export is the final stage and uses `detach=True`, this is likely safe, but consider documenting this assumption. Consider creating a dedicated export executor instead:

```diff
-    # WAR: Export needs access to all GPUs but only 1 task due to bug in NeMo
-    multi_gpu_executor.ntasks_per_node = 1  # will throw error if more than 1 task during export
+    # WAR: Export needs access to all GPUs but only 1 task due to bug in NeMo
+    export_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=1) if not args.use_slurm else create_slurm_executor(SLURM_CONFIG, ntasks_per_node=1, num_gpus=args.train_gpus, nodes=args.nodes)
     _ = exp.add(
         export_model,
-        executor=multi_gpu_executor,
+        executor=export_executor,
```
284-299: SLURM_CONFIG requires user configuration. The placeholder SLURM_CONFIG will fail validation if `--use-slurm` is used without modification. This is intentional to force users to configure their environment, but consider adding a more prominent comment or early validation to provide clearer feedback. Add an early check:

```diff
     if args.use_slurm:
+        # Validate SLURM_CONFIG is configured before proceeding
+        if not all([SLURM_CONFIG.account, SLURM_CONFIG.host, SLURM_CONFIG.user, SLURM_CONFIG.job_dir]):
+            raise ValueError("SLURM_CONFIG must be configured with your cluster details before using --use-slurm")
         SLURM_CONFIG = SlurmConfig(
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)

- examples/nemo_run/common/process_climbmix.py (1 hunks)
- examples/nemo_run/prune_distill/README.md (1 hunks)
- examples/nemo_run/prune_distill/nemo_prune_kd_flow.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/nemo_run/common/process_climbmix.py
🧰 Additional context used
🧬 Code graph analysis (1)
examples/nemo_run/prune_distill/nemo_prune_kd_flow.py (2)
- modelopt/torch/export/plugins/nemo_run.py (1): export_most_recent_ckpt (24-33)
- examples/nemo_run/common/utils.py (2): SlurmConfig (24-72), create_slurm_executor (75-117)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (18)
examples/nemo_run/prune_distill/nemo_prune_kd_flow.py (11)
35-100: Well-structured argument parser. The argument definitions are clear, properly typed, and include helpful descriptions. The new arguments align well with the simplified flow architecture.

103-107: Clean entry point structure. The main function properly encapsulates the workflow logic with clear variable naming.

116-117: Guard against missing `--data-dir` correctly implemented. The ValueError properly prevents silent failures when running in non-mock mode without specifying the data directory. This resolves the previous review concern.

144-160: Pruning configuration looks correct. The single-device pruning setup with explicit device and pipeline parallelism settings is appropriate. The use of cloned data config with adjusted batch sizes is a clean pattern.

161-181: Distillation configuration is comprehensive. The distillation recipe setup includes all necessary hyperparameters with sensible defaults. The `ckpt_load_strictness="log_all"` setting appropriately handles potential checkpoint compatibility issues during distillation.

182-203: MMLU evaluation setup is correct. The conditional path resolution for Slurm vs. local execution and the proper routing of teacher/student checkpoints to their respective evaluation scripts are well-structured.

205-211: Export configuration is correct. The export stage properly references the distillation output directory to export the most recent checkpoint to HuggingFace format.

213-230: Executor setup is well-structured. The conditional executor configuration cleanly separates Slurm and local execution paths. The distinction between single-GPU and multi-GPU executors provides flexibility for different stages.

232-277: Execution graph structure is correct. The staged execution with explicit dependencies forms a clean DAG. The parallel evaluation of teacher (post-import) and sequential distillation → student evaluation → export flow is logical.

301-322: Well-organized hyperparameter configuration. The clear separation between mock and production parameters, along with the token-based calculation for distillation steps, makes the configuration easy to understand and modify.

125: SUBSET_IDX is properly defined and exported. Verification confirms that `SUBSET_IDX` is defined in `examples/nemo_run/common/process_climbmix.py` (lines 23-28) as a module-level constant containing 25 subset indices representing 25% of the ClimbMix dataset. The import statement at line 31 correctly imports this constant, and the usage at line 125 appropriately constructs data paths corresponding to the processed dataset files.

examples/nemo_run/prune_distill/README.md (7)
1-12: Clear documentation overview. The title and overview accurately describe the simplified flow and its purpose. The explanation of accuracy degradation and recovery through distillation provides good context.

13-32: Flow diagram accurately reflects implementation. The mermaid diagram and stage descriptions correctly match the execution graph defined in the code, showing both the linear dependency chain (import → prune → distill → export) and parallel evaluation paths.

33-43: Results clearly demonstrate the value proposition. The results table effectively shows the compression benefits, with the 6B pruned model outperforming the original 4B model in both throughput and accuracy while maintaining competitive performance with the 8B teacher.

45-62: Prerequisites are clear and complete. The setup instructions cover all necessary components. The `chmod 777` recommendation is pragmatic for container environments, though users should be aware this grants broad permissions.

63-72: Dataset preparation instructions are clear. The rationale for manual preprocessing and the command to execute it are well-documented and correctly reference the preprocessing script.

74-85: Slurm execution instructions are correct. The dataset path in the command uses the correct directory name "climbmix_proc", and the instructions properly guide users through the SLURM configuration and execution process.

87-99: Model support and limitations are clearly documented. The hardware requirements, default configuration, and important disclaimer about dataset/hyperparameter sensitivity set appropriate expectations for users attempting different model-dataset combinations.
```diff
     )
     parser.add_argument(
-        "--chat_template",
+        "--data-dir",
```
For pretraining, would it make more sense to provide a flexible-length `--data-dir` argument and set the default to the list of `SUBSET_IDX` here, instead of hardcoding it below in line 129? This way the user can change the data as they wish.
Good point. Done
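The resulting change presumably looks roughly like the sketch below. The option name `--data-prefixes` and the `SUBSET_IDX` import match later review comments; the `nargs`, prefix naming, and help text are illustrative assumptions.

```python
import argparse

from process_climbmix import SUBSET_IDX  # 25 hardcoded ClimbMix subset indices

parser = argparse.ArgumentParser()
parser.add_argument(
    "--data-prefixes",
    nargs="+",  # flexible-length list so users can point at their own shards
    default=[f"climbmix_{i}" for i in SUBSET_IDX],  # assumed prefix naming
    help="Prefixes of preprocessed dataset shards; defaults to the subsets produced by process_climbmix.py.",
)
args = parser.parse_args()
```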
```python
else:
    PRUNE_SAMPLES = 1024
    DISTILL_GBS = 768
    _NUM_TOKENS = 89694564352
```
If the user provided a different dataset, the `DISTILL_STEPS` here would be hardcoded to the `_NUM_TOKENS` based on ClimbMix... Is it possible to just not derive steps from `_NUM_TOKENS` and instead fix `DISTILL_STEPS` to a reasonable number? I also find it odd that there's a giant number stored in this file...
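For reference, the step count is just the token budget divided by the tokens per optimizer step; a quick check with the numbers from this hunk (the 4096 sequence length is an assumption, matching the value discussed later in the review):

```python
NUM_TOKENS = 89_694_564_352   # ClimbMix token budget from this file
DISTILL_GBS = 768             # global batch size
SEQUENCE_LENGTH = 4096        # assumed packing length

DISTILL_STEPS = NUM_TOKENS // (DISTILL_GBS * SEQUENCE_LENGTH)
print(DISTILL_STEPS)  # ≈ 28,500 steps, i.e. roughly one pass over the token budget
```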
```markdown
# NeMo Pruning + Knowledge Distillation Simplified Flow Example

[Slurm Examples](ADVANCED.md)
```
You don't have an ADVANCED.md file; can you move the Slurm info into it?
I was thinking to just leave it as the Slurm-default tutorial in the main README, because KD on such a dataset will be useless on a local node anyway.
```markdown
1. eval_teacher — Evaluate teacher model on 5% of MMLU benchmark
1. eval_student — Evaluate student model on 5% of MMLU benchmark
```
Why are these not numbered? Like 05a, 05b?
They're not actually sequential, but I fixed it by making them 2b and 4b
```markdown
- **DISTILL_STEPS**: Number of distillation steps (default: 7500 for real runs, 20 for mock runs)
- **VAL_INTERVAL**: Validation interval (default: 500 for real runs, 10 for mock runs)
- **PRUNE_SAMPLES**: Number of samples for pruning calibration (default: 1024 for real runs, 3 for mock runs)

### Running the Flow via Slurm
```
Similar to the QAT flow README, can the Slurm info be moved into ADVANCED.md?
Signed-off-by: Asha Anoosheh <[email protected]>
Actionable comments posted: 1
🧹 Nitpick comments (2)
examples/nemo_run/prune_distill/nemo_prune_kd_flow.py (2)
30-32: Consider using relative imports or package installation. The `sys.path.append` pattern works but is fragile and makes the module structure unclear. For a more maintainable approach, consider either:
- Using relative imports if this is part of a package
- Installing the common utilities as a local package
79-84: Document the SUBSET_IDX dependency. The default value for `--data-prefixes` depends on `SUBSET_IDX` imported from `process_climbmix.py`. If users modify the subset indices when running `process_climbmix.py` (or if the script changes), this default becomes stale. Consider adding a help-message note about this dependency, or document it in the README.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)

- examples/nemo_run/prune_distill/README.md (1 hunks)
- examples/nemo_run/prune_distill/nemo_prune_kd_flow.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/nemo_run/prune_distill/nemo_prune_kd_flow.py (2)
- modelopt/torch/export/plugins/nemo_run.py (1): export_most_recent_ckpt (24-33)
- examples/nemo_run/common/utils.py (2): SlurmConfig (24-72), create_slurm_executor (75-117)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (10)
examples/nemo_run/prune_distill/nemo_prune_kd_flow.py (5)
123-124: Good fix for missing data-dir validation. This guard correctly addresses the previous review concern about failing fast when `--data-dir` is not provided for real runs.

186: Verify the assumption that tensor parallelism equals GPU count. Setting `tensor_model_parallel_size` to `args.train_gpus` assumes you want full tensor parallelism across all GPUs. For some model sizes or GPU counts, you might want pipeline parallelism or a different TP/PP split. Consider whether this should be configurable for flexibility. Based on learnings from past comments, this was mentioned as "default to 1, 8", but the current implementation hardcodes it to train_gpus.

190-195: SLURM path assumption may be fragile. When `--use-slurm` is true, the script path is set to `"examples/nemo_run/common/in_memory_mmlu.py"`, assuming this relative path works in the SLURM container environment. Verify that:
- The SLURM packager (GitArchivePackager) preserves this directory structure
- The working directory in the SLURM job is the repository root

If the working directory differs, this path will fail.

291-306: SLURM_CONFIG validation will fail with placeholder values. The `SlurmConfig` class (from `utils.py`) has validation in `__post_init__` that raises `ValueError` for empty required fields. With the current placeholder values:
- `account=""` will fail
- `time="HH:MM:SS"` will fail (not a valid time format)
- `container_image` points to a real image, but other fields are invalid
- `job_dir="/path/to/logs"` is a placeholder path

This is good defensive behavior: it will fail fast with clear error messages when users try to run with `--use-slurm` without configuring these values. Consider adding a comment in the code to make this more explicit for users. The validation design ensures users cannot accidentally run with invalid configuration.
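The failure mode being described is the standard dataclass `__post_init__` guard; a minimal sketch, assuming field names like those in the placeholder config (the real `SlurmConfig` in `examples/nemo_run/common/utils.py` may differ in detail):

```python
import re
from dataclasses import dataclass


@dataclass
class SlurmConfig:
    account: str = ""
    host: str = ""
    user: str = ""
    job_dir: str = ""
    time: str = "HH:MM:SS"
    container_image: str = ""

    def __post_init__(self):
        # Fail fast with a clear message instead of dying mid-experiment.
        for name in ("account", "host", "user", "job_dir", "container_image"):
            if not getattr(self, name):
                raise ValueError(f"SlurmConfig.{name} must be set before using --use-slurm")
        if not re.fullmatch(r"\d{1,2}:\d{2}:\d{2}", self.time):
            raise ValueError(f"SlurmConfig.time must look like HH:MM:SS, got {self.time!r}")
```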
326-327: NUM_TOKENS calculation addresses past review concern. The use of `int(90e9)` for `NUM_TOKENS` is cleaner than the previous hardcoded value and addresses the past review comment about rounding. The subsequent `DISTILL_STEPS` calculation correctly derives steps from tokens, batch size, and sequence length.

examples/nemo_run/prune_distill/README.md (5)
24-31: Flow diagram accurately reflects implementation. The mermaid diagram correctly represents the execution graph implemented in `nemo_prune_kd_flow.py`. The stage names and dependencies match the code exactly.

35: Verify step count consistency. The README states "~28,000 steps", which aligns with the code calculation:
- NUM_TOKENS = 90e9 (line 326)
- DISTILL_GBS = 768 (line 325)
- SEQUENCE_LENGTH = 4096 (line 310)
- DISTILL_STEPS = 90e9 / 768 / 4096 ≈ 28,610

The approximation is reasonable and consistent.

67-74: Clear dataset preparation instructions. The dataset preparation section correctly documents the manual process and provides the right command path relative to the prune_distill directory.

86: Dataset path typo has been corrected. The command now correctly uses `climbmix_proc` (not `climbix_proc`), which matches the output directory created by `process_climbmix.py`. This addresses the previous review concern.

91-99: Supported models section matches code defaults. The default model (`Qwen/Qwen3-8B`) and recipe (`qwen3_8b`) documented here correctly match the argument defaults in `nemo_prune_kd_flow.py`.
```python
    # WAR: Export needs access to all GPUs but only 1 task due to bug in NeMo
    multi_gpu_executor.ntasks_per_node = 1  # will throw error if more than 1 task during export
    _ = exp.add(
```
Executor mutation may cause issues with executor reuse. Mutating `multi_gpu_executor.ntasks_per_node` after it has already been used by the distill task could cause unexpected behavior if executors cache or validate their configuration. Consider creating a separate executor for export instead:
```diff
-    # WAR: Export needs access to all GPUs but only 1 task due to bug in NeMo
-    multi_gpu_executor.ntasks_per_node = 1  # will throw error if more than 1 task during export
+    # WAR: Export needs access to all GPUs but only 1 task due to bug in NeMo
+    if args.use_slurm:
+        export_executor = create_slurm_executor(
+            SLURM_CONFIG,
+            ntasks_per_node=1,
+            num_gpus=args.train_gpus,
+            nodes=args.nodes,
+        )
+    else:
+        export_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=1)
     _ = exp.add(
         export_model,
-        executor=multi_gpu_executor,
+        executor=export_executor,
         tail_logs=True,
         name="04b_export",
         dependencies=[s3],
     )
```
🤖 Prompt for AI Agents

In examples/nemo_run/prune_distill/nemo_prune_kd_flow.py around lines 274 to 276, the code mutates multi_gpu_executor.ntasks_per_node after it has been used for the distill task, which can break executor reuse or internal validation; instead create a new executor instance for the export step (e.g., clone or instantiate a separate executor), set ntasks_per_node = 1 on that new executor, and pass that new executor to exp.add for the export job so the original multi_gpu_executor remains unchanged.
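One low-effort way to follow that guidance is to clone the existing executor before tweaking it. The fragment below reuses the names from the suggestion above and is not standalone; whether `copy.deepcopy` is safe for a given executor type is an assumption worth verifying.

```python
import copy

# Keep the distill executor untouched; give export its own single-task copy.
export_executor = copy.deepcopy(multi_gpu_executor)
export_executor.ntasks_per_node = 1  # WAR: NeMo export currently tolerates only one task

_ = exp.add(
    export_model,
    executor=export_executor,
    tail_logs=True,
    name="04b_export",
    dependencies=[s3],
)
```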
What does this PR do?

Type of change: Updated example

Overview: ?

Usage

Prune + KD NeMo-Run example now supports Slurm and uses the ClimbMix dataset.

# Add a code snippet demonstrating how to use this

Testing

Ran e2e on cluster.

Before your PR is "Ready for review"

Additional Information
Summary by CodeRabbit

- New Features
- Documentation
- Refactor