Support checkpointing Minitron pruning scores / prune without sorting #361
Conversation
Walkthrough

Adds checkpoint save/restore for Minitron pruning activations and layer scores to enable re-pruning without re-running the forward pass; adapts Megatron activation collection/aggregation; extends searcher checkpoint I/O and distributed utilities with optional group support; and updates tests, docs, examples, and CI env setup.
Sequence Diagram(s)

```mermaid
sequenceDiagram
autonumber
actor User
participant Runner as Prune Runner
participant Searcher as MCoreMinitronSearcher
participant Model as Dynamic Model
participant Dist as Distributed Utils
rect rgb(250,250,255)
note right of User: Start prune with config (may include scores_path/checkpoint, skip_sorting)
User->>Runner: prune(config={forward_loop?, scores_path?, skip_sorting?})
Runner->>Searcher: before_search(export_config)
end
alt checkpoint provided & exists
Searcher->>Searcher: load_search_checkpoint() -> activations_per_rank & layer_scores
Searcher->>Model: set_activations_and_layer_scores(restored)
else no checkpoint / missing
Searcher->>Model: run forward loop (collect per-rank activations & outputs)
Model-->>Searcher: activations_per_rank
Searcher->>Dist: allgather(activations_per_rank, group?)
Dist-->>Searcher: aggregated activations
Searcher->>Searcher: _get_layer_scores() -> layer_scores
Searcher->>Searcher: save_search_checkpoint(verbose?)
end
rect rgb(245,255,245)
alt skip_sorting == false
Searcher->>Searcher: sort_parameters / compute final drop list
else skip_sorting == true
Searcher->>Searcher: use existing parameter order
end
Searcher->>Runner: export pruned artifacts
Runner-->>User: done
end
```

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
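As a quick orientation before the detailed comments, here is a minimal usage sketch of the workflow the diagram describes. It is based on the examples/pruning/README.md snippet touched by this PR; model, forward_loop, and the export configs are placeholders, the mtp alias follows the tests, and the config key is spelled "checkpoint" in earlier revisions of this PR and "scores_path" in later ones.

```python
import modelopt.torch.prune as mtp

# First run: collect Minitron activations/scores via the forward loop and save them
# so later prune calls can skip the forward pass (filename is a placeholder).
mtp.prune(
    model,
    mode="mcore_minitron",
    constraints={"export_config": export_config},
    dummy_input=None,  # Not used
    config={"forward_loop": forward_loop, "checkpoint": "modelopt_minitron_scores.pth"},
)

# Later run: re-prune with a different export config without re-running the forward
# loop; the saved per-rank activations and layer scores are restored instead.
mtp.prune(
    model,
    mode="mcore_minitron",
    constraints={"export_config": new_export_config},
    dummy_input=None,  # Not used
    config={"checkpoint": "modelopt_minitron_scores.pth"},
)
```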
Actionable comments posted: 2
🧹 Nitpick comments (15)
docs/source/guides/3_pruning.rst (1)
7-7: Polish phrasing: "Check out"

Use "Check out" instead of "Checkout" for correct phrasing.

```diff
- Checkout `Qwen 3 NeMo Minitron Pruning & Distillation <https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/qwen/pruning-distillation>`_ and
+ Check out `Qwen 3 NeMo Minitron Pruning & Distillation <https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/qwen/pruning-distillation>`_ and
```

tox.ini (1)
68-69: Pin the mamba git dependency to a commit for reproducible, faster CI

Building from git HEAD with --no-build-isolation is slow and non-reproducible. Pin to a known commit and consider keeping build isolation on unless there is a specific reason.

Proposed change:

```diff
- # Install Mamba model dependencies (takes 8-10mins!)
- MAMBA_FORCE_BUILD="TRUE" pip install --no-build-isolation git+https://github.com/state-spaces/mamba.git
+ # Install Mamba model dependencies (pin to a commit for reproducibility; still takes ~8-10 mins)
+ # Replace <commit-sha> with a vetted commit from state-spaces/mamba
+ MAMBA_FORCE_BUILD="TRUE" pip install --no-build-isolation git+https://github.com/state-spaces/mamba.git@<commit-sha>
```

If build isolation can be enabled, prefer:

```diff
- MAMBA_FORCE_BUILD="TRUE" pip install --no-build-isolation git+https://github.com/state-spaces/mamba.git@<commit-sha>
+ MAMBA_FORCE_BUILD="TRUE" pip install git+https://github.com/state-spaces/mamba.git@<commit-sha>
```

Please confirm if GPU tests rely on local repo deps that require --no-build-isolation; otherwise enable isolation.

examples/llm_distill/README.md (1)
147-147: LGTM; minor wording tweak optional

Consider "Check out" instead of "Checkout" for consistency with other docs.

```diff
-You can also look at the NeMo tutorial notebooks [here](https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/qwen/pruning-distillation) which showcase the usage of Minitron pruning followed by distillation for Qwen 3 8B step-by-step in NeMo framework. Hugging Face models can also be converted to NeMo format and used subsequently as shown in the tutorial.
+You can also check out the NeMo tutorial notebooks [here](https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/qwen/pruning-distillation), which showcase Minitron pruning followed by distillation for Qwen 3 8B step-by-step in the NeMo framework. Hugging Face models can also be converted to NeMo format and used subsequently as shown in the tutorial.
```

examples/pruning/README.md (3)
69-72: Clarify dataset-dependent checkpoint behavior

"Skip checkpoint path" is ambiguous. Recommend explicitly saying to omit the "checkpoint" key when changing datasets, since scores are dataset-dependent.

```diff
-# Save minitron scores so we can re-run pruning with different export configs without running the forward loop again
-# NOTE: Skip checkpoint path on re-running if you want to change the dataset
+# Save Minitron scores so we can re-run pruning with different export configs without re-running the forward loop.
+# NOTE: If you change the dataset, omit the "checkpoint" key so fresh activations/scores are collected.
```

77-77: Optional: suggest experiment-scoped checkpoint filename

To avoid accidental reuse across experiments, consider an experiment-scoped path.

```diff
-    config={"forward_loop": forward_loop, "checkpoint": "modelopt_minitron_scores.pth"},
+    config={"forward_loop": forward_loop, "checkpoint": "runs/exp1/modelopt_minitron_scores.pth"},
```

96-101: Docs alignment LGTM; minor grammar optional

The Qwen tutorial link update looks good. Consider the small phrasing tweaks below for consistency.

```diff
-### Minitron Pruning for Megatron-LM / NeMo Framework LLMs (e.g. Qwen 3, Nemotron Nano)
+### Minitron Pruning for Megatron-LM / NeMo framework LLMs (e.g., Qwen 3, Nemotron Nano)

-You can also look at the NeMo tutorial notebooks [here](https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/qwen/pruning-distillation) which showcase the usage of Minitron pruning followed by distillation for Qwen 3 8B step-by-step in NeMo framework. Hugging Face models can also be converted to NeMo format and used subsequently as shown in the tutorial.
+You can also check out the NeMo tutorial notebooks [here](https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/qwen/pruning-distillation), which showcase Minitron pruning followed by distillation for Qwen 3 8B step-by-step in the NeMo framework. Hugging Face models can also be converted to NeMo format and used subsequently as shown in the tutorial.
```

modelopt/torch/nas/plugins/megatron.py (1)
1246-1252: Scope the reduction to PP group (and fix misleading comment)

Comment says "Reduce over all PP ranks" but all_reduce uses the default group (all ranks). Scope to PP group; if you intend averaging, divide by PP world size.

```diff
- # Reduce over all PP ranks
- activations = activations.clone()
- torch.distributed.all_reduce(activations, op=torch.distributed.ReduceOp.SUM)  # average
+ # Reduce over all PP ranks (sum). If averaging is needed, divide by PP world size.
+ activations = activations.clone()
+ torch.distributed.all_reduce(
+     activations, op=torch.distributed.ReduceOp.SUM, group=get_pipeline_model_parallel_group()
+ )
+ # Optional: uncomment to average instead of sum
+ # activations /= get_pipeline_model_parallel_world_size()
```

modelopt/torch/opt/searcher.py (2)
239-248
: Make checkpoint type PathLike‑friendly and improve warning attribution.
- Tests pass a
Path
object; update the local type to acceptos.PathLike[str]
to align with usage.- Add
stacklevel=2
so the warn points at the caller site.

Apply this diff:
- checkpoint: str | None = self.config["checkpoint"] + checkpoint: os.PathLike[str] | str | None = self.config["checkpoint"] @@ - if dist.is_master(): - warn(f"Checkpoint {checkpoint} does not exist! Initializing from scratch.") + if dist.is_master(): + warn( + f"Checkpoint {checkpoint} does not exist! Initializing from scratch.", + stacklevel=2, + )
258-268
: Optional: defaultverbose
to config when not provided.So callers don’t have to thread a flag, you can default to
self.config["verbose"]
whenverbose
is None.Proposed change:
- def save_search_checkpoint(self, verbose=False) -> None: + def save_search_checkpoint(self, verbose: bool | None = None) -> None: @@ - if verbose: + if (self.config.get("verbose", False) if verbose is None else verbose): print(f"Saving searcher state to {checkpoint}...")tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py (2)
98-99
: Pass a string path to avoid static type friction.
config["checkpoint"]
is annotated asstr | None
in code; casting improves static checks.Apply this diff:
- config={"forward_loop": forward_loop, "checkpoint": ckpt_path}, + config={"forward_loop": forward_loop, "checkpoint": str(ckpt_path)},
124-133
: Strengthen re-pruning validation by asserting post-restore behavior.

Add a quick inference and shape checks after re-pruning to exercise the checkpoint path.
Apply this diff:
# Assert re-pruning from checkpoint works without running the forward loop again model = _get_model(initialize_megatron=False) mtp.prune( model, mode="mcore_minitron", constraints={"export_config": export_config}, dummy_input=None, # Not used - config={"checkpoint": ckpt_path}, + config={"checkpoint": str(ckpt_path)}, ) + + # Validate pruned shapes and forward on restored path + mixer = None + for layer in model.decoder.layers: + if isinstance(layer, MambaLayer): + mixer = layer.mixer + break + assert mixer is not None + bc = 2 * mixer.ngroups * mixer.d_state + assert mixer.nheads == pruned_mamba_num_heads + assert mixer.headdim == pruned_mamba_head_dim + assert mixer.in_proj.input_size == pruned_hidden_size + assert mixer.d_inner == pruned_mamba_num_heads * pruned_mamba_head_dim + assert mixer.in_proj.output_size == 2 * mixer.d_inner + bc + pruned_mamba_num_heads + assert mixer.out_proj.input_size == mixer.d_inner + assert mixer.out_proj.output_size == pruned_hidden_size + assert mixer.conv1d.in_channels == mixer.conv1d.out_channels == mixer.d_inner + bc + run_mcore_inference_with_dummy_input(model, batch_size, pruned_hidden_size)modelopt/torch/prune/plugins/mcore_minitron.py (2)
126-135
: Defensive restore from checkpoint to avoid KeyError.

If a module name is missing in the checkpoint maps, current code will KeyError. Guard lookups.
Apply this diff:
- if self.scores and self.activations: # Available from checkpoint - print_rank_0("Loading activations and scores per rank from checkpoint...") + if self.scores and self.activations: # Available from checkpoint + if self.config.get("verbose", False): + print_rank_0("Loading activations and scores per rank from checkpoint...") assert self.ckpt_world_size == dist.size(), "World size mismatch!" rank = dist.rank() - for n, m in self.model.named_modules(): - if hasattr(m, "_scores"): - m._scores = self.scores[rank][n] - if hasattr(m, "_activations"): - m._activations = self.activations[rank][n] + scores_for_rank = self.scores.get(rank, {}) + activations_for_rank = self.activations.get(rank, {}) + for n, m in self.model.named_modules(): + if hasattr(m, "_scores") and n in scores_for_rank: + m._scores = scores_for_rank[n] + if hasattr(m, "_activations") and n in activations_for_rank: + m._activations = activations_for_rank[n]
136-166
: Gate logs byconfig["verbose"]
and honor it downstream.Reduce noisy prints and align with the searcher’s verbosity semantics.
Apply this diff:
- print_rank_0("Running forward loop...") + if self.config.get("verbose", False): + print_rank_0("Running forward loop...") @@ - self.save_search_checkpoint(verbose=True) + self.save_search_checkpoint(verbose=self.config.get("verbose", False)) - dist.barrier() + dist.barrier() @@ - sort_parameters(self.model, self.hps_to_sort, verbose=True) + sort_parameters(self.model, self.hps_to_sort, verbose=self.config.get("verbose", False))tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py (2)
118-119
: Pass a string path for checkpoint to match the API annotation.

Prevents static typing complaints and keeps consistency across tests.
Apply this diff:
- config={"forward_loop": forward_loop, "checkpoint": ckpt_path}, + config={"forward_loop": forward_loop, "checkpoint": str(ckpt_path) if ckpt_path else None},
147-156
: Assert restored path behavior to fully validate checkpoint re-use.

Add shape checks and a quick forward pass after re-pruning.
Apply this diff:
if ckpt_path: model = _get_model(initialize_megatron=False) mtp.prune( model, mode="mcore_minitron", constraints={"export_config": export_config}, dummy_input=None, # Not used - config={"checkpoint": ckpt_path}, + config={"checkpoint": str(ckpt_path)}, ) + + # Validate shapes match expectations after restore + for layer in model.decoder.layers: + assert layer.mlp.linear_fc1.weight.shape == ( + pruned_ffn * (2 if activation_func == "swiglu" else 1), + pruned_hidden_size, + ) + assert layer.mlp.linear_fc2.weight.shape == (pruned_hidden_size, pruned_ffn) + assert layer.self_attention.linear_qkv.weight.shape == ( + (pruned_num_heads_per_group + 2) * pruned_num_query_groups * model.config.kv_channels, + pruned_hidden_size, + ) + assert layer.self_attention.linear_proj.weight.shape == ( + pruned_hidden_size, + pruned_num_heads_per_group * pruned_num_query_groups * model.config.kv_channels, + ) + run_mcore_inference_with_dummy_input(model, batch_size, pruned_hidden_size)
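As a side note on the searcher checkpoint handling reviewed above (modelopt/torch/opt/searcher.py), here is a hedged sketch of a load guard along the lines of the suggested diff. It is not the BaseSearcher implementation; the helper name is illustrative.

```python
import os
from warnings import warn

import torch


def load_searcher_state(checkpoint: os.PathLike[str] | str | None) -> dict | None:
    """Return the saved searcher state, or None when there is nothing to load."""
    if checkpoint is None:
        return None
    if not os.path.exists(checkpoint):
        # stacklevel=2 attributes the warning to the caller rather than this helper.
        warn(f"Checkpoint {checkpoint} does not exist! Initializing from scratch.", stacklevel=2)
        return None
    # map_location="cpu" keeps restore independent of the GPUs used during collection.
    return torch.load(checkpoint, map_location="cpu", weights_only=False)
```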
📜 Review details
📒 Files selected for processing (10)
- CHANGELOG.rst (1 hunks)
- docs/source/guides/3_pruning.rst (1 hunks)
- examples/llm_distill/README.md (1 hunks)
- examples/pruning/README.md (2 hunks)
- modelopt/torch/nas/plugins/megatron.py (3 hunks)
- modelopt/torch/opt/searcher.py (3 hunks)
- modelopt/torch/prune/plugins/mcore_minitron.py (5 hunks)
- tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py (7 hunks)
- tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py (5 hunks)
- tox.ini (1 hunks)
🔇 Additional comments (3)
modelopt/torch/nas/plugins/megatron.py (1)

1296-1306: Depth pruning: PP aggregation and logging LGTM. Gathering per-rank layer scores into a global 1-indexed map and logging on rank 0 is correct and clearer.

modelopt/torch/opt/searcher.py (1)

30-30: Importing warn is fine. No issues with bringing in warnings.warn for checkpoint notices.

CHANGELOG.rst (1)

20-20: Changelog entry reads well. Accurately reflects the new checkpointing capability for Minitron pruning.
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```
@@            Coverage Diff             @@
##             main     #361      +/-   ##
==========================================
- Coverage   73.46%   73.44%   -0.03%
==========================================
  Files         172      172
  Lines       17640    17645       +5
==========================================
  Hits        12960    12960
- Misses       4680     4685       +5
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/nas/plugins/megatron.py (1)
1249-1252
: Reduce over PP group, not the default global group.

The comment says "Reduce over all PP ranks" but all_reduce uses the default group. On multi-DP runs this aggregates across DP as well, skewing importance.
Apply this diff:
- activations = activations.clone() - torch.distributed.all_reduce(activations, op=torch.distributed.ReduceOp.SUM) # average + activations = activations.clone() + torch.distributed.all_reduce( + activations, + op=torch.distributed.ReduceOp.SUM, + group=get_pipeline_model_parallel_group(), + ) # sum across PP ranks
🧹 Nitpick comments (4)
modelopt/torch/opt/searcher.py (2)
251-256
: Gate the "Loading searcher state…" print to master and verbose.

Prevents duplicate logs from all ranks and respects verbosity.
Apply this diff:
- print(f"Loading searcher state from {checkpoint}...") + if dist.is_master() and self.config.get("verbose", False): + print(f"Loading searcher state from {checkpoint}...", flush=True)
252-253
: Optional: use weights_only=True for safer deserialization.

Mitigates pickle code-execution risk on untrusted files (Torch ≥2.5).
Apply this diff:
```diff
- state_dict = torch.load(checkpoint, weights_only=False)
+ state_dict = torch.load(checkpoint, weights_only=True)
```

modelopt/torch/prune/plugins/mcore_minitron.py (2)
127-136
: Checkpoint restore path: small robustness tweak.

Direct indexing assumes every module with _scores/_activations had entries saved. Safer to use .get with default None to avoid KeyError in edge cases.
Apply this diff:
- for n, m in self.model.named_modules(): - if hasattr(m, "_scores"): - m._scores = self.scores[rank][n] - if hasattr(m, "_activations"): - m._activations = self.activations[rank][n] + for n, m in self.model.named_modules(): + if hasattr(m, "_scores"): + m._scores = self.scores.get(rank, {}).get(n, getattr(m, "_scores", None)) + if hasattr(m, "_activations"): + m._activations = self.activations.get(rank, {}).get( + n, getattr(m, "_activations", None) + )
147-154
: Filter out None to shrink checkpoint and avoid Nones on restore.

Skip modules with no collected data (None) to reduce size and prevent unexpected None assignments later.
Apply this diff:
- for n, m in self.model.named_modules(): - if hasattr(m, "_scores"): - rank_scores[n] = m._scores - if hasattr(m, "_activations"): - rank_activations[n] = m._activations + for n, m in self.model.named_modules(): + if hasattr(m, "_scores") and getattr(m, "_scores") is not None: + rank_scores[n] = m._scores + if hasattr(m, "_activations") and getattr(m, "_activations") is not None: + rank_activations[n] = m._activations
📜 Review details
📒 Files selected for processing (10)
CHANGELOG.rst
(1 hunks)docs/source/guides/3_pruning.rst
(1 hunks)examples/llm_distill/README.md
(1 hunks)examples/pruning/README.md
(2 hunks)modelopt/torch/nas/plugins/megatron.py
(3 hunks)modelopt/torch/opt/searcher.py
(3 hunks)modelopt/torch/prune/plugins/mcore_minitron.py
(5 hunks)tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
(7 hunks)tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py
(5 hunks)tox.ini
(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (4)
- tox.ini
- docs/source/guides/3_pruning.rst
- examples/pruning/README.md
- examples/llm_distill/README.md
🔇 Additional comments (17)
CHANGELOG.rst (1)

20-20: Changelog entry reads well. Accurately reflects the new checkpointing capability for Minitron pruning.

modelopt/torch/opt/searcher.py (2)

243-248: Good: graceful handling of missing checkpoint. Returning False on None/missing path and master-only warn is appropriate.

258-268: Verbose save is master-only. LGTM. The optional verbose print is correctly gated by master-only save.

modelopt/torch/nas/plugins/megatron.py (2)

1237-1241: Fix: consistent id(module) keying for activations. Corrects the earlier mismatch between membership check and dict keys.

1296-1306: Depth-pruning layer-score gather/readout is clear. Gather by PP group and rank-0 print are correct and helpful.

tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py (4)

34-59: Helper factory for model construction is clean. Allows re-instantiation with/without Megatron init for re-pruning-from-ckpt.

98-99: Exercise checkpointed pruning path. Passing checkpoint with forward_loop the first time is correct.

124-133: Re-pruning without forward loop validated. Second prune run from checkpoint only looks good and aligns with design.

135-140: Per-test checkpoint path wiring is sound. Using tmp_path and partial for mp spawn is tidy.

tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py (4)

70-89: Model factory + ckpt plumbed through. Streamlines re-instantiation and optional checkpoint usage.

113-119: Initial prune with optional checkpoint is correct. Including checkpoint (possibly None) is fine; loader handles None.

147-156: Re-pruning-from-checkpoint path guarded by ckpt presence. Good conditional; avoids unnecessary second run when not testing ckpt.

159-179: Parametrization adds ckpt case without inflating matrix. Balanced coverage; last case exercises checkpointed flow.

modelopt/torch/prune/plugins/mcore_minitron.py (4)

66-79: Expose activations/scores/ckpt_world_size in state. Type hints and default_state_dict look correct and future-proof.

89-95: Deep-copy export_config before mutation. Prevents user input side-effects. Good.

156-166: All-gather + master-only save + barrier is solid. The flow ensures every rank sees gathered state and synchronizes after save.

167-188: Post-restore/search flow looks correct. Sorting only selected hparams and updating model config is consistent.
Actionable comments posted: 0
🧹 Nitpick comments (6)
examples/pruning/README.md (2)
70-71
: Capitalize "Minitron" and clarify the NOTE.

Tighten language and mention world-size requirement when restoring.
-# Save minitron scores so we can re-run pruning with different export configs without running the forward loop again -# NOTE: Skip checkpoint path on re-running if you want to change the dataset +# Save Minitron scores so you can re-run pruning with different export configs without running the forward loop again +# NOTE: When re-running with a different dataset, do not pass "checkpoint" in the config. +# Ensure the distributed world size matches the checkpoint used for collection.
77-78
: Consider using a path variable to make the example copy-pasteable.

Small ergonomics tweak; avoids repeating the literal and works cross-platform.
-mtp.prune( +ckpt_path = "modelopt_minitron_scores.pth" +mtp.prune( model, mode="mcore_minitron", constraints={"export_config": export_config}, dummy_input=None, # Not used - config={"forward_loop": forward_loop, "checkpoint": "modelopt_minitron_scores.pth"}, + config={"forward_loop": forward_loop, "checkpoint": ckpt_path}, )tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py (2)
98-99
: PathLike in config is usually fine; cast to str for portability.
Some APIs expect str; casting avoids edge cases.- config={"forward_loop": forward_loop, "checkpoint": ckpt_path}, + config={"forward_loop": forward_loop, "checkpoint": str(ckpt_path)},
124-133
: Add post-restore assertions to validate re-pruning from checkpoint.

Currently, the second prune is not verified. Mirror key assertions to ensure identical pruning without a forward loop.
# Assert re-pruning from checkpoint works without running the forward loop again model = _get_model(initialize_megatron=False) mtp.prune( model, mode="mcore_minitron", constraints={"export_config": export_config}, dummy_input=None, # Not used - config={"checkpoint": ckpt_path}, + config={"checkpoint": str(ckpt_path)}, ) + + # Re-assert shapes on the restored-pruned model + mamba_layer = None + for layer in model.decoder.layers: + if isinstance(layer, MambaLayer): + mamba_layer = layer + break + assert mamba_layer is not None, f"No MambaLayer found in the model PP rank {rank} after restore!" + mixer = mamba_layer.mixer + bc = 2 * mixer.ngroups * mixer.d_state + assert mixer.nheads == pruned_mamba_num_heads + assert mixer.headdim == pruned_mamba_head_dim + assert mixer.in_proj.input_size == pruned_hidden_size + assert mixer.d_inner == pruned_mamba_num_heads * pruned_mamba_head_dim + assert mixer.in_proj.output_size == 2 * mixer.d_inner + bc + pruned_mamba_num_heads + assert mixer.out_proj.input_size == mixer.d_inner + assert mixer.out_proj.output_size == pruned_hidden_size + assert mixer.conv1d.in_channels == mixer.conv1d.out_channels == mixer.d_inner + bcmodelopt/torch/prune/plugins/mcore_minitron.py (2)
127-136
: Harden checkpoint loading against missing keys and improve error.

Avoid KeyError if a module's entry is absent; make the mismatch assert more informative.
- print_rank_0("Loading activations and scores per rank from checkpoint...") - assert self.ckpt_world_size == dist.size(), "World size mismatch!" - rank = dist.rank() - for n, m in self.model.named_modules(): - if hasattr(m, "_scores"): - m._scores = self.scores[rank][n] - if hasattr(m, "_activations"): - m._activations = self.activations[rank][n] + print_rank_0("Loading activations and scores per rank from checkpoint...") + assert ( + self.ckpt_world_size == dist.size() + ), f"World size mismatch! checkpoint={self.ckpt_world_size}, current={dist.size()}" + rank = dist.rank() + scores_rank = self.scores.get(rank, {}) + acts_rank = self.activations.get(rank, {}) + for n, m in self.model.named_modules(): + if hasattr(m, "_scores") and n in scores_rank: + m._scores = scores_rank[n] + if hasattr(m, "_activations") and n in acts_rank: + m._activations = acts_rank[n]
137-166
: Reduce memory: keep aggregated tensors only on rank 0 and move to CPU before saving.

Prevents O(world_size) duplication on every rank and frees VRAM.
- # Store all ranks' data in the searcher's state - for r in range(dist.size()): - self.scores[r] = all_scores[r] - self.activations[r] = all_activations[r] - - self.save_search_checkpoint(verbose=True) - dist.barrier() + # Store all ranks' data in the searcher's state (rank 0 only) and save + if dist.is_master(): + for r in range(dist.size()): + self.scores[r] = {k: v.detach().cpu() for k, v in all_scores[r].items()} + self.activations[r] = {k: v.detach().cpu() for k, v in all_activations[r].items()} + self.ckpt_world_size = dist.size() + self.save_search_checkpoint(verbose=True) + dist.barrier()
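To illustrate the rank-0 CPU offload suggested here, a small hedged sketch; the helper name and the state layout are assumptions for illustration, not the PR's implementation, and it expects torch.distributed to be initialized.

```python
import torch
import torch.distributed as dist


def save_scores_on_master(all_scores: list[dict[str, torch.Tensor]], path: str) -> None:
    """Keep the aggregated tensors only on rank 0, move them to CPU, then save."""
    if dist.get_rank() == 0:
        cpu_state = {
            r: {name: t.detach().cpu() for name, t in per_rank.items()}
            for r, per_rank in enumerate(all_scores)
        }
        torch.save({"scores": cpu_state, "world_size": dist.get_world_size()}, path)
    dist.barrier()  # every rank waits until the checkpoint file exists
```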
📜 Review details
📒 Files selected for processing (10)
CHANGELOG.rst
(1 hunks)docs/source/guides/3_pruning.rst
(1 hunks)examples/llm_distill/README.md
(1 hunks)examples/pruning/README.md
(2 hunks)modelopt/torch/nas/plugins/megatron.py
(3 hunks)modelopt/torch/opt/searcher.py
(3 hunks)modelopt/torch/prune/plugins/mcore_minitron.py
(5 hunks)tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
(7 hunks)tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py
(5 hunks)tox.ini
(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (6)
- examples/llm_distill/README.md
- modelopt/torch/opt/searcher.py
- tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
- docs/source/guides/3_pruning.rst
- CHANGELOG.rst
- tox.ini
🔇 Additional comments (12)
modelopt/torch/nas/plugins/megatron.py (2)
1237-1240
: Bug: Fixed the dict key consistency issue.The code correctly uses
id(module)
as the key to check membership inself._activations
and stores activations using the same key, which fixes the bug identified in the past review comment.
1296-1305
: LGTM: Improved logging accuracy for depth pruning.The updated comment and logging correctly reference "PP ranks" instead of "TP regions" and uses 1-indexed layer numbering, making the debug information more accurate and useful for understanding layer-wise importance scores during pruning.
examples/pruning/README.md (2)
96-96
: Header update looks good.
100-100
: Verified — GitHub tutorial link resolves and is correct. Confirmed https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/qwen/pruning-distillation exists in the NVIDIA-NeMo main branch and contains the pruning + distillation (Minitron) tutorial referenced.tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py (4)
17-18
: Importing partial is appropriate for passing the checkpoint path.
34-44
: Updated worker signature is consistent with spawn usage.
45-59
: Local model factory is a nice touch for re-instantiation.
135-140
: Good use of tmp_path and partial to thread the checkpoint into workers.modelopt/torch/prune/plugins/mcore_minitron.py (4)
66-69
: Corrected types for per-rank nested maps.Matches the stored structure. Thanks for fixing.
77-79
: Useful default state including world size.
89-95
: Deep-copying export_config is the right defensive move.
167-176
: Sorting step placement looks correct.Relies on module-local _scores/_activations for the immediate run; checkpoint path sets them explicitly beforehand.
Please confirm sort_parameters does not require aggregated values to be re-assigned to modules within the same run (it currently uses local per-rank statistics).
very nice!
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
modelopt/torch/nas/plugins/megatron.py (1)
1249-1252
: Bug: all_reduce aggregates over the entire world; restrict to PP group

This sums across all ranks, not just PP, contradicting the comment and skewing importance scores when DP>1. Pass the PP group and fix the misleading "average" comment.
Apply this diff:
- # Reduce over all PP ranks - activations = activations.clone() - torch.distributed.all_reduce(activations, op=torch.distributed.ReduceOp.SUM) # average + # Reduce over all PP ranks (sum) + group = get_pipeline_model_parallel_group() + torch.distributed.all_reduce(activations, op=torch.distributed.ReduceOp.SUM, group=group) + # If average is intended instead of sum, uncomment: + # activations /= get_pipeline_model_parallel_world_size()modelopt/torch/utils/distributed.py (3)
140-151
: Bug: Type mismatch causes padding branch to error (tensor vs int).

tensor_size is a Tensor; comparing it to max_size (int) and using it in arithmetic will error or behave incorrectly. Compute and use an int local_size instead.
Apply this diff:
- tensor_size = torch.LongTensor([tensor.numel()]).cuda() + local_size = tensor.numel() + tensor_size = torch.LongTensor([local_size]).cuda() @@ - if tensor_size != max_size: - padding = torch.ByteTensor(size=(max_size - tensor_size,)).cuda() + if local_size != max_size: + padding = torch.ByteTensor(size=(max_size - local_size,)).cuda() tensor = torch.cat((tensor, padding), dim=0)
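For context on why the int/Tensor distinction matters, here is a self-contained sketch of the pad-then-all_gather pattern this fix targets, keeping the local size as a plain int so the padding arithmetic stays scalar. It runs single-process on the gloo backend and is not the ModelOpt allgather utility itself.

```python
import os
import pickle

import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29511")
dist.init_process_group("gloo", rank=0, world_size=1)

obj = {"layer_scores": {1: 0.7, 2: 0.3}}
payload = torch.tensor(list(pickle.dumps(obj)), dtype=torch.uint8)
local_size = payload.numel()  # plain int, not a 1-element tensor

sizes = [torch.zeros(1, dtype=torch.long) for _ in range(dist.get_world_size())]
dist.all_gather(sizes, torch.tensor([local_size], dtype=torch.long))
max_size = max(int(s.item()) for s in sizes)

if local_size != max_size:  # scalar comparison, no Tensor truthiness involved
    payload = torch.cat([payload, torch.zeros(max_size - local_size, dtype=torch.uint8)])

gathered = [torch.zeros(max_size, dtype=torch.uint8) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, payload)
objs = [pickle.loads(bytes(t[: int(s.item())].tolist())) for t, s in zip(gathered, sizes)]
assert objs[0] == obj

dist.destroy_process_group()
```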
101-121
: Broadcast should use the provided group for size/rank checks.

With subgroup usage, early-return and src-rank checks must be group-scoped; otherwise behavior is incorrect when world size > 1 and group size == 1 or when the group's rank 0 is not global rank 0.
Apply this diff:
- if size() == 1: + if size(group) == 1: return obj @@ - if rank() == src: + if rank(group) == src: tensor = _serialize(obj).cuda() @@ - tensor_size = ( - torch.LongTensor([tensor.numel()]).cuda() if rank() == src else torch.LongTensor([0]).cuda() - ) + tensor_size = ( + torch.LongTensor([tensor.numel()]).cuda() + if rank(group) == src + else torch.LongTensor([0]).cuda() + ) @@ - if rank() != src: + if rank(group) != src: tensor = torch.ByteTensor(size=(tensor_size.item(),)).cuda()
171-175
: Barrier early-exit should respect group size.

Use size(group) to avoid unnecessary global syncs or mismatches with subgroup barriers.
Apply this diff:
- if size() == 1: + if size(group) == 1: return
🧹 Nitpick comments (7)
modelopt/torch/nas/plugins/megatron.py (3)
1287-1307
: Return type mismatch and stricter-than-needed assertion_layer_scores values are floats (see MambaTransformerLayerMixin), not tensors. Also, asserting > 0 can be brittle; non-negative or presence checks are safer.
Apply this diff to align the type and soften the check:
- def _get_layer_scores(self) -> dict[int, torch.Tensor]: + def _get_layer_scores(self) -> dict[int, float]: @@ - for layer in self.decoder.layers: - assert layer._scores > 0, "No scores collected for importance estimation." + for layer in self.decoder.layers: + assert layer._scores is not None, "No scores collected for importance estimation."
1347-1363
: Trim None activations to reduce payload and noiseAvoid gathering modules that never collected activations.
Apply this diff:
- local_activations = {} + local_activations = {} for n, m in self.named_modules(): - if hasattr(m, "_activations"): - local_activations[n] = m._activations + if hasattr(m, "_activations") and m._activations is not None: + local_activations[n] = m._activations
1364-1388
: Docstring and robustness: correct key description and guard missing entriesThe activations map is keyed by module name, not layer number. Also guard against missing keys during restore and validate layer_scores coverage.
Apply this diff:
def set_activations_and_layer_scores( self, activations_per_rank: list[dict[str, torch.Tensor]], layer_scores: dict[int, torch.Tensor], ) -> None: """Set the pre-computed layer_scores and per-rank activations instead of running forward. Args: layer_scores: Dict from layer_number (1-indexed) to score. - activations_per_rank: List of dicts from layer_number (1-indexed) to activations. - Should match PP size. + activations_per_rank: List (len=PP size) of dicts mapping module name (nn.Module.named_modules()) to activations. + Should match PP size. """ rank = get_pipeline_model_parallel_rank() pp_size = get_pipeline_model_parallel_world_size() assert len(activations_per_rank) == pp_size, ( len(activations_per_rank), activations_per_rank, pp_size, ) + expected_layers = {layer.layer_number for layer in self.decoder.layers} + missing = expected_layers - set(layer_scores) + assert not missing, f"Missing layer scores for layers: {sorted(missing)}" for layer in self.decoder.layers: layer._scores = layer_scores[layer.layer_number] - for n, m in self.named_modules(): - if hasattr(m, "_activations"): - m._activations = activations_per_rank[rank][n] + rank_activations = activations_per_rank[rank] + for n, m in self.named_modules(): + if hasattr(m, "_activations") and n in rank_activations: + m._activations = rank_activations[n]CHANGELOG.rst (1)
20-20
: Changelog entry reads well and matches the feature.Consider adding the config keys mentioned in docs/tests (“checkpoint”, “skip_sorting”) for quick discoverability.
modelopt/torch/utils/distributed.py (1)
253-257
: Nit: typo in comment.“backebnd” → “backend”.
- # NCCL has an issue with calling barrier. So we just use the gloo backebnd for group barriers. + # NCCL has an issue with calling barrier. So we just use the gloo backend for group barriers.tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py (1)
124-133
: Strengthen re-prune validation.Verify the checkpoint exists and the re-pruned model runs a forward pass with pruned dims.
Apply this diff:
# Assert re-pruning from checkpoint works without running the forward loop again - model = _get_model(initialize_megatron=False) + assert ckpt_path.exists(), f"Checkpoint not found: {ckpt_path}" + model = _get_model(initialize_megatron=False) mtp.prune( model, mode="mcore_minitron", constraints={"export_config": export_config}, dummy_input=None, # Not used config={"checkpoint": ckpt_path}, ) + # Forward should work with pruned hidden size as before + run_mcore_inference_with_dummy_input(model, batch_size, pruned_hidden_size)tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py (1)
159-169
: Add a forward check after re-pruning to fully validate restoration path.Ensures the checkpointed activations/scores drive consistent pruning without a forward loop.
if ckpt_path: model = _get_model(initialize_megatron=False) - mtp.prune( + model, _ = mtp.prune( model, mode="mcore_minitron", constraints={"export_config": export_config}, dummy_input=None, # Not used config={"checkpoint": ckpt_path}, ) + run_mcore_inference_with_dummy_input(model, batch_size, pruned_hidden_size)
📜 Review details
📒 Files selected for processing (11)
CHANGELOG.rst
(1 hunks)docs/source/guides/3_pruning.rst
(1 hunks)examples/llm_distill/README.md
(1 hunks)examples/pruning/README.md
(2 hunks)modelopt/torch/nas/plugins/megatron.py
(4 hunks)modelopt/torch/opt/searcher.py
(3 hunks)modelopt/torch/prune/plugins/mcore_minitron.py
(5 hunks)modelopt/torch/utils/distributed.py
(3 hunks)tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
(7 hunks)tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py
(5 hunks)tox.ini
(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (4)
- tox.ini
- examples/llm_distill/README.md
- modelopt/torch/opt/searcher.py
- examples/pruning/README.md
🔇 Additional comments (14)
modelopt/torch/nas/plugins/megatron.py (2)
1237-1241
: Fix confirmed: consistent id-based keyingThe membership check now correctly uses id(module) consistently, resolving the accumulation bug flagged earlier.
1308-1316
: Early return when no depth pruning is cleanShort-circuiting when num_layers is max prevents unnecessary work. Looks good.
modelopt/torch/utils/distributed.py (2)
127-130
: all_gather now group-aware — good change.
162-165
: allreduce group threading LGTM.tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py (1)
45-59
: Model factory refactor LGTM.Local _get_model simplifies re-prune flows and mirrors GPT test structure.
tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py (3)
114-121
: Config assembly LGTM; guard on skip_sorting is correct.
129-132
: Good: validate pruning stats only when sorting ran.
71-89
: Model factory refactor LGTM.Matches mamba test pattern; makes re-prune path clear.
modelopt/torch/prune/plugins/mcore_minitron.py (5)
66-68
: Public state typing LGTM.Types align with get/set APIs returning list[dict[str, Tensor]] and dict[int, Tensor].
72-77
: Default search config extension LGTM.skip_sorting default is sensible; max_iter_data_loader retained.
80-82
: State dict defaults LGTM.Matches the checkpointed fields used in run_search.
92-97
: Deep-copy and validation improvements LGTM.Clearer error with received keys; safe mutation isolation.
130-161
: Checkpoint-aware run flow LGTM.
- Properly unwraps DDP/DP to reach DynamicModule.
- Loads activations/scores when present; otherwise runs forward_loop and saves checkpoint.
- Conditional sorting controlled by skip_sorting.
No issues spotted.
Please confirm BaseSearcher loads the saved state_dict when config["checkpoint"] is provided so non-master ranks also see activations_per_rank/layer_scores on resume.
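A hedged sketch of the checkpoint-aware flow summarized in this comment follows. It is not the actual MCoreMinitronSearcher code: the function and its arguments are illustrative, and sort_parameters is passed in as a parameter because its import path is not shown in this thread.

```python
def run_search_sketch(searcher, model, forward_loop, sort_parameters) -> None:
    if searcher.layer_scores and searcher.activations_per_rank:
        # State restored by load_search_checkpoint(): push it back onto the model.
        model.set_activations_and_layer_scores(
            searcher.activations_per_rank, searcher.layer_scores
        )
    else:
        # No checkpoint: collect activations/scores, gather them, and save.
        forward_loop(model)
        searcher.activations_per_rank, searcher.layer_scores = (
            model.get_activations_and_layer_scores()
        )
        searcher.save_search_checkpoint(verbose=True)
    if not searcher.config.get("skip_sorting", False):
        sort_parameters(model, searcher.hps_to_sort, verbose=True)
```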
docs/source/guides/3_pruning.rst (1)
7-7
: Link update LGTM; URL resolves to the correct Qwen 3 pruning & distillation tutorial.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/nas/plugins/megatron.py (1)
1249-1252
: Bug: all_reduce should use the PP group (and "average" comment is misleading).

You're reducing across the world group; this will overcount when DP > 1. Use the pipeline group. If you intend an average, divide by PP world size.
Apply this diff:
- # Reduce over all PP ranks - activations = activations.clone() - torch.distributed.all_reduce(activations, op=torch.distributed.ReduceOp.SUM) # average + # Reduce over all PP ranks + activations = activations.clone() + torch.distributed.all_reduce( + activations, op=torch.distributed.ReduceOp.SUM, group=get_pipeline_model_parallel_group() + ) + # Optionally average to match the comment: + # activations /= get_pipeline_model_parallel_world_size()
🧹 Nitpick comments (2)
modelopt/torch/utils/distributed.py (1)
132-159
: Avoid ambiguous Tensor-vs-int comparisons in allgather padding logic.

Use .item() when comparing/arithmetizing tensor_size to prevent ambiguous truth-value errors and ensure correct padding size.
Apply this diff:
- if tensor_size != max_size: - padding = torch.ByteTensor(size=(max_size - tensor_size,)).cuda() + if tensor_size.item() != max_size: + padding = torch.ByteTensor(size=(max_size - tensor_size.item(),)).cuda() tensor = torch.cat((tensor, padding), dim=0)modelopt/torch/prune/plugins/mcore_minitron.py (1)
151-156
: Make checkpoint save verbosity respect config.

Always-verbose saving can be noisy. Use the configured verbosity.
Apply this diff:
- self.save_search_checkpoint(verbose=True) + self.save_search_checkpoint(verbose=bool(self.config.get("verbose", False)))
📜 Review details
📒 Files selected for processing (11)
CHANGELOG.rst
(1 hunks)docs/source/guides/3_pruning.rst
(1 hunks)examples/llm_distill/README.md
(1 hunks)examples/pruning/README.md
(2 hunks)modelopt/torch/nas/plugins/megatron.py
(4 hunks)modelopt/torch/opt/searcher.py
(3 hunks)modelopt/torch/prune/plugins/mcore_minitron.py
(5 hunks)modelopt/torch/utils/distributed.py
(3 hunks)tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
(7 hunks)tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py
(5 hunks)tox.ini
(1 hunks)
✅ Files skipped from review due to trivial changes (1)
- tox.ini
🚧 Files skipped from review as they are similar to previous changes (4)
- docs/source/guides/3_pruning.rst
- examples/llm_distill/README.md
- examples/pruning/README.md
- CHANGELOG.rst
🔇 Additional comments (14)
modelopt/torch/nas/plugins/megatron.py (1)
1237-1241
: Good fix: consistent id(module) usage for activation keys.Using id(module) for both membership check and updates prevents accidental overwrites.
modelopt/torch/utils/distributed.py (2)
127-130
: Group-aware _allgather looks good.Forwarding the group argument matches the new API shape and usage.
162-169
: allreduce: group propagation is correct.Leverages the group-aware allgather; interface is consistent.
modelopt/torch/prune/plugins/mcore_minitron.py (2)
66-68
: Accurate state typing for checkpoint contents.Types now reflect the stored structures returned by the dynamic module helpers.
92-97
: Good: deep-copy export_config and clearer error report.Prevents side effects and improves diagnostics.
modelopt/torch/opt/searcher.py (2)
243-249
: Graceful handling of missing checkpoints.Returning False and warning only on master is appropriate.
266-271
: Optional verbose save message is a nice touch.Non-breaking and useful for debugging.
tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py (3)
34-41
: Test harness accepts a checkpoint path for reuse.Matches the new workflow and exercises restore-without-forward.
124-133
: Re-pruning from checkpoint path validated.Covers the fast path without forward_loop.
136-140
: Multiprocess job wiring with ckpt path is correct.partial usage and backend selection look good.
tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py (4)
114-122
: Config composition (skip_sorting + checkpoint + optional forward_loop) is sound.Keeps constraints clear and avoids running forward when skip_sorting is on.
129-132
: Asserting pruning_stats presence only when sorting ran is correct.Prevents false negatives in skip_sorting mode.
159-169
: Re-pruning-from-checkpoint path is covered.Ensures load-only flow works for export without forward.
214-232
: Parametrization covers both skip_sorting and checkpointed flows.End-to-end coverage looks sufficient.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/utils/distributed.py (1)
141-153
: Fix size type mismatch in allgather padding (runtime error).

tensor_size is a Tensor but later used as an int, leading to ambiguous Tensor truth-value and invalid subtraction.
Apply this diff:
- tensor_size = torch.LongTensor([tensor.numel()]).cuda() + local_size = int(tensor.numel()) + tensor_size = torch.LongTensor([local_size]).cuda() @@ - if tensor_size != max_size: - padding = torch.ByteTensor(size=(max_size - tensor_size,)).cuda() + if local_size != max_size: + padding = torch.ByteTensor(size=(max_size - local_size,)).cuda()
🧹 Nitpick comments (14)
examples/llm_distill/README.md (1)
147-147
: Good link update; minor follow-up suggestionConsider adding a brief note that the NeMo tutorial covers pruning first, then distillation (as hinted), and optionally pin the link to a stable commit to avoid future breakage if the tutorials move.
tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py (1)
124-133
: Add assertions to validate re-pruning from checkpointRe-pruning path is exercised but not validated. Re-run the structural asserts on the re-pruned model to ensure restoration worked without the forward loop.
Apply this diff:
# Assert re-pruning from scores_path works without running the forward loop again model = _get_model(initialize_megatron=False) mtp.prune( model, mode="mcore_minitron", constraints={"export_config": export_config}, dummy_input=None, # Not used config={"scores_path": ckpt_path}, ) + + # Re-acquire a MambaLayer and re-assert shapes on the re-pruned model + mamba_layer = None + for layer in model.decoder.layers: + if isinstance(layer, MambaLayer): + mamba_layer = layer + break + assert mamba_layer is not None, f"No MambaLayer found in the model PP rank {rank} after re-pruning!" + mixer = mamba_layer.mixer + bc = 2 * mixer.ngroups * mixer.d_state + assert mixer.nheads == pruned_mamba_num_heads + assert mixer.headdim == pruned_mamba_head_dim + assert mixer.in_proj.input_size == pruned_hidden_size + assert mixer.d_inner == pruned_mamba_num_heads * pruned_mamba_head_dim + assert mixer.in_proj.output_size == 2 * mixer.d_inner + bc + pruned_mamba_num_heads + assert mixer.out_proj.input_size == mixer.d_inner + assert mixer.out_proj.output_size == pruned_hidden_size + assert mixer.conv1d.in_channels == mixer.conv1d.out_channels == mixer.d_inner + bc + + # Sanity check a forward pass again + run_mcore_inference_with_dummy_input(model, batch_size, pruned_hidden_size)modelopt/torch/opt/searcher.py (2)
243-249
: Improve warning context and portability when checkpoint missingAdd stacklevel to surface the caller location in logs.
Apply this diff:
- if dist.is_master(): - warn(f"Checkpoint {checkpoint} does not exist! Initializing from scratch.") + if dist.is_master(): + warn( + f"Checkpoint {checkpoint} does not exist! Initializing from scratch.", + stacklevel=2, + )
251-256
: Guard load print by verbosity and load on CPUAvoid unconditional prints in distributed runs; also map to CPU to prevent GPU dependency on load.
Apply this diff:
- print(f"Loading searcher state from {checkpoint}...") - state_dict = torch.load(checkpoint, weights_only=False) + if getattr(self, "config", {}).get("verbose", False) and dist.is_master(): + print(f"Loading searcher state from {checkpoint}...") + state_dict = torch.load(checkpoint, map_location="cpu", weights_only=False)examples/pruning/README.md (2)
81-82
: Clarify skip_sorting usage.

Mention that skip_sorting assumes parameters were already sorted from a prior run or checkpoint; otherwise accuracy may degrade.
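For illustration, a hedged snippet of how skip_sorting could be combined with previously saved scores; the config keys follow the tests in this PR, and the path plus surrounding variables (model, export_config) are placeholders rather than the README's exact text.

```python
import modelopt.torch.prune as mtp

# Re-prune from previously saved scores and keep the existing parameter order.
# skip_sorting assumes parameters were already sorted when the scores were collected.
mtp.prune(
    model,
    mode="mcore_minitron",
    constraints={"export_config": export_config},
    dummy_input=None,  # Not used
    config={"scores_path": "modelopt_minitron_scores.pth", "skip_sorting": True},
)
```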
98-103
: Nit: minor wording polish.“Checkout” → “Check out”.
modelopt/torch/utils/distributed.py (2)
101-106
: Broadcast short‑circuit should respect group size.Use size(group) to avoid unnecessary work when the provided group has a single rank.
- if size() == 1: + if size(group) == 1: return obj
171-176
: Barrier short‑circuit should respect group size.Avoid invoking a barrier for a size‑1 group.
- if size() == 1: + if size(group) == 1: returnmodelopt/torch/prune/plugins/mcore_minitron.py (4)
66-68
: Align type hints with actual values.layer_scores holds floats (from .item()), not tensors.
- layer_scores: dict[int, torch.Tensor] + layer_scores: dict[int, float]
85-88
: Don’t clobber an explicitly provided checkpoint path.If users pass checkpoint directly, prefer it when scores_path is None.
- config["checkpoint"] = config["scores_path"] + config["checkpoint"] = config.get("scores_path") or config.get("checkpoint")
144-149
: Guard against missing state attrs when no checkpoint is loaded.Avoid AttributeError if attributes aren’t initialized by BaseSearcher.
- if self.layer_scores and self.activations_per_rank: # Available from checkpoint + if getattr(self, "layer_scores", None) and getattr(self, "activations_per_rank", None):
158-163
: Use configured verbosity for checkpoint saves.Honor config["verbose"] rather than hardcoding True.
- self.save_search_checkpoint(verbose=True) + self.save_search_checkpoint(verbose=bool(self.config.get("verbose")))modelopt/torch/nas/plugins/megatron.py (2)
1249-1252
: Scope all_reduce to a group (optional).Consider reducing over the pipeline group explicitly for clarity and to avoid unintended world-size coupling.
-        torch.distributed.all_reduce(activations, op=torch.distributed.ReduceOp.SUM)  # average
+        torch.distributed.all_reduce(
+            activations, op=torch.distributed.ReduceOp.SUM, group=get_pipeline_model_parallel_group()
+        )
1349-1362
: Docstring/types: layer_scores are floats; activations are by module name. Adjust annotations/comments for accuracy.
-    ) -> tuple[list[dict[str, torch.Tensor]], dict[int, torch.Tensor]]:
+    ) -> tuple[list[dict[str, torch.Tensor]], dict[int, float]]:
@@
-        """Get the per-rank activations and layer scores from the module."""
+        """Get the per-rank activations (by module name) and per-layer scores (1-indexed, float)."""
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (11)
- CHANGELOG.rst (1 hunks)
- docs/source/guides/3_pruning.rst (1 hunks)
- examples/llm_distill/README.md (1 hunks)
- examples/pruning/README.md (2 hunks)
- modelopt/torch/nas/plugins/megatron.py (4 hunks)
- modelopt/torch/opt/searcher.py (3 hunks)
- modelopt/torch/prune/plugins/mcore_minitron.py (5 hunks)
- modelopt/torch/utils/distributed.py (3 hunks)
- tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py (7 hunks)
- tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py (5 hunks)
- tox.ini (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- CHANGELOG.rst
- tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
🧰 Additional context used
🧬 Code graph analysis (4)
- modelopt/torch/nas/plugins/megatron.py (4): modelopt/torch/opt/dynamic.py: get_hparam (819-821), get_hparam (1235-1240); modelopt/torch/utils/logging.py: print_rank_0 (92-95); modelopt/torch/trace/symbols.py: named_modules (444-447); modelopt/torch/utils/distributed.py: allgather (132-159), rank (68-72), rank (200-202)
- modelopt/torch/opt/searcher.py (2): modelopt/torch/utils/distributed.py: is_master (75-77); tests/conftest.py: verbose (22-23)
- tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py (2): tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py: _get_model (71-88); tests/_test_utils/torch_dist/plugins/megatron_common.py: get_mcore_mamba_model (255-321)
- modelopt/torch/prune/plugins/mcore_minitron.py (3): modelopt/torch/opt/searcher.py: default_search_config (73-85), save_search_checkpoint (258-271); modelopt/torch/utils/logging.py: print_rank_0 (92-95); modelopt/torch/nas/plugins/megatron.py: set_activations_and_layer_scores (1364-1386), get_activations_and_layer_scores (1347-1362)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: wait-checks / wait
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (12)
docs/source/guides/3_pruning.rst (1)
7-7
: Updated tutorial link LGTM. The Qwen 3 NeMo pruning & distillation link is correct and aligns with the repo's refocus.
tox.ini (2)
63-64
: Commented setenv looks fine. No behavior change; leaving this disabled is reasonable given the current build constraints.
70-72
: Clarified timing comment is helpful. Keeping the Mamba install disabled with a clear note avoids flaky GPU CI.
tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py (3)
34-59
: Helper signature + local model factory LGTM. Passing ckpt_path via partial and toggling initialize_megatron through a model factory is clean and aligns with the re-prune-from-checkpoint objective.
98-99
: scores_path propagation: confirm that it triggers checkpoint write. Assuming mcore_minitron consumes scores_path to save aggregated scores on the master, this is correct. If non-master ranks attempt to write concurrently, ensure internal synchronization/barrier exists.
136-140
: Multiprocess invocation LGTM. Using partial to pass the shared checkpoint path is tidy and works with tmp_path.
modelopt/torch/opt/searcher.py (2)
30-31
: Importing warn is fine. Local import of warn keeps call sites concise.
258-271
: Verbose save hook LGTM. Optional verbosity is useful; master-only save behavior is preserved.
examples/pruning/README.md (2)
70-72
: Good rename: "scores_path" is clearer than "checkpoint". This avoids confusion with full model checkpoints.
69-79
: Unpacking example is correct; mtp.prune returns a tuple. The type annotation (-> tuple[nn.Module, dict[str, Any]]) and the implementation confirm it returns (model, pruning_scores), so the README example need not change.
modelopt/torch/nas/plugins/megatron.py (2)
1237-1241
: Good fix: consistent id(module) usage for activations map. Prevents missed accumulation due to mismatched keys.
1364-1387
: Restore tensors to the correct device when loading activations (prevents NCCL/device mismatches). Loaded tensors may be on CPU; hidden-size importance uses all_reduce and will fail on CPU tensors.
        for layer in self.decoder.layers:
            layer._scores = layer_scores[layer.layer_number]

        for n, m in self.named_modules():
            if hasattr(m, "_activations"):
-                m._activations = activations_per_rank[rank][n]
+                v = activations_per_rank[rank][n]
+                dev = get_module_device(m)
+                if isinstance(v, dict):
+                    m._activations = {
+                        k: (t.to(dev) if isinstance(t, torch.Tensor) else t) for k, t in v.items()
+                    }
+                elif isinstance(v, torch.Tensor):
+                    m._activations = v.to(dev)
+                else:
+                    m._activations = v
136e94d to 898a16c (Compare)
Actionable comments posted: 1
🧹 Nitpick comments (2)
modelopt/torch/opt/searcher.py (1)
243-249
: Graceful missing-checkpoint handling is good; consider CPU load for safety. Add map_location="cpu" when loading to minimize GPU pressure and work in CPU-only contexts.
Apply this diff:
-        state_dict = torch.load(checkpoint, weights_only=False)
+        state_dict = torch.load(checkpoint, map_location="cpu", weights_only=False)
modelopt/torch/prune/plugins/mcore_minitron.py (1)
87-89
: Be cautious: forcing verbose=True will spam multi-rank logs. Consider keeping rank-0-only verbosity for forward loops and rely on print_rank_0 for status messages.
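A tiny sketch of the rank-0-only alternative (the import path is assumed from the repo layout referenced in this review, and scores_path is an assumed variable holding the checkpoint location):

from modelopt.torch.utils.logging import print_rank_0

# Emitted once from rank 0 instead of verbose prints on every rank.
print_rank_0(f"Saved Minitron pruning scores to {scores_path}")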
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (11)
- CHANGELOG.rst (1 hunks)
- docs/source/guides/3_pruning.rst (1 hunks)
- examples/llm_distill/README.md (1 hunks)
- examples/pruning/README.md (2 hunks)
- modelopt/torch/nas/plugins/megatron.py (5 hunks)
- modelopt/torch/opt/searcher.py (3 hunks)
- modelopt/torch/prune/plugins/mcore_minitron.py (5 hunks)
- modelopt/torch/utils/distributed.py (3 hunks)
- tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py (7 hunks)
- tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py (5 hunks)
- tox.ini (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
- modelopt/torch/utils/distributed.py
- docs/source/guides/3_pruning.rst
- CHANGELOG.rst
🧰 Additional context used
🧬 Code graph analysis (5)
tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py (2)
tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py (2)
_get_model
(45-58)forward_loop
(70-72)tests/_test_utils/torch_dist/plugins/megatron_common.py (1)
get_mcore_gpt_model
(133-208)
tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py (2)
tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py (1)
_get_model
(71-88)tests/_test_utils/torch_dist/plugins/megatron_common.py (1)
get_mcore_mamba_model
(255-321)
modelopt/torch/prune/plugins/mcore_minitron.py (4)
modelopt/torch/opt/dynamic.py (2)
DynamicModule
(338-914)config
(1265-1278)modelopt/torch/opt/searcher.py (3)
default_search_config
(73-85)sanitize_search_config
(92-103)save_search_checkpoint
(258-271)modelopt/torch/utils/logging.py (1)
print_rank_0
(92-95)modelopt/torch/nas/plugins/megatron.py (2)
set_activations_and_layer_scores
(1368-1390)get_activations_and_layer_scores
(1351-1366)
modelopt/torch/nas/plugins/megatron.py (4)
modelopt/torch/opt/dynamic.py (2)
get_hparam
(819-821)get_hparam
(1235-1240)modelopt/torch/utils/logging.py (1)
print_rank_0
(92-95)modelopt/torch/trace/symbols.py (1)
named_modules
(444-447)modelopt/torch/utils/distributed.py (3)
allgather
(132-159)rank
(68-72)rank
(200-202)
modelopt/torch/opt/searcher.py (2)
modelopt/torch/utils/distributed.py (1)
is_master
(75-77)tests/conftest.py (1)
verbose
(22-23)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (36)
modelopt/torch/nas/plugins/megatron.py (6)
652-655
: Good precision guard on similarity calc. Casting to float32 before cosine sim avoids overflow/NaNs during scoring.
1241-1245
: Bug fix looks correct: consistent id(module) keying. Switching both membership check and assignment to id(module) resolves the accumulation bug flagged earlier.
1291-1311
: Layer score aggregation across PP ranks looks sound. Per-rank dicts, all_gather_object, and final key sanity check are appropriate here.
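As a standalone illustration of that gathering pattern (plain torch.distributed; the function below is hypothetical and not the ModelOpt implementation):

import torch.distributed as dist

def gather_layer_scores(local_scores: dict, group=None) -> dict:
    # Collect each rank's {layer_number: score} dict and merge them.
    gathered = [None] * dist.get_world_size(group=group)
    dist.all_gather_object(gathered, local_scores, group=group)
    merged = {}
    for rank_scores in gathered:
        merged.update(rank_scores or {})
    return merged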
1312-1317
: Refactor improves cohesion (single source of scores).Using _get_layer_scores in export keeps sorting logic centralized.
1351-1367
: Confirm root module's _activations are captured. get_activations_and_layer_scores relies on named_modules(). Ensure the root module (which holds hidden_size activations in self._activations) is included so checkpointing captures them; otherwise hidden_size sorting from checkpoint will fail.
1369-1391
: Critical: Move restored tensors to each module's device. activations_per_rank comes back with CPU tensors; later ops (e.g., NCCL all_reduce in _estimate_hidden_size_importance) require CUDA tensors on the module device. Move them on restore.
Apply this diff:
def set_activations_and_layer_scores( self, activations_per_rank: list[dict[str, torch.Tensor]], layer_scores: dict[int, torch.Tensor], ) -> None: """Set the pre-computed layer_scores and per-rank activations instead of running forward. Args: layer_scores: Dict from layer_number (1-indexed) to score. activations_per_rank: List of dicts from module name to activations. Should match PP size. """ rank = get_pipeline_model_parallel_rank() pp_size = get_pipeline_model_parallel_world_size() assert len(activations_per_rank) == pp_size, ( len(activations_per_rank), activations_per_rank, pp_size, ) for layer in self.decoder.layers: layer._scores = layer_scores[layer.layer_number] for n, m in self.named_modules(): if hasattr(m, "_activations"): - m._activations = activations_per_rank[rank][n] + v = activations_per_rank[rank][n] + dev = get_module_device(m) + if isinstance(v, dict): + m._activations = {k: (t.to(dev) if isinstance(t, torch.Tensor) else t) for k, t in v.items()} + elif isinstance(v, torch.Tensor): + m._activations = v.to(dev) + else: + m._activations = vexamples/llm_distill/README.md (1)
147-147
: Link update LGTM.Pointing to the NeMo Qwen pruning-distillation tutorial is appropriate.
tox.ini (2)
63-65
: Env var for Mamba build is reasonable.MAMBA_FORCE_BUILD=TRUE is helpful to avoid prebuilt wheel mismatches.
70-74
: Verify Triton pin compatibility with your Torch/CUDA. triton<3.4 may not match all torch versions; confirm it resolves and doesn't conflict with your CUDA/Torch stack in CI.
tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py (5)
17-18
: Import addition is fine.partial is used below to pass ckpt_path.
34-59
: Local model factory improves re-pruning ergonomics.Enables clean reconstruction before checkpoint-based pruning.
60-60
: LGTM.Initialize model once for the first pruning pass.
124-133
: Re-pruning from checkpoint path is correct. Reruns prune with only scores_path and no forward_loop.
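For reference, a minimal sketch of the call shape exercised here (variable names follow the test):

mtp.prune(
    model,
    mode="mcore_minitron",
    constraints={"export_config": export_config},
    dummy_input=None,  # not used
    config={"scores_path": ckpt_path},  # restores saved activations/scores; no forward_loop
)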
135-140
: Test harness update LGTM.Using tmp_path for scores_path and NCCL backend is appropriate.
modelopt/torch/opt/searcher.py (2)
30-31
: Minor import addition LGTM.warnings.warn is used below.
258-271
: Optional verbose save is a nice touch.Master-only save, mkdir, and state persistence look correct.
examples/pruning/README.md (3)
62-67
: Export config example LGTM.Straightforward width-prune illustration.
81-82
: Skip-sorting note LGTM.Clear guidance on bypassing forward_loop when already sorted.
98-103
: NeMo tutorial link update LGTM.Consistent with other docs.
tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py (7)
44-46
: New parameters are well-threaded.skip_sorting and ckpt_path are consistently propagated through the flow.
71-89
: Model factory abstraction LGTM.Simplifies model re-instantiation for checkpoint-driven runs.
114-118
: Config composition is clean; good guard. Asserting ckpt_path is None when skipping sorting avoids ambiguous semantics.
122-128
: Config handoff to prune is correct.scores_path and forward_loop wiring aligns with plugin behavior.
129-132
: Basic stats validation LGTM.Ensures checkpoint content is populated when sorting runs.
159-169
: Re-pruning from checkpoint path is correct.Validates the no-forward-loop path with restored scores.
199-233
: Parametrization update LGTM.tmp_path-based checkpoint toggle and spawn job wiring look fine.
modelopt/torch/prune/plugins/mcore_minitron.py (10)
27-28
: copy import is appropriate.Used to deep-copy export_config.
42-42
: Correct import of DynamicModule.Used to unwrap/check converted model.
66-68
: Public state typing is clear.Matches the stored structures used for checkpointing.
72-77
: New config keys are useful.skip_sorting and scores_path align with the checkpointed pruning flow.
81-83
: Default state keys LGTM.Explicit empty state for activations/scores.
100-105
: Helpful error messaging.Deep-copy + explicit key set makes misuse obvious.
138-144
: DDP unwrap + capability check LGTM.Ensures we operate on the converted dynamic module.
145-164
: Checkpointed activations flow is correct; pairs with tests. Load-or-collect, then save for reuse. Make sure set_activations moves tensors to the proper device (see comment in megatron.py).
165-169
: Sorting gate is clear. skip_sorting path is explicit; otherwise defer to sort_parameters().
178-189
: Model config sync LGTM.Preserving kv_channels and updating config fields ensures save/restore correctness.
898a16c to 651f9d2 (Compare)
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/nas/plugins/megatron.py (1)
1254-1256
: all_reduce should use the PP group; comment says "average" but code sums. Missing group can over-reduce across non-PP ranks. Also the inline comment says average but you're summing.
Apply this diff:
-        # Reduce over all PP ranks
-        activations = activations.clone()
-        torch.distributed.all_reduce(activations, op=torch.distributed.ReduceOp.SUM)  # average
+        # Reduce (sum) over PP ranks only
+        activations = activations.clone()
+        torch.distributed.all_reduce(
+            activations,
+            op=torch.distributed.ReduceOp.SUM,
+            group=get_pipeline_model_parallel_group(),
+        )
+        # If average is intended, divide by PP world size:
+        # pp_size = get_pipeline_model_parallel_world_size()
+        # activations /= pp_size
🧹 Nitpick comments (3)
modelopt/torch/opt/searcher.py (1)
251-256
: Gate "Loading searcher state..." print behind verbosity. Load prints on all ranks can be noisy. Mirror save behavior using verbose or master-only.
Apply this diff:
-        # iterate through state dict and load keys
-        print(f"Loading searcher state from {checkpoint}...")
+        # iterate through state dict and load keys
+        if self.config.get("verbose", False):
+            print(f"Loading searcher state from {checkpoint}...")
tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py (1)
159-168
: Re-pruning flow from checkpoint. Flow mirrors Mamba test. Optional: add a quick inference call afterwards to sanity check execution (a sketch follows below).
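A sketch of that optional sanity check, mirroring what the Mamba test already does (helper and variable names are taken from the test utilities and are assumed to be in scope):

# After re-pruning from the checkpoint, run one dummy-input forward pass.
run_mcore_inference_with_dummy_input(model, batch_size, pruned_hidden_size)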
modelopt/torch/prune/plugins/mcore_minitron.py (1)
86-88
: Verbose on all ranks may be noisy. Consider keeping verbosity to rank 0 (print_rank_0 already used), and leave forward-loop silence as master-only to reduce log spam.
- config["checkpoint"] = config["scores_path"] - config["verbose"] = True # Print for all ranks + config["checkpoint"] = config["scores_path"] + # Keep verbose gated to rank 0 (print_rank_0 is already used elsewhere) + # config["verbose"] remains as set by BaseSearcher.sanitize_search_config
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (11)
- CHANGELOG.rst (1 hunks)
- docs/source/guides/3_pruning.rst (1 hunks)
- examples/llm_distill/README.md (1 hunks)
- examples/pruning/README.md (2 hunks)
- modelopt/torch/nas/plugins/megatron.py (5 hunks)
- modelopt/torch/opt/searcher.py (3 hunks)
- modelopt/torch/prune/plugins/mcore_minitron.py (4 hunks)
- modelopt/torch/utils/distributed.py (3 hunks)
- tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py (7 hunks)
- tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py (5 hunks)
- tox.ini (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (5)
- modelopt/torch/utils/distributed.py
- tox.ini
- examples/pruning/README.md
- CHANGELOG.rst
- examples/llm_distill/README.md
🧰 Additional context used
🧬 Code graph analysis (5)
tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py (2)
tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py (1)
_get_model
(71-88)tests/_test_utils/torch_dist/plugins/megatron_common.py (1)
get_mcore_mamba_model
(255-321)
modelopt/torch/opt/searcher.py (2)
modelopt/torch/utils/distributed.py (1)
is_master
(75-77)tests/conftest.py (1)
verbose
(22-23)
modelopt/torch/nas/plugins/megatron.py (5)
modelopt/torch/quantization/qtensor/base_qtensor.py (1)
to
(115-123)modelopt/torch/opt/dynamic.py (2)
get_hparam
(819-821)get_hparam
(1235-1240)modelopt/torch/utils/logging.py (1)
print_rank_0
(92-95)modelopt/torch/trace/symbols.py (1)
named_modules
(444-447)modelopt/torch/utils/distributed.py (3)
allgather
(132-159)rank
(68-72)rank
(200-202)
modelopt/torch/prune/plugins/mcore_minitron.py (3)
modelopt/torch/nas/plugins/megatron.py (3)
_DynamicMCoreLanguageModel
(1151-1390)set_activations_and_layer_scores
(1368-1390)get_activations_and_layer_scores
(1351-1366)modelopt/torch/opt/searcher.py (2)
default_search_config
(73-85)save_search_checkpoint
(258-271)modelopt/torch/utils/logging.py (1)
print_rank_0
(92-95)
tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py (2)
tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py (1)
_get_model
(45-58)tests/_test_utils/torch_dist/plugins/megatron_common.py (1)
get_mcore_gpt_model
(133-208)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (24)
docs/source/guides/3_pruning.rst (1)
7-9
: Updated tutorial reference looks correct.The refreshed pointer to the Qwen 3 NeMo Minitron pruning tutorial aligns with the new workflow coverage. Thanks for keeping the docs current.
modelopt/torch/nas/plugins/megatron.py (6)
652-655
: Casting to float32 before cosine similarity is correct. Prevents overflow/precision loss in importance scoring. Good change.
1241-1245
: Fix uses id(module) consistently for activation accumulationThis resolves the previous membership-check bug and ensures accumulation per module hook is correct.
1291-1311
: Layer score collection across PP ranks looks good. Asserts presence, gathers via all_gather_object, merges, and validates coverage. LGTM.
1312-1320
: Deferring drop to a dedicated helper after centralized scoring is cleanThe flow via _get_layer_scores before pruning is clear and maintainable.
1351-1367
: Sharing activations and layer scores API is useful and consistentCollects only modules with _activations and gathers per PP rank. LGTM.
1368-1391
: Move restored activations to the correct device; handle dict vs tensor. Loaded tensors can land on CPU and break later GPU ops; also some modules store dicts. Move to each module's device and handle both cases.
Apply this diff:
def set_activations_and_layer_scores( self, activations_per_rank: list[dict[str, torch.Tensor]], layer_scores: dict[int, torch.Tensor], ) -> None: """Set the pre-computed layer_scores and per-rank activations instead of running forward. Args: layer_scores: Dict from layer_number (1-indexed) to score. activations_per_rank: List of dicts from module name to activations. Should match PP size. """ rank = get_pipeline_model_parallel_rank() pp_size = get_pipeline_model_parallel_world_size() assert len(activations_per_rank) == pp_size, ( len(activations_per_rank), activations_per_rank, pp_size, ) for layer in self.decoder.layers: layer._scores = layer_scores[layer.layer_number] for n, m in self.named_modules(): if hasattr(m, "_activations"): - m._activations = activations_per_rank[rank][n] + v = activations_per_rank[rank][n] + dev = get_module_device(m) + if isinstance(v, dict): + m._activations = { + k: (t.to(dev) if hasattr(t, "to") else t) for k, t in v.items() + } + else: + m._activations = v.to(dev) if hasattr(v, "to") else vtests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py (4)
34-59
: Refactor to pass ckpt_path and wrap model construction. Test now supports checkpoint-driven re-pruning and cleanly toggles Megatron init via _get_model(). Good.
98-99
: Thread scores_path through prune configEnsures activations/scores are saved for reuse. LGTM.
124-133
: Re-pruning from checkpoint without forward loop. Flow is correct. Ensure tmp_path is shared/accessible by all PP ranks on the same node/FS.
135-140
: Multiprocess invocation with partial is correctPasses identical checkpoint path to all ranks. LGTM.
modelopt/torch/opt/searcher.py (2)
243-248
: Graceful handling when checkpoint missing or None. Early exit + master-only warning is correct. LGTM.
258-267
: No subclasses override save_search_checkpoint; signature change is safe.
tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py (4)
71-89
: Local _get_model wrapper is a good cleanupEnables easy re-instantiation for re-pruning. LGTM.
114-128
: Config wiring for skip_sorting and scores_path is soundAsserts ensure correct usage; forward loop only when needed. LGTM.
129-132
: Pruning stats checksValidates presence of scores only when sorting runs. LGTM.
171-201
: Expanded parametrization is comprehensiveCovers MHA/GQA, depth, skip_sorting, and checkpoint path toggles. LGTM.
modelopt/torch/prune/plugins/mcore_minitron.py (7)
65-67
: State attributes for checkpoint contentsCorrect and explicit types. LGTM.
71-76
: Default config extended with skip_sorting and scores_pathMatches new functionality. LGTM.
81-82
: Default state dict captures activations/scoresAppropriate for checkpointing. LGTM.
99-104
: Deep-copy export_config and clear error messagePrevents caller-side mutation; message enumerates allowed keys. LGTM.
137-143
: Unwrapping to _DynamicMCoreLanguageModel before reuse is correct. Ensures we operate on the dynamic model instance. LGTM.
144-163
: Checkpoint-aware run_search flow is solid
- Loads activations/scores when present
- Forward loop + save when absent (and not skipping)
- Stores per-rank activations and aggregated scores
LGTM. A minimal sketch of this load-or-collect flow is shown below.
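The sketch below is illustrative pseudocode, not the actual run_search implementation; it assumes the attribute and method names discussed in this PR (load_search_checkpoint, set_activations_and_layer_scores, get_activations_and_layer_scores, save_search_checkpoint, skip_sorting):

def run_search_sketch(searcher, model, config):
    searcher.load_search_checkpoint()  # warns and continues if scores_path is missing
    if searcher.layer_scores and searcher.activations_per_rank:
        # Restored from the checkpoint: push the saved stats back into the model.
        model.set_activations_and_layer_scores(
            searcher.activations_per_rank, searcher.layer_scores
        )
    elif not config["skip_sorting"]:
        # No checkpoint: run the forward loop, then persist the collected stats.
        config["forward_loop"](model)
        searcher.activations_per_rank, searcher.layer_scores = (
            model.get_activations_and_layer_scores()
        )
        searcher.save_search_checkpoint(verbose=True)
    if not config["skip_sorting"]:
        pass  # parameter sorting and the final prune/export happen here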
165-168
: Conditional sorting. Respects skip_sorting for pre-sorted models. LGTM.
Signed-off-by: Keval Morabia <[email protected]>
651f9d2 to 6254401 (Compare)
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/utils/distributed.py (1)
132-160
: Fix tensor/int mixing in allgather (runtime error). Comparing and subtracting a torch.Tensor from an int here is invalid; this branch will raise/behave incorrectly.
Apply this diff:
@@
-    # gather the tensor
+    # gather the tensor
     tensors = [torch.ByteTensor(size=(max_size,)).cuda() for _ in tensor_sizes]
-    if tensor_size != max_size:
-        padding = torch.ByteTensor(size=(max_size - tensor_size,)).cuda()
-        tensor = torch.cat((tensor, padding), dim=0)
+    tensor_size_int = int(tensor_size.item())
+    if tensor_size_int != max_size:
+        padding = torch.ByteTensor(size=(max_size - tensor_size_int,)).cuda()
+        tensor = torch.cat((tensor, padding), dim=0)
     _allgather(tensors, tensor, group)
def broadcast(obj: Any, src: int = 0, group=None) -> Any: if size(group) == 1: return obj # ... if rank(group) == src: tensor = _serialize(obj).cuda() # ...
🧹 Nitpick comments (7)
tox.ini (1)
70-74
: Pin the Mamba git dependency to avoid CI nondeterminism.Installing from the HEAD of the mamba repo can break builds unexpectedly. Pin to a tag or commit.
Apply this diff:
- pip install --no-build-isolation git+https://github.com/state-spaces/mamba.git + # Pin Mamba to a stable ref to ensure reproducible CI + pip install --no-build-isolation git+https://github.com/state-spaces/mamba.git@<commit-or-tag>examples/llm_distill/README.md (1)
147-147
: New NeMo tutorial reference: LGTM; optional doc link.Consider adding a link to NeMo docs on converting HF -> NeMo to guide users.
CHANGELOG.rst (1)
20-20
: Call out skip_sorting and scores_path for clarity.Expand this bullet to mention the new config keys (scores_path, skip_sorting) so users discover them easily.
-- Support storing and restoring Minitron pruning activations and scores for re-pruning without running the forward loop again. +- Support storing and restoring Minitron pruning activations and scores for re-pruning without running the forward loop again (config: scores_path). Add option to skip parameter sorting when models are pre-sorted (config: skip_sorting).modelopt/torch/opt/searcher.py (3)
243-248
: Good: graceful handling of missing checkpoints.Minor: consider gating the "Loading searcher state..." print under verbose to reduce multi-rank noise.
Outside the selected range:
if self.config["verbose"]: print(f"Loading searcher state from {checkpoint}...")
253-256
: Use set-equality for checkpoint key validation.dict_keys equality can be fragile; set comparison is clearer.
- assert state_dict.keys() == self.state_dict().keys(), "Keys in checkpoint don't match!" + assert set(state_dict.keys()) == set(self.state_dict().keys()), "Keys in checkpoint don't match!"
258-268
: Default verbose to config to unify logging control.Let callers omit verbose and inherit from config.
- def save_search_checkpoint(self, verbose=False) -> None: + def save_search_checkpoint(self, verbose: bool | None = None) -> None: @@ - if verbose: + if (self.config["verbose"] if verbose is None else verbose): print(f"Saving searcher state to {checkpoint}...")examples/pruning/README.md (1)
81-82
: Add usage example for skip_sorting.While you mention that
skip_sorting
can be used, consider adding a complete code example to clarify the usage pattern.Add a code example after line 82:
# Example: Re-run with skip_sorting for pre-sorted models (e.g., FlexTRON) model = mtp.prune( model, mode="mcore_minitron", constraints={"export_config": export_config}, dummy_input=None, config={"skip_sorting": True}, )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (12)
.gitlab/tests.yml
(1 hunks)CHANGELOG.rst
(1 hunks)docs/source/guides/3_pruning.rst
(1 hunks)examples/llm_distill/README.md
(1 hunks)examples/pruning/README.md
(2 hunks)modelopt/torch/nas/plugins/megatron.py
(5 hunks)modelopt/torch/opt/searcher.py
(3 hunks)modelopt/torch/prune/plugins/mcore_minitron.py
(4 hunks)modelopt/torch/utils/distributed.py
(3 hunks)tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
(7 hunks)tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py
(5 hunks)tox.ini
(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (5)
modelopt/torch/opt/searcher.py (2)
modelopt/torch/utils/distributed.py (1)
is_master
(75-77)tests/conftest.py (1)
verbose
(22-23)
tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py (2)
tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py (1)
_get_model
(71-88)tests/_test_utils/torch_dist/plugins/megatron_common.py (1)
get_mcore_mamba_model
(255-321)
modelopt/torch/prune/plugins/mcore_minitron.py (3)
modelopt/torch/nas/plugins/megatron.py (3)
_DynamicMCoreLanguageModel
(1151-1390)set_activations_and_layer_scores
(1368-1390)get_activations_and_layer_scores
(1351-1366)modelopt/torch/opt/searcher.py (3)
default_search_config
(73-85)sanitize_search_config
(92-103)save_search_checkpoint
(258-271)modelopt/torch/utils/logging.py (1)
print_rank_0
(92-95)
modelopt/torch/nas/plugins/megatron.py (5)
modelopt/torch/quantization/qtensor/base_qtensor.py (1)
to
(115-123)modelopt/torch/opt/dynamic.py (2)
get_hparam
(819-821)get_hparam
(1235-1240)modelopt/torch/utils/logging.py (1)
print_rank_0
(92-95)modelopt/torch/trace/symbols.py (1)
named_modules
(444-447)modelopt/torch/utils/distributed.py (3)
allgather
(132-159)rank
(68-72)rank
(200-202)
tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py (2)
tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py (2)
_get_model
(45-58)forward_loop
(70-72)tests/_test_utils/torch_dist/plugins/megatron_common.py (1)
get_mcore_gpt_model
(133-208)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (18)
.gitlab/tests.yml (1)
30-30
: GPU tests timeout bump to 90m looks fine.No concerns. This should accommodate the new pruning tests.
tox.ini (1)
63-64
: Environment flag addition LGTM.Setting MAMBA_FORCE_BUILD=TRUE for GPU envs is reasonable.
modelopt/torch/opt/searcher.py (1)
30-30
: Switch to warnings.warn import: LGTM.modelopt/torch/utils/distributed.py (2)
127-130
: Group-aware _allgather: LGTM.
162-169
: Group-aware allreduce: LGTM.Works as expected for subgroup reductions (sum).
docs/source/guides/3_pruning.rst (1)
7-7
: Tutorial link verified: the updated Qwen 3 pruning & distillation tutorial URL is correct.examples/pruning/README.md (1)
62-79
: LGTM! Clear usage example with scores_path parameter.The updated example effectively demonstrates the new checkpoint feature and includes an important note about re-calibration when changing datasets.
tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py (2)
45-59
: LGTM! Clean helper function pattern.The
_get_model
helper function properly handles conditional megatron initialization, enabling both initial pruning and checkpoint-based re-pruning scenarios.
124-133
: LGTM! Validates re-pruning from checkpoint.The test correctly validates that re-pruning from a saved checkpoint works without rerunning the forward loop, ensuring the checkpoint feature functions as intended.
tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py (2)
114-132
: Comprehensive test coverage for pruning configurations.The test properly handles both skip_sorting and scores_path configurations with appropriate validation logic. The assertion that
ckpt_path
should be None whenskip_sorting
is True ensures mutually exclusive usage.
159-169
: LGTM! Clean re-pruning validation.The conditional re-pruning test appropriately validates checkpoint-based pruning without forward pass execution.
modelopt/torch/nas/plugins/megatron.py (3)
652-655
: Necessary float32 conversion for numerical stability.Converting to float32 before computing cosine similarity prevents potential overflow/underflow issues with mixed precision tensors.
1241-1244
: Fixed: Proper use of id() for dictionary keys.The code now correctly uses
id(module)
as the dictionary key for both checking and storing, fixing the previous key mismatch issue.Based on past review comments.
1351-1391
: Missing device handling for loaded tensors.Loaded tensors may be on CPU or incorrect device, which could cause issues during distributed operations. Additionally, the docstring should clarify that activations_per_rank uses module names as keys.
Based on past review comments.
Apply this diff to fix device placement and documentation:
def set_activations_and_layer_scores( self, activations_per_rank: list[dict[str, torch.Tensor]], layer_scores: dict[int, torch.Tensor], ) -> None: """Set the pre-computed layer_scores and per-rank activations instead of running forward. Args: layer_scores: Dict from layer_number (1-indexed) to score. - activations_per_rank: List of dicts from module name to activations. Should match PP size. + activations_per_rank: List of dicts from module name (str) to activations. Should match PP size. """ rank = get_pipeline_model_parallel_rank() pp_size = get_pipeline_model_parallel_world_size() assert len(activations_per_rank) == pp_size, ( len(activations_per_rank), activations_per_rank, pp_size, ) for layer in self.decoder.layers: layer._scores = layer_scores[layer.layer_number] for n, m in self.named_modules(): if hasattr(m, "_activations"): - m._activations = activations_per_rank[rank][n] + activation = activations_per_rank[rank].get(n) + if activation is not None and isinstance(activation, torch.Tensor): + # Move to module's device to avoid cross-device errors + device = next(m.parameters()).device if list(m.parameters()) else None + if device is not None: + activation = activation.to(device, non_blocking=True) + m._activations = activationmodelopt/torch/prune/plugins/mcore_minitron.py (4)
68-76
: LGTM! Well-structured default config with checkpoint support.The addition of
skip_sorting
andscores_path
parameters provides good flexibility for checkpoint-based workflows.
99-99
: Good defensive programming with deep copy.Deep copying the export_config prevents unexpected mutations during processing.
144-163
: Effective checkpoint management implementation.The conditional logic properly handles three scenarios: loading from checkpoint, running forward pass with checkpoint save, and skipping sorting for pre-sorted models.
65-67
: Type hints don't match actual structure.The type annotations incorrectly suggest a single-level dict when these are rank-indexed nested dicts.
Apply this diff to fix the type hints:
- activations_per_rank: list[dict[str, torch.Tensor]] - layer_scores: dict[int, torch.Tensor] + activations_per_rank: list[dict[str, torch.Tensor]] # List indexed by rank, each containing module name -> activation + layer_scores: dict[int, torch.Tensor] # Layer number (1-indexed) -> score
…#361) Signed-off-by: Keval Morabia <[email protected]>
…#361) Signed-off-by: Keval Morabia <[email protected]> Signed-off-by: Ye Yu <[email protected]>
What does this PR do?
Type of change: minitron minor improvements
Usage
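A minimal usage sketch consistent with the config keys added in this PR (it assumes model, export_config, and forward_loop are prepared as in examples/pruning/README.md; the scores file name is illustrative, and in practice the unpruned model would be reconstructed before the second call, as the tests do):

import modelopt.torch.prune as mtp

# First run: collect activations, compute layer scores, and checkpoint them.
model, _ = mtp.prune(
    model,
    mode="mcore_minitron",
    constraints={"export_config": export_config},
    dummy_input=None,  # not used
    config={"forward_loop": forward_loop, "scores_path": "minitron_scores.pth"},
)

# Later run (e.g., with a different export_config): re-prune from the saved
# scores without re-running the forward loop.
model, _ = mtp.prune(
    model,
    mode="mcore_minitron",
    constraints={"export_config": export_config},
    dummy_input=None,
    config={"scores_path": "minitron_scores.pth"},
)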
Testing
Verified ckpt can be restored and pruning can be re-run without forward loop. For toy test models, verified aggregated scores from ckpt are same as before
Before your PR is "Ready for review"
Summary by CodeRabbit
New Features
Bug Fixes
Documentation
Tests