diff --git a/modelopt/torch/distill/losses.py b/modelopt/torch/distill/losses.py
index 258824bf0..47a38c405 100644
--- a/modelopt/torch/distill/losses.py
+++ b/modelopt/torch/distill/losses.py
@@ -18,11 +18,10 @@
 """Different types of distillation losses."""
 
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.modules.loss import _Loss as Loss
 
-__all__ = ["LogitsDistillationLoss", "MFTLoss", "MGDLoss"]
+__all__ = ["LogitsDistillationLoss", "MFTLoss"]
 
 
 class LogitsDistillationLoss(Loss):
@@ -193,83 +192,3 @@ def _prepare_corrected_distributions(
         adjusted_incorrect_distribution,
         adjusted_correct_distribution,
     )  # (batch, channels)
-
-
-class MGDLoss(Loss):
-    """PyTorch version of Masked Generative Distillation.
-
-    This function implements the distillation loss found in the paper: https://arxiv.org/abs/2205.01529.
-    """
-
-    def __init__(
-        self,
-        num_student_channels: int,
-        num_teacher_channels: int,
-        alpha_mgd: float = 1.0,
-        lambda_mgd: float = 0.65,
-    ):
-        """Constructor.
-
-        Args:
-            num_student_channels: Number of channels in the student's feature map.
-            num_teacher_channels: Number of channels in the teacher's feature map.
-            alpha_mgd: Scalar final loss is multiplied by. Defaults to 1.0.
-            lambda_mgd: Masked ratio. Defaults to 0.65.
-        """
-        super().__init__()
-        self._alpha_mgd: float = alpha_mgd
-        self._lambda_mgd: float = lambda_mgd
-
-        if num_student_channels != num_teacher_channels:
-            self.align = nn.Conv2d(
-                num_student_channels,
-                num_teacher_channels,
-                kernel_size=1,
-                stride=1,
-                padding=0,
-            )
-        else:
-            self.align = nn.Identity()
-
-        self.generation = nn.Sequential(
-            nn.Conv2d(
-                num_teacher_channels,
-                num_teacher_channels,
-                kernel_size=3,
-                padding=1,
-            ),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(
-                num_teacher_channels,
-                num_teacher_channels,
-                kernel_size=3,
-                padding=1,
-            ),
-        )
-
-    def _loss_fn(self, out_s: torch.Tensor, out_t: torch.Tensor):
-        n, _, h, w = out_t.shape
-
-        mat = torch.rand((n, 1, h, w), device=out_s.device)
-        mat = torch.where(mat > 1 - self._lambda_mgd, 0, 1)
-
-        masked_feats = torch.mul(out_s, mat)
-        new_feats = self.generation(masked_feats)
-
-        kd_loss = F.mse_loss(new_feats, out_t)
-
-        return kd_loss
-
-    def forward(self, out_s: torch.Tensor, out_t: torch.Tensor):
-        """Forward function.
-
-        Args:
-            out_s: Student's feature map (shape BxCxHxW).
-            out_t: Teacher's feature map (shape BxCxHxW).
-        """
-        assert out_s.shape[-2:] == out_t.shape[-2:]
-
-        out_s = self.align(out_s)
-        loss = self._loss_fn(out_s, out_t) * self._alpha_mgd
-
-        return loss
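
For reviewers who want to see the deleted behavior in isolation: below is a minimal, self-contained sketch of the masked-generation step that the removed `MGDLoss._loss_fn` implemented, assuming equal student/teacher channel counts so that the `align` projection reduces to an identity. The `generation` block, masking logic, and default `lambda_mgd` mirror the code deleted in this diff; the tensor sizes and variable names are illustrative only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative shapes only; equal channels, so no align projection is needed.
n, c, h, w = 2, 16, 8, 8
lambda_mgd = 0.65  # masked ratio, as in the removed default

student_feats = torch.randn(n, c, h, w)
teacher_feats = torch.randn(n, c, h, w)

# Same 3x3 conv -> ReLU -> 3x3 conv "generation" block the removed class built.
generation = nn.Sequential(
    nn.Conv2d(c, c, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(c, c, kernel_size=3, padding=1),
)

# Zero out ~lambda_mgd of spatial locations on the student map, then ask the
# generation block to reconstruct the teacher map from what remains.
mat = torch.rand((n, 1, h, w))
mat = torch.where(mat > 1 - lambda_mgd, 0, 1)
masked_feats = student_feats * mat
kd_loss = F.mse_loss(generation(masked_feats), teacher_feats)
print(kd_loss.item())
```

Note that after this change `MGDLoss` is no longer importable from `modelopt.torch.distill.losses` (it is also dropped from `__all__`), so any downstream usage would need to vendor the class or switch to one of the remaining losses.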