@@ -56,12 +56,14 @@ class DistillationConfig:
         logit_layers: Tuple of logit layer names.
         skip_lm_loss: Whether to skip computing the standard language model loss (default: ``True``).
         kd_loss_scale: Relative scaling factor for the distillation loss if ``skip_lm_loss`` is ``False``.
+        logit_kl_temperature: Temperature for the logit KL-divergence loss.
     """

     intermediate_layer_pairs: list[tuple[str, str]] = field(default_factory=list)
     logit_layers: tuple[str, str] = ("output_layer", "output_layer")
     skip_lm_loss: bool = True
     kd_loss_scale: float = 1.0
+    logit_kl_temperature: float = 1.0
     criterion: Criterion | None = None
     loss_balancer: mtd.DistillationLossBalancer | None = None

@@ -71,6 +73,7 @@ def __post_init__(self):
                 f"{self.intermediate_layer_pairs=}"
             )
         assert self.kd_loss_scale > 0, f"{self.kd_loss_scale=}"
+        assert self.logit_kl_temperature > 0, f"{self.logit_kl_temperature=}"


 def load_distillation_config(
@@ -96,7 +99,9 @@ def load_distillation_config(

     criterion = {}
     if student_cfg.pipeline_model_parallel_size == 1 or parallel_state.is_pipeline_last_stage():
-        criterion[tuple(cfg.logit_layers)] = LogitsKLLoss(student_cfg)
+        criterion[tuple(cfg.logit_layers)] = LogitsKLLoss(
+            student_cfg, temperature=cfg.logit_kl_temperature
+        )
     # NOTE: Projection layer shared among intermediate layer pairs.
     projection_layer = ProjectionLayer(student_cfg, teacher_cfg)

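For context on what the new parameter does: below is a minimal, self-contained sketch of how a temperature-scaled KL distillation loss is typically computed. The `temperature_scaled_kl` helper is hypothetical and for illustration only; it is not the actual `LogitsKLLoss` implementation that this diff configures.

```python
import torch
import torch.nn.functional as F


def temperature_scaled_kl(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    temperature: float = 1.0,
) -> torch.Tensor:
    """Hypothetical sketch of a temperature-scaled KL distillation loss."""
    # Soften both distributions with the temperature before comparing them.
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    # KL(teacher || student), averaged over the batch; the T**2 factor keeps
    # gradient magnitudes roughly constant as the temperature changes.
    kl = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
    return kl * temperature**2
```

Temperatures above 1 soften both distributions so the student also learns from the teacher's relative probabilities on non-argmax tokens; the diff simply threads a configurable temperature through to `LogitsKLLoss` instead of relying on its default.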