Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions modelopt/torch/distill/plugins/megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class DistillationConfig:
logit_kl_temperature: Temperature for the logit KL-divergence loss.
"""

intermediate_layer_pairs: list[tuple[str, str]] = field(default_factory=list)
intermediate_layer_pairs: list[tuple[str, ...]] = field(default_factory=list)
logit_layers: tuple[str, str] = ("output_layer", "output_layer")
skip_lm_loss: bool = True
kd_loss_scale: float = 1.0
Expand All @@ -69,12 +69,28 @@ class DistillationConfig:

def __post_init__(self):
assert len(self.logit_layers) == 2, f"{self.logit_layers=}"
assert all(len(pair) == 2 for pair in self.intermediate_layer_pairs), (
assert all(len(pair) in (2, 3) for pair in self.intermediate_layer_pairs), (
f"{self.intermediate_layer_pairs=}"
)
assert self.kd_loss_scale > 0, f"{self.kd_loss_scale=}"
assert self.logit_kl_temperature > 0, f"{self.logit_kl_temperature=}"

@staticmethod
def parse_intermediate_entry(entry: tuple[str, ...]) -> tuple[str, str, Callable]:
"""Parse an intermediate entry into a student layer, teacher layer, and loss function."""
if len(entry) == 3:
student_layer, teacher_layer, loss_fn_name = entry
if loss_fn_name == "cosine":
loss_fn = HiddenStateCosineLoss
elif loss_fn_name == "mse":
loss_fn = MSELoss
else:
raise ValueError(f"Unknown intermediate loss function: {loss_fn_name}")
else:
student_layer, teacher_layer = entry
loss_fn = HiddenStateCosineLoss # default to cosine loss
return student_layer, teacher_layer, loss_fn


def load_distillation_config(
config_path: str | None, student_cfg: "TransformerConfig", teacher_cfg: "TransformerConfig"
Expand Down Expand Up @@ -105,7 +121,8 @@ def load_distillation_config(
# NOTE: Projection layer shared among intermediate layer pairs.
projection_layer = ProjectionLayer(student_cfg, teacher_cfg)

for student_layer, teacher_layer in cfg.intermediate_layer_pairs:
for entry in cfg.intermediate_layer_pairs:
student_layer, teacher_layer, loss_fn = cfg.parse_intermediate_entry(entry)
if parallel_state.get_tensor_and_context_parallel_rank() == 0:
logger.info(
"Distillation: Adding intermediate loss between"
Expand All @@ -114,7 +131,7 @@ def load_distillation_config(
)
student_layer = _adjust_layer_index_for_pp(student_layer, student_cfg)
teacher_layer = _adjust_layer_index_for_pp(teacher_layer, teacher_cfg)
criterion[(student_layer, teacher_layer)] = HiddenStateCosineLoss(
criterion[(student_layer, teacher_layer)] = loss_fn(
student_cfg, projection_layer=projection_layer
)

Expand Down Expand Up @@ -202,9 +219,9 @@ def forward(self, predictions: Tensor, targets: Tensor) -> Tensor:
predictions, targets = self.pre_forward(predictions, targets)

loss = F.mse_loss(predictions, targets, reduction="none")
loss = loss.sum(dim=-1)
loss = loss.mean(dim=-1)

return self.post_forward(loss)
return self.post_forward(loss, is_sequence_parallel=self._config.sequence_parallel)


class HiddenStateCosineLoss(BaseLoss):
Expand Down
Loading