Skip to content

Commit 06dad2d

Browse files
committed
fix: add comprehensive validation to MultiModalTeacherModel constructor
The constructor was dividing by modalityTeachers.Length (line 24) and accessing _modalityTeachers[0] (lines 14, 36) without validating the array, which could cause DivideByZeroException, IndexOutOfRangeException, or inconsistent behavior. Added validation to fail fast in the constructor: 1. Check that modality teachers array is non-empty (Length > 0) 2. Check that no teacher is null 3. Check that all teachers have the same OutputDimension 4. Improved modalityWeights validation with clearer error message This prevents runtime errors when OutputDimension or GetLogits is invoked and provides clear, actionable error messages indicating which validation failed. Similar to the fix applied to DistributedTeacherModel.
1 parent 38b49eb commit 06dad2d

File tree

1 file changed

+27
-1
lines changed

1 file changed

+27
-1
lines changed

src/KnowledgeDistillation/Teachers/MultiModalTeacherModel.cs

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,40 @@ public MultiModalTeacherModel(
1919
{
2020
_modalityTeachers = modalityTeachers ?? throw new ArgumentNullException(nameof(modalityTeachers));
2121

22+
// Validate modality teachers array is non-empty
23+
if (_modalityTeachers.Length == 0)
24+
throw new ArgumentException("Modality teachers array cannot be empty", nameof(modalityTeachers));
25+
26+
// Validate no teacher is null
27+
for (int i = 0; i < _modalityTeachers.Length; i++)
28+
{
29+
if (_modalityTeachers[i] == null)
30+
throw new ArgumentException($"Modality teacher at index {i} is null", nameof(modalityTeachers));
31+
}
32+
33+
// Validate all teachers have the same output dimension
34+
int expectedOutputDim = _modalityTeachers[0].OutputDimension;
35+
for (int i = 1; i < _modalityTeachers.Length; i++)
36+
{
37+
if (_modalityTeachers[i].OutputDimension != expectedOutputDim)
38+
throw new ArgumentException(
39+
$"Modality teacher at index {i} has OutputDimension {_modalityTeachers[i].OutputDimension}, " +
40+
$"but expected {expectedOutputDim} (from teacher 0). " +
41+
$"All modality teachers must have the same OutputDimension.",
42+
nameof(modalityTeachers));
43+
}
44+
45+
// Set or validate modality weights
2246
if (modalityWeights == null)
2347
{
2448
_modalityWeights = Enumerable.Repeat(1.0 / modalityTeachers.Length, modalityTeachers.Length).ToArray();
2549
}
2650
else
2751
{
2852
if (modalityWeights.Length != modalityTeachers.Length)
29-
throw new ArgumentException("Modality weights must match number of teachers");
53+
throw new ArgumentException(
54+
$"Modality weights length ({modalityWeights.Length}) must match number of teachers ({modalityTeachers.Length})",
55+
nameof(modalityWeights));
3056
_modalityWeights = modalityWeights;
3157
}
3258
}

0 commit comments

Comments
 (0)