File tree Expand file tree Collapse file tree 1 file changed +7
-4
lines changed
python/paddle/fluid/contrib/slim/distillation Expand file tree Collapse file tree 1 file changed +7
-4
lines changed Original file line number Diff line number Diff line change @@ -264,11 +264,14 @@ def apply(self, graph):
264
264
265
265
student_feature_map = ret_graph .var (self .student_feature_map )._var
266
266
teacher_feature_map = ret_graph .var (self .teacher_feature_map )._var
267
- s_fea = student_feature_map / self .student_temperature
268
- t_fea = teacher_feature_map / self .teacher_temperature
267
+ s_fea = layers .softmax (student_feature_map /
268
+ self .student_temperature )
269
+ t_fea = layers .softmax (teacher_feature_map /
270
+ self .teacher_temperature )
269
271
t_fea .stop_gradient = True
270
- ce_loss = layers .softmax_with_cross_entropy (
271
- s_fea , t_fea , soft_label = True )
272
+ ce_loss = layres .reduce_mean (
273
+ layers .cross_entropy (
274
+ s_fea , t_fea , soft_label = True ))
272
275
distillation_loss = ce_loss * self .distillation_loss_weight
273
276
student_loss = 0
274
277
if 'loss' in ret_graph .out_nodes :
You can’t perform that action at this time.
0 commit comments