@@ -87,30 +87,31 @@ def __init__(self, model, model_old, device, opts, trainer_state=None, classes=N
8787 self .lkd_flag = self .lkd > 0. and model_old is not None
8888 self .kd_need_labels = False
8989 self .lgkd_flag = opts .lgkd
90- if opts .unkd :
91- self .lkd_loss = UnbiasedKnowledgeDistillationLoss (reduction = "none" , alpha = opts .alpha )
92- elif opts .lgkd :
93- self .lkd_loss = LabelGuidedKnowledgeDistillationLoss (alpha = opts .alpha ,
94- prev_kd = opts .prev_kd ,
95- novel_kd = opts .novel_kd )
96- elif opts .kd_bce_sig :
97- self .lkd_loss = BCESigmoid (reduction = "none" , alpha = opts .alpha , shape = opts .kd_bce_sig_shape )
98- elif opts .exkd_gt and self .old_classes > 0 and self .step > 0 :
99- self .lkd_loss = ExcludedKnowledgeDistillationLoss (
100- reduction = 'none' , index_new = self .old_classes , new_reduction = "gt" ,
101- initial_nb_classes = opts .inital_nb_classes ,
102- temperature_semiold = opts .temperature_semiold
103- )
104- self .kd_need_labels = True
105- elif opts .exkd_sum and self .old_classes > 0 and self .step > 0 :
106- self .lkd_loss = ExcludedKnowledgeDistillationLoss (
107- reduction = 'none' , index_new = self .old_classes , new_reduction = "sum" ,
108- initial_nb_classes = opts .inital_nb_classes ,
109- temperature_semiold = opts .temperature_semiold
110- )
111- self .kd_need_labels = True
112- else :
113- self .lkd_loss = KnowledgeDistillationLoss (alpha = opts .alpha )
90+ if self .step > 0 :
91+ if opts .unkd :
92+ self .lkd_loss = UnbiasedKnowledgeDistillationLoss (reduction = "none" , alpha = opts .alpha )
93+ elif opts .lgkd :
94+ self .lkd_loss = LabelGuidedKnowledgeDistillationLoss (alpha = opts .alpha ,
95+ prev_kd = opts .prev_kd ,
96+ novel_kd = opts .novel_kd )
97+ elif opts .kd_bce_sig :
98+ self .lkd_loss = BCESigmoid (reduction = "none" , alpha = opts .alpha , shape = opts .kd_bce_sig_shape )
99+ elif opts .exkd_gt and self .old_classes > 0 :
100+ self .lkd_loss = ExcludedKnowledgeDistillationLoss (
101+ reduction = 'none' , index_new = self .old_classes , new_reduction = "gt" ,
102+ initial_nb_classes = opts .inital_nb_classes ,
103+ temperature_semiold = opts .temperature_semiold
104+ )
105+ self .kd_need_labels = True
106+ elif opts .exkd_sum and self .old_classes > 0 :
107+ self .lkd_loss = ExcludedKnowledgeDistillationLoss (
108+ reduction = 'none' , index_new = self .old_classes , new_reduction = "sum" ,
109+ initial_nb_classes = opts .inital_nb_classes ,
110+ temperature_semiold = opts .temperature_semiold
111+ )
112+ self .kd_need_labels = True
113+ else :
114+ self .lkd_loss = KnowledgeDistillationLoss (alpha = opts .alpha )
114115
115116 # ICARL
116117 self .icarl_combined = False
0 commit comments