@@ -59,7 +59,7 @@ class DistillationConfig:
59
59
logit_kl_temperature: Temperature for the logit KL-divergence loss.
60
60
"""
61
61
62
- intermediate_layer_pairs : list [tuple [str , str ]] = field (default_factory = list )
62
+ intermediate_layer_pairs : list [tuple [str , ... ]] = field (default_factory = list )
63
63
logit_layers : tuple [str , str ] = ("output_layer" , "output_layer" )
64
64
skip_lm_loss : bool = True
65
65
kd_loss_scale : float = 1.0
@@ -69,12 +69,28 @@ class DistillationConfig:
69
69
70
70
def __post_init__(self):
    """Sanity-check the configured values.

    Raises:
        AssertionError: If ``logit_layers`` does not hold exactly two items,
            if any entry of ``intermediate_layer_pairs`` has a length other
            than 2 or 3, or if ``kd_loss_scale`` or ``logit_kl_temperature``
            is not strictly positive. The message shows the offending field.
    """
    # Exactly two logit layer names are expected.
    assert len(self.logit_layers) == 2, f"{ self .logit_layers = } "

    # Each intermediate entry carries two layer names, optionally followed by
    # a third element; no entry may fall outside those two sizes.
    assert not any(
        len(entry) not in (2, 3) for entry in self.intermediate_layer_pairs
    ), f"{ self .intermediate_layer_pairs = } "

    # Both scalar settings must be strictly positive.
    assert self.kd_loss_scale > 0, f"{ self .kd_loss_scale = } "
    assert self.logit_kl_temperature > 0, f"{ self .logit_kl_temperature = } "
77
77
78
+ @staticmethod
79
+ def parse_intermediate_entry (entry : tuple [str , ...]) -> tuple [str , str , Callable ]:
80
+ """Parse an intermediate entry into a student layer, teacher layer, and loss function."""
81
+ if len (entry ) == 3 :
82
+ student_layer , teacher_layer , loss_fn_name = entry
83
+ if loss_fn_name == "cosine" :
84
+ loss_fn = HiddenStateCosineLoss
85
+ elif loss_fn_name == "mse" :
86
+ loss_fn = MSELoss
87
+ else :
88
+ raise ValueError (f"Unknown intermediate loss function: { loss_fn_name } " )
89
+ else :
90
+ student_layer , teacher_layer = entry
91
+ loss_fn = HiddenStateCosineLoss # default to cosine loss
92
+ return student_layer , teacher_layer , loss_fn
93
+
78
94
79
95
def load_distillation_config (
80
96
config_path : str | None , student_cfg : "TransformerConfig" , teacher_cfg : "TransformerConfig"
@@ -105,7 +121,8 @@ def load_distillation_config(
105
121
# NOTE: Projection layer shared among intermediate layer pairs.
106
122
projection_layer = ProjectionLayer (student_cfg , teacher_cfg )
107
123
108
- for student_layer , teacher_layer in cfg .intermediate_layer_pairs :
124
+ for entry in cfg .intermediate_layer_pairs :
125
+ student_layer , teacher_layer , loss_fn = cfg .parse_intermediate_entry (entry )
109
126
if parallel_state .get_tensor_and_context_parallel_rank () == 0 :
110
127
logger .info (
111
128
"Distillation: Adding intermediate loss between"
@@ -114,7 +131,7 @@ def load_distillation_config(
114
131
)
115
132
student_layer = _adjust_layer_index_for_pp (student_layer , student_cfg )
116
133
teacher_layer = _adjust_layer_index_for_pp (teacher_layer , teacher_cfg )
117
- criterion [(student_layer , teacher_layer )] = HiddenStateCosineLoss (
134
+ criterion [(student_layer , teacher_layer )] = loss_fn (
118
135
student_cfg , projection_layer = projection_layer
119
136
)
120
137
@@ -202,9 +219,9 @@ def forward(self, predictions: Tensor, targets: Tensor) -> Tensor:
202
219
predictions , targets = self .pre_forward (predictions , targets )
203
220
204
221
loss = F .mse_loss (predictions , targets , reduction = "none" )
205
- loss = loss .sum (dim = - 1 )
222
+ loss = loss .mean (dim = - 1 )
206
223
207
- return self .post_forward (loss )
224
+ return self .post_forward (loss , is_sequence_parallel = self . _config . sequence_parallel )
208
225
209
226
210
227
class HiddenStateCosineLoss (BaseLoss ):
0 commit comments