6 files changed, +146 −180 lines changed

Cutmix augmentation diff:
 class Cutmix_Augmentation:
     def __init__(self, params: dict):
-        # self.transform = kornia.augmentation.RandomCutMixV2(
-        #     num_mix=params["times"],
-        #     cut_size=params["size"],
-        #     same_on_batch=params["same_on_batch"],
-        #     beta=params["beta"],
-        #     keepdim=params["keep_dim"],
-        #     p=params["prob"]
-        # )
-
         self.cutmix_p = params["cutmix_p"]

     def rand_bbox(self, size, lam):
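
The diff cuts off at rand_bbox, which in the standard CutMix recipe samples a patch whose area is a (1 − lam) fraction of the image. A minimal sketch of the conventional implementation, assumed here rather than taken from this repository:

import numpy as np

def rand_bbox(size, lam):
    # size is (N, C, H, W); the patch covers a (1 - lam) fraction of the area
    H, W = size[2], size[3]
    cut_rat = np.sqrt(1.0 - lam)
    cut_h, cut_w = int(H * cut_rat), int(W * cut_rat)
    # sample the patch centre uniformly, then clip the box to the image bounds
    cy, cx = np.random.randint(H), np.random.randint(W)
    bby1, bby2 = np.clip(cy - cut_h // 2, 0, H), np.clip(cy + cut_h // 2, 0, H)
    bbx1, bbx2 = np.clip(cx - cut_w // 2, 0, W), np.clip(cx + cut_w // 2, 0, W)
    return bbx1, bby1, bbx2, bby2
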
Mixup augmentation diff:
 class Mixup_Augmentation:
     def __init__(self, params: dict):
-        # self.transform = kornia.augmentation.RandomMixUpV2(
-        #     lambda_val=params["lambda_range"],
-        #     same_on_batch=params["same_on_batch"],
-        #     keepdim=params["keepdim"],
-        #     p=params["prob"]
-        # )
-
         self.mixup_p = params["mixup_p"]

     def mixup(self, images):
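
The kept mixup method presumably follows the standard MixUp recipe: draw lam from a Beta distribution and blend each image with a randomly permuted partner. A sketch under that assumption (the Beta parameter alpha=0.2 is illustrative, not from this PR):

import torch

def mixup(images, alpha=0.2):
    # lam ~ Beta(alpha, alpha); images is a batch tensor (N, C, H, W)
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    perm = torch.randperm(images.size(0), device=images.device)
    mixed = lam * images + (1.0 - lam) * images[perm]
    return mixed, perm, lam  # perm and lam let the caller mix the targets too
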
KL divergence loss diff:

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.nn import KLDivLoss


 class KLDivergenceLoss(nn.Module):
     def __init__(self, temperature=1.2):
         super(KLDivergenceLoss, self).__init__()
         self.temperature = temperature
-        self.kl = KLDivLoss(reduction='batchmean')

     def forward(self, stu_outputs, tea_outputs):
         stu_probs = F.log_softmax(stu_outputs / self.temperature, dim=1)
-        tea_probs = F.softmax(tea_outputs / self.temperature, dim=1)
-        loss = self.kl(stu_probs, tea_probs)
+        with torch.no_grad():
+            tea_probs = F.softmax(tea_outputs / self.temperature, dim=1)
+        loss = F.kl_div(stu_probs, tea_probs, reduction='batchmean')
         return loss
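
This rewrite is sound on two counts: F.kl_div expects log-probabilities as its input and probabilities as its target, which is exactly what log_softmax and softmax provide here, and the torch.no_grad() block keeps gradients from flowing back through the teacher's softmax. A quick usage check with illustrative shapes:

import torch

criterion = KLDivergenceLoss(temperature=1.2)
stu_logits = torch.randn(8, 10, requires_grad=True)  # student outputs
tea_logits = torch.randn(8, 10)                      # teacher outputs
loss = criterion(stu_logits, tea_logits)
loss.backward()                                      # gradients reach the student only

Note that distillation setups following Hinton et al. often scale this loss by temperature ** 2 to keep gradient magnitudes comparable across temperatures; no such scaling appears in this diff.
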
Soft cross-entropy loss diff:
 class SoftCrossEntropyLoss(nn.Module):
-    def __init__(self):
+    def __init__(self, temperature=1.2):
         super(SoftCrossEntropyLoss, self).__init__()
+        self.temperature = temperature

     def forward(self, stu_outputs, tea_outputs):
-        input_log_likelihood = -F.log_softmax(stu_outputs, dim=1)
-        target_log_likelihood = F.softmax(tea_outputs, dim=1)
+        input_log_likelihood = -F.log_softmax(stu_outputs / self.temperature, dim=1)
+        target_log_likelihood = F.softmax(tea_outputs / self.temperature, dim=1)
         batch_size = stu_outputs.size(0)
         loss = torch.sum(torch.mul(input_log_likelihood, target_log_likelihood)) / batch_size
         return loss
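
With the shared temperature, SoftCrossEntropyLoss now differs from KLDivergenceLoss only by the teacher's entropy, which is constant with respect to the student, so both losses yield identical student gradients. A numerical self-check, assuming both classes from this PR are importable side by side:

import torch
import torch.nn.functional as F

stu = torch.randn(4, 5)
tea = torch.randn(4, 5)
T = 1.2
soft_ce = SoftCrossEntropyLoss(temperature=T)(stu, tea)
kl = KLDivergenceLoss(temperature=T)(stu, tea)
# soft cross-entropy = KL divergence + teacher entropy (constant w.r.t. the student)
tea_probs = F.softmax(tea / T, dim=1)
tea_entropy = -(tea_probs * tea_probs.log()).sum(dim=1).mean()
print(torch.allclose(soft_ce, kl + tea_entropy, atol=1e-5))  # expected: True

Two small nits the diff leaves open: target_log_likelihood actually holds probabilities rather than log-likelihoods, and unlike the KL file the teacher softmax here is not wrapped in torch.no_grad().
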