
Conversation


@AAnoosheh AAnoosheh commented Sep 15, 2025

What does this PR do?

Feature: Change the way "kd_loss" mode saves state

Overview: ?

Usage

# Add a code snippet demonstrating how to use this

Testing

Unit

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Y (technically breaking in theory, but backward compatible in practice)
  • Did you write any new necessary tests?: Y
  • Did you add or update any necessary documentation?: N
  • Did you update Changelog?: N

Additional Information

Summary by CodeRabbit

  • New Features

    • Distillation now accepts a pre-instantiated teacher model (nn.Module) in addition to class, callable, or tuple forms.
  • Refactor

    • Checkpoint restore no longer reinstates the distillation wrapper; exported students restore with KD config reset and a new pre-save reset hook.
    • Saved state now preserves all KD-related entries (no pruning).
  • Documentation

    • Guides and examples updated for teacher_model forms, export/restore semantics, loss kwargs, CLI formatting, arg rename (max_seq_length → max_length), and example requirements.
  • Tests

    • Tests updated to assert unrestored distillation state and adjust examples/tests to direct teacher instantiation.

@AAnoosheh AAnoosheh self-assigned this Sep 15, 2025

copy-pr-bot bot commented Sep 15, 2025

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.


coderabbitai bot commented Sep 15, 2025

Walkthrough

Updated docs, examples, configs, and tests to broaden accepted teacher_model types, change save/restore semantics so distillation wrappers are not re-instantiated on restore, add a pre-save reset hook, stop filtering KD entries from saved state, and adapt examples/CLI and tests accordingly.
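
A minimal end-to-end sketch of the behavior described above (toy modules and file names are for illustration only; the API calls mirror the modelopt snippets quoted later in this review, so treat this as an approximation rather than code from the PR):

import torch.nn as nn

import modelopt.torch.distill as mtd
import modelopt.torch.opt as mto

# Toy stand-ins; real usage would pass actual student/teacher networks.
student = nn.Linear(16, 16)
teacher = nn.Linear(16, 16)

# A pre-instantiated nn.Module teacher is now accepted directly,
# in addition to the class, callable, and tuple forms.
distill_model = mtd.convert(
    student,
    mode=[("kd_loss", {"teacher_model": teacher, "criterion": mtd.LogitsDistillationLoss()})],
)

# ... train, then save. The new pre-save hook resets KD-specific config entries
# (teacher, criterion, loss balancer) before the state is serialized.
mto.save(distill_model, "student_kd.pth")

# Restore returns the plain student: the DistillationModel wrapper and the teacher
# are intentionally not re-instantiated.
restored = mto.restore(nn.Linear(16, 16), "student_kd.pth")
assert type(restored) is nn.Linear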

Changes

Cohort / File(s) Summary of changes
Docs & guide
docs/source/guides/4_distillation.rst
Clarified restore/export wording, broadened teacher_model description to module/class/callable/tuple, updated loss guidance and examples.
Examples & README
examples/llm_distill/README.md, examples/llm_distill/main.py, examples/llm_distill/requirements.txt, examples/llm_distill/accelerate_config/fsdp2.yaml
Replaced teacher factory with direct from_pretrained instantiation passed as teacher_model; renamed max_seq_length → max_length; updated accelerate CLI to config-file usage and added FSDP config; bumped trl to >=0.23.0.
Example tests / CLI tests
tests/examples/llm_distill/test_llm_distill.py, tests/unit/torch/opt/plugins/test_hf_patching.py
Adapted tests to provide concrete teacher instances and to stop asserting reconstitution of DistillationModel/teacher internals.
Distill config typing
modelopt/torch/distill/config.py
Removed local TeacherModel alias; switched KDLossConfig.teacher_model to shared ModelLike and updated descriptive text to accept module, class, callable, or tuple.
Distillation mode logic
modelopt/torch/distill/mode.py
Added update_for_save, consolidated pre-new-mode/pre-save reset via _reset_kd_state_config, and simplified restore helpers to return the input model without reconstructing DistillationModel.
Trainer plugins: save behavior
modelopt/torch/distill/plugins/huggingface.py, modelopt/torch/quantization/plugins/transformers_trainer.py
Removed filtering that pruned modelopt_state["modelopt_state_dict"] entries (e.g., "kd_loss", "export_student"); full state dict now saved intact.
Unit tests: distillation
tests/unit/torch/distill/test_distill.py
Updated assertions to expect KD config reset on restore (e.g., teacher_model marker, criterion contains Loss, loss_balancer None) and to not expect DistillationModel reconstitution.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant T as KDTrainer
  participant KD as KnowledgeDistillationModeDescriptor
  participant M as Model (student/exported)
  participant S as ModeloptStateManager

  Note over T,KD: Pre-save reset then full state save
  T->>KD: update_for_save()
  activate KD
  KD->>M: _reset_kd_state_config (teacher_model -> marker, criterion -> Loss(), loss_balancer -> None)
  deactivate KD
  T->>S: collect modelopt_state (unfiltered)
  T->>T: save model + modelopt_state (kd entries preserved)
sequenceDiagram
  autonumber
  participant U as User
  participant S as ModeloptStateManager
  participant M as Model (on disk)
  participant KD as KnowledgeDistillationModeDescriptor

  Note over U,S: Restore returns raw model (no DistillationModel)
  U->>S: restore(checkpoint)
  S-->>U: return model instance (no DistillationModel)
  U->>KD: (optional) update_for_new_mode()
  KD->>M: _reset_kd_state_config (ensure KD config consistent)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

I hop through checkpoints, snug and light,
No wrapped-up teacher pops to sight,
Saved whole, unpruned, the state stays true,
Reset and ready — then off I chew. 🐇✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 36.36%, below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check (✅ Passed): Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check (✅ Passed): The title "Disable KD mode from saving problematic state" is a concise, single-sentence summary that accurately captures the primary intent of the changeset: preventing knowledge-distillation-related state from being persisted during saves, which aligns with changes to the distill mode, save hooks, and tests. It is specific and clear enough for a reviewer scanning PR history to understand the main change without listing file-level details.


@AAnoosheh AAnoosheh marked this pull request as ready for review September 16, 2025 09:19
@AAnoosheh AAnoosheh requested review from a team as code owners September 16, 2025 09:19

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
docs/source/guides/4_distillation.rst (2)

42-55: Example uses undefined variable teacher_model.

The snippet sets teacher_model in the config but never defines it.

Apply this diff to define a simple teacher for the example:

 from torchvision.models import resnet50
@@
 # User-defined model (student)
 model = resnet50()
 
 # Configure and convert for distillation
 distillation_config = {
-    # `teacher_model` is a model, model class, callable, or a tuple.
-    # If a tuple, it must be of the form (model_cls_or_callable,) or
-    # (model_cls_or_callable, args) or (model_cls_or_callable, args, kwargs).
-    "teacher_model": teacher_model,
+    # A simple example teacher; in practice use a stronger model.
+    "teacher_model": resnet50(),
+    # `teacher_model` can be a model, model class, callable, or a tuple.
+    # If a tuple, it must be of the form (model_cls_or_callable,) or
+    # (model_cls_or_callable, args) or (model_cls_or_callable, args, kwargs).
     "criterion": mtd.LogitsDistillationLoss(),
     "loss_balancer": mtd.StaticLossBalancer(),
 }

156-162: Typo in API alias: atd → mtd.

The code won’t run as written.

Apply this diff:

-    distillation_model = atd.convert(student_model, mode=[("kd_loss", distillation_config)])
+    distillation_model = mtd.convert(student_model, mode=[("kd_loss", distillation_config)])
🧹 Nitpick comments (12)
examples/llm_distill/README.md (4)

42-45: Example may OOM on 70B teacher; suggest device_map to shard or offload.

Loading a 70B model on a single device will likely OOM. Recommend showing a safer pattern in the snippet.

Apply this diff to the example:

-# Define student & teacher
-student_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
-teacher_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-70B-Instruct")
+# Define student & teacher
+student_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
+# Consider sharded/offloaded loading for large teachers:
+teacher_model = AutoModelForCausalLM.from_pretrained(
+    "meta-llama/Llama-3.1-70B-Instruct",
+    device_map="auto"  # or use accelerate/FSDP per your setup
+)

65-66: Clarify serialization behavior of an nn.Module teacher.

Readers may wonder how checkpoints behave when passing a module instance. Add a short note pointing to the guide’s restoration semantics.

Proposed addition:

-The `teacher_model` can be either a `nn.Module`, a callable which returns an `nn.Module`, or a tuple of `(model_cls, args, kwargs)`.
+The `teacher_model` can be either a `nn.Module`, a callable which returns an `nn.Module`, or a tuple of `(model_cls, args, kwargs)`.
+Note: when saving, KD-specific state (including the teacher instance) is not re-instantiated on restore; see the Distillation guide for details.

76-82: Fix variable name mismatch (train_loader vs train_dataloader).

The code defines train_loader but iterates train_dataloader.

Apply this diff:

-for input, labels in train_dataloader:
+for input, labels in train_loader:

14-15: Typo: “intellegant” → “intelligent”.

Minor doc polish.

Apply this diff:

-| Getting Started | Learn how to optimize your models using distillation to produce more intellegant smaller models | [[Link](#getting-started)] | [[docs](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/4_distillation.html)] |
+| Getting Started | Learn how to optimize your models using distillation to produce more intelligent smaller models | [[Link](#getting-started)] | [[docs](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/4_distillation.html)] |
docs/source/guides/4_distillation.rst (3)

19-22: Make restore semantics actionable.

Add a pointer that resuming KD requires re-converting with kd_loss if desired.

Apply this diff:

-Note that restoring the model (via :meth:`mto.restore <modelopt.torch.opt.conversion.restore>`)
-will not reinstantiate the distillation meta-model, in order to avoid unpickling issues.
+Note that restoring the model (via :meth:`mto.restore <modelopt.torch.opt.conversion.restore>`)
+will not reinstantiate the distillation meta-model, in order to avoid unpickling issues. To
+resume KD training, call :meth:`mtd.convert <modelopt.torch.distill.distillation.convert>`
+again on the restored student with your desired ``kd_loss`` config.

56-57: Clarify export comment.

“Previously-present attributes” is vague.

Apply this diff:

-# Export model in original class, with only previously-present attributes
+# Export the original student class; distillation-specific attributes are removed

61-61: Grammar nit: “for to perform” → “to perform”.

Apply this diff:

-    When training the student on a small corpus of ground truth data, consider using :class:`MFTLoss <modelopt.torch.distill.MFTLoss>` for to perform Minifinetuning in lieu of the standard
+    When training the student on a small corpus of ground truth data, consider using :class:`MFTLoss <modelopt.torch.distill.MFTLoss>` to perform Minifinetuning in lieu of the standard
examples/llm_distill/main.py (1)

126-129: Consider specifying dtype to reduce memory (esp. with bf16).

Passing torch_dtype helps avoid unnecessary fp32 allocations.

Apply this diff:

-        teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
-            model_args.teacher_name_or_path,
-            device_map=PartialState().process_index,
-        )
+        teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
+            model_args.teacher_name_or_path,
+            device_map=PartialState().process_index,
+            torch_dtype=torch.bfloat16 if training_args.bf16 else None,
+        )
modelopt/torch/distill/config.py (1)

81-89: Handle None criterion without injecting a sentinel loss.

Returning {(“”, “”): None} creates a dict with a non-Loss value. Prefer returning {} to keep types clean; strict validation already warns on empty criterion.

Apply this diff:

     @pydantic.field_validator("criterion")
     @classmethod
     def format_criterion(cls, criterion: Criterion | None) -> dict[tuple[str, str], Loss]:
         """Ensure criterion is a mapping from layer names to loss (potentially entire module)."""
-        if not isinstance(criterion, dict):
-            # Output-only distillation.
-            criterion = {("", ""): criterion}
-        return criterion
+        if criterion is None:
+            return {}
+        if isinstance(criterion, dict):
+            return criterion
+        # Output-only distillation.
+        return {("", ""): criterion}
modelopt/torch/distill/mode.py (3)

16-20: Docstring mentions NAS; update to Distillation.

Minor copy/paste artifact.

Apply this diff:

-"""Module implementing and describing modes that can be used during the NAS convert process.
+"""Module implementing and describing modes that can be used during the Distillation convert process.
 
-Check out :meth:`mtn.convert <modelopt.torch.nas.conversion.convert>` to learn more about modes.
+Check out :meth:`mtd.convert <modelopt.torch.distill.distillation.convert>` to learn more about modes.
 """

177-181: Optional: emit a debug log on no-op restore.

A one-line debug helps users understand why KD wasn’t reconstructed.

Apply this diff:

 def _restore_kd_model(model: nn.Module, config: KDLossConfig, metadata: MetadataDict) -> nn.Module:
     """Function for restoring a previously convert model to a distillation meta-model."""
-    # NOTE: DistillationModel will purposely remain unrestored
+    # NOTE: DistillationModel will purposely remain unrestored
+    warnings.warn("KD mode state was sanitized at save time; skipping DistillationModel reconstruction.", stacklevel=1)
     return model

183-188: Type hint and minimal doc nit for reset helper.

Make the helper’s intent explicit; keep behavior unchanged.

Apply this diff:

-def _reset_kd_state_config(model: nn.Module, config: KDLossConfig, metadata: MetadataDict):
-    """Function for resetting the state's config."""
+def _reset_kd_state_config(model: nn.Module, config: KDLossConfig, metadata: MetadataDict) -> None:
+    """Reset KD-related config fields to pickle-friendly placeholders before new-mode addition or save."""
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 76e8ce2 and 7df1237.

📒 Files selected for processing (8)
  • docs/source/guides/4_distillation.rst (4 hunks)
  • examples/llm_distill/README.md (2 hunks)
  • examples/llm_distill/main.py (2 hunks)
  • modelopt/torch/distill/config.py (2 hunks)
  • modelopt/torch/distill/mode.py (3 hunks)
  • modelopt/torch/distill/plugins/huggingface.py (0 hunks)
  • modelopt/torch/quantization/plugins/transformers_trainer.py (0 hunks)
  • tests/unit/torch/distill/test_distill.py (1 hunks)
💤 Files with no reviewable changes (2)
  • modelopt/torch/quantization/plugins/transformers_trainer.py
  • modelopt/torch/distill/plugins/huggingface.py
🧰 Additional context used
🧬 Code graph analysis (4)
tests/unit/torch/distill/test_distill.py (2)
modelopt/torch/opt/conversion.py (1)
  • ModeloptStateManager (63-311)
modelopt/torch/opt/config.py (1)
  • values (136-138)
examples/llm_distill/main.py (1)
modelopt/torch/distill/distillation_model.py (1)
  • teacher_model (124-126)
modelopt/torch/distill/config.py (2)
modelopt/torch/distill/distillation_model.py (1)
  • teacher_model (124-126)
modelopt/torch/opt/config.py (1)
  • ModeloptField (50-53)
modelopt/torch/distill/mode.py (1)
modelopt/torch/distill/config.py (1)
  • KDLossConfig (34-120)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: wait-checks / wait
🔇 Additional comments (8)
examples/llm_distill/README.md (1)

57-60: Direct nn.Module acceptance is clear.

Good change; the example demonstrates the broadened API for teacher_model correctly.

examples/llm_distill/main.py (2)

113-115: LGTM: direct single-model load aligns with README.

Using AutoModelForCausalLM directly here improves clarity and removes indirection.


131-135: LGTM: pass instantiated teacher into kd_config.

This matches the broadened ModelLike typing and the docs.

tests/unit/torch/distill/test_distill.py (2)

150-156: Good assertions for sanitized KD state after restore.

Asserting placeholders for teacher_model, criterion, and loss_balancer precisely tests the new save/restore semantics.


157-159: LGTM: restored object remains the raw student type.

This validates the intended “don’t reconstruct DistillationModel” behavior.

modelopt/torch/distill/config.py (1)

43-51: Broadened ModelLike support is appropriate.

Accepting nn.Module in addition to class/callable/tuple simplifies usage and is consistent with init_model_from_model_like.

modelopt/torch/distill/mode.py (1)

84-92: Pre-save hook is a solid addition.

Exposing update_for_save and pointing it to the reset function centralizes KD-state sanitization before checkpointing.

docs/source/guides/4_distillation.rst (1)

205-210: External reference verified — "Minifinetuning" is correct.
Paper: "Minifinetuning: Low-Data Generation Domain Adaptation through Corrective Self-Distillation" — Peter Belcak, Greg Heinrich, Jan Kautz, Pavlo Molchanov; arXiv:2506.15702.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/llm_distill/main.py (1)

83-91: world_size lookup can crash when DDP not initialized; avoid float division.

torch.distributed.get_world_size() raises if the process group isn’t initialized, and float is_integer checks are brittle.

Apply:

-    num_accum_steps = total_batch_size / (
-        training_args.per_device_train_batch_size * torch.distributed.get_world_size()
-    )
-    if not num_accum_steps.is_integer():
-        raise ValueError(
-            f"`per_device_train_batch_size` * `world_size` must be a factor of {total_batch_size}"
-        )
-    training_args.gradient_accumulation_steps = int(num_accum_steps)
+    world_size = PartialState().num_processes
+    per_step = training_args.per_device_train_batch_size * world_size
+    if total_batch_size % per_step != 0:
+        raise ValueError(
+            f"`per_device_train_batch_size` * `world_size` ({per_step}) must divide {total_batch_size}"
+        )
+    training_args.gradient_accumulation_steps = total_batch_size // per_step
🧹 Nitpick comments (3)
examples/llm_distill/main.py (2)

51-51: max_length is not consumed by SFTTrainer; wire it or revert.

Pass it as SFTTrainer’s max_seq_length to avoid silently using the default.

Apply:

 class KDSFTTrainer(SFTTrainer, KDTrainer):
     pass
@@
     trainer = trainer_cls(
         model,
         training_args,
         train_dataset=dset_train,
         eval_dataset=dset_eval,
         formatting_func=llama_text_format_func,
         processing_class=tokenizer,
+        max_seq_length=training_args.max_length,
     )

62-66: Formatting func: guard missing keys and trim whitespace.

Use .get(...) to avoid KeyError and normalize system prompt.

Apply:

-    p, q, r = sample["system_prompt"], sample["question"], sample["response"]
-    if not p:
-        return f"<s>[INST] {q}[/INST]\n{r}</s>"
-    else:
-        return f"<s>[INST] <<SYS>>{p}<</SYS>>\n{q}[/INST]\n{r}</s>"
+    p = (sample.get("system_prompt") or "").strip()
+    q = sample["question"]
+    r = sample["response"]
+    if p:
+        return f"<s>[INST] <<SYS>>{p}<</SYS>>\n{q}[/INST]\n{r}</s>"
+    return f"<s>[INST] {q}[/INST]\n{r}</s>"
examples/llm_distill/requirements.txt (1)

2-2: Pin HF stack for TRL 0.23.0 (avoid resolver drift)

TRL's PyPI metadata shows accelerate>=1.4.0, transformers>=4.56.1, peft>=0.8.0 — add explicit pins or a constraints file so examples/llm_distill/requirements.txt (currently contains trl==0.23.0) also locks compatible versions (suggested: transformers==4.56.1, accelerate==1.4.0, peft==0.8.0).

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7df1237 and eb51261.

📒 Files selected for processing (3)
  • examples/llm_distill/README.md (3 hunks)
  • examples/llm_distill/main.py (3 hunks)
  • examples/llm_distill/requirements.txt (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/llm_distill/README.md
🧰 Additional context used
🧬 Code graph analysis (1)
examples/llm_distill/main.py (1)
modelopt/torch/distill/distillation_model.py (1)
  • teacher_model (124-126)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (1)
examples/llm_distill/main.py (1)

128-131: Passing the instantiated teacher into kd_config looks right.

This aligns with the new ModelLike teacher_model semantics and the PR goal to avoid saving problematic KD wrapper state.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (4)
examples/llm_distill/README.md (1)

172-184: Same here: keep FSDP2 knobs inside the YAML.

Mirror the previous change for the distillation command.

Apply:

-accelerate launch --config-file ./accelerate_config/fsdp2.yaml \
-    --fsdp_cpu_ram_efficient_loading False \
-    --fsdp_activation_checkpointing False \
-    main.py \
+accelerate launch --config-file ./accelerate_config/fsdp2.yaml \
+    main.py \
     --teacher_name_or_path ./llama2-7b-sft \
     --student_name_or_path 'TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T' \
     --output_dir ./llama2-distill \
     --max_length 2048 \
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 4 \
     --max_steps 200 \
     --logging_steps 5
examples/llm_distill/main.py (3)

109-112: Optional: enable memory‑friendly loading (and consider explicit device_map).

For large models, add low_cpu_mem_usage=True. If you hit meta‑device issues in distributed setups, use an explicit device map or rely on accelerate’s placement. This mirrors prior review feedback.

Apply:

-        model = transformers.AutoModelForCausalLM.from_pretrained(
-            model_path, torch_dtype=torch.bfloat16 if training_args.bf16 else None
-        )
+        model = transformers.AutoModelForCausalLM.from_pretrained(
+            model_path,
+            torch_dtype=torch.bfloat16 if training_args.bf16 else None,
+            low_cpu_mem_usage=True,
+        )
@@
-        model = transformers.AutoModelForCausalLM.from_pretrained(
-            model_args.student_name_or_path, torch_dtype=torch.bfloat16 if training_args.bf16 else None
-        )
+        model = transformers.AutoModelForCausalLM.from_pretrained(
+            model_args.student_name_or_path,
+            torch_dtype=torch.bfloat16 if training_args.bf16 else None,
+            low_cpu_mem_usage=True,
+        )
@@
-        teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
-            model_args.teacher_name_or_path, torch_dtype=torch.bfloat16 if training_args.bf16 else None
-        )
+        teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
+            model_args.teacher_name_or_path,
+            torch_dtype=torch.bfloat16 if training_args.bf16 else None,
+            low_cpu_mem_usage=True,
+        )

Also applies to: 115-117, 121-123


107-112: Bug: from_pretrained uses torch_dtype, not dtype.

This will raise TypeError: got an unexpected keyword argument 'dtype'. Use torch_dtype=....

Apply:

-        model = transformers.AutoModelForCausalLM.from_pretrained(
-            model_path, dtype=torch.bfloat16 if training_args.bf16 else None
-        )
+        model = transformers.AutoModelForCausalLM.from_pretrained(
+            model_path, torch_dtype=torch.bfloat16 if training_args.bf16 else None
+        )

121-123: Same torch_dtype fix for teacher load.

Apply:

-        teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
-            model_args.teacher_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
-        )
+        teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
+            model_args.teacher_name_or_path, torch_dtype=torch.bfloat16 if training_args.bf16 else None
+        )
🧹 Nitpick comments (8)
tests/unit/torch/opt/plugins/test_hf_patching.py (2)

40-44: Direct teacher instantiation LGTM; keep dtype/device consistent.

Good move replacing the factory with a concrete teacher instance. Consider ensuring teacher and student use the same dtype/device in tests to avoid accidental dtype/device mismatches when kernels run on CI with different defaults. A simple teacher_model.to(model_ref.dtype).eval() before convert is sufficient.


56-58: Add an assertion that KD state is not re‑materialized after restore.

Given the PR’s goal (avoid saving problematic KD state), add a check that the reloaded model is a plain base model without KD wrappers and that no KD config is present in saved state.

Apply:

 tf_output_tester(model, model_test)
 # since distill model contains loss function, we compare state of model manually
 assert mto.modelopt_state(model.model) == mto.modelopt_state(model_test.model)
+
+# Also verify KD metadata is not persisted/reconstructed
+state = mto.modelopt_state(model_test.model)
+assert not any(k.startswith("kd_") for k in state.keys())
examples/llm_distill/README.md (3)

42-45: Example matches new API; add note about memory‑friendly loading.

Since large models are used here, add a brief note (or code) to pass low_cpu_mem_usage=True (and optionally a device map) to reduce CPU spikes during from_pretrained.

Apply:

-student_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
-teacher_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-70B-Instruct")
+student_model = AutoModelForCausalLM.from_pretrained(
+    "meta-llama/Llama-3.1-8B-Instruct", low_cpu_mem_usage=True
+)
+teacher_model = AutoModelForCausalLM.from_pretrained(
+    "meta-llama/Llama-3.1-70B-Instruct", low_cpu_mem_usage=True
+)

57-60: Docs: clarify loss balancer recommendation.

Mention that when only KD loss is used, omitting loss_balancer makes the KD loss the total loss by default; otherwise, show a weight (e.g., StaticLossBalancer(kd_weight=1.0, student_weight=1.0)).
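
A hedged sketch of the two configurations this note contrasts (the no-argument StaticLossBalancer mirrors the guide snippet earlier in this review; the weighted constructor arguments mentioned above are unverified and therefore omitted here):

import torch.nn as nn

import modelopt.torch.distill as mtd

teacher_model = nn.Linear(8, 8)  # stand-in teacher for illustration

# KD-only: with no loss_balancer, the distillation loss alone serves as the total loss.
kd_only_config = {
    "teacher_model": teacher_model,
    "criterion": mtd.LogitsDistillationLoss(),
}

# Combined: a balancer mixes the KD loss with the student's own task loss.
balanced_config = {
    "teacher_model": teacher_model,
    "criterion": mtd.LogitsDistillationLoss(),
    "loss_balancer": mtd.StaticLossBalancer(),
}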


157-167: Avoid non‑portable accelerate flags; keep them in YAML.

--fsdp_cpu_ram_efficient_loading and --fsdp_activation_checkpointing are config keys, not stable CLI flags across accelerate versions. Recommend removing them from the command and keeping overrides in accelerate_config/fsdp2.yaml to prevent CI breakage.

Apply:

-accelerate launch --config-file ./accelerate_config/fsdp2.yaml \
-    main.py \
+accelerate launch --config-file ./accelerate_config/fsdp2.yaml \
+    main.py \
     --single_model \
     --teacher_name_or_path 'meta-llama/Llama-2-7b-hf' \
     --output_dir ./llama2-7b-sft \
     --max_length 2048 \
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 4 \
     --max_steps 400 \
     --logging_steps 5
tests/examples/llm_distill/test_llm_distill.py (1)

24-27: Make test independent of accelerate CLI internals.

Drop accelerate‑specific FSDP flags from the CLI; keep them solely in accelerate_config/fsdp2.yaml to reduce flakiness across versions.

Apply:

-            "accelerate", "launch", "--config-file", "./accelerate_config/fsdp2.yaml",
-            "--fsdp_cpu_ram_efficient_loading", "False",
-            "--fsdp_activation_checkpointing", "False",
-            "main.py",
+            "accelerate", "launch", "--config-file", "./accelerate_config/fsdp2.yaml",
+            "main.py",
examples/llm_distill/main.py (2)

80-90: Guard world size when not launched under accelerate.

torch.distributed.get_world_size() errors if the process group isn’t initialized. Add a safe fallback.

Apply:

-    num_accum_steps = total_batch_size / (
-        training_args.per_device_train_batch_size * torch.distributed.get_world_size()
-    )
+    world_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
+    num_accum_steps = total_batch_size / (
+        training_args.per_device_train_batch_size * world_size
+    )

124-129: Freeze teacher to avoid accidental gradients.

Explicitly set teacher_model.requires_grad_(False) to prevent optimizer from touching it if a custom trainer is used.

Apply:

         kd_config = {
             "teacher_model": teacher_model,
             "criterion": LMLogitsLoss(),
         }
+        teacher_model.requires_grad_(False).eval()
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between eb51261 and 3dcd039.

📒 Files selected for processing (6)
  • examples/llm_distill/README.md (3 hunks)
  • examples/llm_distill/accelerate_config/fsdp2.yaml (1 hunks)
  • examples/llm_distill/main.py (2 hunks)
  • examples/llm_distill/requirements.txt (1 hunks)
  • tests/examples/llm_distill/test_llm_distill.py (1 hunks)
  • tests/unit/torch/opt/plugins/test_hf_patching.py (2 hunks)
✅ Files skipped from review due to trivial changes (1)
  • examples/llm_distill/accelerate_config/fsdp2.yaml
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/llm_distill/requirements.txt
🧰 Additional context used
🧬 Code graph analysis (3)
tests/examples/llm_distill/test_llm_distill.py (1)
tests/examples/conftest.py (1)
  • tiny_llama_path (33-41)
tests/unit/torch/opt/plugins/test_hf_patching.py (2)
modelopt/torch/distill/distillation_model.py (1)
  • teacher_model (124-126)
tests/_test_utils/torch_model/transformers_models.py (1)
  • get_tiny_qwen3 (44-60)
examples/llm_distill/main.py (1)
modelopt/torch/distill/distillation_model.py (1)
  • teacher_model (124-126)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: wait-checks / wait
🔇 Additional comments (2)
tests/examples/llm_distill/test_llm_distill.py (1)

31-35: Flag rename looks correct.

--max_seq_length → --max_length aligns with the updated TrainingArguments. Good.

examples/llm_distill/main.py (1)

132-134: Confirm None is valid for generation fields across transformers versions.

Setting temperature/top_p to None is okay on recent versions; older versions expect floats. If you need broader compatibility, delete the attributes instead of assigning None.

@AAnoosheh AAnoosheh force-pushed the aanoosheh/remove-kd-state branch from 3dcd039 to 549d20f on September 17, 2025 15:12

codecov bot commented Sep 17, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.82%. Comparing base (682bf6d) to head (549d20f).
⚠️ Report is 11 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #320      +/-   ##
==========================================
- Coverage   73.82%   73.82%   -0.01%     
==========================================
  Files         172      172              
  Lines       17438    17437       -1     
==========================================
- Hits        12874    12872       -2     
- Misses       4564     4565       +1     

☔ View full report in Codecov by Sentry.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
modelopt/torch/distill/mode.py (1)

183-188: Use a concrete Loss to avoid accidental runtime errors

Loss() is the abstract base; calling it would raise at runtime if ever used. Prefer a harmless concrete default.

 def _reset_kd_state_config(model: nn.Module, config: KDLossConfig, metadata: MetadataDict):
     """Function for resetting the state's config."""
     config.teacher_model = nn.Module
-    config.criterion = Loss()
+    # Use a concrete, parameter-free loss to avoid runtime errors if ever invoked.
+    config.criterion = nn.MSELoss()
     config.loss_balancer = None
examples/llm_distill/main.py (1)

82-90: Single-process robustness: guard world_size when not initialized

torch.distributed.get_world_size() errors if the process group isn’t initialized (e.g., local runs).

-    num_accum_steps = total_batch_size / (
-        training_args.per_device_train_batch_size * torch.distributed.get_world_size()
-    )
+    world_size = (
+        torch.distributed.get_world_size()
+        if torch.distributed.is_available() and torch.distributed.is_initialized()
+        else 1
+    )
+    num_accum_steps = total_batch_size / (
+        training_args.per_device_train_batch_size * world_size
+    )
♻️ Duplicate comments (2)
examples/llm_distill/main.py (2)

116-117: Same kwarg bug for student load: use torch_dtype, not dtype

Matches the above issue.

-        model = transformers.AutoModelForCausalLM.from_pretrained(
-            model_args.student_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
-        )
+        model = transformers.AutoModelForCausalLM.from_pretrained(
+            model_args.student_name_or_path, torch_dtype=torch.bfloat16 if training_args.bf16 else None
+        )

109-111: Replace invalid from_pretrained kwarg dtype with torch_dtype

from_pretrained does not accept dtype; use torch_dtype to set model precision.

File: examples/llm_distill/main.py (lines 109-111)

-        model = transformers.AutoModelForCausalLM.from_pretrained(
-            model_path, dtype=torch.bfloat16 if training_args.bf16 else None
-        )
+        model = transformers.AutoModelForCausalLM.from_pretrained(
+            model_path, torch_dtype=torch.bfloat16 if training_args.bf16 else None
+        )
🧹 Nitpick comments (2)
examples/llm_distill/main.py (2)

61-66: Guard dataset fields to avoid KeyError/None formatting

Open-Orca variants can miss system_prompt. Use .get and empty-string fallbacks.

-def llama_text_format_func(sample):
-    p, q, r = sample["system_prompt"], sample["question"], sample["response"]
-    if not p:
-        return f"<s>[INST] {q}[/INST]\n{r}</s>"
-    else:
-        return f"<s>[INST] <<SYS>>{p}<</SYS>>\n{q}[/INST]\n{r}</s>"
+def llama_text_format_func(sample):
+    p = (sample.get("system_prompt") or "").strip()
+    q = (sample.get("question") or "").strip()
+    r = (sample.get("response") or "").strip()
+    if not p:
+        return f"<s>[INST] {q}[/INST]\n{r}</s>"
+    return f"<s>[INST] <<SYS>>{p}<</SYS>>\n{q}[/INST]\n{r}</s>"

Please confirm the split’s actual field names (some dumps use “system_prompt”, some “system_prompt_template”, etc.).


137-144: Pass dataset_num_proc to SFTTrainer for faster preprocessing

You defined it in TrainingArguments but don’t pass it; wiring it through speeds tokenization.

     trainer = trainer_cls(
         model,
         training_args,
         train_dataset=dset_train,
         eval_dataset=dset_eval,
         formatting_func=llama_text_format_func,
         processing_class=tokenizer,
+        dataset_num_proc=training_args.dataset_num_proc,
     )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3dcd039 and 549d20f.

📒 Files selected for processing (12)
  • docs/source/guides/4_distillation.rst (4 hunks)
  • examples/llm_distill/README.md (3 hunks)
  • examples/llm_distill/accelerate_config/fsdp2.yaml (1 hunks)
  • examples/llm_distill/main.py (2 hunks)
  • examples/llm_distill/requirements.txt (1 hunks)
  • modelopt/torch/distill/config.py (2 hunks)
  • modelopt/torch/distill/mode.py (3 hunks)
  • modelopt/torch/distill/plugins/huggingface.py (0 hunks)
  • modelopt/torch/quantization/plugins/transformers_trainer.py (0 hunks)
  • tests/examples/llm_distill/test_llm_distill.py (1 hunks)
  • tests/unit/torch/distill/test_distill.py (1 hunks)
  • tests/unit/torch/opt/plugins/test_hf_patching.py (2 hunks)
💤 Files with no reviewable changes (2)
  • modelopt/torch/quantization/plugins/transformers_trainer.py
  • modelopt/torch/distill/plugins/huggingface.py
🚧 Files skipped from review as they are similar to previous changes (8)
  • docs/source/guides/4_distillation.rst
  • examples/llm_distill/accelerate_config/fsdp2.yaml
  • modelopt/torch/distill/config.py
  • tests/unit/torch/opt/plugins/test_hf_patching.py
  • examples/llm_distill/requirements.txt
  • tests/examples/llm_distill/test_llm_distill.py
  • tests/unit/torch/distill/test_distill.py
  • examples/llm_distill/README.md
🧰 Additional context used
🧬 Code graph analysis (2)
modelopt/torch/distill/mode.py (1)
modelopt/torch/distill/config.py (1)
  • KDLossConfig (34-120)
examples/llm_distill/main.py (1)
modelopt/torch/distill/distillation_model.py (1)
  • teacher_model (124-126)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (6)
modelopt/torch/distill/mode.py (4)

86-87: Unifying pre-new-mode reset via _reset_kd_state_config — OK

Centralizing the reset path here looks good and simplifies lifecycle handling.


179-181: Restore is intentionally a no-op — OK

Matches the documented decision to avoid re-instantiation of the distillation wrapper on restore.

Please ensure docs and error messages clearly state that KD is not reconstructed on restore and must be re-converted if needed.


220-221: Exported-student restore no-op — OK

Consistent with the KD restore behavior; no objections.


88-91: Pre-save hook mutates config in-place; avoid post-save side-effects

update_for_save in KD returns _reset_kd_state_config (modelopt/torch/distill/mode.py:88–91) but conversion calls last_mode.update_for_save(model, last_config, self._last_metadata) on the stored _last_config (modelopt/torch/opt/conversion.py:310). If _reset_kd_state_config mutates teacher/criterion/loss_balancer (or other config fields) in-place, a save performed mid-training will change runtime behavior after saving. Ensure update_for_save either works on a deepcopy or restores the original config (or change the caller to pass a deepcopy) so saves have no side effects.
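
One possible mitigation, sketched against the caller quoted above (the attribute and method names are taken from that quote and are otherwise assumptions, not the repository's actual implementation):

from copy import deepcopy

def _modelopt_state_for_save(self, model):
    """Hypothetical sketch: sanitize a copy of the stored config so saving has no runtime side effects."""
    config_for_save = deepcopy(self._last_config)
    # _reset_kd_state_config then mutates only the copy (teacher -> marker, criterion -> Loss(), balancer -> None).
    self._last_mode.update_for_save(model, config_for_save, self._last_metadata)
    return config_for_save  # serialize this copy; keep using self._last_config at runtime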

examples/llm_distill/main.py (2)

50-50: Rename to max_length — OK

Tracks recent TRL/Transformers arg conventions.


100-105: Tokenizer pad token may be None

Some LLM tokenizers lack eos_token/pad_token. Consider adding a pad token if missing to avoid collation errors.

-    tokenizer.pad_token = tokenizer.eos_token
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+    # If still None, add a pad token and resize later if needed.
+    if tokenizer.pad_token is None:
+        tokenizer.add_special_tokens({"pad_token": "<pad>"})

Comment on lines +121 to +123
teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
model_args.teacher_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
)


⚠️ Potential issue

Same kwarg bug for teacher load: use torch_dtype

Keep teacher/student loads consistent.

-        teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
-            model_args.teacher_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
-        )
+        teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
+            model_args.teacher_name_or_path, torch_dtype=torch.bfloat16 if training_args.bf16 else None
+        )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
model_args.teacher_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
)
teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
model_args.teacher_name_or_path, torch_dtype=torch.bfloat16 if training_args.bf16 else None
)
🤖 Prompt for AI Agents
In examples/llm_distill/main.py around lines 121 to 123, the teacher model is
loaded using the incorrect keyword argument dtype= when calling
transformers.AutoModelForCausalLM.from_pretrained; change that to torch_dtype=
and pass torch.bfloat16 if training_args.bf16 else None so the teacher load
matches the student load and uses the correct HF transformers parameter.

@kevalmorabia97 kevalmorabia97 merged commit b895dc5 into main Sep 18, 2025
41 of 44 checks passed
@kevalmorabia97 kevalmorabia97 deleted the aanoosheh/remove-kd-state branch September 18, 2025 21:01