Commit 281357d

Copy changes made to Megatron-LM
Signed-off-by: Asha Anoosheh <[email protected]>
1 parent 26c203a commit 281357d

File tree: 1 file changed, +23 −6 lines

modelopt/torch/distill/plugins/megatron.py

Lines changed: 23 additions & 6 deletions
@@ -59,7 +59,7 @@ class DistillationConfig:
         logit_kl_temperature: Temperature for the logit KL-divergence loss.
     """

-    intermediate_layer_pairs: list[tuple[str, str]] = field(default_factory=list)
+    intermediate_layer_pairs: list[tuple[str, ...]] = field(default_factory=list)
     logit_layers: tuple[str, str] = ("output_layer", "output_layer")
     skip_lm_loss: bool = True
     kd_loss_scale: float = 1.0
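
The widened field type (tuple[str, ...] instead of tuple[str, str]) lets each intermediate entry optionally carry a third element naming the loss to use. A minimal sketch, assuming the dataclass is importable as modelopt.torch.distill.plugins.megatron.DistillationConfig and using placeholder layer names:

    from modelopt.torch.distill.plugins.megatron import DistillationConfig

    # Layer names below are illustrative placeholders, not taken from this commit.
    cfg = DistillationConfig(
        intermediate_layer_pairs=[
            ("decoder.layers.3", "decoder.layers.7"),           # 2-tuple: default loss (cosine)
            ("decoder.layers.11", "decoder.layers.23", "mse"),  # 3-tuple: explicit loss choice
        ],
    )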
@@ -69,12 +69,28 @@ class DistillationConfig:

     def __post_init__(self):
         assert len(self.logit_layers) == 2, f"{self.logit_layers=}"
-        assert all(len(pair) == 2 for pair in self.intermediate_layer_pairs), (
+        assert all(len(pair) in (2, 3) for pair in self.intermediate_layer_pairs), (
             f"{self.intermediate_layer_pairs=}"
         )
         assert self.kd_loss_scale > 0, f"{self.kd_loss_scale=}"
         assert self.logit_kl_temperature > 0, f"{self.logit_kl_temperature=}"

+    @staticmethod
+    def parse_intermediate_entry(entry: tuple[str, ...]) -> tuple[str, str, Callable]:
+        """Parse an intermediate entry into a student layer, teacher layer, and loss function."""
+        if len(entry) == 3:
+            student_layer, teacher_layer, loss_fn_name = entry
+            if loss_fn_name == "cosine":
+                loss_fn = HiddenStateCosineLoss
+            elif loss_fn_name == "mse":
+                loss_fn = MSELoss
+            else:
+                raise ValueError(f"Unknown intermediate loss function: {loss_fn_name}")
+        else:
+            student_layer, teacher_layer = entry
+            loss_fn = HiddenStateCosineLoss  # default to cosine loss
+        return student_layer, teacher_layer, loss_fn
+

 def load_distillation_config(
     config_path: str | None, student_cfg: "TransformerConfig", teacher_cfg: "TransformerConfig"
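
A hedged usage sketch of the new helper (same assumed import path; the return values follow the branch logic shown above, and the layer names are placeholders):

    from modelopt.torch.distill.plugins.megatron import DistillationConfig

    # 2-tuple: no loss name given, so the helper falls back to HiddenStateCosineLoss.
    student, teacher, loss_cls = DistillationConfig.parse_intermediate_entry(
        ("decoder.layers.3", "decoder.layers.7")
    )

    # 3-tuple: "mse" selects MSELoss; any other name raises ValueError.
    student, teacher, loss_cls = DistillationConfig.parse_intermediate_entry(
        ("decoder.layers.3", "decoder.layers.7", "mse")
    )

Note that loss_cls is the loss class itself; load_distillation_config instantiates it later with the shared projection layer, as the hunks below show.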
@@ -105,7 +121,8 @@ def load_distillation_config(
     # NOTE: Projection layer shared among intermediate layer pairs.
     projection_layer = ProjectionLayer(student_cfg, teacher_cfg)

-    for student_layer, teacher_layer in cfg.intermediate_layer_pairs:
+    for entry in cfg.intermediate_layer_pairs:
+        student_layer, teacher_layer, loss_fn = cfg.parse_intermediate_entry(entry)
         if parallel_state.get_tensor_and_context_parallel_rank() == 0:
             logger.info(
                 "Distillation: Adding intermediate loss between"
@@ -114,7 +131,7 @@ def load_distillation_config(
             )
         student_layer = _adjust_layer_index_for_pp(student_layer, student_cfg)
         teacher_layer = _adjust_layer_index_for_pp(teacher_layer, teacher_cfg)
-        criterion[(student_layer, teacher_layer)] = HiddenStateCosineLoss(
+        criterion[(student_layer, teacher_layer)] = loss_fn(
             student_cfg, projection_layer=projection_layer
         )

@@ -202,9 +219,9 @@ def forward(self, predictions: Tensor, targets: Tensor) -> Tensor:
         predictions, targets = self.pre_forward(predictions, targets)

         loss = F.mse_loss(predictions, targets, reduction="none")
-        loss = loss.sum(dim=-1)
+        loss = loss.mean(dim=-1)

-        return self.post_forward(loss)
+        return self.post_forward(loss, is_sequence_parallel=self._config.sequence_parallel)


 class HiddenStateCosineLoss(BaseLoss):
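
One practical effect of switching the reduction from sum(dim=-1) to mean(dim=-1) is that the per-token MSE no longer scales with the hidden size. A small standalone check in plain PyTorch (shapes are illustrative, unrelated to the commit):

    import torch
    import torch.nn.functional as F

    pred = torch.randn(2, 8, 4096)  # [batch, seq, hidden]
    tgt = torch.randn(2, 8, 4096)

    per_elem = F.mse_loss(pred, tgt, reduction="none")
    summed = per_elem.sum(dim=-1)   # old behavior: grows with hidden size
    meaned = per_elem.mean(dim=-1)  # new behavior: independent of hidden size
    assert torch.allclose(summed, meaned * 4096)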
