Skip to content

Commit 87c8373

Browse files
committed
fix: throw exception instead of guessing output dimension in teachermodelwrapper
Remove unreliable dummy prediction logic and silent default to 10 that could cause downstream training failures with incorrect label sizes. Changes: - Remove dummy inference attempt with Vector<T>(1) that fails for real models - Remove catch-all fallback that silently returns 10 - Throw InvalidOperationException with clear message if metadata lacks dimension info - Provide guidance to use explicit outputDimension constructor parameter This prevents corrupted training from mismatched one-hot label dimensions. Resolves coderabbitai review comment on line 132.
1 parent 214ff18 commit 87c8373

File tree

1 file changed

+14
-26
lines changed

1 file changed

+14
-26
lines changed

src/KnowledgeDistillation/TeacherModelWrapper.cs

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -100,36 +100,24 @@ public TeacherModelWrapper(IFullModel<T, Vector<T>, Vector<T>> model)
100100
private static int GetOutputDimensionFromModel(IFullModel<T, Vector<T>, Vector<T>> model)
101101
{
102102
// Try to infer output dimension from model metadata
103-
try
104-
{
105-
var metadata = model.GetMetadata();
103+
var metadata = model.GetMetadata();
106104

107-
// Check if metadata contains output dimension/class count
108-
if (metadata.TryGetValue("OutputDimension", out var outputDimValue) && outputDimValue is int outputDim && outputDim > 0)
109-
return outputDim;
105+
// Check if metadata contains output dimension/class count
106+
if (metadata.TryGetValue("OutputDimension", out var outputDimValue) && outputDimValue is int outputDim && outputDim > 0)
107+
return outputDim;
110108

111-
if (metadata.TryGetValue("NumClasses", out var numClassesValue) && numClassesValue is int numClasses && numClasses > 0)
112-
return numClasses;
109+
if (metadata.TryGetValue("NumClasses", out var numClassesValue) && numClassesValue is int numClasses && numClasses > 0)
110+
return numClasses;
113111

114-
if (metadata.TryGetValue("ClassCount", out var classCountValue) && classCountValue is int classCount && classCount > 0)
115-
return classCount;
112+
if (metadata.TryGetValue("ClassCount", out var classCountValue) && classCountValue is int classCount && classCount > 0)
113+
return classCount;
116114

117-
// If metadata doesn't contain dimension info, try inferring from a dummy prediction
118-
// Create a minimal dummy input vector (size 1) to get output shape
119-
var dummyInput = new Vector<T>(1);
120-
var dummyOutput = model.Predict(dummyInput);
121-
122-
if (dummyOutput != null && dummyOutput.Length > 0)
123-
return dummyOutput.Length;
124-
125-
// Ultimate fallback if all else fails
126-
return 10; // Common default for classification (e.g., CIFAR-10)
127-
}
128-
catch
129-
{
130-
// If any error occurs during inference, use reasonable default
131-
return 10; // Common default for classification tasks
132-
}
115+
// If metadata doesn't contain dimension info, we cannot reliably determine output dimension
116+
// Throw instead of guessing to prevent downstream errors with incorrect label sizes
117+
throw new InvalidOperationException(
118+
"Cannot determine output dimension from model metadata. " +
119+
"Please use the constructor overload that explicitly specifies outputDimension, " +
120+
"or ensure the model's GetMetadata() returns 'OutputDimension', 'NumClasses', or 'ClassCount'.");
133121
}
134122

135123
/// <summary>

0 commit comments

Comments
 (0)