Skip to content

Commit f5a28ea

Browse files
ooplesclaude
andcommitted
fix: resolve compilation errors in PredictionModelBuilder
Resolves review comment on PredictionModelBuilder.cs:1205 - Removed async modifier from PerformKnowledgeDistillationAsync (no await operations) - Added type validation and casting for Teachers array to Vector-specialized types - Added type validation and casting for TeacherForward function to Vector-specialized types - Wrapped return statements in Task.FromResult to match Task<> return type 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 5f82106 commit f5a28ea

File tree

1 file changed

+25
-9
lines changed

1 file changed

+25
-9
lines changed

src/PredictionModelBuilder.cs

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -897,7 +897,7 @@ public IPredictionModelBuilder<T, TInput, TOutput> ConfigureKnowledgeDistillatio
897897
/// <summary>
898898
/// Performs knowledge distillation training using the configured options.
899899
/// </summary>
900-
private async Task<OptimizationResult<T, TInput, TOutput>> PerformKnowledgeDistillationAsync(
900+
private Task<OptimizationResult<T, TInput, TOutput>> PerformKnowledgeDistillationAsync(
901901
IFullModel<T, TInput, TOutput> studentModel,
902902
IOptimizer<T, TInput, TOutput> optimizer,
903903
TInput XTrain,
@@ -933,18 +933,34 @@ private async Task<OptimizationResult<T, TInput, TOutput>> PerformKnowledgeDisti
933933
}
934934
else if (options.Teachers != null && options.Teachers.Length > 0)
935935
{
936-
// Use ensemble of teachers
936+
// Use ensemble of teachers - validate and convert to Vector-specialized types
937+
if (typeof(TInput) != typeof(Vector<T>) || typeof(TOutput) != typeof(Vector<T>))
938+
throw new InvalidOperationException(
939+
"Teachers array requires Vector<T> input/output types for distillation. " +
940+
"Ensure your model is configured with Vector<T> types.");
941+
942+
var vectorTeachers = options.Teachers.Cast<ITeacherModel<Vector<T>, Vector<T>>>().ToArray();
937943
teacher = KnowledgeDistillation.TeacherModelFactory<T>.CreateTeacher(
938944
TeacherModelType.Ensemble,
939-
ensembleModels: options.Teachers,
945+
ensembleModels: vectorTeachers,
940946
ensembleWeights: options.EnsembleWeights);
941947
}
942948
else if (options.TeacherForward != null)
943949
{
944-
// Wrap forward function as teacher
950+
// Wrap forward function as teacher - validate and convert to Vector-specialized types
951+
if (typeof(TInput) != typeof(Vector<T>) || typeof(TOutput) != typeof(Vector<T>))
952+
throw new InvalidOperationException(
953+
"TeacherForward requires Vector<T> input/output types for distillation. " +
954+
"Ensure your model is configured with Vector<T> types.");
955+
945956
int outputDim = options.OutputDimension ?? 10;
957+
var vectorForward = options.TeacherForward as Func<Vector<T>, Vector<T>>;
958+
if (vectorForward == null)
959+
throw new InvalidOperationException(
960+
"TeacherForward must be Func<Vector<T>, Vector<T>> for distillation.");
961+
946962
teacher = new KnowledgeDistillation.TeacherModelWrapper<T>(
947-
options.TeacherForward,
963+
vectorForward,
948964
outputDim);
949965
}
950966
else
@@ -1192,15 +1208,15 @@ private async Task<OptimizationResult<T, TInput, TOutput>> PerformKnowledgeDisti
11921208
}
11931209

11941210
// Step 7: Return result using optimizer's infrastructure
1195-
return optimizer.Optimize(OptimizerHelper<T, TInput, TOutput>.CreateOptimizationInputData(
1196-
XTrain, yTrain, XVal, yVal, XTest, yTest));
1211+
return Task.FromResult(optimizer.Optimize(OptimizerHelper<T, TInput, TOutput>.CreateOptimizationInputData(
1212+
XTrain, yTrain, XVal, yVal, XTest, yTest)));
11971213
}
11981214
catch (Exception ex)
11991215
{
12001216
Console.WriteLine($"Error setting up knowledge distillation: {ex.Message}");
12011217
Console.WriteLine("Falling back to standard training.");
1202-
return optimizer.Optimize(OptimizerHelper<T, TInput, TOutput>.CreateOptimizationInputData(
1203-
XTrain, yTrain, XVal, yVal, XTest, yTest));
1218+
return Task.FromResult(optimizer.Optimize(OptimizerHelper<T, TInput, TOutput>.CreateOptimizationInputData(
1219+
XTrain, yTrain, XVal, yVal, XTest, yTest)));
12041220
}
12051221
}
12061222

0 commit comments

Comments
 (0)