Skip to content

Commit b2ce7b4

Browse files
committed
refactor: apply facade pattern to automatic checkpointing architecture
Refactor checkpointing implementation to follow the library's facade pattern, hiding implementation details and providing a clean, simple API. Changes: - Remove public CheckpointConfig and Student properties from trainer base - Add checkpointConfig parameter to trainer constructor (optional) - Add student parameter to Train method (optional) - Make checkpoint manager and fields private/internal - Update KnowledgeDistillationTrainer constructor to accept checkpointConfig - Update CHECKPOINTING_GUIDE.md to show new facade-based usage Architecture improvements: - Follows facade pattern (hide complexity, expose simple interface) - Configuration injected through constructor (proper DI pattern) - Student model passed to Train method (clean method signature) - No mutable public state exposed - Internal implementation fully encapsulated Old usage (anti-pattern): trainer.CheckpointConfig = config; // Public property mutation trainer.Student = student; // Public property mutation trainer.Train(...); New usage (facade pattern): var trainer = new Trainer(teacher, strategy, checkpointConfig: config); trainer.Train(..., student: student); This matches the library's architectural principles and provides a cleaner API.
1 parent 7d2dfa8 commit b2ce7b4

File tree

3 files changed

+93
-70
lines changed

3 files changed

+93
-70
lines changed

src/KnowledgeDistillation/CHECKPOINTING_GUIDE.md

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -62,25 +62,15 @@ DistillationCheckpointManager<T>
6262

6363
## Quick Start: Automatic Checkpointing (Recommended)
6464

65-
The easiest way to enable checkpointing is through automatic checkpointing built into the trainer. Simply configure the checkpoint settings and the trainer handles everything automatically.
65+
The easiest way to enable checkpointing is through automatic checkpointing built into the trainer. Simply pass the checkpoint configuration to the trainer constructor and the student model to the Train method.
6666

6767
### Automatic Checkpointing Example
6868

6969
```csharp
7070
using AiDotNet.KnowledgeDistillation;
7171

72-
// Create trainer
73-
var teacher = LoadPretrainedTeacher();
74-
var student = CreateStudentModel(); // Must implement ICheckpointableModel
75-
var strategy = new ConfidenceBasedAdaptiveStrategy<double>();
76-
77-
var trainer = new KnowledgeDistillationTrainer<double, Vector<double>, Vector<double>>(
78-
teacher,
79-
strategy
80-
);
81-
82-
// Enable automatic checkpointing by setting CheckpointConfig
83-
trainer.CheckpointConfig = new DistillationCheckpointConfig
72+
// Create checkpoint configuration
73+
var checkpointConfig = new DistillationCheckpointConfig
8474
{
8575
CheckpointDirectory = "./checkpoints",
8676
SaveEveryEpochs = 5, // Auto-save every 5 epochs
@@ -90,10 +80,20 @@ trainer.CheckpointConfig = new DistillationCheckpointConfig
9080
LowerIsBetter = true
9181
};
9282

93-
// Set the student model (required for checkpointing)
94-
trainer.Student = student as ICheckpointableModel;
83+
// Create trainer with checkpoint config
84+
var teacher = LoadPretrainedTeacher();
85+
var strategy = new ConfidenceBasedAdaptiveStrategy<double>();
86+
87+
var trainer = new KnowledgeDistillationTrainer<double>(
88+
teacher,
89+
strategy,
90+
checkpointConfig: checkpointConfig // Pass config to constructor
91+
);
92+
93+
// Create student model (must implement ICheckpointableModel)
94+
var student = CreateStudentModel();
9595

96-
// Train - checkpointing happens automatically!
96+
// Train - pass student to Train method for automatic checkpointing!
9797
trainer.Train(
9898
studentForward: student.Predict,
9999
studentBackward: student.ApplyGradient,
@@ -102,7 +102,8 @@ trainer.Train(
102102
epochs: 100,
103103
batchSize: 32,
104104
validationInputs: validationData,
105-
validationLabels: validationLabels
105+
validationLabels: validationLabels,
106+
student: student as ICheckpointableModel // Pass student for checkpointing
106107
);
107108

108109
// After training completes, the best checkpoint is automatically loaded!
@@ -120,13 +121,13 @@ Console.WriteLine("Training complete. Best checkpoint automatically restored.");
120121
- ✅ Automatic best model selection
121122
- ✅ Automatic checkpoint pruning (keeps only best N)
122123
- ✅ Curriculum state preservation (if using curriculum strategies)
123-
- ✅ Clean, simple API
124+
- ✅ Clean, simple API following facade pattern
124125

125126
### Disabling Automatic Checkpointing
126127

127128
```csharp
128-
// Default: no checkpointing
129-
trainer.CheckpointConfig = null; // or simply don't set it
129+
// Default: no checkpointing (don't pass checkpointConfig)
130+
var trainer = new KnowledgeDistillationTrainer<double>(teacher, strategy);
130131

