Skip to content

Commit 5667d5b

Browse files
ooples and claude committed
fix: ensemble geometric mean, multimodal performance, and compatibility issues
Resolves multiple review comments:
- EnsembleTeacherModel.cs:205 - Fix geometric mean for negative logits with sign tracking
- CurriculumTeacherModel.cs:40 - Remove unused strategy parameter
- MultiModalTeacherModel.cs:87 - Cache teacher logits to prevent O(n*m) repeated calls
- PredictionModelBuilder.cs:891 - Replace ThrowIfNull with standard null check for TFM compatibility

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent d83d239 commit 5667d5b

File tree

4 files changed

+17
-10
lines changed

4 files changed

+17
-10
lines changed

src/KnowledgeDistillation/Teachers/CurriculumTeacherModel.cs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,11 @@ public class CurriculumTeacherModel<T> : TeacherModelBase<Vector<T>, Vector<T>,
2929
/// Initializes a new instance of the CurriculumTeacherModel class.
3030
/// </summary>
3131
/// <param name="baseTeacher">The underlying teacher model.</param>
32-
/// <param name="strategy">Curriculum strategy (kept for backward compatibility, not used).</param>
33-
public CurriculumTeacherModel(
34-
ITeacherModel<Vector<T>, Vector<T>> baseTeacher,
35-
CurriculumStrategy strategy = CurriculumStrategy.EasyToHard)
32+
public CurriculumTeacherModel(ITeacherModel<Vector<T>, Vector<T>> baseTeacher)
3633
{
3734
_baseTeacher = baseTeacher ?? throw new ArgumentNullException(nameof(baseTeacher));
38-
// Note: strategy parameter maintained for backward compatibility but curriculum
39-
// logic should be implemented in the training strategy, not the teacher
35+
// Note: Curriculum logic is implemented in CurriculumDistillationStrategy,
36+
// not in the teacher model. This is a simple wrapper around the base teacher.
4037
}
4138

4239
/// <summary>

src/KnowledgeDistillation/Teachers/EnsembleTeacherModel.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,15 +192,18 @@ private Vector<T> AggregateLogits(Vector<T>[] teacherLogits)
192192
case EnsembleAggregationMode.GeometricMean:
193193
// Geometric mean: (x1 * x2 * ... * xn)^(1/n)
194194
// For numerical stability, use log space: exp(mean(log(xi)))
195+
// For logits (which can be negative), track sign separately
195196
for (int i = 0; i < n; i++)
196197
{
197198
double logSum = 0;
199+
int sign = 1;
198200
for (int t = 0; t < _teachers.Length; t++)
199201
{
200202
double val = Convert.ToDouble(teacherLogits[t][i]);
203+
if (val < 0) sign *= -1; // Track overall sign
201204
logSum += Math.Log(Math.Abs(val) + 1e-10) * _weights![t];
202205
}
203-
result[i] = NumOps.FromDouble(Math.Exp(logSum));
206+
result[i] = NumOps.FromDouble(sign * Math.Exp(logSum));
204207
}
205208
break;
206209

src/KnowledgeDistillation/Teachers/MultiModalTeacherModel.cs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,20 @@ public override Vector<T> GetLogits(Vector<T> input)
7171
int n = _modalityTeachers[0].OutputDimension;
7272
var combined = new Vector<T>(n);
7373

74+
// Cache logits from each teacher to avoid repeated calls
75+
var teacherLogits = new Vector<T>[_modalityTeachers.Length];
76+
for (int i = 0; i < _modalityTeachers.Length; i++)
77+
{
78+
teacherLogits[i] = _modalityTeachers[i].GetLogits(input);
79+
}
80+
81+
// Combine weighted logits
7482
for (int j = 0; j < n; j++)
7583
{
7684
T sum = NumOps.Zero;
7785
for (int i = 0; i < _modalityTeachers.Length; i++)
7886
{
79-
var logits = _modalityTeachers[i].GetLogits(input);
80-
var weighted = NumOps.Multiply(logits[j], NumOps.FromDouble(_modalityWeights[i]));
87+
var weighted = NumOps.Multiply(teacherLogits[i][j], NumOps.FromDouble(_modalityWeights[i]));
8188
sum = NumOps.Add(sum, weighted);
8289
}
8390
combined[j] = sum;

src/PredictionModelBuilder.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -885,7 +885,7 @@ public async Task<string> AskAgentAsync(string question)
885885
public IPredictionModelBuilder<T, TInput, TOutput> ConfigureKnowledgeDistillation(
886886
KnowledgeDistillationOptions<T, TInput, TOutput> options)
887887
{
888-
ArgumentNullException.ThrowIfNull(options);
888+
if (options == null) throw new ArgumentNullException(nameof(options));
889889
_knowledgeDistillationOptions = options;
890890
return this;
891891
}

0 commit comments

Comments
 (0)