132 changes: 132 additions & 0 deletions nemo/collections/llm/modelopt/distill/loss.py
@@ -237,6 +237,138 @@ def forward(self, predictions: Tensor, targets: Tensor) -> Tensor:
return self.post_forward(loss, tp_reduce=True)


class MFTLoss(BaseLoss):
"""Calculates the Minifinetuning loss between two logits tensors and with the presence of labels without reducing the sequence dim. This function implements the distillation loss found in the paper: https://arxiv.org/abs/2506.15702."""

def __init__(
self, model_config: "TransformerConfig", threshold: float, temperature: float = 1.0, reverse: bool = False
):
"""Constructor.

Args:
model_config: MCore transformer config.
threshold: Threshold for the MFT loss, used to determine the correction factor for the teacher probability given the ground truth labels.
temperature: Divide tensors by this value prior to calculating loss.
            reverse: Whether to reverse the loss as KLD(teacher, student) instead of KLD(student, teacher).
"""
super().__init__(model_config)
self._temperature = temperature
self._reverse = reverse
self._threshold = threshold

def _prepare_corrected_distributions(
self,
logits: torch.Tensor,
labels: torch.Tensor,
threshold: float,
apply_threshold_to_all: bool = True,
) -> torch.Tensor:
"""Prepare the corrected distributions for MFT loss.

Args:
            logits: The logits to be corrected (teacher logits by default, student logits in
                reverse mode), shape (batch, channels), e.g. (batch_size * seq_len, vocab_size) for LMs.
            labels: The ground truth labels, shape (batch,), e.g. (batch_size * seq_len) for LMs.
threshold: The threshold value for the MFT correction.
apply_threshold_to_all: If True, apply the threshold correction to all tokens,
not just the incorrect argmax tokens. Defaults to True.

Returns:
A tensor containing the corrected distributions, shape (batch_size * seq_len, vocab_size).
"""
# Ensure logits is a 2D tensor and labels is a 1D tensor
if logits.dim() != 2 or labels.dim() != 1:
raise ValueError("Logits must be a 2D tensor and labels must be a 1D tensor.")
# logits: (batch, channels)
# labels: (batch)
distribution = F.softmax(logits, dim=-1) # (batch, channels)

argmax = distribution.argmax(dim=-1) # (batch,)
incorrect_argmax = argmax != labels # (batch,)

p_argmax = torch.gather(distribution, 1, argmax.unsqueeze(1)).squeeze(1) # (batch,)
p_label = torch.gather(distribution, 1, labels.unsqueeze(1)).squeeze(1) # (batch,)

# correction of the distribution at the tokens where the argmax is incorrect
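        # mixin_factor is the fraction of a one-hot label distribution mixed into the distribution
        # so that, afterwards, p(label) exceeds p(argmax) by threshold:
        # solving (1 - m) * p_label + m - (1 - m) * p_argmax = threshold for m gives
        # m = (p_argmax - p_label + threshold) / (1 + p_argmax - p_label).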
mixin_factor = (p_argmax - p_label + threshold) / (1 + p_argmax - p_label + 1e-7) # (batch,)
adjusted_incorrect_distribution = distribution * (1 - mixin_factor.unsqueeze(1)) # (batch, channels)
_ = adjusted_incorrect_distribution.scatter_add_(
1, labels.unsqueeze(1), mixin_factor.unsqueeze(1)
) # (batch, channels)

if apply_threshold_to_all:
# correction of the distribution at the tokens where the argmax is correct but
# the separation may not be large enough
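            # Here argmax == label (so p_argmax == p_label); mix in just enough of the one-hot
            # label distribution to raise p(label) to min(p_label + threshold, 1):
            # m = (capped_target - p_label) / (1 - p_label).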
capped_targets = torch.where(p_label > 1 - threshold, 1, p_label + threshold) # (batch,)
mixin_factor = (capped_targets - p_argmax) / (1 - p_argmax + 1e-7) # (batch,)
adjusted_correct_distribution = distribution * (1 - mixin_factor.unsqueeze(1)) # (batch, channels)
_ = adjusted_correct_distribution.scatter_add_(1, labels.unsqueeze(1), mixin_factor.unsqueeze(1))
else:
adjusted_correct_distribution = distribution

return torch.where(
incorrect_argmax.unsqueeze(1),
adjusted_incorrect_distribution,
adjusted_correct_distribution,
) # (batch, channels)

def forward(self, predictions: Tensor, targets: Tensor, labels: Tensor) -> Tensor:
"""Forward function.

Args:
predictions: Student model tensors (size [s, b, h])
targets: Teacher model tensors (size [s, b, h])
labels: Ground truth labels (size [b, s])

Returns:
KLD loss of tensors (size [b, s])
"""
predictions, targets = self.pre_forward(predictions, targets)

# Division by temp should happen prior to finding max for both student and teacher.
        # Currently we don't use temperature in any of our runs (temp=1.0).
output_teacher = targets.float() / self._temperature
output_student = predictions.float() / self._temperature

        # The MFT correction requires the full vocabulary distribution, which is not
        # available locally when the vocab is sharded across tensor-parallel ranks.
        if self._config.tensor_model_parallel_size > 1:
            raise NotImplementedError(
                "MFTLoss does not support tensor model parallelism. Please use sequence parallelism instead."
            )
        else:
            if self._reverse:
                teacher_log_softmax: Tensor = F.log_softmax(output_teacher, dim=-1)
                # _prepare_corrected_distributions expects 2D logits (it applies softmax
                # internally) and 1D labels, so flatten the [s, b, h] logits and the
                # [b, s] labels in the same (sequence-major) order, then restore the shape.
                corrected_student_distribution = self._prepare_corrected_distributions(
                    output_student.reshape(-1, output_student.size(-1)),
                    labels.transpose(0, 1).reshape(-1),
                    self._threshold,
                    apply_threshold_to_all=True,
                ).reshape(output_student.shape)
                # F.kl_div takes log-probabilities as input and probabilities as target.
                loss = torch.sum(
                    F.kl_div(
                        teacher_log_softmax,
                        corrected_student_distribution,
                        reduction="none",
                    ),
                    dim=-1,
                )
            else:
                student_log_softmax: Tensor = F.log_softmax(output_student, dim=-1)
                corrected_teacher_distribution = self._prepare_corrected_distributions(
                    output_teacher.reshape(-1, output_teacher.size(-1)),
                    labels.transpose(0, 1).reshape(-1),
                    self._threshold,
                    apply_threshold_to_all=True,
                ).reshape(output_teacher.shape)
                loss = torch.sum(
                    F.kl_div(
                        student_log_softmax,
                        corrected_teacher_distribution,
                        reduction="none",
                    ),
                    dim=-1,
                )

return self.post_forward(loss, tp_reduce=True)


class LogitsAndIntermediatesLossBalancer(DistillationLossBalancer):
"""
LossBalancer implementation for Logit and Intermediate losses.
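As a quick illustration of the correction implemented in _prepare_corrected_distributions above, here is a minimal standalone sketch (toy values, not part of the patch) reproducing the incorrect-argmax branch; after the correction, the label token leads the previous argmax by roughly threshold:

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5]])  # (batch=1, vocab=3); argmax is token 0
labels = torch.tensor([1])                # ground truth is token 1, so the argmax is incorrect
threshold = 0.1

p = F.softmax(logits, dim=-1)
p_argmax = p.max(dim=-1).values
p_label = p.gather(1, labels.unsqueeze(1)).squeeze(1)

# Same mixin factor as the incorrect-argmax branch of _prepare_corrected_distributions.
m = (p_argmax - p_label + threshold) / (1 + p_argmax - p_label + 1e-7)
corrected = p * (1 - m.unsqueeze(1))
corrected.scatter_add_(1, labels.unsqueeze(1), m.unsqueeze(1))

print(corrected)  # rows still sum to 1; corrected[0, 1] - corrected[0, 0] ≈ 0.1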
15 changes: 13 additions & 2 deletions nemo/collections/llm/modelopt/distill/utils.py
@@ -30,7 +30,7 @@
from nemo.utils import logging
from nemo.utils.import_utils import safe_import, safe_import_from

from .loss import HiddenStateCosineLoss, LogitsAndIntermediatesLossBalancer, LogitsKLLoss, ProjectionLayer
from .loss import HiddenStateCosineLoss, LogitsAndIntermediatesLossBalancer, LogitsKLLoss, MFTLoss, ProjectionLayer

if TYPE_CHECKING:
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
@@ -53,19 +53,26 @@ class DistillationConfig:
logit_layers: Tuple of logit layer names.
skip_lm_loss: Whether to skip computing the standard language model loss (default: ``True``).
kd_loss_scale: Relative scaling factor for the distillation loss if ``skip_lm_loss`` is ``False``.
use_mft: Whether to use MFT (Minifinetuning) for distillation.
mft_threshold: Threshold for MFT loss, used to determine the correction factor for the teacher probability given the ground truth labels.
"""

intermediate_layer_pairs: List[Tuple[str, str]] = field(default_factory=list)
logit_layers: Tuple[str, str] = ("output_layer", "output_layer")
skip_lm_loss: bool = True
kd_loss_scale: float = 1.0
use_mft: bool = False
mft_threshold: float | None = None
criterion: Optional[Dict[Tuple[str, str], torch.nn.Module]] = None
loss_balancer: Optional[DistillationLossBalancer] = None

def __post_init__(self):
assert len(self.logit_layers) == 2, f"{self.logit_layers=}"
assert all(len(pair) == 2 for pair in self.intermediate_layer_pairs), f"{self.intermediate_layer_pairs=}"
assert self.kd_loss_scale > 0, f"{self.kd_loss_scale=}"
assert not self.use_mft or (
self.mft_threshold is not None and 1 >= self.mft_threshold >= 0
), f"{self.use_mft=} & {self.mft_threshold=}"


def load_distillation_config(
@@ -91,7 +98,11 @@ def load_distillation_config(

criterion = {}
if student_cfg.pipeline_model_parallel_size == 1 or parallel_state.is_pipeline_last_stage():
criterion[tuple(cfg.logit_layers)] = LogitsKLLoss(student_cfg)
        criterion[tuple(cfg.logit_layers)] = (
            MFTLoss(model_config=student_cfg, threshold=cfg.mft_threshold)
            if cfg.use_mft
            else LogitsKLLoss(student_cfg)
        )
# NOTE: Projection layer shared among intermediate layer pairs.
projection_layer = ProjectionLayer(student_cfg, teacher_cfg)

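As a usage sketch (not part of the patch) of the new MFT fields, assuming the DistillationConfig dataclass shown above is importable from nemo.collections.llm.modelopt.distill.utils; in real runs these values would come from the distillation config consumed by load_distillation_config:

from nemo.collections.llm.modelopt.distill.utils import DistillationConfig

# 0 <= mft_threshold <= 1 is enforced in __post_init__ whenever use_mft is True.
cfg = DistillationConfig(use_mft=True, mft_threshold=0.1)

# With use_mft=True, load_distillation_config builds the logit criterion as
# MFTLoss(model_config=student_cfg, threshold=cfg.mft_threshold) instead of LogitsKLLoss(student_cfg).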