Skip to content

Commit 0c9ea96

Browse files
committed
fix: resolve SelfDistillationTrainer cached prediction alignment with shuffled batches
CRITICAL FIX: The previous implementation used array indices to look up cached teacher predictions, but TrainBatch receives batch-local indices (0 to batchSize-1), not global dataset indices. After data shuffling, this caused the student to learn from incorrect teacher outputs, making self-distillation ineffective or harmful.

Changes:
- Changed _cachedTeacherPredictions from Vector<Vector<T>> to Dictionary<Vector<T>, Vector<T>>
- Use ReferenceEqualityComparer to map input instances to their cached predictions
- GetTeacherPredictions now looks up by input reference (not index) to handle shuffled data
- Updated EMA blending logic to work with the dictionary structure

This ensures that, regardless of data shuffling, each input sample is matched with its correct cached teacher prediction from the previous generation.

Addresses: code review comment from @coderabbitai on SelfDistillationTrainer.cs:114-124
1 parent bfb3263 commit 0c9ea96

File tree

1 file changed

+49
-24
lines changed

1 file changed

+49
-24
lines changed

src/KnowledgeDistillation/SelfDistillationTrainer.cs

Lines changed: 49 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ namespace AiDotNet.KnowledgeDistillation;
5151
public class SelfDistillationTrainer<T> : KnowledgeDistillationTrainerBase<T, Vector<T>, Vector<T>>
5252
{
5353
private readonly int _generations;
54-
private Vector<Vector<T>>? _cachedTeacherPredictions;
54+
private Dictionary<Vector<T>, Vector<T>>? _cachedTeacherPredictions;
5555

5656
/// <summary>
5757
/// Gets or sets whether to use exponential moving average for teacher predictions.
@@ -102,25 +102,33 @@ public SelfDistillationTrainer(
102102
}
103103

104104
/// <summary>
105-
/// Gets teacher predictions from the cached predictions array (for self-distillation).
105+
/// Gets teacher predictions from the cached predictions dictionary (for self-distillation).
106106
/// </summary>
107-
/// <param name="input">The input data (unused - predictions are cached).</param>
108-
/// <param name="index">The index in the training batch.</param>
109-
/// <returns>Cached teacher prediction for this index.</returns>
107+
/// <param name="input">The input data to look up cached predictions for.</param>
108+
/// <param name="index">The index in the training batch (unused - we use input for lookup).</param>
109+
/// <returns>Cached teacher prediction for this input.</returns>
110110
/// <remarks>
111111
/// <para><b>For Self-Distillation:</b> Instead of calling a separate teacher model,
112-
/// we return predictions that were cached from the previous generation.</para>
112+
/// we return predictions that were cached from the previous generation. We use the input
113+
/// itself as the key (via reference equality) to handle shuffled batches correctly.</para>
113114
/// </remarks>
114115
protected override Vector<T> GetTeacherPredictions(Vector<T> input, int index)
115116
{
116-
if (_cachedTeacherPredictions == null || index >= _cachedTeacherPredictions.Length)
117+
if (_cachedTeacherPredictions == null)
117118
{
118-
// First generation or out of bounds - return input as dummy teacher
119+
// First generation - return dummy teacher
119120
// This will be handled properly in TrainMultipleGenerations
120121
return new Vector<T>(Teacher.OutputDimension);
121122
}
122123

123-
return _cachedTeacherPredictions[index];
124+
// Look up by input reference to handle shuffled data correctly
125+
if (_cachedTeacherPredictions.TryGetValue(input, out var cachedPrediction))
126+
{
127+
return cachedPrediction;
128+
}
129+
130+
// Fallback for inputs not in cache (shouldn't happen in normal flow)
131+
return new Vector<T>(Teacher.OutputDimension);
124132
}
125133

126134
/// <summary>
@@ -174,33 +182,50 @@ public void TrainMultipleGenerations(
174182
// Cache predictions from previous generation (if not first generation)
175183
if (generation > 0)
176184
{
177-
var newPredictions = new Vector<Vector<T>>(trainInputs.Length);
185+
// Use reference equality comparer to map input instances to their predictions
186+
var newPredictions = new Dictionary<Vector<T>, Vector<T>>(ReferenceEqualityComparer.Instance);
178187
for (int i = 0; i < trainInputs.Length; i++)
179188
{
180-
newPredictions[i] = modelForward(trainInputs[i]);
189+
var input = trainInputs[i];
190+
newPredictions[input] = modelForward(input);
181191
}
182192

183193
// Apply EMA if enabled
184194
if (UseEMA && _cachedTeacherPredictions != null)
185195
{
186-
for (int i = 0; i < trainInputs.Length; i++)
196+
var blendedPredictions = new Dictionary<Vector<T>, Vector<T>>(ReferenceEqualityComparer.Instance);
197+
foreach (var kvp in newPredictions)
187198
{
188-
var blended = new Vector<T>(newPredictions[i].Length);
189-
for (int j = 0; j < newPredictions[i].Length; j++)
199+
var input = kvp.Key;
200+
var newPred = kvp.Value;
201+
202+
if (_cachedTeacherPredictions.TryGetValue(input, out var oldPred))
190203
{
191-
var oldValue = NumOps.Multiply(
192-
_cachedTeacherPredictions[i][j],
193-
NumOps.FromDouble(EMADecay));
194-
var newValue = NumOps.Multiply(
195-
newPredictions[i][j],
196-
NumOps.FromDouble(1.0 - EMADecay));
197-
blended[j] = NumOps.Add(oldValue, newValue);
204+
var blended = new Vector<T>(newPred.Length);
205+
for (int j = 0; j < newPred.Length; j++)
206+
{
207+
var oldValue = NumOps.Multiply(
208+
oldPred[j],
209+
NumOps.FromDouble(EMADecay));
210+
var newValue = NumOps.Multiply(
211+
newPred[j],
212+
NumOps.FromDouble(1.0 - EMADecay));
213+
blended[j] = NumOps.Add(oldValue, newValue);
214+
}
215+
blendedPredictions[input] = blended;
216+
}
217+
else
218+
{
219+
// No old prediction, use new one as-is
220+
blendedPredictions[input] = newPred;
198221
}
199-
newPredictions[i] = blended;
200222
}
223+
_cachedTeacherPredictions = blendedPredictions;
224+
}
225+
else
226+
{
227+
_cachedTeacherPredictions = newPredictions;
201228
}
202-
203-
_cachedTeacherPredictions = newPredictions;
204229
}
205230

206231
// Train using base class Train method

0 commit comments

Comments (0)