83 changes: 1 addition & 82 deletions modelopt/torch/distill/losses.py
@@ -18,11 +18,10 @@
 """Different types of distillation losses."""
 
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.modules.loss import _Loss as Loss
 
-__all__ = ["LogitsDistillationLoss", "MFTLoss", "MGDLoss"]
+__all__ = ["LogitsDistillationLoss", "MFTLoss"]
 
 
 class LogitsDistillationLoss(Loss):
@@ -193,83 +192,3 @@ def _prepare_corrected_distributions(
             adjusted_incorrect_distribution,
             adjusted_correct_distribution,
         ) # (batch, channels)
-
-
-class MGDLoss(Loss):
-    """PyTorch version of Masked Generative Distillation.
-
-    This function implements the distillation loss found in the paper: https://arxiv.org/abs/2205.01529.
-    """
-
-    def __init__(
-        self,
-        num_student_channels: int,
-        num_teacher_channels: int,
-        alpha_mgd: float = 1.0,
-        lambda_mgd: float = 0.65,
-    ):
-        """Constructor.
-
-        Args:
-            num_student_channels: Number of channels in the student's feature map.
-            num_teacher_channels: Number of channels in the teacher's feature map.
-            alpha_mgd: Scalar final loss is multiplied by. Defaults to 1.0.
-            lambda_mgd: Masked ratio. Defaults to 0.65.
-        """
-        super().__init__()
-        self._alpha_mgd: float = alpha_mgd
-        self._lambda_mgd: float = lambda_mgd
-
-        if num_student_channels != num_teacher_channels:
-            self.align = nn.Conv2d(
-                num_student_channels,
-                num_teacher_channels,
-                kernel_size=1,
-                stride=1,
-                padding=0,
-            )
-        else:
-            self.align = nn.Identity()
-
-        self.generation = nn.Sequential(
-            nn.Conv2d(
-                num_teacher_channels,
-                num_teacher_channels,
-                kernel_size=3,
-                padding=1,
-            ),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(
-                num_teacher_channels,
-                num_teacher_channels,
-                kernel_size=3,
-                padding=1,
-            ),
-        )
-
-    def _loss_fn(self, out_s: torch.Tensor, out_t: torch.Tensor):
-        n, _, h, w = out_t.shape
-
-        mat = torch.rand((n, 1, h, w), device=out_s.device)
-        mat = torch.where(mat > 1 - self._lambda_mgd, 0, 1)
-
-        masked_feats = torch.mul(out_s, mat)
-        new_feats = self.generation(masked_feats)
-
-        kd_loss = F.mse_loss(new_feats, out_t)
-
-        return kd_loss
-
-    def forward(self, out_s: torch.Tensor, out_t: torch.Tensor):
-        """Forward function.
-
-        Args:
-            out_s: Student's feature map (shape BxCxHxW).
-            out_t: Teacher's feature map (shape BxCxHxW).
-        """
-        assert out_s.shape[-2:] == out_t.shape[-2:]
-
-        out_s = self.align(out_s)
-        loss = self._loss_fn(out_s, out_t) * self._alpha_mgd
-
-        return loss
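
For context on what this PR deletes, here is a minimal usage sketch of the removed `MGDLoss`. It only works on code checked out before this change; the batch size, channel counts, and feature-map sizes below are made-up values for illustration:

```python
import torch

# Import path taken from this diff's file header; MGDLoss exists only prior to this PR.
from modelopt.torch.distill.losses import MGDLoss

# Student and teacher channel widths differ here, so `align` becomes a 1x1 Conv2d.
criterion = MGDLoss(num_student_channels=256, num_teacher_channels=512)

# Random stand-ins for intermediate feature maps (B x C x H x W).
# Spatial sizes must match (the forward() asserts this); channels may differ.
student_feats = torch.rand(8, 256, 14, 14)
teacher_feats = torch.rand(8, 512, 14, 14)

# Gradients flow through the loss module's own align/generation parameters.
loss = criterion(student_feats, teacher_feats)
loss.backward()
```

The sketch mirrors the deleted forward path: `align` projects the student channels to the teacher's width, a random binary mask zeroes roughly `lambda_mgd` of the spatial positions, and the `generation` block must reconstruct the teacher features under an MSE penalty scaled by `alpha_mgd`.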