Skip to content

Commit d83d239

Browse files
ooples and claude committed
fix: add validation and fix curriculum/teacher issues
Resolves multiple review comments: - NeuralNetworkBase.cs:2153 - Implement SaveState/LoadState using Serialize/Deserialize - SelfDistillationTrainer.cs:69 - Validate EMADecay is between 0 and 1 - SelfDistillationTrainer.cs:82 - Fix generation documentation off-by-one error - CurriculumDistillationStrategyBase.cs:85 - Fix progress normalization to allow 100% - HardToEasyCurriculumStrategy.cs:190 - Update comments for reachable thresholds - TeacherModelFactory.cs:83 - Use Average aggregation when no weights supplied - TeacherModelFactory.cs:117 - Validate modality weights array length - TransformerTeacherModel.cs:39 - Validate outputDimension is positive - SelfTeacherModel.cs:45 - Validate cached predictions before storing 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent ab00d56 commit d83d239

File tree

7 files changed

+66
-11
lines changed

7 files changed

+66
-11
lines changed

src/KnowledgeDistillation/SelfDistillationTrainer.cs

Lines changed: 16 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -66,7 +66,19 @@ public class SelfDistillationTrainer<T> : KnowledgeDistillationTrainerBase<T, Ve
6666
/// <summary>
6767
/// Gets or sets the EMA decay rate (default 0.99). Higher values give more weight to history.
6868
/// </summary>
69-
public double EMADecay { get; set; }
69+
/// <exception cref="ArgumentOutOfRangeException">Thrown when value is not between 0 and 1.</exception>
70+
public double EMADecay
71+
{
72+
get => _emaDecay;
73+
set
74+
{
75+
if (value <= 0 || value >= 1)
76+
throw new ArgumentOutOfRangeException(nameof(value),
77+
"EMADecay must be between 0 and 1 (exclusive). Typical values are 0.9-0.999.");
78+
_emaDecay = value;
79+
}
80+
}
81+
private double _emaDecay = 0.99;
7082

7183
/// <summary>
7284
/// Initializes a new instance of the SelfDistillationTrainer class.
@@ -77,8 +89,9 @@ public class SelfDistillationTrainer<T> : KnowledgeDistillationTrainerBase<T, Ve
7789
/// <param name="seed">Optional random seed for reproducibility.</param>
7890
/// <remarks>
7991
/// <para><b>For Beginners:</b> Generations control how many times the model relearns from itself:
80-
/// - 1 generation: Train normally, then retrain with self as teacher
81-
/// - 2 generations: Do it twice (teacher → student1 → student2)
92+
/// - 1 generation: Train normally (standard training, no self-distillation)
93+
/// - 2 generations: Train, then retrain using self as teacher (first self-distillation)
94+
/// - 3 generations: Train → self-teach → self-teach again
8295
/// - More generations: Diminishing returns, usually not worth it beyond 2-3</para>
8396
///
8497
/// <para>Example:

src/KnowledgeDistillation/Strategies/CurriculumDistillationStrategyBase.cs

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -81,7 +81,8 @@ protected CurriculumDistillationStrategyBase(
8181
/// </summary>
8282
public virtual void UpdateProgress(int step)
8383
{
84-
_currentStep = Math.Max(0, Math.Min(step, TotalSteps - 1));
84+
// Don't clamp to TotalSteps - 1; allow reaching TotalSteps for full 100% progress
85+
_currentStep = Math.Max(0, Math.Min(step, TotalSteps));
8586
}
8687

8788
/// <summary>

src/KnowledgeDistillation/Strategies/HardToEasyCurriculumStrategy.cs

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -148,9 +148,11 @@ public override bool ShouldIncludeSample(int sampleIndex)
148148
if (difficulty == null)
149149
return true;
150150

151-
// Hard-to-Easy: Include samples with difficulty ≥ (1 - current progress)
152-
// Progress 0.0 → only hardest samples (difficulty ≥ 1.0)
153-
// Progress 1.0 → all samples (difficulty ≥ 0.0)
151+
// Hard-to-Easy: Include samples with difficulty >= (1 - current progress)
152+
// Progress 0.0 → only hardest samples (difficulty >= 1.0, threshold = 1.0)
153+
// Progress 0.5 → medium samples (difficulty >= 0.5, threshold = 0.5)
154+
// Progress 1.0 → all samples (difficulty >= 0.0, threshold = 0.0)
155+
// With progress now able to reach exactly 1.0, threshold can reach 0.0
154156
double threshold = 1.0 - CurriculumProgress;
155157
return difficulty.Value >= threshold;
156158
}

