Skip to content

Commit 6b30c8c

Browse files
committed
fix: use recorded student performance in AccuracyBased adaptive strategy
The AccuracyBased case in ComputeAdaptiveTemperature was hardcoded to difficulty = 0.5, ignoring the StudentPerformance dictionary. Now: - Add optional sampleIndex parameter to GetSoftPredictions overload - Pass sampleIndex through to ComputeAdaptiveTemperature - Look up recorded performance: difficulty = 1.0 - StudentPerformance[index] - Fallback to 0.5 (medium difficulty) if no performance data available - High performance -> low difficulty -> sharper temperature - Low performance -> high difficulty -> softer temperature
1 parent b0b3266 commit 6b30c8c

File tree

1 file changed

+29
-4
lines changed

1 file changed

+29
-4
lines changed

src/KnowledgeDistillation/Teachers/AdaptiveTeacherModel.cs

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,11 +117,23 @@ public override Vector<T> GetLogits(Vector<T> input)
117117
/// <param name="temperature">Base temperature (will be adapted).</param>
118118
/// <returns>Soft predictions with adaptive temperature.</returns>
119119
public override Vector<T> GetSoftPredictions(Vector<T> input, double temperature = 1.0)
120+
{
121+
return GetSoftPredictions(input, temperature, sampleIndex: null);
122+
}
123+
124+
/// <summary>
125+
/// Gets soft predictions with adaptive temperature based on sample difficulty.
126+
/// </summary>
127+
/// <param name="input">Input data.</param>
128+
/// <param name="temperature">Base temperature (will be adapted).</param>
129+
/// <param name="sampleIndex">Optional sample index to look up recorded performance (for AccuracyBased strategy).</param>
130+
/// <returns>Soft predictions with adaptive temperature.</returns>
131+
public Vector<T> GetSoftPredictions(Vector<T> input, double temperature, int? sampleIndex)
120132
{
121133
var logits = GetLogits(input);
122134

123135
// Compute adaptive temperature based on strategy
124-
double adaptiveTemp = ComputeAdaptiveTemperature(logits, temperature);
136+
double adaptiveTemp = ComputeAdaptiveTemperature(logits, temperature, sampleIndex);
125137

126138
return ApplyTemperatureSoftmax(logits, adaptiveTemp);
127139
}
@@ -176,7 +188,10 @@ public void UpdateStudentPerformance(int sampleIndex, Vector<T> studentPredictio
176188
/// <summary>
177189
/// Computes adaptive temperature based on sample difficulty.
178190
/// </summary>
179-
private double ComputeAdaptiveTemperature(Vector<T> logits, double baseTemperature)
191+
/// <param name="logits">Raw model outputs.</param>
192+
/// <param name="baseTemperature">Base temperature to scale.</param>
193+
/// <param name="sampleIndex">Optional sample index to look up recorded performance.</param>
194+
private double ComputeAdaptiveTemperature(Vector<T> logits, double baseTemperature, int? sampleIndex)
180195
{
181196
// Convert logits to normalized probabilities
182197
var probs = ApplyTemperatureSoftmax(logits, 1.0);
@@ -196,8 +211,18 @@ private double ComputeAdaptiveTemperature(Vector<T> logits, double baseTemperatu
196211
break;
197212

198213
case AdaptiveStrategy.AccuracyBased:
199-
// Use stored performance (default to medium difficulty)
200-
difficulty = 0.5;
214+
// Use stored performance for this sample if available
215+
// High performance (1.0) -> low difficulty (0.0) -> min temp (sharper)
216+
// Low performance (0.0) -> high difficulty (1.0) -> max temp (softer)
217+
if (sampleIndex.HasValue && StudentPerformance.ContainsKey(sampleIndex.Value))
218+
{
219+
difficulty = 1.0 - StudentPerformance[sampleIndex.Value];
220+
}
221+
else
222+
{
223+
// Fallback to medium difficulty if no performance data
224+
difficulty = 0.5;
225+
}
201226
break;
202227
}
203228

0 commit comments

Comments
 (0)