Skip to content

Commit 2bc248f

Browse files
committed
fix: remove double-scaling of distributionWeight in ProbabilisticDistillationStrategy
The standard distillation term was being scaled by (1 - _distributionWeight) twice:

1. Once when computing softLoss (line 63)
2. Again when multiplying combinedLoss (line 73 for trueLabels case)

This double-scaling incorrectly reduced the soft component by (1 - distributionWeight)^2 instead of (1 - distributionWeight).

Fixed by:

ComputeLoss:
- Removed distributionWeight from initial softLoss scaling (line 63)
- Computing finalLoss (either combinedLoss or softLoss)
- Applying (1.0 - _distributionWeight) scaling exactly once at the end

ComputeGradient:
- Removed distributionWeight from soft gradient scaling
- Removed distributionWeight from hard gradient blending
- Computing combined gradient: Alpha * hardGrad + (1 - Alpha) * softGrad
- Applying (1.0 - _distributionWeight) scaling exactly once per element

Now distributionWeight correctly balances:
- (1 - distributionWeight) * standard_distillation
- distributionWeight * distributional_matching

Note: KL divergence already uses the correct direction, KLDivergence(teacherSoft, studentSoft), which computes KL(teacher || student), matching the gradient computation.
1 parent c87817d commit 2bc248f

File tree

1 file changed

+35
-14
lines changed

1 file changed

+35
-14
lines changed

src/KnowledgeDistillation/Strategies/ProbabilisticDistillationStrategy.cs

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -56,24 +56,29 @@ public override T ComputeLoss(Vector<T> studentOutput, Vector<T> teacherOutput,
5656
{
5757
ValidateOutputDimensions(studentOutput, teacherOutput, v => v.Length);
5858

59-
// Standard distillation loss (weighted)
59+
// Standard distillation loss
6060
var studentSoft = Softmax(studentOutput, Temperature);
6161
var teacherSoft = Softmax(teacherOutput, Temperature);
6262
var softLoss = KLDivergence(teacherSoft, studentSoft);
63-
softLoss = NumOps.Multiply(softLoss, NumOps.FromDouble(Temperature * Temperature * (1.0 - _distributionWeight)));
63+
softLoss = NumOps.Multiply(softLoss, NumOps.FromDouble(Temperature * Temperature));
6464

65+
T finalLoss;
6566
if (trueLabels != null)
6667
{
6768
ValidateLabelDimensions(studentOutput, trueLabels, v => v.Length);
6869
var studentProbs = Softmax(studentOutput, 1.0);
6970
var hardLoss = CrossEntropy(studentProbs, trueLabels);
70-
var combinedLoss = NumOps.Add(
71+
finalLoss = NumOps.Add(
7172
NumOps.Multiply(NumOps.FromDouble(Alpha), hardLoss),
7273
NumOps.Multiply(NumOps.FromDouble(1.0 - Alpha), softLoss));
73-
return NumOps.Multiply(combinedLoss, NumOps.FromDouble(1.0 - _distributionWeight));
74+
}
75+
else
76+
{
77+
finalLoss = softLoss;
7478
}
7579

76-
return softLoss;
80+
// Apply distribution weight reduction exactly once
81+
return NumOps.Multiply(finalLoss, NumOps.FromDouble(1.0 - _distributionWeight));
7782
}
7883

7984
public override Vector<T> ComputeGradient(Vector<T> studentOutput, Vector<T> teacherOutput, Vector<T>? trueLabels = null)
@@ -86,23 +91,39 @@ public override Vector<T> ComputeGradient(Vector<T> studentOutput, Vector<T> tea
8691
var studentSoft = Softmax(studentOutput, Temperature);
8792
var teacherSoft = Softmax(teacherOutput, Temperature);
8893

89-
for (int i = 0; i < n; i++)
90-
{
91-
var diff = NumOps.Subtract(studentSoft[i], teacherSoft[i]);
92-
gradient[i] = NumOps.Multiply(diff, NumOps.FromDouble(Temperature * Temperature * (1.0 - _distributionWeight)));
93-
}
94-
9594
if (trueLabels != null)
9695
{
9796
ValidateLabelDimensions(studentOutput, trueLabels, v => v.Length);
9897
var studentProbs = Softmax(studentOutput, 1.0);
9998

10099
for (int i = 0; i < n; i++)
101100
{
101+
// Soft gradient (temperature-scaled)
102+
var softGrad = NumOps.Subtract(studentSoft[i], teacherSoft[i]);
103+
softGrad = NumOps.Multiply(softGrad, NumOps.FromDouble(Temperature * Temperature));
104+
105+
// Hard gradient
102106
var hardGrad = NumOps.Subtract(studentProbs[i], trueLabels[i]);
103-
gradient[i] = NumOps.Add(
104-
NumOps.Multiply(NumOps.FromDouble(Alpha * (1.0 - _distributionWeight)), hardGrad),
105-
NumOps.Multiply(NumOps.FromDouble((1.0 - Alpha) * (1.0 - _distributionWeight)), gradient[i]));
107+
108+
// Combined gradient: Alpha * hardGrad + (1 - Alpha) * softGrad
109+
var combined = NumOps.Add(
110+
NumOps.Multiply(NumOps.FromDouble(Alpha), hardGrad),
111+
NumOps.Multiply(NumOps.FromDouble(1.0 - Alpha), softGrad));
112+
113+
// Apply distribution weight reduction exactly once
114+
gradient[i] = NumOps.Multiply(combined, NumOps.FromDouble(1.0 - _distributionWeight));
115+
}
116+
}
117+
else
118+
{
119+
for (int i = 0; i < n; i++)
120+
{
121+
// Soft gradient (temperature-scaled)
122+
var softGrad = NumOps.Subtract(studentSoft[i], teacherSoft[i]);
123+
softGrad = NumOps.Multiply(softGrad, NumOps.FromDouble(Temperature * Temperature));
124+
125+
// Apply distribution weight reduction exactly once
126+
gradient[i] = NumOps.Multiply(softGrad, NumOps.FromDouble(1.0 - _distributionWeight));
106127
}
107128
}
108129

0 commit comments

Comments (0)