@@ -56,12 +56,14 @@ class DistillationConfig:
         logit_layers: Tuple of logit layer names.
         skip_lm_loss: Whether to skip computing the standard language model loss (default: ``True``).
         kd_loss_scale: Relative scaling factor for the distillation loss if ``skip_lm_loss`` is ``False``.
+        logit_kl_temperature: Temperature for the logit KL-divergence loss.
     """
 
     intermediate_layer_pairs: list[tuple[str, str]] = field(default_factory=list)
     logit_layers: tuple[str, str] = ("output_layer", "output_layer")
     skip_lm_loss: bool = True
     kd_loss_scale: float = 1.0
+    logit_kl_temperature: float = 1.0
     criterion: Criterion | None = None
     loss_balancer: mtd.DistillationLossBalancer | None = None
 
@@ -71,6 +73,7 @@ def __post_init__(self):
             f"{self.intermediate_layer_pairs=}"
         )
         assert self.kd_loss_scale > 0, f"{self.kd_loss_scale=}"
+        assert self.logit_kl_temperature > 0, f"{self.logit_kl_temperature=}"
 
 
 def load_distillation_config(
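Taken together, the first two hunks add the new field, document it, and validate it at construction time. A minimal usage sketch (hypothetical field values; only logit_kl_temperature is new here):

    # Hypothetical construction of the dataclass patched above.
    cfg = DistillationConfig(
        logit_layers=("output_layer", "output_layer"),
        kd_loss_scale=1.0,
        logit_kl_temperature=2.0,  # must be > 0, else __post_init__ asserts
    )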
@@ -96,7 +99,9 @@ def load_distillation_config(
 
     criterion = {}
     if student_cfg.pipeline_model_parallel_size == 1 or parallel_state.is_pipeline_last_stage():
-        criterion[tuple(cfg.logit_layers)] = LogitsKLLoss(student_cfg)
+        criterion[tuple(cfg.logit_layers)] = LogitsKLLoss(
+            student_cfg, temperature=cfg.logit_kl_temperature
+        )
     # NOTE: Projection layer shared among intermediate layer pairs.
     projection_layer = ProjectionLayer(student_cfg, teacher_cfg)
 
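LogitsKLLoss itself is not shown in this diff, but the temperature argument threaded through here follows the standard softened-softmax distillation recipe. Below is a minimal sketch, assuming the classic Hinton et al. (2015) formulation; the standalone logits_kl_loss function and its T^2 rescaling are illustrative assumptions, not the library's actual implementation:

    import torch
    import torch.nn.functional as F

    def logits_kl_loss(
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        # Dividing logits by the temperature before the softmax flattens both
        # distributions; T > 1 exposes more of the teacher's probability mass
        # on non-argmax classes ("dark knowledge").
        student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
        teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
        kl = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
        # Hinton et al. rescale by T^2: gradients through the softened softmax
        # shrink roughly as 1/T^2, so this keeps loss magnitudes comparable
        # across temperature settings.
        return kl * temperature**2

With the default temperature=1.0 such a loss reduces to a plain KL divergence, which is presumably why the new field defaults to 1.0: configs that never set it keep the old LogitsKLLoss(student_cfg) behavior.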