20 changes: 12 additions & 8 deletions modelopt/torch/distill/plugins/megatron.py
@@ -92,26 +92,30 @@ def parse_intermediate_entry(entry: tuple[str, ...]) -> tuple[str, str, Callable
     return student_layer, teacher_layer, loss_fn


-def load_distillation_config(
-    config_path: str | None, student_cfg: "TransformerConfig", teacher_cfg: "TransformerConfig"
+def setup_distillation_config(
+    config_or_path: str | DistillationConfig | None,
+    student_cfg: "TransformerConfig",
+    teacher_cfg: "TransformerConfig",
 ) -> DistillationConfig:
     """Read the distillation yaml config file specified by ``args.export_kd_cfg``.

     Args:
-        config_path: Path to user-defined distillation settings yaml file.
+        config_or_path: Path to user-defined distillation settings yaml file, or the incomplete config itself.
             If `None`, uses default logits-only distillation mode for GPT models.
         student_cfg: Model config for student model.
         teacher_cfg: Model config for teacher model.

     WARNING: Assumes intermediate hidden sizes are always that found in the model config's ``hidden_size`` attribute.
     """
-    if config_path:
-        with open(config_path) as f:
-            cfg = yaml.safe_load(f)
-        cfg = DistillationConfig(**cfg)
-    else:
+    if config_or_path is None:
         logger.warning("Distillation config not provided. Using default.")
         cfg = DistillationConfig()
+    elif isinstance(config_or_path, DistillationConfig):
+        cfg = config_or_path
+    else:
+        with open(config_or_path) as f:
+            cfg = yaml.safe_load(f)
+        cfg = DistillationConfig(**cfg)
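For orientation, here is a minimal usage sketch of the new `setup_distillation_config` signature as described in the updated docstring. The import path and the pre-existing `student_cfg`/`teacher_cfg` objects are assumptions for illustration only, not something this PR adds:

```python
# Hypothetical usage sketch of the three accepted argument forms (assumed import path).
from modelopt.torch.distill.plugins.megatron import (
    DistillationConfig,
    setup_distillation_config,
)

# student_cfg and teacher_cfg are assumed to be existing Megatron TransformerConfig objects.

# 1) None -> default logits-only distillation for GPT models (logs a warning).
cfg = setup_distillation_config(None, student_cfg, teacher_cfg)

# 2) Path to a YAML file -> parsed with yaml.safe_load and expanded into a DistillationConfig.
cfg = setup_distillation_config("distill_config.yaml", student_cfg, teacher_cfg)

# 3) A (possibly partial) DistillationConfig instance -> used directly and completed in place.
partial_cfg = DistillationConfig()
cfg = setup_distillation_config(partial_cfg, student_cfg, teacher_cfg)
```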
Comment on lines +110 to +118

⚠️ Potential issue | 🟡 Minor

Consider the implications of mutating the input DistillationConfig.

When a DistillationConfig instance is passed in (line 114), the function directly assigns it to cfg and later mutates it by setting cfg.criterion (line 146) and cfg.loss_balancer (line 147). This means the caller's original config object will be modified, which may be unexpected if they intend to reuse the config.

Consider one of these approaches:

  1. Document the mutation in the docstring:
         config_or_path: One of:
             - `None`: Uses default logits-only distillation mode for GPT models.
-            - `DistillationConfig`: Uses the provided config instance directly.
+            - `DistillationConfig`: Uses the provided config instance directly (will be modified in-place).
             - `str`: Path to a YAML file containing distillation settings.
  2. Create a copy to avoid mutating the input:
     elif isinstance(config_or_path, DistillationConfig):
-        cfg = config_or_path
+        from copy import deepcopy
+        cfg = deepcopy(config_or_path)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-    if config_or_path is None:
-        logger.warning("Distillation config not provided. Using default.")
-        cfg = DistillationConfig()
-    elif isinstance(config_or_path, DistillationConfig):
-        cfg = config_or_path
-    else:
-        with open(config_or_path) as f:
-            cfg = yaml.safe_load(f)
-        cfg = DistillationConfig(**cfg)
+    if config_or_path is None:
+        logger.warning("Distillation config not provided. Using default.")
+        cfg = DistillationConfig()
+    elif isinstance(config_or_path, DistillationConfig):
+        from copy import deepcopy
+        cfg = deepcopy(config_or_path)
+    else:
+        with open(config_or_path) as f:
+            cfg = yaml.safe_load(f)
+        cfg = DistillationConfig(**cfg)
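To make the aliasing concern concrete, here is a small self-contained sketch using a stand-in dataclass rather than the real `DistillationConfig`; it only illustrates why the deepcopy variant leaves the caller's object untouched:

```python
# Stand-in illustration of the aliasing issue; FakeConfig is not the real DistillationConfig.
from copy import deepcopy
from dataclasses import dataclass, field


@dataclass
class FakeConfig:
    criterion: dict = field(default_factory=dict)


def setup_without_copy(cfg_in: FakeConfig) -> FakeConfig:
    cfg = cfg_in                      # alias: same object as the caller's
    cfg.criterion = {"logits": "kd"}  # mutation is visible to the caller
    return cfg


def setup_with_copy(cfg_in: FakeConfig) -> FakeConfig:
    cfg = deepcopy(cfg_in)            # independent copy; caller's object is untouched
    cfg.criterion = {"logits": "kd"}
    return cfg


caller_cfg = FakeConfig()
setup_without_copy(caller_cfg)
print(caller_cfg.criterion)  # {'logits': 'kd'} -- caller's config was changed

caller_cfg = FakeConfig()
setup_with_copy(caller_cfg)
print(caller_cfg.criterion)  # {} -- caller's config remains unchanged
```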
🤖 Prompt for AI Agents
In modelopt/torch/distill/plugins/megatron.py around lines 110 to 118, the
function assigns a passed-in DistillationConfig instance directly to cfg and
later mutates it, which unintentionally alters the caller's object; to fix this,
create a shallow or deep copy of the incoming DistillationConfig (e.g., via
copy.deepcopy or a provided copy/clone/from_dict constructor) and assign that
copy to cfg before any mutations so the original remains unchanged, or
alternatively document in the function docstring that the input config will be
mutated if that behavior is intended.


     criterion = {}
     if student_cfg.pipeline_model_parallel_size == 1 or parallel_state.is_pipeline_last_stage():