Skip to content

Commit 0b531ca

Browse files
committed
fix: resolve nullable type inference compilation error in TrainBatch
Replace the null-conditional operator `trueLabels?[i]` with an explicit null check to avoid the "TOutput cannot be made nullable" compiler error. Changes: use an explicit null check with a `hasLabel` flag instead of `trueLabels?[i]`; pass `default(TOutput)` when `hasLabel` is false. This preserves the original intent without requiring nullable constraints on TOutput, and fixes the compilation blocker at line 153. Resolves the coderabbitai review comment on line 172 (formerly 112).
1 parent 87c8373 commit 0b531ca

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

src/KnowledgeDistillation/KnowledgeDistillationTrainerBase.cs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,13 @@ public virtual T TrainBatch(
150150
for (int i = 0; i < inputs.Length; i++)
151151
{
152152
var input = inputs[i];
153-
var label = trueLabels?[i];
153+
TOutput label = default;
154+
bool hasLabel = false;
155+
if (trueLabels != null)
156+
{
157+
label = trueLabels[i];
158+
hasLabel = true;
159+
}
154160

155161
// Student forward pass
156162
var studentOutput = studentForward(input);
@@ -159,8 +165,8 @@ public virtual T TrainBatch(
159165
var teacherOutput = GetTeacherPredictions(input, i);
160166

161167
// Compute loss and gradient
162-
var loss = DistillationStrategy.ComputeLoss(studentOutput, teacherOutput, label);
163-
var gradient = DistillationStrategy.ComputeGradient(studentOutput, teacherOutput, label);
168+
var loss = DistillationStrategy.ComputeLoss(studentOutput, teacherOutput, hasLabel ? label : default);
169+
var gradient = DistillationStrategy.ComputeGradient(studentOutput, teacherOutput, hasLabel ? label : default);
164170

165171
totalLoss = NumOps.Add(totalLoss, loss);
166172

0 commit comments

Comments
 (0)