|
@@ -18,11 +18,10 @@
 """Different types of distillation losses."""
 
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.modules.loss import _Loss as Loss
 
-__all__ = ["LogitsDistillationLoss", "MFTLoss", "MGDLoss"]
+__all__ = ["LogitsDistillationLoss", "MFTLoss"]
 
 
 class LogitsDistillationLoss(Loss):
@@ -193,83 +192,3 @@ def _prepare_corrected_distributions( |
             adjusted_incorrect_distribution,
             adjusted_correct_distribution,
         )  # (batch, channels)
-
-
-class MGDLoss(Loss):
-    """PyTorch version of Masked Generative Distillation.
-
-    This function implements the distillation loss found in the paper: https://arxiv.org/abs/2205.01529.
-    """
-
-    def __init__(
-        self,
-        num_student_channels: int,
-        num_teacher_channels: int,
-        alpha_mgd: float = 1.0,
-        lambda_mgd: float = 0.65,
-    ):
-        """Constructor.
-
-        Args:
-            num_student_channels: Number of channels in the student's feature map.
-            num_teacher_channels: Number of channels in the teacher's feature map.
-            alpha_mgd: Scalar final loss is multiplied by. Defaults to 1.0.
-            lambda_mgd: Masked ratio. Defaults to 0.65.
-        """
-        super().__init__()
-        self._alpha_mgd: float = alpha_mgd
-        self._lambda_mgd: float = lambda_mgd
-
-        if num_student_channels != num_teacher_channels:
-            self.align = nn.Conv2d(
-                num_student_channels,
-                num_teacher_channels,
-                kernel_size=1,
-                stride=1,
-                padding=0,
-            )
-        else:
-            self.align = nn.Identity()
-
-        self.generation = nn.Sequential(
-            nn.Conv2d(
-                num_teacher_channels,
-                num_teacher_channels,
-                kernel_size=3,
-                padding=1,
-            ),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(
-                num_teacher_channels,
-                num_teacher_channels,
-                kernel_size=3,
-                padding=1,
-            ),
-        )
-
-    def _loss_fn(self, out_s: torch.Tensor, out_t: torch.Tensor):
-        n, _, h, w = out_t.shape
-
-        mat = torch.rand((n, 1, h, w), device=out_s.device)
-        mat = torch.where(mat > 1 - self._lambda_mgd, 0, 1)
-
-        masked_feats = torch.mul(out_s, mat)
-        new_feats = self.generation(masked_feats)
-
-        kd_loss = F.mse_loss(new_feats, out_t)
-
-        return kd_loss
-
-    def forward(self, out_s: torch.Tensor, out_t: torch.Tensor):
-        """Forward function.
-
-        Args:
-            out_s: Student's feature map (shape BxCxHxW).
-            out_t: Teacher's feature map (shape BxCxHxW).
-        """
-        assert out_s.shape[-2:] == out_t.shape[-2:]
-
-        out_s = self.align(out_s)
-        loss = self._loss_fn(out_s, out_t) * self._alpha_mgd
-
-        return loss
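For context on the removed masking step: in `_loss_fn`, `torch.where(mat > 1 - self._lambda_mgd, 0, 1)` zeroes each spatial position with probability `lambda_mgd`, so with the default of 0.65 roughly 65% of the student's feature map is hidden before the `generation` block tries to reconstruct the teacher's features. A minimal standalone sketch of just that masking step (the batch and spatial sizes here are illustrative, not from the code):

import torch

lambda_mgd = 0.65  # masked ratio, the default used by the removed MGDLoss

# Same construction as MGDLoss._loss_fn: one random mask per sample,
# broadcast across all channels via the singleton channel dimension.
mat = torch.rand((4, 1, 32, 32))
mask = torch.where(mat > 1 - lambda_mgd, 0, 1)  # 0 with prob. lambda_mgd, else 1

print(mask.float().mean().item())  # ~0.35: about 65% of positions are zeroed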
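For anyone who still depends on the removed loss, a minimal usage sketch, assuming MGDLoss is kept around locally exactly as defined in the deleted lines above (the feature shapes are illustrative):

import torch

# (B, C, H, W); spatial dims of student and teacher must match,
# channel counts may differ thanks to the 1x1 align convolution.
student_feats = torch.randn(2, 128, 14, 14, requires_grad=True)
teacher_feats = torch.randn(2, 256, 14, 14)

criterion = MGDLoss(num_student_channels=128, num_teacher_channels=256)
loss = criterion(student_feats, teacher_feats)
loss.backward()  # gradients flow into the student features and the align/generation layers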