Commit 155668a

Modify file for no reason
Signed-off-by: Asha Anoosheh <[email protected]>
1 parent: 6ec8cdc

1 file changed: +1 -82 lines
modelopt/torch/distill/losses.py

Lines changed: 1 addition & 82 deletions
@@ -18,11 +18,10 @@
 """Different types of distillation losses."""
 
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.modules.loss import _Loss as Loss
 
-__all__ = ["LogitsDistillationLoss", "MFTLoss", "MGDLoss"]
+__all__ = ["LogitsDistillationLoss", "MFTLoss"]
 
 
 class LogitsDistillationLoss(Loss):
@@ -193,83 +192,3 @@ def _prepare_corrected_distributions(
             adjusted_incorrect_distribution,
             adjusted_correct_distribution,
         ) # (batch, channels)
-
-
-class MGDLoss(Loss):
-    """PyTorch version of Masked Generative Distillation.
-
-    This function implements the distillation loss found in the paper: https://arxiv.org/abs/2205.01529.
-    """
-
-    def __init__(
-        self,
-        num_student_channels: int,
-        num_teacher_channels: int,
-        alpha_mgd: float = 1.0,
-        lambda_mgd: float = 0.65,
-    ):
-        """Constructor.
-
-        Args:
-            num_student_channels: Number of channels in the student's feature map.
-            num_teacher_channels: Number of channels in the teacher's feature map.
-            alpha_mgd: Scalar final loss is multiplied by. Defaults to 1.0.
-            lambda_mgd: Masked ratio. Defaults to 0.65.
-        """
-        super().__init__()
-        self._alpha_mgd: float = alpha_mgd
-        self._lambda_mgd: float = lambda_mgd
-
-        if num_student_channels != num_teacher_channels:
-            self.align = nn.Conv2d(
-                num_student_channels,
-                num_teacher_channels,
-                kernel_size=1,
-                stride=1,
-                padding=0,
-            )
-        else:
-            self.align = nn.Identity()
-
-        self.generation = nn.Sequential(
-            nn.Conv2d(
-                num_teacher_channels,
-                num_teacher_channels,
-                kernel_size=3,
-                padding=1,
-            ),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(
-                num_teacher_channels,
-                num_teacher_channels,
-                kernel_size=3,
-                padding=1,
-            ),
-        )
-
-    def _loss_fn(self, out_s: torch.Tensor, out_t: torch.Tensor):
-        n, _, h, w = out_t.shape
-
-        mat = torch.rand((n, 1, h, w), device=out_s.device)
-        mat = torch.where(mat > 1 - self._lambda_mgd, 0, 1)
-
-        masked_feats = torch.mul(out_s, mat)
-        new_feats = self.generation(masked_feats)
-
-        kd_loss = F.mse_loss(new_feats, out_t)
-
-        return kd_loss
-
-    def forward(self, out_s: torch.Tensor, out_t: torch.Tensor):
-        """Forward function.
-
-        Args:
-            out_s: Student's feature map (shape BxCxHxW).
-            out_t: Teacher's feature map (shape BxCxHxW).
-        """
-        assert out_s.shape[-2:] == out_t.shape[-2:]
-
-        out_s = self.align(out_s)
-        loss = self._loss_fn(out_s, out_t) * self._alpha_mgd
-
-        return loss
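
For context on what this commit removes, here is a minimal usage sketch of the deleted MGDLoss, pieced together from the removed code above. The channel counts, spatial size, and tensors are illustrative assumptions, not values from the repository, and the import only works on the parent commit 6ec8cdc, before the class was deleted.

import torch

# Only valid on the parent commit 6ec8cdc, before MGDLoss was removed.
from modelopt.torch.distill.losses import MGDLoss

# Illustrative feature maps (assumed shapes): batch 4, student 256 channels,
# teacher 512 channels, both on a 32x32 spatial grid.
student_feats = torch.randn(4, 256, 32, 32)
teacher_feats = torch.randn(4, 512, 32, 32)

# Channel counts differ, so the deleted class would build a 1x1 conv to align
# the student features to the teacher's channel count before masking.
criterion = MGDLoss(
    num_student_channels=256,
    num_teacher_channels=512,
    alpha_mgd=1.0,    # scalar the final loss is multiplied by
    lambda_mgd=0.65,  # masked ratio of spatial positions
)

# forward(out_s, out_t) returns the MSE between the regenerated student
# features and the teacher features, scaled by alpha_mgd.
loss = criterion(student_feats, teacher_feats)
loss.backward()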