131132
// Training proceeds without checkpointing
132133
trainer.Train(...);

src/KnowledgeDistillation/KnowledgeDistillationTrainer.cs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,13 @@ public class KnowledgeDistillationTrainer<T> : KnowledgeDistillationTrainerBase<
4141
/// </summary>
4242
/// <param name="teacher">The teacher model to learn from.</param>
4343
/// <param name="distillationStrategy">The strategy for computing distillation loss.</param>
44+
/// <param name="checkpointConfig">Optional checkpoint configuration for automatic model saving during training.</param>
4445
/// <param name="seed">Optional random seed for reproducibility.</param>
4546
/// <remarks>
4647
/// <para><b>For Beginners:</b> Create a trainer by providing:
4748
/// 1. A trained teacher model (already performing well on your task)
48-
/// 2. A distillation strategy (defines how to transfer knowledge)</para>
49+
/// 2. A distillation strategy (defines how to transfer knowledge)
50+
/// 3. Optional checkpoint configuration (for automatic model saving)</para>
4951
///
5052
/// <para>Example:
5153
/// <code>
@@ -54,12 +56,28 @@ public class KnowledgeDistillationTrainer<T> : KnowledgeDistillationTrainerBase<
5456
/// var trainer = new KnowledgeDistillationTrainer&lt;double&gt;(teacher, distillationLoss);
5557
/// </code>
5658
/// </para>
59+
///
60+
/// <para>Example with automatic checkpointing:
61+
/// <code>
62+
/// var checkpointConfig = new DistillationCheckpointConfig
63+
/// {
64+
/// SaveEveryEpochs = 5,
65+
/// KeepBestN = 3
66+
/// };
67+
/// var trainer = new KnowledgeDistillationTrainer&lt;double&gt;(
68+
/// teacher,
69+
/// distillationLoss,
70+
/// checkpointConfig: checkpointConfig
71+
/// );
72+
/// </code>
73+
/// </para>
5774
/// </remarks>
5875
public KnowledgeDistillationTrainer(
5976
ITeacherModel<Vector<T>, Vector<T>> teacher,
6077
IDistillationStrategy<T, Vector<T>> distillationStrategy,
78+
DistillationCheckpointConfig? checkpointConfig = null,
6179
int? seed = null)
62-
: base(teacher, distillationStrategy, seed)
80+
: base(teacher, distillationStrategy, checkpointConfig, seed)
6381
{
6482
}
6583

src/KnowledgeDistillation/KnowledgeDistillationTrainerBase.cs

Lines changed: 52 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -57,41 +57,19 @@ public abstract class KnowledgeDistillationTrainerBase<T, TInput, TOutput> : IKn
5757
public IDistillationStrategy<T, TOutput> DistillationStrategy { get; protected set; }
5858

