@@ -57,41 +57,19 @@ public abstract class KnowledgeDistillationTrainerBase<T, TInput, TOutput> : IKn
5757 public IDistillationStrategy < T , TOutput > DistillationStrategy { get ; protected set ; }
5858
5959 /// <summary>
60- /// Gets or sets the checkpoint configuration for automatic model saving during training.
60+ /// Checkpoint configuration for automatic model saving during training (internal).
6161 /// </summary>
62- /// <remarks>
63- /// <para><b>For Beginners:</b> Set this property to enable automatic checkpointing.
64- /// If null (default), no automatic checkpointing occurs. If set, the trainer will automatically:
65- /// - Save checkpoints based on your configuration (e.g., every 5 epochs)
66- /// - Keep only the best N checkpoints to save disk space
67- /// - Load the best checkpoint after training completes</para>
68- ///
69- /// <para><b>Example:</b>
70- /// <code>
71- /// trainer.CheckpointConfig = new DistillationCheckpointConfig
72- /// {
73- /// SaveEveryEpochs = 5,
74- /// KeepBestN = 3,
75- /// BestMetric = "validation_loss"
76- /// };
77- /// </code>
78- /// </para>
79- /// </remarks>
80- public DistillationCheckpointConfig ? CheckpointConfig { get ; set ; }
62+ private readonly DistillationCheckpointConfig ? _checkpointConfig ;
8163
8264 /// <summary>
83- /// Gets or sets the student model for checkpointing .
65+ /// Checkpoint manager for handling checkpoint operations (internal).
8466 /// </summary>
85- /// <remarks>
86- /// <para>Set this if you want automatic checkpointing. The student must implement
87- /// <see cref="ICheckpointableModel"/> for checkpoint saving/loading to work.</para>
88- /// </remarks>
89- public ICheckpointableModel ? Student { get ; set ; }
67+ private DistillationCheckpointManager < T > ? _checkpointManager ;
9068
9169 /// <summary>
92- /// Gets the checkpoint manager (created automatically when CheckpointConfig is set ).
70+ /// Student model reference for checkpointing (internal).
9371 /// </summary>
94- protected DistillationCheckpointManager < T > ? CheckpointManager { get ; private set ; }
72+ private ICheckpointableModel ? _student ;
9573
9674 private double _lastValidationMetric ;
9775 private T _lastTrainingLoss ;
@@ -102,22 +80,40 @@ public abstract class KnowledgeDistillationTrainerBase<T, TInput, TOutput> : IKn
10280 /// </summary>
10381 /// <param name="teacher">The teacher model.</param>
10482 /// <param name="distillationStrategy">The distillation strategy.</param>
83+ /// <param name="checkpointConfig">Optional checkpoint configuration for automatic model saving during training.</param>
10584 /// <param name="seed">Optional random seed for reproducibility.</param>
10685 /// <remarks>
10786 /// <para><b>For Beginners:</b> The teacher and strategy are the core components:
10887 /// - Teacher: Provides the "expert" knowledge to transfer
10988 /// - Strategy: Defines how to measure and optimize the knowledge transfer</para>
11089 ///
111- /// <para><b>Automatic Checkpointing:</b> To enable automatic checkpointing, set the
112- /// <see cref="CheckpointConfig"/> and <see cref="Student"/> properties after construction.</para>
90+ /// <para><b>Automatic Checkpointing:</b> To enable automatic checkpointing, pass a
91+ /// <see cref="DistillationCheckpointConfig"/> instance. If null (default), no automatic checkpointing occurs.
92+ /// When enabled, the trainer will automatically:
93+ /// - Save checkpoints based on your configuration (e.g., every 5 epochs)
94+ /// - Keep only the best N checkpoints to save disk space
95+ /// - Load the best checkpoint after training completes</para>
96+ ///
97+ /// <para><b>Example with Checkpointing:</b>
98+ /// <code>
99+ /// var config = new DistillationCheckpointConfig
100+ /// {
101+ /// SaveEveryEpochs = 5,
102+ /// KeepBestN = 3
103+ /// };
104+ /// var trainer = new KnowledgeDistillationTrainer(teacher, strategy, checkpointConfig: config);
105+ /// </code>
106+ /// </para>
113107 /// </remarks>
114108 protected KnowledgeDistillationTrainerBase (
115109 ITeacherModel < TInput , TOutput > teacher ,
116110 IDistillationStrategy < T , TOutput > distillationStrategy ,
111+ DistillationCheckpointConfig ? checkpointConfig = null ,
117112 int ? seed = null )
118113 {
119114 Teacher = teacher ?? throw new ArgumentNullException ( nameof ( teacher ) ) ;
120115 DistillationStrategy = distillationStrategy ?? throw new ArgumentNullException ( nameof ( distillationStrategy ) ) ;
116+ _checkpointConfig = checkpointConfig ;
121117 NumOps = MathHelper . GetNumericOperations < T > ( ) ;
122118 Random = seed . HasValue ? new Random ( seed . Value ) : new Random ( ) ;
123119 _lastTrainingLoss = NumOps . Zero ;
@@ -186,6 +182,7 @@ public virtual T TrainBatch(
186182 /// <param name="batchSize">Batch size for mini-batch training.</param>
187183 /// <param name="validationInputs">Optional validation inputs for monitoring.</param>
188184 /// <param name="validationLabels">Optional validation labels.</param>
185+ /// <param name="student">Optional student model for automatic checkpointing (must implement ICheckpointableModel).</param>
189186 /// <param name="onEpochComplete">Optional callback invoked after each epoch with (epoch, avgLoss).</param>
190187 /// <remarks>
191188 /// <para><b>For Beginners:</b> This method orchestrates the complete training process:
@@ -201,6 +198,9 @@ public virtual T TrainBatch(
201198 /// - Use batch sizes that fit in memory (32-128 typical)
202199 /// - Monitor validation loss to detect overfitting
203200 /// - Invoke callbacks to log progress or save checkpoints</para>
201+ ///
202+ /// <para><b>Automatic Checkpointing:</b> If a checkpoint configuration was provided to the constructor
203+ /// and a student model is passed via the <paramref name="student"/> parameter, automatic checkpointing is enabled for this call.</para>
204204 /// </remarks>
205205 public virtual void Train (
206206 Func < TInput , TOutput > studentForward ,
@@ -211,6 +211,7 @@ public virtual void Train(
211211 int batchSize = 32 ,
212212 Vector < TInput > ? validationInputs = null ,
213213 Vector < TOutput > ? validationLabels = null ,
214+ ICheckpointableModel ? student = null ,
214215 Action < int , T > ? onEpochComplete = null )
215216 {
216217 if ( studentForward == null ) throw new ArgumentNullException ( nameof ( studentForward ) ) ;
@@ -226,6 +227,9 @@ public virtual void Train(
226227 if ( validationInputs != null && validationLabels != null && validationInputs . Length != validationLabels . Length )
227228 throw new ArgumentException ( "Validation inputs and labels must have the same length" ) ;
228229
230+ // Store student reference for checkpointing
231+ _student = student ;
232+
229233 // Prepare for training
230234 OnTrainingStart ( trainInputs , trainLabels ) ;
231235
@@ -303,7 +307,7 @@ public virtual void Train(
303307 int batchSize = 32 ,
304308 Action < int , T > ? onEpochComplete = null )
305309 {
306- Train ( studentForward , studentBackward , trainInputs , trainLabels , epochs , batchSize , null , null , onEpochComplete ) ;
310+ Train ( studentForward , studentBackward , trainInputs , trainLabels , epochs , batchSize , null , null , null , onEpochComplete ) ;
307311 }
308312
309313 /// <summary>
@@ -468,15 +472,15 @@ protected int ArgMax(Vector<T> vector)
468472 /// - Setup curriculum schedules
469473 /// - Allocate temporary buffers</para>
470474 ///
471- /// <para><b>Automatic Checkpointing:</b> If <see cref="CheckpointConfig"/> is set, this method
472- /// automatically initializes the checkpoint manager.</para>
475+ /// <para><b>Automatic Checkpointing:</b> If checkpoint configuration was provided to the constructor,
476+ /// this method automatically initializes the checkpoint manager.</para>
473477 /// </remarks>
474478 protected virtual void OnTrainingStart ( Vector < TInput > trainInputs , Vector < TOutput > ? trainLabels )
475479 {
476480 // Initialize checkpoint manager if config is provided. A new manager is constructed on every call, and this check does not depend on whether a student model was supplied; trainInputs and trainLabels are not used here.
477- if ( CheckpointConfig != null )
481+ if ( _checkpointConfig != null )
478482 {
479- CheckpointManager = new DistillationCheckpointManager < T > ( CheckpointConfig ) ;
483+ _checkpointManager = new DistillationCheckpointManager < T > ( _checkpointConfig ) ;
480484 }
481485 }
482486
@@ -492,25 +496,25 @@ protected virtual void OnTrainingStart(Vector<TInput> trainInputs, Vector<TOutpu
492496 /// - Log final metrics
493497 /// - Free temporary resources</para>
494498 ///
495- /// <para><b>Automatic Checkpointing:</b> If <see cref="CheckpointConfig"/> is set, this method
496- /// automatically loads the best checkpoint (based on validation metrics) after training completes.</para>
499+ /// <para><b>Automatic Checkpointing:</b> If checkpoint configuration was provided to the constructor
500+ /// and a student model was passed to Train, this method automatically loads the best checkpoint (based on validation metrics) after training completes.</para>
497501 /// </remarks>
498502 protected virtual void OnTrainingEnd ( Vector < TInput > trainInputs , Vector < TOutput > ? trainLabels )
499503 {
500504 // Load best checkpoint if checkpointing was enabled
501- if ( CheckpointManager != null && Student != null )
505+ if ( _checkpointManager != null && _student != null )
502506 {
503- var bestCheckpoint = CheckpointManager . LoadBestCheckpoint (
504- student : Student ,
507+ var bestCheckpoint = _checkpointManager . LoadBestCheckpoint (
508+ student : _student ,
505509 teacher : Teacher as ICheckpointableModel
506510 ) ;
507511
508512 if ( bestCheckpoint != null )
509513 {
510514 Console . WriteLine ( $ "[Checkpointing] Loaded best checkpoint from epoch { bestCheckpoint . Epoch } ") ;
511- if ( bestCheckpoint . Metrics . ContainsKey ( CheckpointConfig ! . BestMetric ) )
515+ if ( bestCheckpoint . Metrics . ContainsKey ( _checkpointConfig ! . BestMetric ) )
512516 {
513- Console . WriteLine ( $ "[Checkpointing] Best { CheckpointConfig . BestMetric } : { bestCheckpoint . Metrics [ CheckpointConfig . BestMetric ] : F4} ") ;
517+ Console . WriteLine ( $ "[Checkpointing] Best { _checkpointConfig . BestMetric } : { bestCheckpoint . Metrics [ _checkpointConfig . BestMetric ] : F4} ") ;
514518 }
515519 }
516520 }
@@ -550,8 +554,8 @@ protected virtual void OnEpochStart(int epoch, Vector<TInput> trainInputs, Vecto
550554 /// to flush partial batches and prevent buffer leakage between epochs. Derived classes should
551555 /// call base.OnEpochEnd() if they override this method.</para>
552556 ///
553- /// <para><b>Automatic Checkpointing:</b> If <see cref="CheckpointConfig"/> is set, this method
554- /// automatically saves checkpoints based on your configuration.</para>
557+ /// <para><b>Automatic Checkpointing:</b> If checkpoint configuration was provided to the constructor,
558+ /// this method automatically saves checkpoints based on your configuration.</para>
555559 /// </remarks>
556560 protected virtual void OnEpochEnd ( int epoch , T avgLoss )
557561 {
@@ -567,7 +571,7 @@ protected virtual void OnEpochEnd(int epoch, T avgLoss)
567571 _lastTrainingLoss = avgLoss ;
568572
569573 // Automatic checkpoint saving
570- if ( CheckpointManager != null )
574+ if ( _checkpointManager != null )
571575 {
572576 var metrics = new Dictionary < string , double >
573577 {
@@ -577,12 +581,12 @@ protected virtual void OnEpochEnd(int epoch, T avgLoss)
577581 // Include validation metric if available
578582 if ( _lastValidationMetric > 0 )
579583 {
580- metrics [ CheckpointConfig ! . BestMetric ] = _lastValidationMetric ;
584+ metrics [ _checkpointConfig ! . BestMetric ] = _lastValidationMetric ;
581585 }
582586
583- CheckpointManager . SaveCheckpointIfNeeded (
587+ _checkpointManager . SaveCheckpointIfNeeded (
584588 epoch : epoch ,
585- student : Student ,
589+ student : _student ,
586590 teacher : Teacher as ICheckpointableModel ,
587591 strategy : DistillationStrategy ,
588592 metrics : metrics
0 commit comments