@@ -59,7 +59,7 @@ class DistillationConfig:
5959 logit_kl_temperature: Temperature for the logit KL-divergence loss.
6060 """
6161
62- intermediate_layer_pairs : list [tuple [str , str ]] = field (default_factory = list )
62+ intermediate_layer_pairs : list [tuple [str , ... ]] = field (default_factory = list )
6363 logit_layers : tuple [str , str ] = ("output_layer" , "output_layer" )
6464 skip_lm_loss : bool = True
6565 kd_loss_scale : float = 1.0
@@ -69,12 +69,28 @@ class DistillationConfig:
6969
7070 def __post_init__ (self ):
7171 assert len (self .logit_layers ) == 2 , f"{ self .logit_layers = } "
72- assert all (len (pair ) == 2 for pair in self .intermediate_layer_pairs ), (
72+ assert all (len (pair ) in ( 2 , 3 ) for pair in self .intermediate_layer_pairs ), (
7373 f"{ self .intermediate_layer_pairs = } "
7474 )
7575 assert self .kd_loss_scale > 0 , f"{ self .kd_loss_scale = } "
7676 assert self .logit_kl_temperature > 0 , f"{ self .logit_kl_temperature = } "
7777
78+ @staticmethod
79+ def parse_intermediate_entry (entry : tuple [str , ...]) -> tuple [str , str , Callable ]:
80+ """Parse an intermediate entry into a student layer, teacher layer, and loss function."""
81+ if len (entry ) == 3 :
82+ student_layer , teacher_layer , loss_fn_name = entry
83+ if loss_fn_name == "cosine" :
84+ loss_fn = HiddenStateCosineLoss
85+ elif loss_fn_name == "mse" :
86+ loss_fn = MSELoss
87+ else :
88+ raise ValueError (f"Unknown intermediate loss function: { loss_fn_name } " )
89+ else :
90+ student_layer , teacher_layer = entry
91+ loss_fn = HiddenStateCosineLoss # default to cosine loss
92+ return student_layer , teacher_layer , loss_fn
93+
7894
7995def load_distillation_config (
8096 config_path : str | None , student_cfg : "TransformerConfig" , teacher_cfg : "TransformerConfig"
@@ -105,7 +121,8 @@ def load_distillation_config(
105121 # NOTE: Projection layer shared among intermediate layer pairs.
106122 projection_layer = ProjectionLayer (student_cfg , teacher_cfg )
107123
108- for student_layer , teacher_layer in cfg .intermediate_layer_pairs :
124+ for entry in cfg .intermediate_layer_pairs :
125+ student_layer , teacher_layer , loss_fn = cfg .parse_intermediate_entry (entry )
109126 if parallel_state .get_tensor_and_context_parallel_rank () == 0 :
110127 logger .info (
111128 "Distillation: Adding intermediate loss between"
@@ -114,7 +131,7 @@ def load_distillation_config(
114131 )
115132 student_layer = _adjust_layer_index_for_pp (student_layer , student_cfg )
116133 teacher_layer = _adjust_layer_index_for_pp (teacher_layer , teacher_cfg )
117- criterion [(student_layer , teacher_layer )] = HiddenStateCosineLoss (
134+ criterion [(student_layer , teacher_layer )] = loss_fn (
118135 student_cfg , projection_layer = projection_layer
119136 )
120137
@@ -202,9 +219,9 @@ def forward(self, predictions: Tensor, targets: Tensor) -> Tensor:
202219 predictions , targets = self .pre_forward (predictions , targets )
203220
204221 loss = F .mse_loss (predictions , targets , reduction = "none" )
205- loss = loss .sum (dim = - 1 )
222+ loss = loss .mean (dim = - 1 )
206223
207- return self .post_forward (loss )
224+ return self .post_forward (loss , is_sequence_parallel = self . _config . sequence_parallel )
208225
209226
210227class HiddenStateCosineLoss (BaseLoss ):
0 commit comments