67 changes: 67 additions & 0 deletions src/Interfaces/IPruningMask.cs
@@ -0,0 +1,67 @@
namespace AiDotNet.Interfaces;

/// <summary>
/// Represents a binary mask for pruning weights in a neural network layer.
/// </summary>
/// <typeparam name="T">Numeric type for mask values</typeparam>
/// <remarks>
/// <para>
/// A pruning mask is a binary matrix that determines which weights to keep (1) and which to remove (0)
/// during model compression. It enables selective removal of network parameters while maintaining the
/// ability to restore the network structure.
/// </para>
/// <para><b>For Beginners:</b> Think of a pruning mask as a stencil or template.
///
/// Imagine you're painting a picture and want to cover certain areas:
/// - The mask has holes (1s) where paint should go through (weights to keep)
/// - The mask is solid (0s) where paint should be blocked (weights to prune/remove)
///
/// In neural networks:
/// - A pruning mask helps you selectively remove less important connections
/// - This makes your model smaller and faster without losing too much accuracy
/// - The mask can be applied to weight matrices to zero out pruned weights
/// </para>
/// </remarks>
public interface IPruningMask<T>
{
    /// <summary>
    /// Gets the mask dimensions matching the weight matrix shape.
    /// </summary>
    int[] Shape { get; }

    /// <summary>
    /// Gets the sparsity ratio (proportion of zeros).
    /// </summary>
    /// <returns>Value between 0 (dense) and 1 (fully pruned)</returns>
    /// <remarks>
    /// <para><b>For Beginners:</b> Sparsity measures how many weights have been removed.
    /// - 0.0 means no weights removed (0% sparse, 100% dense)
    /// - 0.5 means half the weights removed (50% sparse)
    /// - 0.9 means 90% of weights removed (90% sparse)
    /// </para>
    /// </remarks>
    double GetSparsity();

    /// <summary>
    /// Applies the mask to a weight matrix (element-wise multiplication).
    /// </summary>
    /// <param name="weights">Weight matrix to prune</param>
    /// <returns>Pruned weights (zeros where mask is zero)</returns>
    Matrix<T> Apply(Matrix<T> weights);

    /// <summary>
    /// Applies the mask to a weight tensor (for convolutional layers).
    /// </summary>
    Tensor<T> Apply(Tensor<T> weights);

    /// <summary>
    /// Updates the mask based on new pruning criteria.
    /// </summary>
    /// <param name="keepIndices">Indices of weights to keep (not prune)</param>
    void UpdateMask(bool[,] keepIndices);

    /// <summary>
    /// Combines this mask with another mask (logical AND).
    /// </summary>
    IPruningMask<T> CombineWith(IPruningMask<T> otherMask);
}
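The concrete PruningMask<T> class that GradientPruningStrategy instantiates later in this PR does not appear in this diff. For orientation only, here is a minimal sketch of how an implementation might satisfy the Matrix<T> overload of Apply, assuming the mask stores its values in a Matrix<T> field (_mask) and uses the same INumericOperations<T> helper (_numOps) pattern seen in GradientPruningStrategy below:

    // Sketch only; the actual PruningMask<T> in this PR may differ.
    // Assumes a Matrix<T> _mask field and an INumericOperations<T> _numOps helper.
    public Matrix<T> Apply(Matrix<T> weights)
    {
        var result = new Matrix<T>(weights.Rows, weights.Columns);

        for (int i = 0; i < weights.Rows; i++)
        {
            for (int j = 0; j < weights.Columns; j++)
            {
                // Element-wise multiply: a 0 in the mask zeroes the weight,
                // a 1 passes it through unchanged.
                result[i, j] = _numOps.Multiply(weights[i, j], _mask[i, j]);
            }
        }

        return result;
    }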
89 changes: 89 additions & 0 deletions src/Interfaces/IPruningStrategy.cs
@@ -0,0 +1,89 @@
namespace AiDotNet.Interfaces;

/// <summary>
/// Defines a strategy for pruning neural network weights.
/// </summary>
/// <typeparam name="T">Numeric type for weights and gradients</typeparam>
/// <remarks>
/// <para>
/// A pruning strategy determines which weights in a neural network should be removed
/// to reduce model size and computational requirements while maintaining accuracy.
/// Different strategies use different criteria to measure weight importance.
/// </para>
/// <para><b>For Beginners:</b> A pruning strategy decides which connections to remove from a neural network.
///
/// Think of it like pruning a tree:
/// - You want to remove branches that don't contribute much to the tree's health
/// - You keep the important branches that carry nutrients and support the structure
/// - The goal is a healthier, more efficient tree
///
/// In neural networks:
/// - Different strategies measure "importance" differently
/// - Magnitude-based: Remove smallest weights (they contribute less to output)
/// - Gradient-based: Remove weights with smallest gradients (they learn slowly)
/// - Structured: Remove entire neurons or filters (cleaner architecture)
///
/// All strategies aim to compress the model while preserving its predictive power.
/// </para>
/// </remarks>
public interface IPruningStrategy<T>
{
    /// <summary>
    /// Computes importance scores for each weight.
    /// </summary>
    /// <param name="weights">Weight matrix</param>
    /// <param name="gradients">Gradient matrix (optional, can be null)</param>
    /// <returns>Importance score for each weight (higher = more important)</returns>
    /// <remarks>
    /// <para><b>For Beginners:</b> This method assigns each weight a score representing its importance.
    /// Higher scores mean the weight is more important and should be kept.
    /// Lower scores mean the weight can be safely removed.
    /// </para>
    /// </remarks>
    Matrix<T> ComputeImportanceScores(Matrix<T> weights, Matrix<T>? gradients = null);

    /// <summary>
    /// Creates a pruning mask based on target sparsity.
    /// </summary>
    /// <param name="importanceScores">Importance scores from ComputeImportanceScores</param>
    /// <param name="targetSparsity">Target sparsity ratio (0 to 1)</param>
    /// <returns>Binary mask (1 = keep, 0 = prune)</returns>
    /// <remarks>
    /// <para><b>For Beginners:</b> This creates the actual mask that determines which weights to remove.
    /// If targetSparsity is 0.7, it will mark 70% of the least important weights for removal.
    /// </para>
    /// </remarks>
    IPruningMask<T> CreateMask(Matrix<T> importanceScores, double targetSparsity);

    /// <summary>
    /// Prunes a weight matrix in-place.
    /// </summary>
    /// <param name="weights">Weight matrix to prune</param>
    /// <param name="mask">Pruning mask to apply</param>
    /// <remarks>
    /// <para><b>For Beginners:</b> This actually removes the weights by applying the mask.
    /// After this operation, pruned weights become zero.
    /// </para>
    /// </remarks>
    void ApplyPruning(Matrix<T> weights, IPruningMask<T> mask);

    /// <summary>
    /// Gets whether this strategy requires gradients.
    /// </summary>
    /// <remarks>
    /// <para><b>For Beginners:</b> Some strategies need gradient information to determine importance.
    /// If true, you must provide gradients when calling ComputeImportanceScores.
    /// </para>
    /// </remarks>
    bool RequiresGradients { get; }

    /// <summary>
    /// Gets whether this is structured pruning (removes entire rows/columns).
    /// </summary>
    /// <remarks>
    /// <para><b>For Beginners:</b> Structured pruning removes entire neurons or filters.
    /// Unstructured pruning (false) removes individual weights anywhere in the network.
    /// </para>
    /// </remarks>
    bool IsStructured { get; }
}
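Taken together, the interface implies a score → mask → apply workflow. A minimal sketch of how calling code might drive it, assuming weights and gradients are Matrix<double> values captured during training and using the GradientPruningStrategy added later in this PR:

    IPruningStrategy<double> strategy = new GradientPruningStrategy<double>();

    // Gradient-based strategies report RequiresGradients == true,
    // so gradients must be supplied here.
    var scores = strategy.ComputeImportanceScores(weights, gradients);

    // Mark the least important half of the weights for removal.
    var mask = strategy.CreateMask(scores, targetSparsity: 0.5);

    // Zero out the pruned weights in-place.
    strategy.ApplyPruning(weights, mask);

    Console.WriteLine(mask.GetSparsity()); // ≈ 0.5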
169 changes: 169 additions & 0 deletions src/Pruning/GradientPruningStrategy.cs
@@ -0,0 +1,169 @@
namespace AiDotNet.Pruning;

/// <summary>
/// Prunes weights based on gradient magnitude (sensitivity).
/// </summary>
/// <typeparam name="T">The numeric type used for calculations (e.g., float, double).</typeparam>
/// <remarks>
/// <para>
/// Gradient-based pruning uses gradient information to determine weight importance.
/// Weights with small gradients have little impact on the loss function and can be safely removed.
/// This approach considers both the weight value and how much it affects learning.
/// </para>
/// <para><b>For Beginners:</b> This strategy removes connections that don't learn much.
///
/// Think of it like identifying which team members actively contribute to a project:
/// - High gradient = this weight changes a lot during training; it's learning something important
/// - Low gradient = this weight barely changes; it's not contributing much to learning
///
/// The importance score is calculated as |weight × gradient|:
/// - If a weight is large BUT has tiny gradients, it might not be doing much
/// - If a weight is learning slowly (small gradient), removing it is unlikely to hurt performance much
///
/// This is smarter than magnitude-based pruning because it considers learning dynamics,
/// not just weight size. However, it requires gradient information from training.
///
/// Example:
/// - Weight = 0.5, Gradient = 0.001 → Importance = |0.5 × 0.001| = 0.0005 (low, prune it)
/// - Weight = 0.3, Gradient = 0.9 → Importance = |0.3 × 0.9| = 0.27 (high, keep it)
/// </para>
/// </remarks>
public class GradientPruningStrategy<T> : IPruningStrategy<T>
{
    private readonly INumericOperations<T> _numOps;

    /// <summary>
    /// Gets whether this strategy requires gradients (true for gradient-based).
    /// </summary>
    public bool RequiresGradients => true;

    /// <summary>
    /// Gets whether this is structured pruning (false for gradient-based).
    /// </summary>
    public bool IsStructured => false;

    /// <summary>
    /// Initializes a new instance of GradientPruningStrategy.
    /// </summary>
    public GradientPruningStrategy()
    {
        _numOps = MathHelper.GetNumericOperations<T>();
    }

    /// <summary>
    /// Computes importance scores as the product of weight magnitude and gradient magnitude.
    /// </summary>
    /// <param name="weights">Weight matrix</param>
    /// <param name="gradients">Gradient matrix (required for this strategy)</param>
    /// <returns>Matrix of importance scores</returns>
    /// <exception cref="ArgumentException">Thrown when gradients are null or shape doesn't match weights</exception>
    /// <remarks>
    /// <para><b>For Beginners:</b> This calculates how important each weight is by looking at both:
    /// 1. The weight's value
    /// 2. How much the weight is learning (its gradient)
    ///
    /// The importance is |weight × gradient|. This tells us how much removing the weight
    /// would affect the model's learning and output.
    /// </para>
    /// </remarks>
    public Matrix<T> ComputeImportanceScores(Matrix<T> weights, Matrix<T>? gradients = null)
    {
        if (gradients == null)
            throw new ArgumentException("GradientPruningStrategy requires gradients", nameof(gradients));

        if (weights.Rows != gradients.Rows || weights.Columns != gradients.Columns)
            throw new ArgumentException("Weights and gradients must have same shape", nameof(gradients));

        // Importance = |weight * gradient|
        // This measures how much removing the weight affects the loss
        var scores = new Matrix<T>(weights.Rows, weights.Columns);

        for (int i = 0; i < weights.Rows; i++)
        {
            for (int j = 0; j < weights.Columns; j++)
            {
                // |w_ij * g_ij|
                var product = _numOps.Multiply(weights[i, j], gradients[i, j]);
                scores[i, j] = _numOps.Abs(product);
            }
        }

        return scores;
    }

    /// <summary>
    /// Creates a pruning mask by selecting weights with lowest gradient-based importance.
    /// </summary>
    /// <param name="importanceScores">Importance scores from ComputeImportanceScores</param>
    /// <param name="targetSparsity">Target sparsity ratio (0 to 1)</param>
    /// <returns>Binary mask (1 = keep, 0 = prune)</returns>
    /// <exception cref="ArgumentException">Thrown when targetSparsity is not between 0 and 1</exception>
    /// <remarks>
    /// <para><b>For Beginners:</b> This creates the mask that decides which weights to remove.
    ///
    /// Similar to magnitude pruning, but using gradient-based scores:
    /// - Weights with low |weight × gradient| scores are pruned
    /// - Weights with high scores are kept
    ///
    /// This tends to preserve weights that are actively contributing to learning.
    /// </para>
    /// </remarks>
    public IPruningMask<T> CreateMask(Matrix<T> importanceScores, double targetSparsity)
    {
        // Same logic as magnitude pruning, but with gradient-based scores
        if (targetSparsity < 0 || targetSparsity > 1)
            throw new ArgumentException("targetSparsity must be between 0 and 1", nameof(targetSparsity));

        int totalElements = importanceScores.Rows * importanceScores.Columns;
        int numToPrune = (int)(totalElements * targetSparsity);
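        // Example: a 100×100 layer (10,000 weights) at targetSparsity 0.7
        // prunes the 7,000 lowest-scoring weights; the (int) cast truncates.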

        var flatScores = new List<(int row, int col, T score)>();

        for (int i = 0; i < importanceScores.Rows; i++)
            for (int j = 0; j < importanceScores.Columns; j++)
                flatScores.Add((i, j, importanceScores[i, j]));

        // Sort ascending so the least important weights come first.
        flatScores.Sort((a, b) =>
        {
            double aVal = Convert.ToDouble(a.score);
            double bVal = Convert.ToDouble(b.score);
            return aVal.CompareTo(bVal);
        });

        var keepIndices = new bool[importanceScores.Rows, importanceScores.Columns];

        for (int i = 0; i < importanceScores.Rows; i++)
            for (int j = 0; j < importanceScores.Columns; j++)
                keepIndices[i, j] = true;

        for (int i = 0; i < numToPrune && i < flatScores.Count; i++)
        {
            var (row, col, _) = flatScores[i];
            keepIndices[row, col] = false;
        }

        var mask = new PruningMask<T>(importanceScores.Rows, importanceScores.Columns);
        mask.UpdateMask(keepIndices);

        return mask;
    }

    /// <summary>
    /// Applies the pruning mask to weights in-place.
    /// </summary>
    /// <param name="weights">Weight matrix to prune</param>
    /// <param name="mask">Pruning mask to apply</param>
    /// <remarks>
    /// <para><b>For Beginners:</b> This actually removes the weights by setting them to zero.
    /// The pruned weights are those identified as having low gradient-based importance.
    /// </para>
    /// </remarks>
    public void ApplyPruning(Matrix<T> weights, IPruningMask<T> mask)
    {
        var pruned = mask.Apply(weights);

        // Copy the masked values back so the caller's matrix is updated in-place.
        for (int i = 0; i < weights.Rows; i++)
            for (int j = 0; j < weights.Columns; j++)
                weights[i, j] = pruned[i, j];
    }
}
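To make the scoring concrete, here is a hedged end-to-end sketch of the 2×2 case from the class documentation above, assuming Matrix<double> exposes the (rows, columns) constructor and indexer used throughout this diff:

    var strategy = new GradientPruningStrategy<double>();

    var weights = new Matrix<double>(2, 2);
    var gradients = new Matrix<double>(2, 2);
    weights[0, 0] = 0.5;  gradients[0, 0] = 0.001; // |0.5 × 0.001| = 0.0005
    weights[0, 1] = 0.3;  gradients[0, 1] = 0.9;   // |0.3 × 0.9|   = 0.27
    weights[1, 0] = -0.8; gradients[1, 0] = 0.01;  // |-0.8 × 0.01| = 0.008
    weights[1, 1] = 0.1;  gradients[1, 1] = 0.5;   // |0.1 × 0.5|   = 0.05

    var scores = strategy.ComputeImportanceScores(weights, gradients);

    // With targetSparsity 0.5, numToPrune = (int)(4 * 0.5) = 2: the two
    // lowest scores (0.0005 and 0.008) are pruned, so the weights at
    // [0,0] and [1,0] become zero while [0,1] and [1,1] survive.
    var mask = strategy.CreateMask(scores, targetSparsity: 0.5);
    strategy.ApplyPruning(weights, mask);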