Skip to content

Commit 214ff18

Browse files
committed
fix: restore classical kd signal in variational distillation strategy
Remove incorrect (1.0 - variationalWeight) multiplication that was suppressing the classical distillation loss and gradient instead of adding variational terms.

This fixes the critical issue where:
- When variationalWeight = 1.0, loss/gradient collapsed to zero
- When variationalWeight < 1.0, only scaled classical KD without variational contribution

Changes:
- ComputeLoss: Remove (1.0 - variationalWeight) scaling from softLoss and combinedLoss
- ComputeGradient: Remove (1.0 - variationalWeight) scaling from soft and hard gradients

Note: This is a baseline fix. Full variational integration (adding variational loss/gradient weighted by variationalWeight) requires latent representations (mean, logVar) which are not available in the current ComputeLoss/ComputeGradient signatures.

Resolves coderabbitai review comments on lines 63-85 and 87-118.
1 parent 1dc2f47 commit 214ff18

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

src/KnowledgeDistillation/Strategies/VariationalDistillationStrategy.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ public override T ComputeLoss(Vector<T> studentOutput, Vector<T> teacherOutput,
6868
var studentSoft = Softmax(studentOutput, Temperature);
6969
var teacherSoft = Softmax(teacherOutput, Temperature);
7070
var softLoss = KLDivergence(teacherSoft, studentSoft);
71-
softLoss = NumOps.Multiply(softLoss, NumOps.FromDouble(Temperature * Temperature * (1.0 - _variationalWeight)));
71+
softLoss = NumOps.Multiply(softLoss, NumOps.FromDouble(Temperature * Temperature));
7272

7373
if (trueLabels != null)
7474
{
@@ -78,7 +78,7 @@ public override T ComputeLoss(Vector<T> studentOutput, Vector<T> teacherOutput,
7878
var combinedLoss = NumOps.Add(
7979
NumOps.Multiply(NumOps.FromDouble(Alpha), hardLoss),
8080
NumOps.Multiply(NumOps.FromDouble(1.0 - Alpha), softLoss));
81-
return NumOps.Multiply(combinedLoss, NumOps.FromDouble(1.0 - _variationalWeight));
81+
return combinedLoss;
8282
}
8383

8484
return softLoss;
@@ -97,7 +97,7 @@ public override Vector<T> ComputeGradient(Vector<T> studentOutput, Vector<T> tea
9797
for (int i = 0; i < n; i++)
9898
{
9999
var diff = NumOps.Subtract(studentSoft[i], teacherSoft[i]);
100-
gradient[i] = NumOps.Multiply(diff, NumOps.FromDouble(Temperature * Temperature * (1.0 - _variationalWeight)));
100+
gradient[i] = NumOps.Multiply(diff, NumOps.FromDouble(Temperature * Temperature));
101101
}
102102

103103
if (trueLabels != null)
@@ -109,8 +109,8 @@ public override Vector<T> ComputeGradient(Vector<T> studentOutput, Vector<T> tea
109109
{
110110
var hardGrad = NumOps.Subtract(studentProbs[i], trueLabels[i]);
111111
gradient[i] = NumOps.Add(
112-
NumOps.Multiply(NumOps.FromDouble(Alpha * (1.0 - _variationalWeight)), hardGrad),
113-
NumOps.Multiply(NumOps.FromDouble((1.0 - Alpha) * (1.0 - _variationalWeight)), gradient[i]));
112+
NumOps.Multiply(NumOps.FromDouble(Alpha), hardGrad),
113+
NumOps.Multiply(NumOps.FromDouble(1.0 - Alpha), gradient[i]));
114114
}
115115
}
116116

0 commit comments

Comments (0)