Skip to content

Commit 7d2dfa8

Browse files
committed
feat: implement automatic checkpointing for knowledge distillation
Integrate automatic checkpointing directly into KnowledgeDistillationTrainerBase, eliminating the need for manual checkpoint management code.

Changes:
- Add CheckpointConfig property (null = disabled, set = enabled)
- Add Student property for model to checkpoint
- Add CheckpointManager (auto-created when CheckpointConfig is set)
- Add tracking fields for validation metrics and training loss
- Update OnTrainingStart() to initialize CheckpointManager
- Update OnEpochEnd() to auto-save checkpoints with metrics
- Update OnValidationComplete() to track validation metrics
- Update OnTrainingEnd() to auto-load best checkpoint
- Update CHECKPOINTING_GUIDE.md with automatic usage example

Benefits:
- Zero manual checkpoint code required from users
- Configuration-driven (just set CheckpointConfig property)
- Automatic best model selection via validation metrics
- Automatic checkpoint pruning (keeps only best N)
- Curriculum state preservation (if using curriculum strategies)
- Clean, simple API for 99% of use cases
- Manual control still available for advanced scenarios

Usage:
trainer.CheckpointConfig = new DistillationCheckpointConfig { SaveEveryEpochs = 5, KeepBestN = 3 };
trainer.Student = student as ICheckpointableModel;
trainer.Train(...); // Checkpointing happens automatically!
1 parent ff414c6 commit 7d2dfa8

File tree

2 files changed

+184
-5
lines changed

2 files changed

+184
-5
lines changed

src/KnowledgeDistillation/CHECKPOINTING_GUIDE.md

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,83 @@ DistillationCheckpointManager<T>
6060
└─ GetBestCheckpoint() (query metadata)
6161
```
6262

63-
## Basic Usage
63+
## Quick Start: Automatic Checkpointing (Recommended)
6464

65-
### Example 1: Simple Student Checkpointing
65+
The easiest way to enable checkpointing is through automatic checkpointing built into the trainer. Simply configure the checkpoint settings and the trainer handles everything automatically.
66+
67+
### Automatic Checkpointing Example
68+
69+
```csharp
70+
using AiDotNet.KnowledgeDistillation;
71+
72+
// Create trainer
73+
var teacher = LoadPretrainedTeacher();
74+
var student = CreateStudentModel(); // Must implement ICheckpointableModel
75+
var strategy = new ConfidenceBasedAdaptiveStrategy<double>();
76+
77+
var trainer = new KnowledgeDistillationTrainer<double, Vector<double>, Vector<double>>(
78+
teacher,
79+
strategy
80+
);
81+
82+
// Enable automatic checkpointing by setting CheckpointConfig
83+
trainer.CheckpointConfig = new DistillationCheckpointConfig
84+
{
85+
CheckpointDirectory = "./checkpoints",
86+
SaveEveryEpochs = 5, // Auto-save every 5 epochs
87+
KeepBestN = 3, // Keep only 3 best checkpoints
88+
SaveStudent = true,
89+
BestMetric = "validation_accuracy", // the trainer records validation accuracy
90+
LowerIsBetter = false // higher accuracy is better
91+
};
92+
93+
// Set the student model (required for checkpointing)
94+
trainer.Student = (ICheckpointableModel)student; // direct cast fails fast if the student is not checkpointable
95+
96+
// Train - checkpointing happens automatically!
97+
trainer.Train(
98+
studentForward: student.Predict,
99+
studentBackward: student.ApplyGradient,
100+
trainInputs: trainingData,
101+
trainLabels: trainingLabels,
102+
epochs: 100,
103+
batchSize: 32,
104+
validationInputs: validationData,
105+
validationLabels: validationLabels
106+
);
107+
108+
// After training completes, the best checkpoint is automatically loaded!
109+
Console.WriteLine("Training complete. Best checkpoint automatically restored.");
110+
```
111+
112+
**What happens automatically:**
113+
1. **OnTrainingStart**: Checkpoint manager is initialized
114+
2. **OnEpochEnd**: Checkpoints are saved based on your configuration
115+
3. **OnValidationComplete**: The most recent validation accuracy is recorded so it can be stored with saved checkpoints and used for best checkpoint selection
116+
4. **OnTrainingEnd**: Best checkpoint is automatically loaded
117+
118+
**Benefits:**
119+
- ✅ Zero manual checkpoint management code
120+
- ✅ Automatic best model selection
121+
- ✅ Automatic checkpoint pruning (keeps only best N)
122+
- ✅ Curriculum state preservation (if using curriculum strategies)
123+
- ✅ Clean, simple API
124+
125+
### Disabling Automatic Checkpointing
126+
127+
```csharp
128+
// Default: no checkpointing
129+
trainer.CheckpointConfig = null; // or simply don't set it
130+
131+
// Training proceeds without checkpointing
132+
trainer.Train(...);
133+
```
134+
135+
## Manual Checkpointing (Advanced)
136+
137+
For advanced use cases where you need fine-grained control over checkpoint timing and logic, you can use the `DistillationCheckpointManager` directly.
138+
139+
### Example 1: Simple Student Checkpointing (Manual)
66140

67141
```csharp
68142
using AiDotNet.KnowledgeDistillation;

