Skip to content

Commit baa7938

Browse files
committed
refactor: improve gradient computation structure in FactorTransferDistillationStrategy
Refactored ComputeGradient to compute combined gradients first, then apply (1.0 - _factorWeight) scaling exactly once per element. This eliminates the separate final loop and makes the logic clearer.

Changes:
- Compute softGrad = temperature-scaled soft difference (without factorWeight)
- Compute hardGrad = studentProbs - trueLabels
- Form combined = Alpha * hardGrad + (1 - Alpha) * softGrad
- Multiply the final combined value by (1.0 - _factorWeight) before assigning to gradient[i]
- Handle both the trueLabels != null and trueLabels == null cases cleanly

This ensures (1.0 - _factorWeight) is applied exactly once per gradient element in a single assignment, improving clarity and efficiency.
1 parent 5402416 commit baa7938

File tree

1 file changed

+22
-12
lines changed

1 file changed

+22
-12
lines changed

src/KnowledgeDistillation/Strategies/FactorTransferDistillationStrategy.cs

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -110,30 +110,40 @@ public override Vector<T> ComputeGradient(Vector<T> studentOutput, Vector<T> tea
110110
var studentSoft = Softmax(studentOutput, Temperature);
111111
var teacherSoft = Softmax(teacherOutput, Temperature);
112112

113-
for (int i = 0; i < n; i++)
114-
{
115-
var diff = NumOps.Subtract(studentSoft[i], teacherSoft[i]);
116-
gradient[i] = NumOps.Multiply(diff, NumOps.FromDouble(Temperature * Temperature));
117-
}
118-
119113
if (trueLabels != null)
120114
{
121115
ValidateLabelDimensions(studentOutput, trueLabels, v => v.Length);
122116
var studentProbs = Softmax(studentOutput, 1.0);
123117

124118
for (int i = 0; i < n; i++)
125119
{
120+
// Soft gradient (temperature-scaled)
121+
var softGrad = NumOps.Subtract(studentSoft[i], teacherSoft[i]);
122+
softGrad = NumOps.Multiply(softGrad, NumOps.FromDouble(Temperature * Temperature));
123+
124+
// Hard gradient
126125
var hardGrad = NumOps.Subtract(studentProbs[i], trueLabels[i]);
127-
gradient[i] = NumOps.Add(
126+
127+
// Combined gradient: Alpha * hardGrad + (1 - Alpha) * softGrad
128+
var combined = NumOps.Add(
128129
NumOps.Multiply(NumOps.FromDouble(Alpha), hardGrad),
129-
NumOps.Multiply(NumOps.FromDouble(1.0 - Alpha), gradient[i]));
130+
NumOps.Multiply(NumOps.FromDouble(1.0 - Alpha), softGrad));
131+
132+
// Apply factor weight reduction exactly once
133+
gradient[i] = NumOps.Multiply(combined, NumOps.FromDouble(1.0 - _factorWeight));
130134
}
131135
}
132-
133-
// Apply factor weight reduction exactly once
134-
for (int i = 0; i < n; i++)
136+
else
135137
{
136-
gradient[i] = NumOps.Multiply(gradient[i], NumOps.FromDouble(1.0 - _factorWeight));
138+
for (int i = 0; i < n; i++)
139+
{
140+
// Soft gradient (temperature-scaled)
141+
var softGrad = NumOps.Subtract(studentSoft[i], teacherSoft[i]);
142+
softGrad = NumOps.Multiply(softGrad, NumOps.FromDouble(Temperature * Temperature));
143+
144+
// Apply factor weight reduction exactly once
145+
gradient[i] = NumOps.Multiply(softGrad, NumOps.FromDouble(1.0 - _factorWeight));
146+
}
137147
}
138148

139149
return gradient;

0 commit comments

Comments (0)