src/KnowledgeDistillation/TeacherModelFactory.cs

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -75,9 +75,10 @@ private static ITeacherModel<Vector<T>, Vector<T>> CreateEnsembleTeacher(
7575
if (ensembleModels == null || ensembleModels.Length == 0)
7676
throw new ArgumentException("Ensemble models are required for Ensemble teacher type");
7777

78-
var aggregation = ensembleWeights != null
78+
// Use plain average when no weights are supplied
79+
var aggregation = ensembleWeights != null && ensembleWeights.Length > 0
7980
? EnsembleAggregationMode.WeightedAverage
80-
: EnsembleAggregationMode.WeightedAverage;
81+
: EnsembleAggregationMode.Average;
8182

8283
return new EnsembleTeacherModel<T>(ensembleModels, ensembleWeights, aggregation);
8384
}
@@ -112,6 +113,9 @@ private static ITeacherModel<Vector<T>, Vector<T>> CreateMultiModalTeacher(
112113
{
113114
if (modalityTeachers == null || modalityTeachers.Length == 0)
114115
throw new ArgumentException("Modality teachers are required for MultiModal teacher type");
116+
if (modalityWeights != null && modalityWeights.Length != modalityTeachers.Length)
117+
throw new ArgumentException(
118+
$"Modality weights length ({modalityWeights.Length}) must match modality teachers length ({modalityTeachers.Length})");
115119

116120
return new MultiModalTeacherModel<T>(modalityTeachers, modalityWeights);
117121
}

src/KnowledgeDistillation/Teachers/SelfTeacherModel.cs

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,8 +18,28 @@ public SelfTeacherModel(int outputDimension)
1818
_outputDim = outputDimension;
1919
}
2020

21+
/// <summary>
22+
/// Caches predictions from the student model for later use.
23+
/// </summary>
24+
/// <param name="predictions">The predictions to cache.</param>
25+
/// <exception cref="ArgumentNullException">Thrown when predictions is null.</exception>
26+
/// <exception cref="ArgumentException">Thrown when predictions array is empty or contains null/zero-length vectors.</exception>
2127
public void CachePredictions(Vector<T>[] predictions)
2228
{
29+
if (predictions == null)
30+
throw new ArgumentNullException(nameof(predictions));
31+
if (predictions.Length == 0)
32+
throw new ArgumentException("Predictions array cannot be empty.", nameof(predictions));
33+
34+
// Validate that all prediction vectors are non-null and non-empty
35+
for (int i = 0; i < predictions.Length; i++)
36+
{
37+
if (predictions[i] == null)
38+
throw new ArgumentException($"Prediction at index {i} is null.", nameof(predictions));
39+
if (predictions[i].Length == 0)
40+
throw new ArgumentException($"Prediction at index {i} has zero length.", nameof(predictions));
41+
}
42+
2343
_cachedPredictions = predictions;
2444
}
2545

src/KnowledgeDistillation/Teachers/TransformerTeacherModel.cs

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -30,11 +30,16 @@ public class TransformerTeacherModel<T> : TeacherModelBase<Vector<T>, Vector<T>,
3030
/// </summary>
3131
/// <param name="forwardFunc">Function that performs forward pass and returns logits.</param>
3232
/// <param name="outputDimension">The number of output dimensions.</param>
33+
/// <exception cref="ArgumentNullException">Thrown when forwardFunc is null.</exception>
34+
/// <exception cref="ArgumentOutOfRangeException">Thrown when outputDimension is not positive.</exception>
3335
public TransformerTeacherModel(
3436
Func<Vector<T>, Vector<T>> forwardFunc,
3537
int outputDimension)
3638
{
3739
_forwardFunc = forwardFunc ?? throw new ArgumentNullException(nameof(forwardFunc));
40+
if (outputDimension <= 0)
41+
throw new ArgumentOutOfRangeException(nameof(outputDimension),
42+
"Output dimension must be positive.");
3843
_outputDim = outputDimension;
3944
}
4045

src/NeuralNetworks/NeuralNetworkBase.cs

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2140,7 +2140,11 @@ public virtual void ApplyGradients(Vector<T> gradients, T learningRate)
21402140
/// <param name="stream">The stream to write the model state to.</param>
21412141
public virtual void SaveState(Stream stream)
21422142
{
2143-
throw new NotImplementedException("SaveState is not yet implemented for NeuralNetworkBase. Consider using explicit serialization of layer parameters.");
2143+
if (stream == null) throw new ArgumentNullException(nameof(stream));
2144+
if (!stream.CanWrite) throw new ArgumentException("Stream must be writable.", nameof(stream));
2145+
var data = Serialize();
2146+
stream.Write(data, 0, data.Length);
2147+
stream.Flush();
21442148
}
21452149

21462150
/// <summary>
@@ -2149,6 +2153,12 @@ public virtual void SaveState(Stream stream)
21492153
/// <param name="stream">The stream to read the model state from.</param>
21502154
public virtual void LoadState(Stream stream)
21512155
{
2152-
throw new NotImplementedException("LoadState is not yet implemented for NeuralNetworkBase. Consider using explicit deserialization of layer parameters.");
2156+
if (stream == null) throw new ArgumentNullException(nameof(stream));
2157+
if (!stream.CanRead) throw new ArgumentException("Stream must be readable.", nameof(stream));
2158+
using var ms = new MemoryStream();
2159+
stream.CopyTo(ms);
2160+
var data = ms.ToArray();
2161+
if (data.Length == 0) throw new InvalidOperationException("Stream contains no data.");
2162+
Deserialize(data);
21532163
}
21542164
}

0 commit comments

Comments
 (0)