```diff
@@ -18,11 +18,10 @@
 """Different types of distillation losses."""
 
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.modules.loss import _Loss as Loss
 
-__all__ = ["LogitsDistillationLoss", "MFTLoss", "MGDLoss"]
+__all__ = ["LogitsDistillationLoss", "MFTLoss"]
 
 
 class LogitsDistillationLoss(Loss):
@@ -193,83 +192,3 @@ def _prepare_corrected_distributions(
         adjusted_incorrect_distribution,
         adjusted_correct_distribution,
     )  # (batch, channels)
-
-
-class MGDLoss(Loss):
-    """PyTorch version of Masked Generative Distillation.
-
-    This function implements the distillation loss found in the paper: https://arxiv.org/abs/2205.01529.
-    """
-
-    def __init__(
-        self,
-        num_student_channels: int,
-        num_teacher_channels: int,
-        alpha_mgd: float = 1.0,
-        lambda_mgd: float = 0.65,
-    ):
-        """Constructor.
-
-        Args:
-            num_student_channels: Number of channels in the student's feature map.
-            num_teacher_channels: Number of channels in the teacher's feature map.
-            alpha_mgd: Scalar final loss is multiplied by. Defaults to 1.0.
-            lambda_mgd: Masked ratio. Defaults to 0.65.
-        """
-        super().__init__()
-        self._alpha_mgd: float = alpha_mgd
-        self._lambda_mgd: float = lambda_mgd
-
-        if num_student_channels != num_teacher_channels:
-            self.align = nn.Conv2d(
-                num_student_channels,
-                num_teacher_channels,
-                kernel_size=1,
-                stride=1,
-                padding=0,
-            )
-        else:
-            self.align = nn.Identity()
-
-        self.generation = nn.Sequential(
-            nn.Conv2d(
-                num_teacher_channels,
-                num_teacher_channels,
-                kernel_size=3,
-                padding=1,
-            ),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(
-                num_teacher_channels,
-                num_teacher_channels,
-                kernel_size=3,
-                padding=1,
-            ),
-        )
-
-    def _loss_fn(self, out_s: torch.Tensor, out_t: torch.Tensor):
-        n, _, h, w = out_t.shape
-
-        mat = torch.rand((n, 1, h, w), device=out_s.device)
-        mat = torch.where(mat > 1 - self._lambda_mgd, 0, 1)
-
-        masked_feats = torch.mul(out_s, mat)
-        new_feats = self.generation(masked_feats)
-
-        kd_loss = F.mse_loss(new_feats, out_t)
-
-        return kd_loss
-
-    def forward(self, out_s: torch.Tensor, out_t: torch.Tensor):
-        """Forward function.
-
-        Args:
-            out_s: Student's feature map (shape BxCxHxW).
-            out_t: Teacher's feature map (shape BxCxHxW).
-        """
-        assert out_s.shape[-2:] == out_t.shape[-2:]
-
-        out_s = self.align(out_s)
-        loss = self._loss_fn(out_s, out_t) * self._alpha_mgd
-
-        return loss
```
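For context on what is being deleted: MGD randomly masks out a fraction `lambda_mgd` of the student's spatial positions and trains a small convolutional head to regenerate the teacher's full feature map from what remains. The removed line `torch.where(mat > 1 - self._lambda_mgd, 0, 1)` reads inverted at first glance, so here is a minimal sketch of just that masking step; the shapes and variable names below are illustrative, not part of this module:

```python
import torch

lambda_mgd = 0.65                      # masked ratio, the removed constructor default
feats = torch.randn(2, 256, 14, 14)    # stand-in student feature map, (B, C, H, W)

# One {0, 1} value per spatial position, broadcast across all channels.
# Equivalent to the removed torch.where(mat > 1 - lambda_mgd, 0, 1): a position
# is zeroed with probability lambda_mgd and kept with probability 1 - lambda_mgd.
mat = torch.rand((feats.shape[0], 1, *feats.shape[2:]), device=feats.device)
mask = (mat <= 1 - lambda_mgd).to(feats.dtype)

masked_feats = feats * mask            # ~65% of positions zeroed in every channel
```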
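If downstream code still imports `MGDLoss`, the deleted class is small enough to vendor. Below is a compact, self-contained equivalent plus a call site, written as a sketch rather than a drop-in: it subclasses `nn.Module` instead of the module's `_Loss` alias, and every shape and channel count is made up for the example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MGDLoss(nn.Module):
    """Standalone Masked Generative Distillation loss (https://arxiv.org/abs/2205.01529)."""

    def __init__(self, num_student_channels, num_teacher_channels, alpha_mgd=1.0, lambda_mgd=0.65):
        super().__init__()
        self._alpha_mgd = alpha_mgd
        self._lambda_mgd = lambda_mgd
        # 1x1 conv projects student channels onto teacher channels when they differ.
        self.align = (
            nn.Identity()
            if num_student_channels == num_teacher_channels
            else nn.Conv2d(num_student_channels, num_teacher_channels, kernel_size=1)
        )
        # Generation head: reconstructs teacher features from the masked student features.
        self.generation = nn.Sequential(
            nn.Conv2d(num_teacher_channels, num_teacher_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(num_teacher_channels, num_teacher_channels, kernel_size=3, padding=1),
        )

    def forward(self, out_s: torch.Tensor, out_t: torch.Tensor) -> torch.Tensor:
        assert out_s.shape[-2:] == out_t.shape[-2:], "student/teacher spatial sizes must match"
        out_s = self.align(out_s)
        n, _, h, w = out_t.shape
        # Zero each spatial position with probability lambda_mgd, shared across channels.
        mask = (torch.rand(n, 1, h, w, device=out_s.device) <= 1 - self._lambda_mgd).to(out_s.dtype)
        new_feats = self.generation(out_s * mask)
        return F.mse_loss(new_feats, out_t) * self._alpha_mgd


# Example call site with made-up shapes: 128-channel student vs. 256-channel teacher.
mgd = MGDLoss(num_student_channels=128, num_teacher_channels=256)
loss = mgd(torch.randn(4, 128, 14, 14), torch.randn(4, 256, 14, 14))
loss.backward()  # also trains the align/generation convs
```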