5959
/// <summary>
60-
/// Gets or sets the checkpoint configuration for automatic model saving during training.
60+
/// Checkpoint configuration for automatic model saving during training (internal).
6161
/// </summary>
62-
/// <remarks>
63-
/// <para><b>For Beginners:</b> Set this property to enable automatic checkpointing.
64-
/// If null (default), no automatic checkpointing occurs. If set, the trainer will automatically:
65-
/// - Save checkpoints based on your configuration (e.g., every 5 epochs)
66-
/// - Keep only the best N checkpoints to save disk space
67-
/// - Load the best checkpoint after training completes</para>
68-
///
69-
/// <para><b>Example:</b>
70-
/// <code>
71-
/// trainer.CheckpointConfig = new DistillationCheckpointConfig
72-
/// {
73-
/// SaveEveryEpochs = 5,
74-
/// KeepBestN = 3,
75-
/// BestMetric = "validation_loss"
76-
/// };
77-
/// </code>
78-
/// </para>
79-
/// </remarks>
80-
public DistillationCheckpointConfig? CheckpointConfig { get; set; }
62+
private readonly DistillationCheckpointConfig? _checkpointConfig;
8163

8264
/// <summary>
83-
/// Gets or sets the student model for checkpointing.
65+
/// Checkpoint manager for handling checkpoint operations (internal).
8466
/// </summary>
85-
/// <remarks>
86-
/// <para>Set this if you want automatic checkpointing. The student must implement
87-
/// <see cref="ICheckpointableModel"/> for checkpoint saving/loading to work.</para>
88-
/// </remarks>
89-
public ICheckpointableModel? Student { get; set; }
67+
private DistillationCheckpointManager<T>? _checkpointManager;
9068

9169
/// <summary>
92-
/// Gets the checkpoint manager (created automatically when CheckpointConfig is set).
70+
/// Student model reference for checkpointing (internal).
9371
/// </summary>
94-
protected DistillationCheckpointManager<T>? CheckpointManager { get; private set; }
72+
private ICheckpointableModel? _student;
9573

