@@ -897,7 +897,7 @@ public IPredictionModelBuilder<T, TInput, TOutput> ConfigureKnowledgeDistillatio
897897 /// <summary>
898898 /// Performs knowledge distillation training using the configured options.
899899 /// </summary>
900- private async Task < OptimizationResult < T , TInput , TOutput > > PerformKnowledgeDistillationAsync (
900+ private Task < OptimizationResult < T , TInput , TOutput > > PerformKnowledgeDistillationAsync (
901901 IFullModel < T , TInput , TOutput > studentModel ,
902902 IOptimizer < T , TInput , TOutput > optimizer ,
903903 TInput XTrain ,
@@ -933,18 +933,34 @@ private async Task<OptimizationResult<T, TInput, TOutput>> PerformKnowledgeDisti
933933 }
934934 else if ( options . Teachers != null && options . Teachers . Length > 0 )
935935 {
936- // Use ensemble of teachers
936+ // Use ensemble of teachers - validate and convert to Vector-specialized types
937+ if ( typeof ( TInput ) != typeof ( Vector < T > ) || typeof ( TOutput ) != typeof ( Vector < T > ) )
938+ throw new InvalidOperationException (
939+ "Teachers array requires Vector<T> input/output types for distillation. " +
940+ "Ensure your model is configured with Vector<T> types." ) ;
941+
942+ var vectorTeachers = options . Teachers . Cast < ITeacherModel < Vector < T > , Vector < T > > > ( ) . ToArray ( ) ;
937943 teacher = KnowledgeDistillation . TeacherModelFactory < T > . CreateTeacher (
938944 TeacherModelType . Ensemble ,
939- ensembleModels : options . Teachers ,
945+ ensembleModels : vectorTeachers ,
940946 ensembleWeights : options . EnsembleWeights ) ;
941947 }
942948 else if ( options . TeacherForward != null )
943949 {
944- // Wrap forward function as teacher
950+ // Wrap forward function as teacher - validate and convert to Vector-specialized types
951+ if ( typeof ( TInput ) != typeof ( Vector < T > ) || typeof ( TOutput ) != typeof ( Vector < T > ) )
952+ throw new InvalidOperationException (
953+ "TeacherForward requires Vector<T> input/output types for distillation. " +
954+ "Ensure your model is configured with Vector<T> types." ) ;
955+
945956 int outputDim = options . OutputDimension ?? 10 ;
957+ var vectorForward = options . TeacherForward as Func < Vector < T > , Vector < T > > ;
958+ if ( vectorForward == null )
959+ throw new InvalidOperationException (
960+ "TeacherForward must be Func<Vector<T>, Vector<T>> for distillation." ) ;
961+
946962 teacher = new KnowledgeDistillation . TeacherModelWrapper < T > (
947- options . TeacherForward ,
963+ vectorForward ,
948964 outputDim ) ;
949965 }
950966 else
@@ -1192,15 +1208,15 @@ private async Task<OptimizationResult<T, TInput, TOutput>> PerformKnowledgeDisti
11921208 }
11931209
11941210 // Step 7: Return result using optimizer's infrastructure
1195- return optimizer . Optimize ( OptimizerHelper < T , TInput , TOutput > . CreateOptimizationInputData (
1196- XTrain , yTrain , XVal , yVal , XTest , yTest ) ) ;
1211+ return Task . FromResult ( optimizer . Optimize ( OptimizerHelper < T , TInput , TOutput > . CreateOptimizationInputData (
1212+ XTrain , yTrain , XVal , yVal , XTest , yTest ) ) ) ;
11971213 }
11981214 catch ( Exception ex )
11991215 {
12001216 Console . WriteLine ( $ "Error setting up knowledge distillation: { ex . Message } ") ;
12011217 Console . WriteLine ( "Falling back to standard training." ) ;
1202- return optimizer . Optimize ( OptimizerHelper < T , TInput , TOutput > . CreateOptimizationInputData (
1203- XTrain , yTrain , XVal , yVal , XTest , yTest ) ) ;
1218+ return Task . FromResult ( optimizer . Optimize ( OptimizerHelper < T , TInput , TOutput > . CreateOptimizationInputData (
1219+ XTrain , yTrain , XVal , yVal , XTest , yTest ) ) ) ;
12041220 }
12051221 }
12061222
0 commit comments