
Commit 4763c02

Update main
1 parent 42f029c commit 4763c02

6 files changed: +146 -180 lines changed

dd_ranking/aug/cutmix.py

Lines changed: 0 additions & 9 deletions

@@ -5,15 +5,6 @@
 
 class Cutmix_Augmentation:
     def __init__(self, params: dict):
-        # self.transform = kornia.augmentation.RandomCutMixV2(
-        #     num_mix = params["times"],
-        #     cut_size = params["size"],
-        #     same_on_batch = params["same_on_batch"],
-        #     beta = params["beta"],
-        #     keepdim = params["keep_dim"],
-        #     p = params["prob"]
-        # )
-
         self.cutmix_p = params["cutmix_p"]
 
     def rand_bbox(self, size, lam):
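This hunk drops the commented-out kornia `RandomCutMixV2` setup and keeps the manual path: the `cutmix_p` parameter plus the `rand_bbox` helper. The body of `rand_bbox` is not shown in this commit, so the following is only a minimal sketch of the standard CutMix bounding-box sampler, not necessarily the repository's exact implementation:

```python
import numpy as np

def rand_bbox(size, lam):
    # size: (N, C, H, W); lam: mixing coefficient in [0, 1].
    # The cut patch covers roughly a (1 - lam) fraction of the image area.
    W, H = size[3], size[2]
    cut_rat = np.sqrt(1.0 - lam)
    cut_w, cut_h = int(W * cut_rat), int(H * cut_rat)

    # Sample the patch center uniformly, then clip the box to the image bounds.
    cx, cy = np.random.randint(W), np.random.randint(H)
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2
```

The cut area scales with 1 - lam, so a lam near 1 leaves the image mostly intact.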

dd_ranking/aug/mixup.py

Lines changed: 0 additions & 7 deletions

@@ -5,13 +5,6 @@
 
 class Mixup_Augmentation:
     def __init__(self, params: dict):
-        # self.transform = kornia.augmentation.RandomMixUpV2(
-        #     lambda_val = params["lambda_range"],
-        #     same_on_batch = params["same_on_batch"],
-        #     keepdim = params["keepdim"],
-        #     p = params["prob"]
-        # )
-
         self.mixup_p = params["mixup_p"]
 
     def mixup(self, images):
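As with cutmix.py, the kornia `RandomMixUpV2` alternative is removed in favor of a manual method. The `mixup` body is outside this hunk; given the images-only signature `mixup(self, images)`, a typical implementation would look roughly like this sketch, where the Beta concentration and the per-batch permutation are assumptions:

```python
import torch

def mixup(images, alpha=0.8):
    # Sample one mixing coefficient and a random pairing of the batch,
    # then blend each image with its randomly chosen partner.
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    index = torch.randperm(images.size(0), device=images.device)
    return lam * images + (1.0 - lam) * images[index]
```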

dd_ranking/loss/kl.py

Lines changed: 3 additions & 4 deletions

@@ -1,17 +1,16 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.nn import KLDivLoss
 
 
 class KLDivergenceLoss(nn.Module):
     def __init__(self, temperature=1.2):
         super(KLDivergenceLoss, self).__init__()
         self.temperature = temperature
-        self.kl = KLDivLoss(reduction='batchmean')
 
     def forward(self, stu_outputs, tea_outputs):
         stu_probs = F.log_softmax(stu_outputs / self.temperature, dim=1)
-        tea_probs = F.softmax(tea_outputs / self.temperature, dim=1)
-        loss = self.kl(stu_probs, tea_probs)
+        with torch.no_grad():
+            tea_probs = F.softmax(tea_outputs / self.temperature, dim=1)
+        loss = F.kl_div(stu_probs, tea_probs, reduction='batchmean')
         return loss
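The change swaps the stored `nn.KLDivLoss` module for a direct `F.kl_div` call and wraps the teacher softmax in `torch.no_grad()`, so no autograd graph is built through the teacher logits. A minimal smoke test of the revised loss (random logits; the import path is assumed from the file layout):

```python
import torch
from dd_ranking.loss.kl import KLDivergenceLoss  # path assumed from this commit

criterion = KLDivergenceLoss(temperature=1.2)
stu_logits = torch.randn(8, 10, requires_grad=True)
tea_logits = torch.randn(8, 10)

loss = criterion(stu_logits, tea_logits)
loss.backward()  # gradients reach the student only; teacher probs are under no_grad
print(loss.item())
```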

dd_ranking/loss/sce.py

Lines changed: 4 additions & 3 deletions

@@ -4,12 +4,13 @@
 
 
 class SoftCrossEntropyLoss(nn.Module):
-    def __init__(self):
+    def __init__(self, temperature=1.2):
         super(SoftCrossEntropyLoss, self).__init__()
+        self.temperature = temperature
 
     def forward(self, stu_outputs, tea_outputs):
-        input_log_likelihood = -F.log_softmax(stu_outputs, dim=1)
-        target_log_likelihood = F.softmax(tea_outputs, dim=1)
+        input_log_likelihood = -F.log_softmax(stu_outputs / self.temperature, dim=1)
+        target_log_likelihood = F.softmax(tea_outputs / self.temperature, dim=1)
         batch_size = stu_outputs.size(0)
         loss = torch.sum(torch.mul(input_log_likelihood, target_log_likelihood)) / batch_size
         return loss
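With the temperature applied to both sides, the forward pass now computes temperature-scaled soft-label cross-entropy. One way to sanity-check the change is an equivalence test against PyTorch's built-in soft-target `F.cross_entropy` (supported since PyTorch 1.10); the tensor names here are placeholders:

```python
import torch
import torch.nn.functional as F

T = 1.2
stu = torch.randn(8, 10)
tea = torch.randn(8, 10)

# Hand-rolled version, mirroring SoftCrossEntropyLoss.forward after this commit.
manual = torch.sum(-F.log_softmax(stu / T, dim=1) * F.softmax(tea / T, dim=1)) / stu.size(0)

# Built-in cross-entropy with probability targets on the same scaled inputs.
builtin = F.cross_entropy(stu / T, F.softmax(tea / T, dim=1))

assert torch.allclose(manual, builtin, atol=1e-6)
```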