9674
private double _lastValidationMetric;
9775
private T _lastTrainingLoss;
@@ -102,22 +80,40 @@ public abstract class KnowledgeDistillationTrainerBase<T, TInput, TOutput> : IKn
10280
/// </summary>
10381
/// <param name="teacher">The teacher model.</param>
10482
/// <param name="distillationStrategy">The distillation strategy.</param>
83+
/// <param name="checkpointConfig">Optional checkpoint configuration for automatic model saving during training.</param>
10584
/// <param name="seed">Optional random seed for reproducibility.</param>
10685
/// <remarks>
10786
/// <para><b>For Beginners:</b> The teacher and strategy are the core components:
10887
/// - Teacher: Provides the "expert" knowledge to transfer
10988
/// - Strategy: Defines how to measure and optimize the knowledge transfer</para>
11089
///
111-
/// <para><b>Automatic Checkpointing:</b> To enable automatic checkpointing, set the
112-
/// <see cref="CheckpointConfig"/> and <see cref="Student"/> properties after construction.</para>
90+
/// <para><b>Automatic Checkpointing:</b> To enable automatic checkpointing, pass a
91+
/// <see cref="DistillationCheckpointConfig"/> instance. If null (default), no automatic checkpointing occurs.
92+
/// When enabled, the trainer will automatically:
93+
/// - Save checkpoints based on your configuration (e.g., every 5 epochs)
94+
/// - Keep only the best N checkpoints to save disk space
95+
/// - Load the best checkpoint after training completes</para>
96+
///
97+
/// <para><b>Example with Checkpointing:</b>
98+
/// <code>
99+
/// var config = new DistillationCheckpointConfig
100+
/// {
101+
/// SaveEveryEpochs = 5,
102+
/// KeepBestN = 3
103+
/// };
104+
/// var trainer = new KnowledgeDistillationTrainer(teacher, strategy, checkpointConfig: config);
105+
/// </code>
106+
/// </para>
113107
/// </remarks>
114108
protected KnowledgeDistillationTrainerBase(
115109
ITeacherModel<TInput, TOutput> teacher,
116110
IDistillationStrategy<T, TOutput> distillationStrategy,
111+
DistillationCheckpointConfig? checkpointConfig = null,
117112
int? seed = null)
118113
{
119114
Teacher = teacher ?? throw new ArgumentNullException(nameof(teacher));
120115
DistillationStrategy = distillationStrategy ?? throw new ArgumentNullException(nameof(distillationStrategy));
116+
_checkpointConfig = checkpointConfig;
121117
NumOps = MathHelper.GetNumericOperations<T>();
122118
Random = seed.HasValue ? new Random(seed.Value) : new Random();
123119
_lastTrainingLoss = NumOps.Zero;
@@ -186,6 +182,7 @@ public virtual T TrainBatch(
186182
/// <param name="batchSize">Batch size for mini-batch training.</param>
187183
/// <param name="validationInputs">Optional validation inputs for monitoring.</param>
188184
/// <param name="validationLabels">Optional validation labels.</param>
185+
/// <param name="student">Optional student model for automatic checkpointing (must implement ICheckpointableModel).</param>
189186
/// <param name="onEpochComplete">Optional callback invoked after each epoch with (epoch, avgLoss).</param>
190187
/// <remarks>
191188
/// <para><b>For Beginners:</b> This method orchestrates the complete training process:
@@ -201,6 +198,9 @@ public virtual T TrainBatch(
201198
/// - Use batch sizes that fit in memory (32-128 typical)
202199
/// - Monitor validation loss to detect overfitting
203200
/// - Invoke callbacks to log progress or save checkpoints</para>
201+
///
202+
/// <para><b>Automatic Checkpointing:</b> If a checkpoint configuration was provided to the constructor
203+
/// and you pass the student model parameter, automatic checkpointing will be enabled.</para>
204204
/// </remarks>
205205
public virtual void Train(
206206
Func<TInput, TOutput> studentForward,
@@ -211,6 +211,7 @@ public virtual void Train(
211211
int batchSize = 32,
212212
Vector<TInput>? validationInputs = null,
213213
Vector<TOutput>? validationLabels = null,
214+
ICheckpointableModel? student = null,
214215
Action<int, T>? onEpochComplete = null)
215216
{
216217
if (studentForward == null) throw new ArgumentNullException(nameof(studentForward));
@@ -226,6 +227,9 @@ public virtual void Train(
226227
if (validationInputs != null && validationLabels != null && validationInputs.Length != validationLabels.Length)
227228
throw new ArgumentException("Validation inputs and labels must have the same length");
228229

230+
// Store student reference for checkpointing
231+
_student = student;
232+
229233
// Prepare for training
230234
OnTrainingStart(trainInputs, trainLabels);
231235

@@ -303,7 +307,7 @@ public virtual void Train(
303307
int batchSize = 32,
304308
Action<int, T>? onEpochComplete = null)
305309
{
306-
Train(studentForward, studentBackward, trainInputs, trainLabels, epochs, batchSize, null, null, onEpochComplete);
310+
Train(studentForward, studentBackward, trainInputs, trainLabels, epochs, batchSize, null, null, null, onEpochComplete);
307311
}
308312

309313
/// <summary>
@@ -468,15 +472,15 @@ protected int ArgMax(Vector<T> vector)
468472
/// - Setup curriculum schedules
469473
/// - Allocate temporary buffers</para>
470474
///
471-
/// <para><b>Automatic Checkpointing:</b> If <see cref="CheckpointConfig"/> is set, this method
472-
/// automatically initializes the checkpoint manager.</para>
475+
/// <para><b>Automatic Checkpointing:</b> If checkpoint configuration was provided to the constructor,
476+
/// this method automatically initializes the checkpoint manager.</para>
473477
/// </remarks>
474478
protected virtual void OnTrainingStart(Vector<TInput> trainInputs, Vector<TOutput>? trainLabels)
475479
{
476480
// Initialize checkpoint manager if config is provided
477-
if (CheckpointConfig != null)
481+
if (_checkpointConfig != null)
478482
{
479-
CheckpointManager = new DistillationCheckpointManager<T>(CheckpointConfig);
483+
_checkpointManager = new DistillationCheckpointManager<T>(_checkpointConfig);
480484
}
481485
}
482486

@@ -492,25 +496,25 @@ protected virtual void OnTrainingStart(Vector<TInput> trainInputs, Vector<TOutpu
492496
/// - Log final metrics
493497
/// - Free temporary resources</para>
494498
///
495-
/// <para><b>Automatic Checkpointing:</b> If <see cref="CheckpointConfig"/> is set, this method
496-
/// automatically loads the best checkpoint (based on validation metrics) after training completes.</para>
499+
/// <para><b>Automatic Checkpointing:</b> If checkpoint configuration was provided to the constructor,
500+
/// this method automatically loads the best checkpoint (based on validation metrics) after training completes.</para>
497501
/// </remarks>
498502
protected virtual void OnTrainingEnd(Vector<TInput> trainInputs, Vector<TOutput>? trainLabels)
499503
{
500504
// Load best checkpoint if checkpointing was enabled
501-
if (CheckpointManager != null && Student != null)
505+
if (_checkpointManager != null && _student != null)
502506
{
503-
var bestCheckpoint = CheckpointManager.LoadBestCheckpoint(
504-
student: Student,
507+
var bestCheckpoint = _checkpointManager.LoadBestCheckpoint(
508+
student: _student,
505509
teacher: Teacher as ICheckpointableModel
506510
);
507511

508512
if (bestCheckpoint != null)
509513
{
510514
Console.WriteLine($"[Checkpointing] Loaded best checkpoint from epoch {bestCheckpoint.Epoch}");
511-
if (bestCheckpoint.Metrics.ContainsKey(CheckpointConfig!.BestMetric))
515+
if (bestCheckpoint.Metrics.ContainsKey(_checkpointConfig!.BestMetric))
512516
{
513-
Console.WriteLine($"[Checkpointing] Best {CheckpointConfig.BestMetric}: {bestCheckpoint.Metrics[CheckpointConfig.BestMetric]:F4}");
517+
Console.WriteLine($"[Checkpointing] Best {_checkpointConfig.BestMetric}: {bestCheckpoint.Metrics[_checkpointConfig.BestMetric]:F4}");
514518
}
515519
}
516520
}
@@ -550,8 +554,8 @@ protected virtual void OnEpochStart(int epoch, Vector<TInput> trainInputs, Vecto
550554
/// to flush partial batches and prevent buffer leakage between epochs. Derived classes should
551555
/// call base.OnEpochEnd() if they override this method.</para>
552556
///
553-
/// <para><b>Automatic Checkpointing:</b> If <see cref="CheckpointConfig"/> is set, this method
554-
/// automatically saves checkpoints based on your configuration.</para>
557+
/// <para><b>Automatic Checkpointing:</b> If checkpoint configuration was provided to the constructor,
558+
/// this method automatically saves checkpoints based on your configuration.</para>
555559
/// </remarks>
556560
protected virtual void OnEpochEnd(int epoch, T avgLoss)
557561
{
@@ -567,7 +571,7 @@ protected virtual void OnEpochEnd(int epoch, T avgLoss)
567571
_lastTrainingLoss = avgLoss;
568572

569573
// Automatic checkpoint saving
570-
if (CheckpointManager != null)
574+
if (_checkpointManager != null)
571575
{
572576
var metrics = new Dictionary<string, double>
573577
{
@@ -577,12 +581,12 @@ protected virtual void OnEpochEnd(int epoch, T avgLoss)
577581
// Include validation metric if available
578582
if (_lastValidationMetric > 0)
579583
{
580-
metrics[CheckpointConfig!.BestMetric] = _lastValidationMetric;
584+
metrics[_checkpointConfig!.BestMetric] = _lastValidationMetric;
581585
}
582586

583-
CheckpointManager.SaveCheckpointIfNeeded(
587+
_checkpointManager.SaveCheckpointIfNeeded(
584588
epoch: epoch,
585-
student: Student,
589+
student: _student,
586590
teacher: Teacher as ICheckpointableModel,
587591
strategy: DistillationStrategy,
588592
metrics: metrics

0 commit comments

Comments
 (0)