Skip to content

Commit 5402416

Browse files
committed
fix: remove double-scaling of factorWeight in FactorTransferDistillationStrategy
The standard distillation term was being scaled by (1 - _factorWeight) twice: 1. Once when computing softLoss (line 82) 2. Again when multiplying combinedLoss (line 92 for trueLabels case) This double-scaling incorrectly reduced the soft component by (1-factorWeight)^2 instead of (1-factorWeight). Fixed by: - Removing factorWeight from initial softLoss scaling (line 82) - Computing finalLoss (either combinedLoss or softLoss) - Applying (1.0 - _factorWeight) scaling exactly once at the end Same fix applied to ComputeGradient method: - Removed factorWeight from soft gradient scaling (line 111) - Removed factorWeight from hard gradient blending (lines 123-124) - Applied (1.0 - _factorWeight) scaling exactly once at the end Now factorWeight correctly balances: - (1 - factorWeight) * standard_distillation - factorWeight * factor_transfer
1 parent 6f8b5a3 commit 5402416

File tree

1 file changed

+19
-8
lines changed

src/KnowledgeDistillation/Strategies/FactorTransferDistillationStrategy.cs

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,24 +75,29 @@ public override T ComputeLoss(Vector<T> studentOutput, Vector<T> teacherOutput,
7575
{
7676
ValidateOutputDimensions(studentOutput, teacherOutput, v => v.Length);
7777

78-
// Standard distillation loss (weighted)
78+
// Standard distillation loss
7979
var studentSoft = Softmax(studentOutput, Temperature);
8080
var teacherSoft = Softmax(teacherOutput, Temperature);
8181
var softLoss = KLDivergence(teacherSoft, studentSoft);
82-
softLoss = NumOps.Multiply(softLoss, NumOps.FromDouble(Temperature * Temperature * (1.0 - _factorWeight)));
82+
softLoss = NumOps.Multiply(softLoss, NumOps.FromDouble(Temperature * Temperature));
8383

84+
T finalLoss;
8485
if (trueLabels != null)
8586
{
8687
ValidateLabelDimensions(studentOutput, trueLabels, v => v.Length);
8788
var studentProbs = Softmax(studentOutput, 1.0);
8889
var hardLoss = CrossEntropy(studentProbs, trueLabels);
89-
var combinedLoss = NumOps.Add(
90+
finalLoss = NumOps.Add(
9091
NumOps.Multiply(NumOps.FromDouble(Alpha), hardLoss),
9192
NumOps.Multiply(NumOps.FromDouble(1.0 - Alpha), softLoss));
92-
return NumOps.Multiply(combinedLoss, NumOps.FromDouble(1.0 - _factorWeight));
93+
}
94+
else
95+
{
96+
finalLoss = softLoss;
9397
}
9498

95-
return softLoss;
99+
// Apply factor weight reduction exactly once
100+
return NumOps.Multiply(finalLoss, NumOps.FromDouble(1.0 - _factorWeight));
96101
}
97102

98103
public override Vector<T> ComputeGradient(Vector<T> studentOutput, Vector<T> teacherOutput, Vector<T>? trueLabels = null)
@@ -108,7 +113,7 @@ public override Vector<T> ComputeGradient(Vector<T> studentOutput, Vector<T> tea
108113
for (int i = 0; i < n; i++)
109114
{
110115
var diff = NumOps.Subtract(studentSoft[i], teacherSoft[i]);
111-
gradient[i] = NumOps.Multiply(diff, NumOps.FromDouble(Temperature * Temperature * (1.0 - _factorWeight)));
116+
gradient[i] = NumOps.Multiply(diff, NumOps.FromDouble(Temperature * Temperature));
112117
}
113118

114119
if (trueLabels != null)
@@ -120,11 +125,17 @@ public override Vector<T> ComputeGradient(Vector<T> studentOutput, Vector<T> tea
120125
{
121126
var hardGrad = NumOps.Subtract(studentProbs[i], trueLabels[i]);
122127
gradient[i] = NumOps.Add(
123-
NumOps.Multiply(NumOps.FromDouble(Alpha * (1.0 - _factorWeight)), hardGrad),
124-
NumOps.Multiply(NumOps.FromDouble((1.0 - Alpha) * (1.0 - _factorWeight)), gradient[i]));
128+
NumOps.Multiply(NumOps.FromDouble(Alpha), hardGrad),
129+
NumOps.Multiply(NumOps.FromDouble(1.0 - Alpha), gradient[i]));
125130
}
126131
}
127132

133+
// Apply factor weight reduction exactly once
134+
for (int i = 0; i < n; i++)
135+
{
136+
gradient[i] = NumOps.Multiply(gradient[i], NumOps.FromDouble(1.0 - _factorWeight));
137+
}
138+
128139
return gradient;
129140
}
130141

0 commit comments

Comments
 (0)