Commit 935e666

Add temperature config field
Signed-off-by: Asha Anoosheh <[email protected]>
1 parent c706315 commit 935e666

File tree

1 file changed: +6 −1 lines

modelopt/torch/distill/plugins/megatron.py

Lines changed: 6 additions & 1 deletion
@@ -56,12 +56,14 @@ class DistillationConfig:
         logit_layers: Tuple of logit layer names.
         skip_lm_loss: Whether to skip computing the standard language model loss (default: ``True``).
         kd_loss_scale: Relative scaling factor for the distillation loss if ``skip_lm_loss`` is ``False``.
+        logit_kl_temperature: Temperature for the logit KL-divergence loss.
     """
 
     intermediate_layer_pairs: list[tuple[str, str]] = field(default_factory=list)
     logit_layers: tuple[str, str] = ("output_layer", "output_layer")
     skip_lm_loss: bool = True
     kd_loss_scale: float = 1.0
+    logit_kl_temperature: float = 1.0
     criterion: Criterion | None = None
     loss_balancer: mtd.DistillationLossBalancer | None = None
 
@@ -71,6 +73,7 @@ def __post_init__(self):
                 f"{self.intermediate_layer_pairs=}"
             )
         assert self.kd_loss_scale > 0, f"{self.kd_loss_scale=}"
+        assert self.logit_kl_temperature > 0, f"{self.logit_kl_temperature=}"
 
 
 def load_distillation_config(
@@ -96,7 +99,9 @@ def load_distillation_config(
 
     criterion = {}
     if student_cfg.pipeline_model_parallel_size == 1 or parallel_state.is_pipeline_last_stage():
-        criterion[tuple(cfg.logit_layers)] = LogitsKLLoss(student_cfg)
+        criterion[tuple(cfg.logit_layers)] = LogitsKLLoss(
+            student_cfg, temperature=cfg.logit_kl_temperature
+        )
     # NOTE: Projection layer shared among intermediate layer pairs.
     projection_layer = ProjectionLayer(student_cfg, teacher_cfg)
 
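For reference, a minimal sketch of what a logit KL temperature typically does: both models' logits are divided by the temperature before the softmax, so values above 1.0 soften the teacher distribution. This is a generic illustration only, not the actual LogitsKLLoss implementation in this repo; the T**2 rescaling is a common knowledge-distillation convention and an assumption here.

import torch.nn.functional as F

def temperature_scaled_kl(student_logits, teacher_logits, temperature=1.0):
    # Generic sketch: soften both distributions by the same temperature.
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    # T**2 keeps gradient magnitudes comparable across temperatures (common KD convention).
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * temperature**2

With this commit the temperature is set through the config dataclass, e.g. DistillationConfig(logit_kl_temperature=2.0); the new field defaults to 1.0 and __post_init__ asserts it is strictly positive.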
0 commit comments