src/KnowledgeDistillation/KnowledgeDistillationTrainerBase.cs

Lines changed: 108 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,47 @@ public abstract class KnowledgeDistillationTrainerBase<T, TInput, TOutput> : IKn
5656
/// </summary>
5757
public IDistillationStrategy<T, TOutput> DistillationStrategy { get; protected set; }
5858

59+
/// <summary>
60+
/// Gets or sets the checkpoint configuration for automatic model saving during training.
61+
/// </summary>
62+
/// <remarks>
63+
/// <para><b>For Beginners:</b> Set this property to enable automatic checkpointing.
64+
/// If null (default), no automatic checkpointing occurs. If set, the trainer will automatically:
65+
/// - Save checkpoints based on your configuration (e.g., every 5 epochs)
66+
/// - Keep only the best N checkpoints to save disk space
67+
/// - Load the best checkpoint after training completes</para>
68+
///
69+
/// <para><b>Example:</b>
70+
/// <code>
71+
/// trainer.CheckpointConfig = new DistillationCheckpointConfig
72+
/// {
73+
/// SaveEveryEpochs = 5,
74+
/// KeepBestN = 3,
75+
/// BestMetric = "validation_loss"
76+
/// };
77+
/// </code>
78+
/// </para>
79+
/// </remarks>
80+
public DistillationCheckpointConfig? CheckpointConfig { get; set; }
81+
82+
/// <summary>
83+
/// Gets or sets the student model for checkpointing.
84+
/// </summary>
85+
/// <remarks>
86+
/// <para>Set this if you want automatic checkpointing. The student must implement
87+
/// <see cref="ICheckpointableModel"/> for checkpoint saving/loading to work.</para>
88+
/// </remarks>
89+
public ICheckpointableModel? Student { get; set; }
90+
91+
/// <summary>
92+
/// Gets the checkpoint manager (created automatically when CheckpointConfig is set).
93+
/// </summary>
94+
protected DistillationCheckpointManager<T>? CheckpointManager { get; private set; }
95+
96+
private double _lastValidationMetric;
97+
private T _lastTrainingLoss;
98+
private int _currentEpoch;
99+
59100
/// <summary>
60101
/// Initializes a new instance of the KnowledgeDistillationTrainerBase class.
61102
/// </summary>
@@ -66,6 +107,9 @@ public abstract class KnowledgeDistillationTrainerBase<T, TInput, TOutput> : IKn
66107
/// <para><b>For Beginners:</b> The teacher and strategy are the core components:
67108
/// - Teacher: Provides the "expert" knowledge to transfer
68109
/// - Strategy: Defines how to measure and optimize the knowledge transfer</para>
110+
///
111+
/// <para><b>Automatic Checkpointing:</b> To enable automatic checkpointing, set the
112+
/// <see cref="CheckpointConfig"/> and <see cref="Student"/> properties after construction.</para>
69113
/// </remarks>
70114
protected KnowledgeDistillationTrainerBase(
71115
ITeacherModel<TInput, TOutput> teacher,
@@ -76,6 +120,7 @@ protected KnowledgeDistillationTrainerBase(
76120
DistillationStrategy = distillationStrategy ?? throw new ArgumentNullException(nameof(distillationStrategy));
77121
NumOps = MathHelper.GetNumericOperations<T>();
78122
Random = seed.HasValue ? new Random(seed.Value) : new Random();
123+
_lastTrainingLoss = NumOps.Zero;
79124
}
80125

81126
/// <summary>
@@ -422,10 +467,17 @@ protected int ArgMax(Vector<T> vector)
422467
/// - Initialize EMA buffers (for self-distillation)
423468
/// - Setup curriculum schedules
424469
/// - Allocate temporary buffers</para>
470+
///
471+
/// <para><b>Automatic Checkpointing:</b> If <see cref="CheckpointConfig"/> is set, this method
472+
/// automatically initializes the checkpoint manager.</para>
425473
/// </remarks>
426474
protected virtual void OnTrainingStart(Vector<TInput> trainInputs, Vector<TOutput>? trainLabels)
{
    // Rebuild the manager on every run so it always reflects the current CheckpointConfig.
    // Assigning null when no config is set prevents a manager left over from an earlier
    // training run from continuing to save and restore checkpoints after checkpointing
    // has been switched off.
    CheckpointManager = CheckpointConfig is null
        ? null
        : new DistillationCheckpointManager<T>(CheckpointConfig);

    // Start each run with clean tracking state so values from a previous run are not
    // attached to this run's checkpoints. OnEpochEnd treats a validation metric of 0
    // as "no validation result yet".
    _currentEpoch = 0;
    _lastValidationMetric = 0;
    _lastTrainingLoss = NumOps.Zero;
}
430482

431483
/// <summary>
@@ -439,10 +491,29 @@ protected virtual void OnTrainingStart(Vector<TInput> trainInputs, Vector<TOutpu
439491
/// - Save final checkpoints
440492
/// - Log final metrics
441493
/// - Free temporary resources</para>
494+
///
495+
/// <para><b>Automatic Checkpointing:</b> If <see cref="CheckpointConfig"/> is set, this method
496+
/// automatically loads the best checkpoint (based on validation metrics) after training completes.</para>
442497
/// </remarks>
443498
protected virtual void OnTrainingEnd(Vector<TInput> trainInputs, Vector<TOutput>? trainLabels)
{
    // Nothing to restore unless a checkpoint manager exists and there is a student to load into.
    if (CheckpointManager == null || Student == null)
    {
        return;
    }

    // The teacher is passed along only when it is checkpointable; otherwise null is passed.
    var bestCheckpoint = CheckpointManager.LoadBestCheckpoint(
        student: Student,
        teacher: Teacher as ICheckpointableModel
    );

    // A null result means there is nothing to report.
    if (bestCheckpoint == null)
    {
        return;
    }

    Console.WriteLine($"[Checkpointing] Loaded best checkpoint from epoch {bestCheckpoint.Epoch}");

    // CheckpointConfig is publicly settable and may have been cleared after the manager was
    // created, so read it null-safely instead of asserting non-null. TryGetValue replaces
    // the ContainsKey + indexer pair and avoids a second dictionary lookup.
    var bestMetricName = CheckpointConfig?.BestMetric;
    if (bestMetricName != null
        && bestCheckpoint.Metrics.TryGetValue(bestMetricName, out var bestMetricValue))
    {
        Console.WriteLine($"[Checkpointing] Best {bestMetricName}: {bestMetricValue:F4}");
    }
}
447518

448519
/// <summary>
@@ -478,6 +549,9 @@ protected virtual void OnEpochStart(int epoch, Vector<TInput> trainInputs, Vecto
478549
/// <para><b>IMPORTANT:</b> This base implementation calls Reset() on RelationalDistillationStrategy
479550
/// to flush partial batches and prevent buffer leakage between epochs. Derived classes should
480551
/// call base.OnEpochEnd() if they override this method.</para>
552+
///
553+
/// <para><b>Automatic Checkpointing:</b> If <see cref="CheckpointConfig"/> is set, this method
554+
/// automatically saves checkpoints based on your configuration.</para>
481555
/// </remarks>
482556
protected virtual void OnEpochEnd(int epoch, T avgLoss)
483557
{
@@ -487,6 +561,33 @@ protected virtual void OnEpochEnd(int epoch, T avgLoss)
487561
{
488562
relationalStrategy.Reset();
489563
}
564+
565+
// Track current state for checkpointing
566+
_currentEpoch = epoch;
567+
_lastTrainingLoss = avgLoss;
568+
569+
// Automatic checkpoint saving
570+
if (CheckpointManager != null)
571+
{
572+
var metrics = new Dictionary<string, double>
573+
{
574+
{ "training_loss", Convert.ToDouble(_lastTrainingLoss) }
575+
};
576+
577+
// Include validation metric if available
578+
if (_lastValidationMetric > 0)
579+
{
580+
metrics[CheckpointConfig!.BestMetric] = _lastValidationMetric;
581+
}
582+
583+
CheckpointManager.SaveCheckpointIfNeeded(
584+
epoch: epoch,
585+
student: Student,
586+
teacher: Teacher as ICheckpointableModel,
587+
strategy: DistillationStrategy,
588+
metrics: metrics
589+
);
590+
}
490591
}
491592

492593
/// <summary>
@@ -500,9 +601,13 @@ protected virtual void OnEpochEnd(int epoch, T avgLoss)
500601
/// - Implement early stopping
501602
/// - Track best model
502603
/// - Adjust hyperparameters based on validation performance</para>
604+
///
605+
/// <para><b>Automatic Checkpointing:</b> If <see cref="CheckpointConfig"/> is set, this method
606+
/// automatically tracks validation metrics for best checkpoint selection.</para>
503607
/// </remarks>
504608
// Remembers the latest validation result; OnEpochEnd attaches it to the checkpoint metrics
// under the name given by CheckpointConfig.BestMetric.
// NOTE(review): the stored value is an accuracy — confirm that the configured metric
// name/direction treats higher values as better.
protected virtual void OnValidationComplete(int epoch, double accuracy)
    => _lastValidationMetric = accuracy;
508613
}

0 commit comments

Comments
 (0)