diff --git a/src/FederatedLearning/Aggregators/FedAvgAggregationStrategy.cs b/src/FederatedLearning/Aggregators/FedAvgAggregationStrategy.cs
new file mode 100644
index 000000000..a5efe6a29
--- /dev/null
+++ b/src/FederatedLearning/Aggregators/FedAvgAggregationStrategy.cs
@@ -0,0 +1,158 @@
+namespace AiDotNet.FederatedLearning.Aggregators;
+
+using AiDotNet.Interfaces;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+
+///
+/// Implements the Federated Averaging (FedAvg) aggregation strategy.
+///
+///
+/// FedAvg is the foundational aggregation algorithm for federated learning, proposed by
+/// McMahan et al. in 2017. It performs a weighted average of client model updates based
+/// on the number of training samples each client has.
+///
+/// For Beginners: FedAvg is like calculating a weighted class average where students
+/// who solved more practice problems have more influence on the final answer.
+///
+/// How FedAvg works:
+/// 1. Each client trains on their local data and computes model updates
+/// 2. Clients send their updated model weights to the server
+/// 3. Server computes weighted average: weight = (client_samples / total_samples)
+/// 4. New global model = Σ(weight_i × client_model_i)
+///
+/// For example, with 3 hospitals:
+/// - Hospital A: 1000 patients, model accuracy 90%
+/// - Hospital B: 500 patients, model accuracy 88%
+/// - Hospital C: 1500 patients, model accuracy 92%
+///
+/// Total patients: 3000
+/// Hospital A weight: 1000/3000 = 0.333
+/// Hospital B weight: 500/3000 = 0.167
+/// Hospital C weight: 1500/3000 = 0.500
+///
+/// For each model parameter:
+/// global_param = 0.333 × A_param + 0.167 × B_param + 0.500 × C_param
+///
+/// Benefits:
+/// - Simple and efficient
+/// - Well-studied theoretically
+/// - Works well when clients have similar data distributions (IID data)
+///
+/// Limitations:
+/// - Assumes clients are equally reliable
+/// - Can struggle with non-IID data (different distributions across clients)
+/// - No built-in handling for stragglers (slow clients)
+///
+/// Reference: McMahan, H. B., et al. (2017). "Communication-Efficient Learning of Deep Networks
+/// from Decentralized Data." AISTATS 2017.
+///
+/// The numeric type for model parameters (e.g., double, float).
+public class FedAvgAggregationStrategy<T> : IAggregationStrategy<Dictionary<string, T[]>>
+ where T : struct, IComparable, IConvertible
+{
+ ///
+ /// Aggregates client models using weighted averaging based on the number of samples.
+ ///
+ ///
+ /// This method implements the core FedAvg algorithm:
+ ///
+ /// Mathematical formulation:
+ /// w_global = Σ(n_k / n_total) × w_k
+ ///
+ /// where:
+ /// - w_global: global model weights
+ /// - w_k: client k's model weights
+ /// - n_k: number of samples at client k
+ /// - n_total: total samples across all clients
+ ///
+ /// For Beginners: This combines all client models into one by taking a weighted
+ /// average, where clients with more data have more influence.
+ ///
+ /// Step-by-step process:
+ /// 1. Calculate total samples across all clients
+ /// 2. For each client, compute weight = client_samples / total_samples
+ /// 3. For each model parameter, compute weighted sum
+ /// 4. Return the aggregated model
+ ///
+ /// For example, if we have 2 clients with a simple model (one parameter):
+ /// - Client 1: 300 samples, parameter value = 0.8
+ /// - Client 2: 700 samples, parameter value = 0.6
+ ///
+ /// Total samples: 1000
+ /// Client 1 weight: 300/1000 = 0.3
+ /// Client 2 weight: 700/1000 = 0.7
+ /// Aggregated parameter: 0.3 × 0.8 + 0.7 × 0.6 = 0.24 + 0.42 = 0.66
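+ ///
+ /// A minimal usage sketch of that two-client example (illustrative; assumes double parameters
+ /// and a single layer named "w"):
+ ///
+ ///     var strategy = new FedAvgAggregationStrategy<double>();
+ ///     var clientModels = new Dictionary<int, Dictionary<string, double[]>>
+ ///     {
+ ///         [1] = new() { ["w"] = new[] { 0.8 } },   // client with 300 samples
+ ///         [2] = new() { ["w"] = new[] { 0.6 } }    // client with 700 samples
+ ///     };
+ ///     var clientWeights = new Dictionary<int, double> { [1] = 300, [2] = 700 };
+ ///     var globalModel = strategy.Aggregate(clientModels, clientWeights);
+ ///     // globalModel["w"][0] ≈ 0.3 × 0.8 + 0.7 × 0.6 = 0.66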
+ ///
+ /// Dictionary mapping client IDs to their model parameters.
+ /// Dictionary mapping client IDs to their sample counts (weights).
+ /// The aggregated global model parameters.
+ public Dictionary<string, T[]> Aggregate(
+ Dictionary<int, Dictionary<string, T[]>> clientModels,
+ Dictionary<int, double> clientWeights)
+ {
+ if (clientModels == null || clientModels.Count == 0)
+ {
+ throw new ArgumentException("Client models cannot be null or empty.", nameof(clientModels));
+ }
+
+ if (clientWeights == null || clientWeights.Count == 0)
+ {
+ throw new ArgumentException("Client weights cannot be null or empty.", nameof(clientWeights));
+ }
+
+ // Calculate total weight (total number of samples across all clients)
+ double totalWeight = clientWeights.Values.Sum();
+
+ if (totalWeight <= 0)
+ {
+ throw new ArgumentException("Total weight must be positive.", nameof(clientWeights));
+ }
+
+ // Get the first client's model structure to initialize the aggregated model
+ var firstClientModel = clientModels.First().Value;
+ var aggregatedModel = new Dictionary<string, T[]>();
+
+ // Initialize aggregated model with zeros
+ foreach (var layerName in firstClientModel.Keys)
+ {
+ aggregatedModel[layerName] = new T[firstClientModel[layerName].Length];
+ }
+
+ // Perform weighted aggregation
+ foreach (var clientId in clientModels.Keys)
+ {
+ var clientModel = clientModels[clientId];
+ var clientWeight = clientWeights[clientId];
+
+ // Normalized weight for this client
+ double normalizedWeight = clientWeight / totalWeight;
+
+ // Add weighted contribution from this client to the aggregated model
+ foreach (var layerName in clientModel.Keys)
+ {
+ var clientParams = clientModel[layerName];
+ var aggregatedParams = aggregatedModel[layerName];
+
+ for (int i = 0; i < clientParams.Length; i++)
+ {
+ // Convert T to double, perform weighted addition, convert back to T
+ double currentValue = Convert.ToDouble(aggregatedParams[i]);
+ double clientValue = Convert.ToDouble(clientParams[i]);
+ double weightedValue = currentValue + (normalizedWeight * clientValue);
+
+ aggregatedParams[i] = (T)Convert.ChangeType(weightedValue, typeof(T));
+ }
+ }
+ }
+
+ return aggregatedModel;
+ }
+
+ ///
+ /// Gets the name of the aggregation strategy.
+ ///
+ /// The string "FedAvg".
+ public string GetStrategyName()
+ {
+ return "FedAvg";
+ }
+}
diff --git a/src/FederatedLearning/Aggregators/FedBNAggregationStrategy.cs b/src/FederatedLearning/Aggregators/FedBNAggregationStrategy.cs
new file mode 100644
index 000000000..ccee14b1f
--- /dev/null
+++ b/src/FederatedLearning/Aggregators/FedBNAggregationStrategy.cs
@@ -0,0 +1,257 @@
+namespace AiDotNet.FederatedLearning.Aggregators;
+
+using AiDotNet.Interfaces;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+
+///
+/// Implements the Federated Batch Normalization (FedBN) aggregation strategy.
+///
+///
+/// FedBN is a specialized aggregation strategy that handles batch normalization layers
+/// differently from other layers in neural networks. Proposed by Li et al. in 2021,
+/// it addresses the challenge of non-IID data by keeping batch normalization parameters local.
+///
+/// For Beginners: FedBN recognizes that some parts of a neural network should remain
+/// personalized to each client rather than being averaged globally.
+///
+/// The key insight:
+/// - Batch Normalization (BN) layers learn statistics specific to each client's data
+/// - Averaging BN parameters across clients with different data distributions hurts performance
+/// - Solution: Keep BN layers local, only aggregate other layers (Conv, FC, etc.)
+///
+/// How FedBN works:
+/// 1. During aggregation, identify batch normalization layers
+/// 2. Aggregate only non-BN layers using weighted averaging
+/// 3. Keep each client's BN layers unchanged (personalized)
+/// 4. Send back global model with client-specific BN layers
+///
+/// For example, in a CNN with layers:
+/// - Conv1 (filters) → BN1 (normalization) → ReLU → Conv2 → BN2 → FC (classification)
+///
+/// FedBN aggregates:
+/// - ✓ Conv1 filters: Averaged across clients
+/// - ✗ BN1 params: Kept local to each client
+/// - ✓ Conv2 filters: Averaged across clients
+/// - ✗ BN2 params: Kept local to each client
+/// - ✓ FC weights: Averaged across clients
+///
+/// Why this matters:
+/// - Different clients may have different data ranges, distributions
+/// - Hospital A images: brightness range [0, 100]
+/// - Hospital B images: brightness range [50, 200]
+/// - Each needs different normalization parameters
+/// - Shared feature extractors (Conv layers) + personalized normalization works better
+///
+/// When to use FedBN:
+/// - Training deep neural networks (especially CNNs)
+/// - Non-IID data with distribution shift
+/// - Batch normalization or layer normalization in architecture
+/// - Want to improve accuracy without changing training much
+///
+/// Benefits:
+/// - Significantly improves accuracy on non-IID data
+/// - Simple modification to FedAvg
+/// - No additional communication cost
+/// - Each client keeps personalized normalization
+///
+/// Limitations:
+/// - Only helps when using batch normalization
+/// - Doesn't address other heterogeneity challenges
+/// - Requires identifying BN layers in model structure
+///
+/// Reference: Li, X., et al. (2021). "FedBN: Federated Learning on Non-IID Features via Local
+/// Batch Normalization." ICLR 2021.
+///
+/// The numeric type for model parameters (e.g., double, float).
+public class FedBNAggregationStrategy<T> : IAggregationStrategy<Dictionary<string, T[]>>
+ where T : struct, IComparable, IConvertible
+{
+ private readonly HashSet<string> _batchNormLayerPatterns;
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ ///
+ /// For Beginners: Creates a FedBN aggregator that knows how to identify
+ /// batch normalization layers in your model.
+ ///
+ /// Common BN layer naming patterns:
+ /// - "bn", "batchnorm", "batch_norm": Explicit BN layers
+ /// - "gamma", "beta": BN trainable parameters
+ /// - "running_mean", "running_var": BN statistics
+ ///
+ /// The strategy will exclude these from aggregation, keeping them client-specific.
+ ///
+ ///
+ /// Patterns to identify batch normalization layers. If null, uses default patterns.
+ ///
+ public FedBNAggregationStrategy(HashSet<string>? batchNormLayerPatterns = null)
+ {
+ // Default patterns for identifying batch normalization layers
+ _batchNormLayerPatterns = batchNormLayerPatterns ?? new HashSet<string>(StringComparer.OrdinalIgnoreCase)
+ {
+ "bn",
+ "batchnorm",
+ "batch_norm",
+ "batch_normalization",
+ "gamma",
+ "beta",
+ "running_mean",
+ "running_var",
+ "moving_mean",
+ "moving_variance"
+ };
+ }
+
+ ///
+ /// Aggregates client models while keeping batch normalization layers local.
+ ///
+ ///
+ /// This method implements selective aggregation:
+ ///
+ /// For Beginners: Think of this as a smart averaging that knows some parameters
+ /// should stay personal (like BN layers) while others should be shared (like Conv layers).
+ ///
+ /// Step-by-step process:
+ /// 1. For each layer in the model:
+ /// - Check if it's a batch normalization layer (by name matching)
+ /// - If BN: Keep first client's values (could be any client's, they stay local)
+ /// - If not BN: Compute weighted average across all clients
+ /// 2. Return aggregated model
+ ///
+ /// Mathematical formulation:
+ /// For non-BN layers:
+ /// w_global[layer] = Σ(n_k / n_total) × w_k[layer]
+ ///
+ /// For BN layers:
+ /// w_global[layer] = w_client[layer] (keeps local values)
+ ///
+ /// For example, with 3 clients and a model with:
+ /// - "conv1": [0.5, 0.6, 0.7] at clients → Average these
+ /// - "bn1_gamma": [1.0, 1.2, 0.9] at clients → Keep local (don't average)
+ /// - "conv2": [0.3, 0.4, 0.5] at clients → Average these
+ /// - "bn2_beta": [0.1, 0.2, 0.15] at clients → Keep local (don't average)
+ ///
+ /// Note: In practice, each client would maintain their own BN parameters.
+ /// The "global" model returned includes BN params that each client will replace
+ /// with their local version upon receiving the update.
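+ ///
+ /// A minimal usage sketch (illustrative; assumes double parameters and the layer names shown):
+ ///
+ ///     var fedBn = new FedBNAggregationStrategy<double>();
+ ///     var clientModels = new Dictionary<int, Dictionary<string, double[]>>
+ ///     {
+ ///         [1] = new() { ["conv1"] = new[] { 0.5 }, ["bn1_gamma"] = new[] { 1.0 } },
+ ///         [2] = new() { ["conv1"] = new[] { 0.7 }, ["bn1_gamma"] = new[] { 1.2 } }
+ ///     };
+ ///     var weights = new Dictionary<int, double> { [1] = 100, [2] = 100 };
+ ///     var global = fedBn.Aggregate(clientModels, weights);
+ ///     // global["conv1"][0] ≈ 0.6 (averaged); global["bn1_gamma"][0] stays 1.0 (the first client's value is kept)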
+ ///
+ /// Dictionary mapping client IDs to their model parameters.
+ /// Dictionary mapping client IDs to their sample counts (weights).
+ /// The aggregated global model parameters with BN layers excluded from aggregation.
+ public Dictionary<string, T[]> Aggregate(
+ Dictionary<int, Dictionary<string, T[]>> clientModels,
+ Dictionary<int, double> clientWeights)
+ {
+ if (clientModels == null || clientModels.Count == 0)
+ {
+ throw new ArgumentException("Client models cannot be null or empty.", nameof(clientModels));
+ }
+
+ if (clientWeights == null || clientWeights.Count == 0)
+ {
+ throw new ArgumentException("Client weights cannot be null or empty.", nameof(clientWeights));
+ }
+
+ double totalWeight = clientWeights.Values.Sum();
+
+ if (totalWeight <= 0)
+ {
+ throw new ArgumentException("Total weight must be positive.", nameof(clientWeights));
+ }
+
+ var firstClientModel = clientModels.First().Value;
+ var aggregatedModel = new Dictionary<string, T[]>();
+
+ // Process each layer
+ foreach (var layerName in firstClientModel.Keys)
+ {
+ // Check if this is a batch normalization layer
+ bool isBatchNormLayer = IsBatchNormalizationLayer(layerName);
+
+ if (isBatchNormLayer)
+ {
+ // For BN layers, keep the first client's parameters (they stay local)
+ // In practice, each client will maintain their own BN params
+ aggregatedModel[layerName] = (T[])firstClientModel[layerName].Clone();
+ }
+ else
+ {
+ // For non-BN layers, perform weighted aggregation (like FedAvg)
+ aggregatedModel[layerName] = new T[firstClientModel[layerName].Length];
+
+ foreach (var clientId in clientModels.Keys)
+ {
+ var clientModel = clientModels[clientId];
+ var clientWeight = clientWeights[clientId];
+ double normalizedWeight = clientWeight / totalWeight;
+
+ var clientParams = clientModel[layerName];
+ var aggregatedParams = aggregatedModel[layerName];
+
+ for (int i = 0; i < clientParams.Length; i++)
+ {
+ double currentValue = Convert.ToDouble(aggregatedParams[i]);
+ double clientValue = Convert.ToDouble(clientParams[i]);
+ double weightedValue = currentValue + (normalizedWeight * clientValue);
+
+ aggregatedParams[i] = (T)Convert.ChangeType(weightedValue, typeof(T));
+ }
+ }
+ }
+ }
+
+ return aggregatedModel;
+ }
+
+ ///
+ /// Determines whether a layer is a batch normalization layer based on its name.
+ ///
+ ///
+ /// For Beginners: This checks if a layer name contains any of the known
+ /// batch normalization patterns.
+ ///
+ /// For example:
+ /// - "conv1_weights" → false (not BN)
+ /// - "bn1_gamma" → true (contains "bn")
+ /// - "batch_norm_2_beta" → true (contains "batch_norm")
+ /// - "fc_bias" → false (not BN)
+ ///
+ /// The name of the layer to check.
+ /// True if the layer is a batch normalization layer, false otherwise.
+ private bool IsBatchNormalizationLayer(string layerName)
+ {
+ string lowerLayerName = layerName.ToLowerInvariant();
+
+ foreach (var pattern in _batchNormLayerPatterns)
+ {
+ if (lowerLayerName.Contains(pattern.ToLowerInvariant()))
+ {
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ ///
+ /// Gets the name of the aggregation strategy.
+ ///
+ /// The string "FedBN".
+ public string GetStrategyName()
+ {
+ return "FedBN";
+ }
+
+ ///
+ /// Gets the batch normalization layer patterns used for identification.
+ ///
+ ///
+ /// For Beginners: Returns the list of patterns used to recognize which
+ /// layers are batch normalization layers.
+ ///
+ /// A set of BN layer patterns.
+ public IReadOnlySet<string> GetBatchNormPatterns()
+ {
+ return _batchNormLayerPatterns;
+ }
+}
diff --git a/src/FederatedLearning/Aggregators/FedProxAggregationStrategy.cs b/src/FederatedLearning/Aggregators/FedProxAggregationStrategy.cs
new file mode 100644
index 000000000..0f033d666
--- /dev/null
+++ b/src/FederatedLearning/Aggregators/FedProxAggregationStrategy.cs
@@ -0,0 +1,194 @@
+namespace AiDotNet.FederatedLearning.Aggregators;
+
+using AiDotNet.Interfaces;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+
+///
+/// Implements the Federated Proximal (FedProx) aggregation strategy.
+///
+///
+/// FedProx is an extension of FedAvg that handles system and statistical heterogeneity
+/// in federated learning. It was proposed by Li et al. in 2020 to address challenges
+/// when clients have different computational capabilities or data distributions.
+///
+/// For Beginners: FedProx is like FedAvg with a "safety rope" that prevents
+/// individual clients from pulling the shared model too far in their own direction.
+///
+/// Key differences from FedAvg:
+/// 1. Adds a proximal term to local training objective
+/// 2. Prevents client models from deviating too much from global model
+/// 3. Improves convergence when clients have heterogeneous data or capabilities
+///
+/// How FedProx works:
+/// During local training, each client minimizes:
+/// Local Loss + (μ/2) × ||w - w_global||²
+///
+/// where:
+/// - Local Loss: Standard loss on client's data
+/// - μ (mu): Proximal term coefficient (controls constraint strength)
+/// - w: Client's current model weights
+/// - w_global: Global model weights received from server
+/// - ||w - w_global||²: Squared distance between client and global model
+///
+/// For example, with μ = 0.01:
+/// - Client trains on local data
+/// - Proximal term penalizes large deviations from global model
+/// - If a client's data is very different, it can still adapt, but only to a limited extent
+/// - Prevents overfitting to local data distribution
+///
+/// When to use FedProx:
+/// - Non-IID data (different distributions across clients)
+/// - System heterogeneity (some clients much slower/faster)
+/// - Want more stable convergence than FedAvg
+/// - Stragglers problem (some clients take much longer)
+///
+/// Benefits:
+/// - Better convergence on non-IID data
+/// - More robust to stragglers
+/// - Theoretically proven convergence guarantees
+/// - Small computational overhead
+///
+/// Limitations:
+/// - Requires tuning μ parameter
+/// - Slightly slower local training per iteration
+/// - May converge slower if μ is too large
+///
+/// Reference: Li, T., et al. (2020). "Federated Optimization in Heterogeneous Networks."
+/// MLSys 2020.
+///
+/// The numeric type for model parameters (e.g., double, float).
+public class FedProxAggregationStrategy<T> : IAggregationStrategy<Dictionary<string, T[]>>
+ where T : struct, IComparable, IConvertible
+{
+ private readonly double _mu;
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ ///
+ /// For Beginners: Creates a FedProx aggregator with a specified proximal term strength.
+ ///
+ /// The μ (mu) parameter controls the trade-off between local adaptation and global consistency:
+ /// - μ = 0: Equivalent to FedAvg (no constraint)
+ /// - μ = 0.01: Mild constraint (recommended starting point)
+ /// - μ = 0.1: Moderate constraint
+ /// - μ = 1.0+: Strong constraint (may be too restrictive)
+ ///
+ /// Recommendations:
+ /// - Start with μ = 0.01
+ /// - Increase if convergence is unstable
+ /// - Decrease if convergence is too slow
+ ///
+ /// The proximal term coefficient (typically 0.01 to 1.0).
+ public FedProxAggregationStrategy(double mu = 0.01)
+ {
+ if (mu < 0)
+ {
+ throw new ArgumentException("Mu must be non-negative.", nameof(mu));
+ }
+
+ _mu = mu;
+ }
+
+ ///
+ /// Aggregates client models using FedProx weighted averaging.
+ ///
+ ///
+ /// The aggregation step in FedProx is identical to FedAvg. The key difference is in
+ /// the local training objective (which includes the proximal term), not in aggregation.
+ ///
+ /// For Beginners: At the server side, FedProx aggregates just like FedAvg.
+ /// The magic happens during client-side training where the proximal term keeps
+ /// client models from straying too far.
+ ///
+ /// Aggregation formula (same as FedAvg):
+ /// w_global = Σ(n_k / n_total) × w_k
+ ///
+ /// The proximal term μ affects how w_k is computed during local training, but not
+ /// how we aggregate the models here.
+ ///
+ /// For implementation in local training (not shown here):
+ /// - Gradient = ∇Loss + μ(w - w_global)
+ /// - This additional term pulls weights towards global model
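+ ///
+ /// A rough sketch of that client-side step (illustrative only; w, wGlobal, lossGradient,
+ /// learningRate and mu are hypothetical local-training variables, not part of this class):
+ ///
+ ///     for (int i = 0; i < w.Length; i++)
+ ///     {
+ ///         double proximalGradient = lossGradient[i] + mu * (w[i] - wGlobal[i]);
+ ///         w[i] -= learningRate * proximalGradient;
+ ///     }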
+ ///
+ /// Dictionary mapping client IDs to their model parameters.
+ /// Dictionary mapping client IDs to their sample counts (weights).
+ /// The aggregated global model parameters.
+ public Dictionary<string, T[]> Aggregate(
+ Dictionary<int, Dictionary<string, T[]>> clientModels,
+ Dictionary<int, double> clientWeights)
+ {
+ if (clientModels == null || clientModels.Count == 0)
+ {
+ throw new ArgumentException("Client models cannot be null or empty.", nameof(clientModels));
+ }
+
+ if (clientWeights == null || clientWeights.Count == 0)
+ {
+ throw new ArgumentException("Client weights cannot be null or empty.", nameof(clientWeights));
+ }
+
+ // Calculate total weight
+ double totalWeight = clientWeights.Values.Sum();
+
+ if (totalWeight <= 0)
+ {
+ throw new ArgumentException("Total weight must be positive.", nameof(clientWeights));
+ }
+
+ // Initialize aggregated model
+ var firstClientModel = clientModels.First().Value;
+ var aggregatedModel = new Dictionary<string, T[]>();
+
+ foreach (var layerName in firstClientModel.Keys)
+ {
+ aggregatedModel[layerName] = new T[firstClientModel[layerName].Length];
+ }
+
+ // Perform weighted aggregation (same as FedAvg)
+ foreach (var clientId in clientModels.Keys)
+ {
+ var clientModel = clientModels[clientId];
+ var clientWeight = clientWeights[clientId];
+ double normalizedWeight = clientWeight / totalWeight;
+
+ foreach (var layerName in clientModel.Keys)
+ {
+ var clientParams = clientModel[layerName];
+ var aggregatedParams = aggregatedModel[layerName];
+
+ for (int i = 0; i < clientParams.Length; i++)
+ {
+ double currentValue = Convert.ToDouble(aggregatedParams[i]);
+ double clientValue = Convert.ToDouble(clientParams[i]);
+ double weightedValue = currentValue + (normalizedWeight * clientValue);
+
+ aggregatedParams[i] = (T)Convert.ChangeType(weightedValue, typeof(T));
+ }
+ }
+ }
+
+ return aggregatedModel;
+ }
+
+ ///
+ /// Gets the name of the aggregation strategy.
+ ///
+ /// A string indicating "FedProx" with the μ parameter value.
+ public string GetStrategyName()
+ {
+ return $"FedProx(μ={_mu})";
+ }
+
+ ///
+ /// Gets the proximal term coefficient μ.
+ ///
+ ///
+ /// For Beginners: Returns the strength of the constraint that keeps client
+ /// models from deviating too far from the global model.
+ ///
+ /// The μ parameter value.
+ public double GetMu()
+ {
+ return _mu;
+ }
+}
diff --git a/src/FederatedLearning/Personalization/PersonalizedFederatedLearning.cs b/src/FederatedLearning/Personalization/PersonalizedFederatedLearning.cs
new file mode 100644
index 000000000..6f28d480a
--- /dev/null
+++ b/src/FederatedLearning/Personalization/PersonalizedFederatedLearning.cs
@@ -0,0 +1,371 @@
+namespace AiDotNet.FederatedLearning.Personalization;
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+
+///
+/// Implements personalized federated learning where each client maintains some client-specific parameters.
+///
+///
+/// Personalized Federated Learning (PFL) addresses the challenge of heterogeneous data distributions
+/// across clients by allowing each client to maintain personalized model components while still
+/// benefiting from collaborative learning.
+///
+/// For Beginners: Personalized FL is like having a shared textbook but personal notes.
+/// Everyone learns from the same core material (global model) but adapts it to their specific
+/// needs (personalized layers).
+///
+/// Key concept:
+/// - Global layers: Shared across all clients, learn common patterns
+/// - Personalized layers: Client-specific, adapt to local data distribution
+/// - Clients train both but only global layers are aggregated
+///
+/// How it works:
+/// 1. Model is split into global and personalized parts
+/// 2. During local training, both parts are updated
+/// 3. Only global parts are sent to server for aggregation
+/// 4. Personalized parts stay on the client
+/// 5. Client receives updated global parts and keeps personalized parts
+///
+/// For example, in healthcare:
+/// - Hospital A: Urban population, young average age
+/// - Hospital B: Rural population, old average age
+/// - Hospital C: Suburban population, mixed age
+///
+/// Model structure:
+/// - Global layers (shared): General disease detection features
+/// - Personalized layers: Adapt to local demographics
+///
+/// Benefits:
+/// - Better performance on non-IID data
+/// - Each client gets a model optimized for their data
+/// - Preserves privacy (personalized parts never leave client)
+/// - Relatively simple to implement
+///
+/// Common approaches:
+/// 1. Layer-wise personalization: Last few layers personalized
+/// 2. Feature-wise personalization: Some features personalized
+/// 3. Meta-learning: Learn how to adapt quickly to local data
+/// 4. Multi-task learning: Treat each client as a separate task
+///
+/// When to use PFL:
+/// - Clients have significantly different data distributions
+/// - Standard FedAvg performance is poor
+/// - Can afford client-side storage for personalized parameters
+/// - Want better local performance even at cost of global performance
+///
+/// Limitations:
+/// - Requires more storage on client (for personalized params)
+/// - May sacrifice some global model quality
+/// - Need to choose which layers to personalize
+/// - Risk of overfitting to local data
+///
+/// Reference:
+/// - Wang, K., et al. (2019). "Federated Evaluation of On-device Personalization." arXiv preprint.
+/// - Fallah, A., et al. (2020). "Personalized Federated Learning with Theoretical Guarantees: A Model-Agnostic Meta-Learning Approach." NeurIPS 2020.
+///
+/// The numeric type for model parameters (e.g., double, float).
+public class PersonalizedFederatedLearning<T>
+ where T : struct, IComparable, IConvertible
+{
+ private readonly double _personalizationFraction;
+ private readonly HashSet<string> _personalizedLayers;
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ ///
+ /// For Beginners: Sets up personalized federated learning with a specified
+ /// fraction of the model kept personalized.
+ ///
+ /// The personalization fraction determines the split:
+ /// - 0.0: No personalization (standard federated learning)
+ /// - 0.2: Last 20% of layers personalized (common choice)
+ /// - 0.5: Half personalized, half global
+ /// - 1.0: Fully personalized (no collaboration)
+ ///
+ /// Typical strategy:
+ /// - Keep early layers (feature extractors) global
+ /// - Keep late layers (task-specific) personalized
+ ///
+ /// For example, in a CNN:
+ /// - Conv layers 1-3: Global (learn general visual features)
+ /// - Conv layers 4-5: Personalized (adapt to local image characteristics)
+ /// - FC layers: Personalized (task-specific classification)
+ ///
+ ///
+ /// The fraction of model layers to keep personalized (0.0 to 1.0).
+ /// Typically 0.2 for last 20% of layers.
+ ///
+ public PersonalizedFederatedLearning(double personalizationFraction = 0.2)
+ {
+ if (personalizationFraction < 0.0 || personalizationFraction > 1.0)
+ {
+ throw new ArgumentException("Personalization fraction must be between 0 and 1.", nameof(personalizationFraction));
+ }
+
+ _personalizationFraction = personalizationFraction;
+ _personalizedLayers = new HashSet<string>();
+ }
+
+ ///
+ /// Identifies which layers should be personalized based on the model structure.
+ ///
+ ///
+ /// For Beginners: This decides which parts of the model will be personalized
+ /// vs. shared globally.
+ ///
+ /// Common strategies:
+ /// 1. Last-N layers: Personalize the final layers (default)
+ /// 2. By name: Personalize layers matching specific patterns
+ /// 3. By type: Personalize certain layer types (e.g., batch norm)
+ ///
+ /// For example, with 10 layers and 20% personalization:
+ /// - Layers 0-7: Global (shared)
+ /// - Layers 8-9: Personalized (last 2 layers = 20%)
+ ///
+ /// The intuition:
+ /// - Early layers learn low-level features (edges, textures) → should be shared
+ /// - Late layers learn high-level, task-specific features → can be personalized
+ ///
+ /// The model structure with layer names.
+ /// The strategy for selecting personalized layers ("last_n", "by_pattern").
+ /// Optional patterns for "by_pattern" strategy.
+ public void IdentifyPersonalizedLayers(
+ Dictionary<string, T[]> modelStructure,
+ string strategy = "last_n",
+ HashSet<string>? customPatterns = null)
+ {
+ if (modelStructure == null || modelStructure.Count == 0)
+ {
+ throw new ArgumentException("Model structure cannot be null or empty.", nameof(modelStructure));
+ }
+
+ _personalizedLayers.Clear();
+
+ if (strategy == "last_n")
+ {
+ // Personalize the last N% of layers
+ int totalLayers = modelStructure.Count;
+ int personalizedCount = (int)Math.Ceiling(totalLayers * _personalizationFraction);
+
+ var layerNames = modelStructure.Keys.ToList();
+
+ // Take the last personalizedCount layers
+ for (int i = totalLayers - personalizedCount; i < totalLayers; i++)
+ {
+ _personalizedLayers.Add(layerNames[i]);
+ }
+ }
+ else if (strategy == "by_pattern" && customPatterns != null)
+ {
+ // Personalize layers matching specific patterns
+ foreach (var layerName in modelStructure.Keys)
+ {
+ foreach (var pattern in customPatterns)
+ {
+ if (layerName.Contains(pattern, StringComparison.OrdinalIgnoreCase))
+ {
+ _personalizedLayers.Add(layerName);
+ break;
+ }
+ }
+ }
+ }
+ else
+ {
+ throw new ArgumentException($"Unknown strategy: {strategy}. Use 'last_n' or 'by_pattern'.", nameof(strategy));
+ }
+ }
+
+ ///
+ /// Separates a model into global and personalized components.
+ ///
+ ///
+ /// For Beginners: Splits the model into two parts:
+ /// - Global part: Will be sent to server and aggregated
+ /// - Personalized part: Stays on client
+ ///
+ /// For example:
+ /// Original model: {layer1: [...], layer2: [...], layer3: [...], layer4: [...]}
+ /// If layers 3-4 are personalized:
+ /// - Global: {layer1: [...], layer2: [...]}
+ /// - Personalized: {layer3: [...], layer4: [...]}
+ ///
+ /// This separation enables:
+ /// - Efficient communication (only send global parts)
+ /// - Privacy (personalized parts never leave client)
+ /// - Flexibility (different personalization per client)
+ ///
+ /// The complete model with all layers.
+ /// Output: The global layers to be aggregated.
+ /// Output: The personalized layers to keep local.
+ public void SeparateModel(
+ Dictionary<string, T[]> fullModel,
+ out Dictionary<string, T[]> globalPart,
+ out Dictionary<string, T[]> personalizedPart)
+ {
+ if (fullModel == null || fullModel.Count == 0)
+ {
+ throw new ArgumentException("Full model cannot be null or empty.", nameof(fullModel));
+ }
+
+ globalPart = new Dictionary<string, T[]>();
+ personalizedPart = new Dictionary<string, T[]>();
+
+ foreach (var layer in fullModel)
+ {
+ if (_personalizedLayers.Contains(layer.Key))
+ {
+ // This layer is personalized - keep local
+ personalizedPart[layer.Key] = (T[])layer.Value.Clone();
+ }
+ else
+ {
+ // This layer is global - will be aggregated
+ globalPart[layer.Key] = (T[])layer.Value.Clone();
+ }
+ }
+ }
+
+ ///
+ /// Combines global model update with client's personalized layers.
+ ///
+ ///
+ /// For Beginners: After the server aggregates global layers, each client
+ /// combines the updated global layers with their own personalized layers to form
+ /// the complete model for the next round.
+ ///
+ /// Process:
+ /// 1. Receive updated global layers from server
+ /// 2. Retrieve client's personalized layers from local storage
+ /// 3. Merge them into one complete model
+ /// 4. Ready for next round of local training
+ ///
+ /// For example:
+ /// Server sends global update: {layer1: [...], layer2: [...]}
+ /// Client has personalized: {layer3: [...], layer4: [...]}
+ /// Combined model: {layer1: [...], layer2: [...], layer3: [...], layer4: [...]}
+ ///
+ /// This ensures:
+ /// - Global knowledge is incorporated (layers 1-2 updated)
+ /// - Local adaptation is preserved (layers 3-4 unchanged)
+ /// - Model structure remains consistent
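+ ///
+ /// A minimal sketch of one client round (illustrative; fullModel and globalUpdate are
+ /// hypothetical Dictionary<string, double[]> instances held by the client):
+ ///
+ ///     var pfl = new PersonalizedFederatedLearning<double>(personalizationFraction: 0.5);
+ ///     pfl.IdentifyPersonalizedLayers(fullModel);   // default "last_n": marks the last half of the layers
+ ///     pfl.SeparateModel(fullModel, out var globalPart, out var personalizedPart);
+ ///     // ... send globalPart to the server, receive the aggregated globalUpdate back ...
+ ///     var nextRoundModel = pfl.CombineModels(globalUpdate, personalizedPart);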
+ ///
+ /// The updated global layers from server.
+ /// The client's personalized layers.
+ /// The complete model combining both parts.
+ public Dictionary<string, T[]> CombineModels(
+ Dictionary<string, T[]> globalUpdate,
+ Dictionary<string, T[]> personalizedLayers)
+ {
+ if (globalUpdate == null)
+ {
+ throw new ArgumentNullException(nameof(globalUpdate));
+ }
+
+ if (personalizedLayers == null)
+ {
+ throw new ArgumentNullException(nameof(personalizedLayers));
+ }
+
+ var combinedModel = new Dictionary<string, T[]>();
+
+ // Add all global layers
+ foreach (var layer in globalUpdate)
+ {
+ combinedModel[layer.Key] = (T[])layer.Value.Clone();
+ }
+
+ // Add all personalized layers
+ foreach (var layer in personalizedLayers)
+ {
+ combinedModel[layer.Key] = (T[])layer.Value.Clone();
+ }
+
+ return combinedModel;
+ }
+
+ ///
+ /// Checks if a specific layer is personalized.
+ ///
+ ///
+ /// For Beginners: Returns whether a given layer should be kept local
+ /// (personalized) or sent to the server (global).
+ ///
+ /// The name of the layer to check.
+ /// True if the layer is personalized, false if global.
+ public bool IsLayerPersonalized(string layerName)
+ {
+ return _personalizedLayers.Contains(layerName);
+ }
+
+ ///
+ /// Gets the set of all personalized layer names.
+ ///
+ ///
+ /// For Beginners: Returns the list of which layers are personalized.
+ /// Useful for logging, debugging, and understanding the model split.
+ ///
+ /// A read-only set of personalized layer names.
+ public IReadOnlySet<string> GetPersonalizedLayers()
+ {
+ return _personalizedLayers;
+ }
+
+ ///
+ /// Gets the personalization fraction.
+ ///
+ /// The fraction of layers that are personalized.
+ public double GetPersonalizationFraction()
+ {
+ return _personalizationFraction;
+ }
+
+ ///
+ /// Calculates statistics about the model split.
+ ///
+ ///
+ /// For Beginners: Provides useful information about how the model is divided:
+ /// - How many parameters are global vs. personalized
+ /// - What percentage of the model is personalized
+ /// - Communication savings from personalization
+ ///
+ /// This helps understand the trade-offs:
+ /// - More personalized → Less communication, more storage per client
+ /// - More global → More communication, less storage per client
+ ///
+ /// The complete model.
+ /// A dictionary with statistics.
+ public Dictionary<string, object> GetModelSplitStatistics(Dictionary<string, T[]> fullModel)
+ {
+ if (fullModel == null || fullModel.Count == 0)
+ {
+ throw new ArgumentException("Full model cannot be null or empty.", nameof(fullModel));
+ }
+
+ int totalParams = fullModel.Values.Sum(layer => layer.Length);
+ int personalizedParams = fullModel
+ .Where(layer => _personalizedLayers.Contains(layer.Key))
+ .Sum(layer => layer.Value.Length);
+ int globalParams = totalParams - personalizedParams;
+
+ int totalLayers = fullModel.Count;
+ int personalizedLayerCount = _personalizedLayers.Count;
+ int globalLayerCount = totalLayers - personalizedLayerCount;
+
+ return new Dictionary<string, object>
+ {
+ ["total_parameters"] = totalParams,
+ ["global_parameters"] = globalParams,
+ ["personalized_parameters"] = personalizedParams,
+ ["global_parameter_fraction"] = totalParams > 0 ? (double)globalParams / totalParams : 0,
+ ["personalized_parameter_fraction"] = totalParams > 0 ? (double)personalizedParams / totalParams : 0,
+ ["total_layers"] = totalLayers,
+ ["global_layers"] = globalLayerCount,
+ ["personalized_layers"] = personalizedLayerCount,
+ ["communication_reduction"] = totalParams > 0 ? (double)personalizedParams / totalParams : 0
+ };
+ }
+}
diff --git a/src/FederatedLearning/Privacy/GaussianDifferentialPrivacy.cs b/src/FederatedLearning/Privacy/GaussianDifferentialPrivacy.cs
new file mode 100644
index 000000000..5f99c92e5
--- /dev/null
+++ b/src/FederatedLearning/Privacy/GaussianDifferentialPrivacy.cs
@@ -0,0 +1,337 @@
+namespace AiDotNet.FederatedLearning.Privacy;
+
+using AiDotNet.Interfaces;
+using System;
+using System.Collections.Generic;
+
+///
+/// Implements differential privacy using the Gaussian mechanism.
+///
+///
+/// The Gaussian mechanism provides (ε, δ)-differential privacy by adding calibrated
+/// Gaussian (normal distribution) noise to model parameters. This is one of the most
+/// widely used privacy mechanisms in federated learning.
+///
+/// For Beginners: Differential privacy is like adding static to a phone conversation.
+/// You add just enough noise that individual voices can't be identified, but the overall
+/// message still gets through clearly.
+///
+/// How Gaussian Differential Privacy works:
+/// 1. Clip gradients/parameters to bound sensitivity (maximum change any single data point can cause)
+/// 2. Add Gaussian noise: noise ~ N(0, σ²) where σ depends on ε, δ, and sensitivity
+/// 3. The noise calibration ensures (ε, δ)-DP guarantee
+///
+/// Mathematical formulation:
+/// - Sensitivity Δ: Maximum L2 norm of gradient for any single training example
+/// - Noise scale σ = (Δ/ε) × sqrt(2 × ln(1.25/δ))
+/// - For each parameter w: w_private = w + N(0, σ²)
+///
+/// Privacy parameters:
+/// - ε (epsilon): Privacy budget - smaller is more private
+/// * ε = 0.1: Very strong privacy, significant noise
+/// * ε = 1.0: Strong privacy, moderate noise (recommended)
+/// * ε = 10: Weak privacy, minimal noise
+///
+/// - δ (delta): Failure probability - should be very small
+/// * Typically δ = 1/n² where n is dataset size
+/// * Common choice: δ = 1e-5
+///
+/// For example, protecting hospital patient data:
+/// - Original gradient: [0.5, -0.3, 0.8, -0.2]
+/// - Clip to max norm 1.0: norm ≈ 1.01, so scale by ≈ 0.99 → ≈ [0.495, -0.297, 0.792, -0.198]
+/// - Add Gaussian noise with σ=0.1: [0.47, -0.29, 0.75, -0.21]
+/// - Result: Individual patient influence is masked by noise
+///
+/// Privacy composition:
+/// - Each time you share data, you consume privacy budget
+/// - After T rounds with ε per round: total ε_total ≈ ε × sqrt(2T × ln(1/δ))
+/// - This is more efficient than naive composition (ε × T)
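+///
+/// For example, with ε = 0.5 per round over T = 100 rounds and δ = 1e-5:
+/// ε_total ≈ 0.5 × sqrt(2 × 100 × ln(1e5)) ≈ 0.5 × 48 ≈ 24, versus 0.5 × 100 = 50 under naive composition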
+///
+/// Trade-offs:
+/// - More privacy (smaller ε) → more noise → lower accuracy
+/// - Less privacy (larger ε) → less noise → higher accuracy
+/// - Must find acceptable balance for your application
+///
+/// When to use Gaussian DP:
+/// - Need provable privacy guarantees
+/// - Working with sensitive data (healthcare, finance)
+/// - Regulatory requirements (GDPR, HIPAA)
+/// - Publishing models or sharing with untrusted parties
+///
+/// Reference: Dwork, C., & Roth, A. (2014). "The Algorithmic Foundations of Differential Privacy."
+/// Abadi, M., et al. (2016). "Deep Learning with Differential Privacy." CCS 2016.
+///
+/// The numeric type for model parameters (e.g., double, float).
+public class GaussianDifferentialPrivacy<T> : IPrivacyMechanism<Dictionary<string, T[]>>
+ where T : struct, IComparable, IConvertible
+{
+ private double _privacyBudgetConsumed;
+ private readonly double _clipNorm;
+ private readonly Random _random;
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ ///
+ /// For Beginners: Creates a differential privacy mechanism with a specified
+ /// gradient clipping threshold.
+ ///
+ /// Gradient clipping (clipNorm) is crucial for DP:
+ /// - Bounds the maximum influence any single data point can have
+ /// - Makes noise calibration possible
+ /// - Common values: 0.1 - 10.0 depending on model and data
+ ///
+ /// Lower clipNorm:
+ /// - Stronger privacy guarantee
+ /// - More aggressive clipping
+ /// - May slow convergence
+ ///
+ /// Higher clipNorm:
+ /// - Less clipping
+ /// - Faster convergence
+ /// - Requires more noise for same privacy
+ ///
+ /// Recommendations:
+ /// - Start with clipNorm = 1.0
+ /// - Monitor gradient norms during training
+ /// - Adjust based on typical gradient magnitudes
+ ///
+ /// The maximum L2 norm for gradient clipping (sensitivity bound).
+ /// Optional random seed for reproducibility.
+ public GaussianDifferentialPrivacy(double clipNorm = 1.0, int? randomSeed = null)
+ {
+ if (clipNorm <= 0)
+ {
+ throw new ArgumentException("Clip norm must be positive.", nameof(clipNorm));
+ }
+
+ _clipNorm = clipNorm;
+ _privacyBudgetConsumed = 0.0;
+ _random = randomSeed.HasValue ? new Random(randomSeed.Value) : new Random();
+ }
+
+ ///
+ /// Applies differential privacy to model parameters by adding calibrated Gaussian noise.
+ ///
+ ///
+ /// This method implements the Gaussian mechanism for (ε, δ)-differential privacy:
+ ///
+ /// For Beginners: This adds carefully calculated random noise to protect privacy
+ /// while maintaining model utility.
+ ///
+ /// Step-by-step process:
+ /// 1. Calculate current L2 norm of model parameters
+ /// 2. If norm > clipNorm, scale down parameters to clipNorm
+ /// 3. Calculate noise scale σ based on ε, δ, and sensitivity
+ /// 4. Add Gaussian noise N(0, σ²) to each parameter
+ /// 5. Update privacy budget consumed
+ ///
+ /// Mathematical details:
+ /// - Sensitivity Δ = clipNorm (worst-case parameter change)
+ /// - σ = (Δ/ε) × sqrt(2 × ln(1.25/δ))
+ /// - Noise ~ N(0, σ²) added to each parameter independently
+ ///
+ /// For example, with ε=1.0, δ=1e-5, clipNorm=1.0:
+ /// - σ = (1.0/1.0) × sqrt(2 × ln(125000)) ≈ 4.8
+ /// - Each parameter gets noise from N(0, 4.8²)
+ /// - Original params: [0.5, -0.3, 0.8]
+ /// - Noisy params: e.g. [4.2, -2.9, 6.1] (with σ ≈ 4.8, individual draws are typically of this size)
+ ///
+ /// Privacy accounting:
+ /// - Each call consumes ε privacy budget
+ /// - Total budget accumulates: ε_total = ε_1 + ε_2 + ... (simplified)
+ /// - Advanced: Use Rényi DP for tighter composition bounds
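+ ///
+ /// A minimal usage sketch (illustrative; assumes double parameters):
+ ///
+ ///     var dp = new GaussianDifferentialPrivacy<double>(clipNorm: 1.0, randomSeed: 42);
+ ///     var model = new Dictionary<string, double[]> { ["w"] = new[] { 0.5, -0.3, 0.8 } };
+ ///     var privateModel = dp.ApplyPrivacy(model, epsilon: 1.0, delta: 1e-5);
+ ///     double consumed = dp.GetPrivacyBudgetConsumed();   // 1.0 after this single call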
+ ///
+ /// The model parameters to add noise to.
+ /// Privacy budget for this operation (smaller = more private).
+ /// Failure probability (typically 1e-5 or smaller).
+ /// The model with differential privacy applied.
+ public Dictionary<string, T[]> ApplyPrivacy(Dictionary<string, T[]> model, double epsilon, double delta)
+ {
+ if (model == null || model.Count == 0)
+ {
+ throw new ArgumentException("Model cannot be null or empty.", nameof(model));
+ }
+
+ if (epsilon <= 0)
+ {
+ throw new ArgumentException("Epsilon must be positive.", nameof(epsilon));
+ }
+
+ if (delta <= 0 || delta >= 1)
+ {
+ throw new ArgumentException("Delta must be in (0, 1).", nameof(delta));
+ }
+
+ // Create a copy of the model
+ var noisyModel = new Dictionary<string, T[]>();
+ foreach (var layer in model)
+ {
+ noisyModel[layer.Key] = (T[])layer.Value.Clone();
+ }
+
+ // Step 1: Gradient clipping - Calculate L2 norm of all parameters
+ double l2Norm = CalculateL2Norm(noisyModel);
+
+ // If norm exceeds clip threshold, scale down
+ if (l2Norm > _clipNorm)
+ {
+ double scaleFactor = _clipNorm / l2Norm;
+
+ foreach (var layerName in noisyModel.Keys)
+ {
+ var parameters = noisyModel[layerName];
+ for (int i = 0; i < parameters.Length; i++)
+ {
+ double value = Convert.ToDouble(parameters[i]);
+ parameters[i] = (T)Convert.ChangeType(value * scaleFactor, typeof(T));
+ }
+ }
+ }
+
+ // Step 2: Calculate noise scale based on Gaussian mechanism
+ // σ = (Δ/ε) × sqrt(2 × ln(1.25/δ))
+ // where Δ = clipNorm (sensitivity)
+ double sensitivity = _clipNorm;
+ double noiseSigma = (sensitivity / epsilon) * Math.Sqrt(2.0 * Math.Log(1.25 / delta));
+
+ // Step 3: Add Gaussian noise to each parameter
+ foreach (var layerName in noisyModel.Keys)
+ {
+ var parameters = noisyModel[layerName];
+ for (int i = 0; i < parameters.Length; i++)
+ {
+ double value = Convert.ToDouble(parameters[i]);
+ double noise = GenerateGaussianNoise(0.0, noiseSigma);
+ parameters[i] = (T)Convert.ChangeType(value + noise, typeof(T));
+ }
+ }
+
+ // Update privacy budget consumed
+ _privacyBudgetConsumed += epsilon;
+
+ return noisyModel;
+ }
+
+ ///
+ /// Calculates the L2 norm (Euclidean norm) of all model parameters.
+ ///
+ ///
+ /// For Beginners: L2 norm is the "length" of the parameter vector in
+ /// high-dimensional space. It's calculated as sqrt(sum of squares).
+ ///
+ /// For example, with parameters [3, 4]:
+ /// - L2 norm = sqrt(3² + 4²) = sqrt(9 + 16) = sqrt(25) = 5
+ ///
+ /// Used for gradient clipping to bound sensitivity.
+ ///
+ /// The model to calculate norm for.
+ /// The L2 norm of all parameters.
+ private double CalculateL2Norm(Dictionary<string, T[]> model)
+ {
+ double sumOfSquares = 0.0;
+
+ foreach (var layer in model.Values)
+ {
+ foreach (var param in layer)
+ {
+ double value = Convert.ToDouble(param);
+ sumOfSquares += value * value;
+ }
+ }
+
+ return Math.Sqrt(sumOfSquares);
+ }
+
+ ///
+ /// Generates a sample from a Gaussian (normal) distribution.
+ ///
+ ///
+ /// Uses the Box-Muller transform to generate Gaussian random variables from
+ /// uniform random variables.
+ ///
+ /// For Beginners: This creates random noise from a bell curve distribution.
+ /// Most noise values will be close to the mean, with rare large values.
+ ///
+ /// Box-Muller transform:
+ /// - Generate two uniform random numbers U1, U2 in [0, 1]
+ /// - Z = sqrt(-2 × ln(U1)) × cos(2π × U2)
+ /// - Z follows standard normal N(0, 1)
+ /// - Scale and shift: X = mean + sigma × Z
+ ///
+ /// The mean of the Gaussian distribution.
+ /// The standard deviation of the Gaussian distribution.
+ /// A random sample from N(mean, sigma²).
+ private double GenerateGaussianNoise(double mean, double sigma)
+ {
+ // Box-Muller transform
+ double u1 = 1.0 - _random.NextDouble(); // Uniform(0,1]
+ double u2 = 1.0 - _random.NextDouble();
+ double standardNormal = Math.Sqrt(-2.0 * Math.Log(u1)) * Math.Cos(2.0 * Math.PI * u2);
+
+ return mean + sigma * standardNormal;
+ }
+
+ ///
+ /// Gets the total privacy budget consumed so far.
+ ///
+ ///
+ /// For Beginners: Returns how much privacy budget has been used up.
+ /// Privacy budget is cumulative - once spent, it's gone.
+ ///
+ /// For example:
+ /// - Round 1: ε = 0.5 consumed, total = 0.5
+ /// - Round 2: ε = 0.5 consumed, total = 1.0
+ /// - Round 3: ε = 0.5 consumed, total = 1.5
+ ///
+ /// If you started with total budget 10.0, you have 8.5 remaining.
+ ///
+ /// Note: This uses basic composition. Advanced composition (Rényi DP) gives
+ /// tighter bounds and would show less budget consumed.
+ ///
+ /// The cumulative privacy budget (epsilon) consumed.
+ public double GetPrivacyBudgetConsumed()
+ {
+ return _privacyBudgetConsumed;
+ }
+
+ ///
+ /// Gets the name of the privacy mechanism.
+ ///
+ /// A string describing the mechanism.
+ public string GetMechanismName()
+ {
+ return $"Gaussian DP (clip={_clipNorm})";
+ }
+
+ ///
+ /// Gets the gradient clipping norm used for sensitivity bounding.
+ ///
+ ///
+ /// For Beginners: Returns the maximum allowed parameter norm.
+ /// Parameters larger than this are scaled down before adding noise.
+ ///
+ /// The clipping norm value.
+ public double GetClipNorm()
+ {
+ return _clipNorm;
+ }
+
+ ///
+ /// Resets the privacy budget counter.
+ ///
+ ///
+ /// For Beginners: Resets the privacy budget tracker to zero.
+ ///
+ /// WARNING: This should only be used when starting a completely new training run.
+ /// Do not reset during active training as it would give false privacy accounting.
+ ///
+ /// Use cases:
+ /// - Starting new experiment with same mechanism instance
+ /// - Testing and debugging
+ /// - Separate training phases with independent privacy guarantees
+ ///
+ public void ResetPrivacyBudget()
+ {
+ _privacyBudgetConsumed = 0.0;
+ }
+}
diff --git a/src/FederatedLearning/Privacy/SecureAggregation.cs b/src/FederatedLearning/Privacy/SecureAggregation.cs
new file mode 100644
index 000000000..efe999339
--- /dev/null
+++ b/src/FederatedLearning/Privacy/SecureAggregation.cs
@@ -0,0 +1,381 @@
+namespace AiDotNet.FederatedLearning.Privacy;
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Security.Cryptography;
+
+///
+/// Implements secure aggregation for federated learning using cryptographic techniques.
+///
+///
+/// Secure aggregation is a cryptographic protocol that allows a server to compute the sum
+/// of client updates without seeing individual contributions. Only the final aggregate is
+/// visible to the server.
+///
+/// For Beginners: Secure aggregation is like a secret ballot election where votes
+/// are counted but individual votes remain private.
+///
+/// How it works (simplified):
+/// 1. Each client generates pairwise secret keys with other clients
+/// 2. Clients mask their model updates with these secret keys
+/// 3. Server receives masked updates: masked_update_i = update_i + Σ(secrets_ij)
+/// 4. Secret masks cancel out when summing: Σ(masked_update_i) = Σ(update_i)
+/// 5. Server gets the sum without seeing individual updates
+///
+/// Example with 3 clients:
+/// - Client 1 shares secrets: s₁₂ with Client 2, s₁₃ with Client 3
+/// - Client 2 shares secrets: s₂₁ with Client 1, s₂₃ with Client 3
+/// - Client 3 shares secrets: s₃₁ with Client 1, s₃₂ with Client 2
+///
+/// Note: s₁₂ = -s₂₁ (secrets cancel in pairs)
+///
+/// Client 1 sends: update₁ + s₁₂ + s₁₃
+/// Client 2 sends: update₂ + s₂₁ + s₂₃
+/// Client 3 sends: update₃ + s₃₁ + s₃₂
+///
+/// Server computes sum:
+/// (update₁ + s₁₂ + s₁₃) + (update₂ + s₂₁ + s₂₃) + (update₃ + s₃₁ + s₃₂)
+/// = update₁ + update₂ + update₃ + (s₁₂ + s₂₁) + (s₁₃ + s₃₁) + (s₂₃ + s₃₂)
+/// = update₁ + update₂ + update₃ + 0 + 0 + 0
+/// = Σ(updates) ← Only this is visible to server!
+///
+/// This implementation uses a simplified version with random masking for demonstration.
+/// Production systems should use proper cryptographic protocols like:
+/// - Bonawitz et al.'s Secure Aggregation protocol
+/// - Threshold homomorphic encryption
+/// - Secret sharing schemes
+///
+/// Benefits:
+/// - Server cannot see individual client updates
+/// - Protects against honest-but-curious server
+/// - No trusted third party needed
+/// - Computation overhead is reasonable
+///
+/// Limitations:
+/// - Requires coordination between clients
+/// - All (or threshold) clients must participate for masks to cancel
+/// - Dropout handling requires additional mechanisms
+/// - Communication overhead for key exchange
+///
+/// When to use Secure Aggregation:
+/// - Don't fully trust the central server
+/// - Regulatory requirements for data protection
+/// - Want cryptographic privacy guarantees
+/// - Willing to handle additional complexity
+///
+/// Can be combined with differential privacy for stronger protection:
+/// - Secure aggregation: Protects individual updates from server
+/// - Differential privacy: Protects individual data points from anyone
+///
+/// Reference: Bonawitz, K., et al. (2017). "Practical Secure Aggregation for Privacy-Preserving
+/// Machine Learning." CCS 2017.
+///
+/// The numeric type for model parameters (e.g., double, float).
+public class SecureAggregation<T>
+ where T : struct, IComparable, IConvertible
+{
+ private readonly Dictionary<int, Dictionary<int, double[]>> _pairwiseSecrets;
+ private readonly Random _random;
+ private readonly int _parameterCount;
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ ///
+ /// For Beginners: Sets up the secure aggregation protocol for a specific
+ /// number of model parameters.
+ ///
+ /// In practice, this would involve:
+ /// - Secure key exchange between clients
+ /// - Authenticated channels
+ /// - Agreement on random seed for deterministic mask generation
+ ///
+ /// This simplified implementation uses pseudorandom masks that cancel out.
+ ///
+ /// The total number of model parameters to protect.
+ /// Optional random seed for reproducibility.
+ public SecureAggregation(int parameterCount, int? randomSeed = null)
+ {
+ if (parameterCount <= 0)
+ {
+ throw new ArgumentException("Parameter count must be positive.", nameof(parameterCount));
+ }
+
+ _parameterCount = parameterCount;
+ _pairwiseSecrets = new Dictionary<int, Dictionary<int, double[]>>();
+ _random = randomSeed.HasValue ? new Random(randomSeed.Value) : new Random();
+ }
+
+ ///
+ /// Generates pairwise secrets between all clients.
+ ///
+ ///
+ /// For Beginners: This creates secret keys that clients will use to mask
+ /// their updates. The secrets are designed so they cancel out when aggregated.
+ ///
+ /// For each pair of clients (i, j):
+ /// - Generate random secret s_ij
+ /// - Set s_ji = -s_ij (so they cancel: s_ij + s_ji = 0)
+ ///
+ /// In production, this would use:
+ /// - Diffie-Hellman key exchange
+ /// - Public key infrastructure
+ /// - Secure random number generation
+ ///
+ /// List of all participating client IDs.
+ public void GeneratePairwiseSecrets(List<int> clientIds)
+ {
+ if (clientIds == null || clientIds.Count < 2)
+ {
+ throw new ArgumentException("Need at least 2 clients for secure aggregation.", nameof(clientIds));
+ }
+
+ _pairwiseSecrets.Clear();
+
+ // Generate pairwise secrets for each pair of clients
+ for (int i = 0; i < clientIds.Count; i++)
+ {
+ int clientI = clientIds[i];
+ _pairwiseSecrets[clientI] = new Dictionary<int, double[]>();
+
+ for (int j = i + 1; j < clientIds.Count; j++)
+ {
+ int clientJ = clientIds[j];
+
+ // Generate random secret for this pair
+ double[] secret = new double[_parameterCount];
+ for (int k = 0; k < _parameterCount; k++)
+ {
+ // Use cryptographically secure random in production
+ secret[k] = (_random.NextDouble() - 0.5) * 2.0; // Range: [-1, 1]
+ }
+
+ // Store secret for client i with respect to client j
+ _pairwiseSecrets[clientI][clientJ] = secret;
+
+ // Store negated secret for client j with respect to client i
+ // This ensures secrets cancel: secret_ij + secret_ji = 0
+ if (!_pairwiseSecrets.ContainsKey(clientJ))
+ {
+ _pairwiseSecrets[clientJ] = new Dictionary<int, double[]>();
+ }
+
+ double[] negatedSecret = new double[_parameterCount];
+ for (int k = 0; k < _parameterCount; k++)
+ {
+ negatedSecret[k] = -secret[k];
+ }
+ _pairwiseSecrets[clientJ][clientI] = negatedSecret;
+ }
+ }
+ }
+
+ ///
+ /// Masks a client's model update with pairwise secrets.
+ ///
+ ///
+ /// For Beginners: This adds secret masks to the client's update so the
+ /// server can't see the original values. Only after aggregating all clients
+ /// do the masks cancel out.
+ ///
+ /// Mathematical operation:
+ /// masked_update = original_update + Σ(secrets_with_other_clients)
+ ///
+ /// For example, Client 1 with 3 clients total:
+ /// - Original update: [0.5, -0.3, 0.8]
+ /// - Secret with Client 2: [0.1, 0.2, -0.1]
+ /// - Secret with Client 3: [-0.2, 0.1, 0.15]
+ /// - Masked update: [0.4, 0.0, 0.85]
+ ///
+ /// Server sees: [0.4, 0.0, 0.85] ← Cannot recover original [0.5, -0.3, 0.8]
+ /// But after aggregating all clients, secrets cancel and server gets correct sum.
+ ///
+ /// The ID of the client whose update to mask.
+ /// The client's model update.
+ /// The masked model update.
+ public Dictionary<string, T[]> MaskUpdate(int clientId, Dictionary<string, T[]> clientUpdate)
+ {
+ if (clientUpdate == null || clientUpdate.Count == 0)
+ {
+ throw new ArgumentException("Client update cannot be null or empty.", nameof(clientUpdate));
+ }
+
+ if (!_pairwiseSecrets.ContainsKey(clientId))
+ {
+ throw new ArgumentException($"No secrets found for client {clientId}. Call GeneratePairwiseSecrets first.", nameof(clientId));
+ }
+
+ // Create masked update
+ var maskedUpdate = new Dictionary<string, T[]>();
+
+ // Flatten all parameters to apply masks
+ var flatParams = FlattenParameters(clientUpdate);
+ var maskedFlatParams = new double[flatParams.Length];
+ Array.Copy(flatParams, maskedFlatParams, flatParams.Length);
+
+ // Add all pairwise secrets for this client
+ foreach (var otherClientSecrets in _pairwiseSecrets[clientId].Values)
+ {
+ for (int i = 0; i < Math.Min(maskedFlatParams.Length, otherClientSecrets.Length); i++)
+ {
+ maskedFlatParams[i] += otherClientSecrets[i];
+ }
+ }
+
+ // Unflatten back to original structure
+ int paramIndex = 0;
+ foreach (var layerName in clientUpdate.Keys)
+ {
+ var originalLayer = clientUpdate[layerName];
+ var maskedLayer = new T[originalLayer.Length];
+
+ for (int i = 0; i < originalLayer.Length && paramIndex < maskedFlatParams.Length; i++, paramIndex++)
+ {
+ maskedLayer[i] = (T)Convert.ChangeType(maskedFlatParams[paramIndex], typeof(T));
+ }
+
+ maskedUpdate[layerName] = maskedLayer;
+ }
+
+ return maskedUpdate;
+ }
+
+ ///
+ /// Aggregates masked updates from all clients, recovering the true sum.
+ ///
+ ///
+ /// For Beginners: This sums up all the masked updates. Because the secret
+ /// masks cancel out, the result is the true sum of client updates.
+ ///
+ /// Mathematical property:
+ /// Σ(masked_update_i) = Σ(update_i + secrets_i)
+ /// = Σ(update_i) + Σ(secrets_i)
+ /// = Σ(update_i) + 0 ← secrets cancel
+ /// = True sum of updates
+ ///
+ /// The server performs this aggregation without ever seeing individual updates!
+ ///
+ /// For example with 2 clients:
+ /// Client 1 masked: [0.4, 0.0, 0.85] = [0.5, -0.3, 0.8] + [-0.1, 0.3, 0.05]
+ /// Client 2 masked: [0.7, 0.1, 1.05] = [0.6, 0.4, 1.1] + [0.1, -0.3, -0.05]
+ ///
+ /// Sum of masked: [1.1, 0.1, 1.9]
+ /// True sum: [0.5, -0.3, 0.8] + [0.6, 0.4, 1.1] = [1.1, 0.1, 1.9] ← Matches!
+ /// (Note: Secrets [-0.1, 0.3, 0.05] + [0.1, -0.3, -0.05] = [0, 0, 0] ← Cancelled)
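+ ///
+ /// A minimal end-to-end sketch (illustrative; assumes double parameters and a one-parameter model):
+ ///
+ ///     var secAgg = new SecureAggregation<double>(parameterCount: 1, randomSeed: 7);
+ ///     secAgg.GeneratePairwiseSecrets(new List<int> { 1, 2 });
+ ///     var masked = new Dictionary<int, Dictionary<string, double[]>>
+ ///     {
+ ///         [1] = secAgg.MaskUpdate(1, new Dictionary<string, double[]> { ["w"] = new[] { 0.5 } }),
+ ///         [2] = secAgg.MaskUpdate(2, new Dictionary<string, double[]> { ["w"] = new[] { 0.7 } })
+ ///     };
+ ///     var sum = secAgg.AggregateSecurely(masked, new Dictionary<int, double>());
+ ///     // sum["w"][0] ≈ 1.2 (the pairwise masks cancel in the aggregate)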
+ ///
+ /// Dictionary of client IDs to their masked updates.
+ /// Dictionary of client IDs to their aggregation weights.
+ /// The securely aggregated model (sum of original updates with masks cancelled).
+ public Dictionary<string, T[]> AggregateSecurely(
+ Dictionary<int, Dictionary<string, T[]>> maskedUpdates,
+ Dictionary<int, double> clientWeights)
+ {
+ if (maskedUpdates == null || maskedUpdates.Count == 0)
+ {
+ throw new ArgumentException("Masked updates cannot be null or empty.", nameof(maskedUpdates));
+ }
+
+ // Get model structure from first client
+ var firstUpdate = maskedUpdates.First().Value;
+ var aggregatedUpdate = new Dictionary<string, T[]>();
+
+ // Initialize aggregated update with zeros
+ foreach (var layerName in firstUpdate.Keys)
+ {
+ aggregatedUpdate[layerName] = new T[firstUpdate[layerName].Length];
+ }
+
+ // Sum all masked updates
+ // The pairwise secrets will cancel out, leaving only the true sum
+ foreach (var clientId in maskedUpdates.Keys)
+ {
+ var maskedUpdate = maskedUpdates[clientId];
+
+ foreach (var layerName in maskedUpdate.Keys)
+ {
+ var maskedParams = maskedUpdate[layerName];
+ var aggregatedParams = aggregatedUpdate[layerName];
+
+ for (int i = 0; i < maskedParams.Length; i++)
+ {
+ double currentValue = Convert.ToDouble(aggregatedParams[i]);
+ double maskedValue = Convert.ToDouble(maskedParams[i]);
+ aggregatedParams[i] = (T)Convert.ChangeType(currentValue + maskedValue, typeof(T));
+ }
+ }
+ }
+
+ // If using weighted aggregation, divide by total weight
+ if (clientWeights != null && clientWeights.Count > 0)
+ {
+ double totalWeight = clientWeights.Values.Sum();
+
+ if (totalWeight > 0)
+ {
+ foreach (var layerName in aggregatedUpdate.Keys)
+ {
+ var aggregatedParams = aggregatedUpdate[layerName];
+
+ for (int i = 0; i < aggregatedParams.Length; i++)
+ {
+ double value = Convert.ToDouble(aggregatedParams[i]);
+ aggregatedParams[i] = (T)Convert.ChangeType(value / totalWeight, typeof(T));
+ }
+ }
+ }
+ }
+
+ return aggregatedUpdate;
+ }
+
+ ///
+ /// Flattens a hierarchical model structure into a single parameter array.
+ ///
+ ///
+ /// For Beginners: Converts the model from a dictionary of layers to a
+ /// single flat array of all parameters. Makes it easier to apply masks uniformly.
+ ///
+ /// The model to flatten.
+ /// A flat array of all parameters.
+ private double[] FlattenParameters(Dictionary<string, T[]> model)
+ {
+ int totalParams = model.Values.Sum(layer => layer.Length);
+ double[] flatParams = new double[totalParams];
+
+ int index = 0;
+ foreach (var layer in model.Values)
+ {
+ foreach (var param in layer)
+ {
+ flatParams[index++] = Convert.ToDouble(param);
+ }
+ }
+
+ return flatParams;
+ }
+
+ ///
+ /// Clears all stored pairwise secrets.
+ ///
+ ///
+ /// For Beginners: Removes all secret keys from memory. Should be called
+ /// after aggregation is complete for security.
+ ///
+ /// Security best practice:
+ /// - Generate fresh secrets for each round
+ /// - Clear old secrets to prevent reuse
+ /// - Minimize time secrets are stored in memory
+ ///
+ public void ClearSecrets()
+ {
+ _pairwiseSecrets.Clear();
+ }
+
+ ///
+ /// Gets the number of clients with stored secrets.
+ ///
+ /// The count of clients.
+ public int GetClientCount()
+ {
+ return _pairwiseSecrets.Count;
+ }
+}
diff --git a/src/FederatedLearning/README.md b/src/FederatedLearning/README.md
new file mode 100644
index 000000000..1e0b4288c
--- /dev/null
+++ b/src/FederatedLearning/README.md
@@ -0,0 +1,250 @@
+# Federated Learning Framework
+
+This directory contains a comprehensive implementation of Federated Learning (FL) for the AiDotNet library, addressing Issue #398 (Phase 3).
+
+## Overview
+
+Federated Learning is a privacy-preserving distributed machine learning approach where multiple clients (devices, institutions, edge nodes) collaboratively train a shared model without sharing their raw data. Only model updates are exchanged, ensuring data privacy.
+
+## Features Implemented
+
+### Core Algorithms
+
+#### 1. FedAvg (Federated Averaging)
+- **Location**: `Aggregators/FedAvgAggregationStrategy.cs`
+- **Description**: The foundational FL algorithm that performs weighted averaging of client model updates
+- **Use Case**: Standard federated learning with IID or mildly non-IID data
+- **Reference**: McMahan et al. (2017) - "Communication-Efficient Learning of Deep Networks from Decentralized Data"
+
+#### 2. FedProx (Federated Proximal)
+- **Location**: `Aggregators/FedProxAggregationStrategy.cs`
+- **Description**: Handles system heterogeneity by adding proximal terms to prevent client drift
+- **Use Case**: Non-IID data, heterogeneous client capabilities, stragglers
+- **Key Parameter**: μ (proximal term coefficient)
+- **Reference**: Li et al. (2020) - "Federated Optimization in Heterogeneous Networks"
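+- **Local objective (informal)**: each client approximately minimizes `L_k(w) + (μ/2)·||w − w_global||²`, so a larger μ pulls local models more strongly toward the current global model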
+
+#### 3. FedBN (Federated Batch Normalization)
+- **Location**: `Aggregators/FedBNAggregationStrategy.cs`
+- **Description**: Keeps batch normalization layers local while aggregating other layers
+- **Use Case**: Deep neural networks with batch normalization, non-IID data with distribution shift
+- **Reference**: Li et al. (2021) - "Federated Learning on Non-IID Data Silos"
+
+### Privacy Features
+
+#### 1. Differential Privacy
+- **Location**: `Privacy/GaussianDifferentialPrivacy.cs`
+- **Description**: Implements (ε, δ)-differential privacy using the Gaussian mechanism
+- **Features**:
+ - Gradient clipping for sensitivity bounding
+ - Calibrated Gaussian noise addition
+ - Privacy budget tracking
+ - Configurable ε (privacy budget) and δ (failure probability)
+- **Reference**: Dwork & Roth (2014) - "The Algorithmic Foundations of Differential Privacy"
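+- **Noise scale (typical calibration)**: for the Gaussian mechanism, the noise standard deviation is commonly set to `σ = clipNorm · sqrt(2·ln(1.25/δ)) / ε`; the exact calibration used by this implementation may differ slightly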
+
+#### 2. Secure Aggregation
+- **Location**: `Privacy/SecureAggregation.cs`
+- **Description**: Cryptographic protocol to aggregate updates without revealing individual contributions
+- **Features**:
+ - Pairwise secret masking
+ - Server only sees aggregated result
+ - Protection against honest-but-curious server
+- **Reference**: Bonawitz et al. (2017) - "Practical Secure Aggregation for Privacy-Preserving Machine Learning"
+
+### Personalization
+
+#### Personalized Federated Learning
+- **Location**: `Personalization/PersonalizedFederatedLearning.cs`
+- **Description**: Enables client-specific model layers while maintaining shared global layers
+- **Features**:
+ - Layer-wise personalization
+ - Configurable personalization fraction
+ - Model split statistics
+ - Flexible personalization strategies
+- **Use Case**: Non-IID data, client-specific adaptations, multi-task learning
+- **Reference**: Fallah et al. (2020) - "Personalized Federated Learning with Theoretical Guarantees: A Model-Agnostic Meta-Learning Approach"
+
+## Architecture
+
+### Interfaces
+
+All core interfaces are located in `src/Interfaces/`:
+
+- **IFederatedTrainer**: Main trainer for orchestrating federated learning
+- **IAggregationStrategy**: Strategy pattern for different aggregation algorithms
+- **IPrivacyMechanism**: Privacy-preserving mechanisms for protecting client data
+- **IClientModel**: Client-side model operations
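+
+As an illustration of the `IAggregationStrategy` shape, a custom strategy might look like the following minimal sketch (an unweighted mean; `MeanAggregationStrategy` is a hypothetical example, not a class shipped in this directory):
+
+```csharp
+using System.Collections.Generic;
+using System.Linq;
+using AiDotNet.Interfaces;
+
+public class MeanAggregationStrategy : IAggregationStrategy<Dictionary<string, double[]>>
+{
+    public Dictionary<string, double[]> Aggregate(
+        Dictionary<int, Dictionary<string, double[]>> clientModels,
+        Dictionary<int, double> clientWeights)
+    {
+        var result = new Dictionary<string, double[]>();
+
+        foreach (var layerName in clientModels.Values.First().Keys)
+        {
+            int length = clientModels.Values.First()[layerName].Length;
+            var mean = new double[length];
+
+            // Sum each layer across clients, then divide by the client count
+            foreach (var model in clientModels.Values)
+            {
+                for (int i = 0; i < length; i++)
+                {
+                    mean[i] += model[layerName][i];
+                }
+            }
+
+            for (int i = 0; i < length; i++)
+            {
+                mean[i] /= clientModels.Count; // every client counts equally; weights are ignored here
+            }
+
+            result[layerName] = mean;
+        }
+
+        return result;
+    }
+
+    public string GetStrategyName() => "UnweightedMean";
+}
+```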
+
+### Configuration
+
+- **FederatedLearningOptions** (`src/Models/Options/FederatedLearningOptions.cs`): Comprehensive configuration options
+ - Client management (number of clients, selection fraction)
+ - Training parameters (learning rate, epochs, batch size)
+ - Privacy settings (differential privacy, secure aggregation)
+ - Convergence criteria
+ - Personalization settings
+ - Compression options
+
+- **FederatedLearningMetadata** (`src/Models/FederatedLearningMetadata.cs`): Training metrics and statistics
+ - Performance metrics (accuracy, loss)
+ - Resource usage (time, communication)
+ - Privacy budget tracking
+ - Per-round detailed metrics
+
+## Usage Examples
+
+### Basic FedAvg
+
+```csharp
+using AiDotNet.FederatedLearning.Aggregators;
+using AiDotNet.Models.Options;
+
+// Create FedAvg aggregation strategy
+var aggregator = new FedAvgAggregationStrategy<double>();
+
+// Define client models and weights
+var clientModels = new Dictionary<int, Dictionary<string, double[]>>
+{
+ [0] = new Dictionary<string, double[]> { ["layer1"] = new[] { 1.0, 2.0 } },
+ [1] = new Dictionary<string, double[]> { ["layer1"] = new[] { 3.0, 4.0 } }
+};
+
+var clientWeights = new Dictionary<int, double>
+{
+ [0] = 100.0, // 100 samples
+ [1] = 200.0 // 200 samples
+};
+
+// Aggregate models
+var globalModel = aggregator.Aggregate(clientModels, clientWeights);
+```
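+
+With these weights, `globalModel["layer1"]` comes out to roughly `[2.33, 3.33]`: client 1 holds two thirds of the samples (200 of 300), so its parameters dominate the weighted average.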
+
+### Differential Privacy
+
+```csharp
+using AiDotNet.FederatedLearning.Privacy;
+
+// Create differential privacy mechanism
+var dp = new GaussianDifferentialPrivacy<double>(
+ clipNorm: 1.0,
+ randomSeed: 42 // For reproducibility
+);
+
+// Apply privacy to model
+var privateModel = dp.ApplyPrivacy(
+ model: clientModel,
+ epsilon: 1.0, // Privacy budget
+ delta: 1e-5 // Failure probability
+);
+
+// Check privacy budget consumed
+Console.WriteLine($"Privacy budget used: {dp.GetPrivacyBudgetConsumed()}");
+```
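+
+Each call to `ApplyPrivacy` spends part of the privacy budget; `GetPrivacyBudgetConsumed()` reports the running total, which is useful for stopping training before the configured budget is exhausted.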
+
+### Personalized Federated Learning
+
+```csharp
+using AiDotNet.FederatedLearning.Personalization;
+
+// Create personalization handler
+var pfl = new PersonalizedFederatedLearning<double>(
+ personalizationFraction: 0.2 // Keep last 20% of layers personalized
+);
+
+// Identify which layers to personalize
+pfl.IdentifyPersonalizedLayers(modelStructure, strategy: "last_n");
+
+// Separate model into global and personalized parts
+pfl.SeparateModel(
+ fullModel: clientModel,
+ out var globalPart,
+ out var personalizedPart
+);
+
+// Send only globalPart to server for aggregation
+// Keep personalizedPart on client
+
+// After receiving aggregated global model from server
+var updatedFullModel = pfl.CombineModels(
+ globalUpdate: aggregatedGlobalModel,
+ personalizedLayers: personalizedPart
+);
+```
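+
+### Secure Aggregation
+
+A minimal sketch of the masking flow, reusing `clientModels` and `clientWeights` from the FedAvg example above. The `SecureAggregation<double>` constructor and the exact signature of `GeneratePairwiseSecrets` are assumptions here; `MaskUpdate`, `AggregateSecurely`, and `ClearSecrets` follow the API in `Privacy/SecureAggregation.cs`. In this single-process sketch one object stands in for every client; in a real deployment each client would hold only its own secrets:
+
+```csharp
+using AiDotNet.FederatedLearning.Privacy;
+
+var secureAgg = new SecureAggregation<double>();
+
+// Establish pairwise secrets for the participating clients (assumed signature)
+secureAgg.GeneratePairwiseSecrets(new[] { 0, 1 }, parameterCount: 2);
+
+// Each client masks its own update; the server never sees the raw updates
+var maskedUpdates = new Dictionary<int, Dictionary<string, double[]>>
+{
+    [0] = secureAgg.MaskUpdate(0, clientModels[0]),
+    [1] = secureAgg.MaskUpdate(1, clientModels[1])
+};
+
+// Pairwise masks cancel in the sum, so the server recovers the true aggregate
+var aggregated = secureAgg.AggregateSecurely(maskedUpdates, clientWeights);
+
+// Discard secrets once the round is complete
+secureAgg.ClearSecrets();
+```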
+
+## Configuration Options
+
+Key parameters in `FederatedLearningOptions`:
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `NumberOfClients` | int | 10 | Total number of participating clients |
+| `ClientSelectionFraction` | double | 1.0 | Fraction of clients selected per round (0.0-1.0) |
+| `LocalEpochs` | int | 5 | Number of epochs each client trains locally |
+| `MaxRounds` | int | 100 | Maximum federated learning rounds |
+| `LearningRate` | double | 0.01 | Learning rate for local training |
+| `BatchSize` | int | 32 | Batch size for local training |
+| `UseDifferentialPrivacy` | bool | false | Enable differential privacy |
+| `PrivacyEpsilon` | double | 1.0 | Privacy budget (ε) |
+| `PrivacyDelta` | double | 1e-5 | Privacy failure probability (δ) |
+| `UseSecureAggregation` | bool | false | Enable secure aggregation |
+| `AggregationStrategy` | string | "FedAvg" | Aggregation algorithm to use |
+| `ProximalMu` | double | 0.01 | FedProx proximal term coefficient |
+| `EnablePersonalization` | bool | false | Enable personalized layers |
+| `PersonalizationLayerFraction` | double | 0.2 | Fraction of layers to personalize |
+
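+For example, a run with partial client participation and differential privacy enabled could be configured like this (all property names come from the table above):
+
+```csharp
+using AiDotNet.Models.Options;
+
+var options = new FederatedLearningOptions
+{
+    NumberOfClients = 50,
+    ClientSelectionFraction = 0.2,  // 10 clients per round
+    LocalEpochs = 5,
+    MaxRounds = 100,
+    LearningRate = 0.01,
+    BatchSize = 32,
+    UseDifferentialPrivacy = true,
+    PrivacyEpsilon = 1.0,
+    PrivacyDelta = 1e-5,
+    AggregationStrategy = "FedAvg"
+};
+```
+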
+## Success Criteria (from Issue #398)
+
+✅ **Core Algorithms**: FedAvg, FedProx, FedBN implemented
+✅ **Privacy Features**: Differential Privacy and Secure Aggregation implemented
+✅ **Personalization**: Personalized Federated Learning implemented
+✅ **Architecture**: Clean interfaces (IFederatedTrainer, IAggregationStrategy)
+✅ **Configuration**: Comprehensive options and metadata classes
+✅ **Documentation**: Extensive XML documentation with beginner-friendly explanations
+✅ **Testing**: Unit tests for core algorithms and privacy mechanisms
+
+## Testing
+
+Unit tests are located in `tests/AiDotNet.Tests/FederatedLearning/`:
+
+- `FedAvgAggregationStrategyTests.cs`: Tests for FedAvg aggregation
+- `GaussianDifferentialPrivacyTests.cs`: Tests for differential privacy mechanism
+
+Run tests:
+```bash
+dotnet test tests/AiDotNet.Tests/
+```
+
+## Future Enhancements
+
+Potential extensions for future phases:
+
+1. **LEAF Benchmark Integration**: Add support for LEAF federated datasets
+2. **Communication Efficiency**: Implement gradient compression and quantization
+3. **Advanced Privacy**: Add Rényi Differential Privacy for tighter composition bounds
+4. **Byzantine Robustness**: Implement Krum, Median, Trimmed Mean aggregation
+5. **Meta-Learning**: Add MAML (Model-Agnostic Meta-Learning) for personalization
+6. **Asynchronous FL**: Support for asynchronous client updates
+7. **Vertical FL**: Support for vertically partitioned data
+8. **Cross-Silo FL**: Enhanced support for enterprise federated learning scenarios
+
+## References
+
+1. McMahan, H. B., et al. (2017). "Communication-Efficient Learning of Deep Networks from Decentralized Data." AISTATS 2017.
+2. Li, T., et al. (2020). "Federated Optimization in Heterogeneous Networks." MLSys 2020.
+3. Li, X., et al. (2021). "Federated Learning on Non-IID Data Silos: An Experimental Study." ICDE 2021.
+4. Bonawitz, K., et al. (2017). "Practical Secure Aggregation for Privacy-Preserving Machine Learning." CCS 2017.
+5. Dwork, C., & Roth, A. (2014). "The Algorithmic Foundations of Differential Privacy." Foundations and Trends in Theoretical Computer Science.
+6. Fallah, A., et al. (2020). "Personalized Federated Learning with Theoretical Guarantees: A Model-Agnostic Meta-Learning Approach." NeurIPS 2020.
+
+## Contributing
+
+When adding new federated learning algorithms or features:
+
+1. Follow the existing interface patterns (IAggregationStrategy, IPrivacyMechanism, etc.)
+2. Add comprehensive XML documentation with beginner-friendly explanations
+3. Include mathematical formulations and references
+4. Add unit tests for new functionality
+5. Update this README with usage examples
+
+## License
+
+This implementation is part of the AiDotNet library and is licensed under Apache-2.0.
diff --git a/src/Interfaces/IAggregationStrategy.cs b/src/Interfaces/IAggregationStrategy.cs
new file mode 100644
index 000000000..7595ab88b
--- /dev/null
+++ b/src/Interfaces/IAggregationStrategy.cs
@@ -0,0 +1,67 @@
+namespace AiDotNet.Interfaces;
+
+///
+/// Defines strategies for aggregating model updates from multiple clients in federated learning.
+///
+///
+/// This interface represents different methods for combining model updates from distributed clients
+/// into a single improved global model.
+///
+/// For Beginners: An aggregation strategy is like a voting system or consensus mechanism
+/// that decides how to combine different opinions into a single decision.
+///
+/// Think of aggregation strategies as different ways to combine contributions:
+/// - Simple average: Everyone's input counts equally
+/// - Weighted average: Some contributors' inputs count more based on criteria (data size, accuracy)
+/// - Robust methods: Ignore outliers or malicious contributions
+///
+/// For example, in a federated learning scenario with hospitals:
+/// - Hospital A has 10,000 patients: gets weight of 10,000
+/// - Hospital B has 5,000 patients: gets weight of 5,000
+/// - The aggregation strategy might weight Hospital A's updates more heavily
+///
+/// Different strategies handle different challenges:
+/// - FedAvg: Standard weighted averaging
+/// - FedProx: Handles clients with different update frequencies
+/// - Krum: Robust to Byzantine (malicious) clients
+/// - Median aggregation: Resistant to outliers
+///
+/// The type of model being aggregated.
+public interface IAggregationStrategy<TModel>
+{
+ ///
+ /// Aggregates model updates from multiple clients into a single global model update.
+ ///
+ ///
+ /// This method combines model updates from clients using the strategy's specific algorithm.
+ ///
+ /// For Beginners: Aggregation is like combining multiple rough drafts of a document
+ /// into one polished version that incorporates the best parts of each.
+ ///
+ /// The aggregation process typically:
+ /// 1. Takes model updates (weight changes) from each client
+ /// 2. Considers the weight or importance of each client (based on data size, accuracy, etc.)
+ /// 3. Combines these updates using the strategy's algorithm
+ /// 4. Returns a single aggregated model that represents the collective improvement
+ ///
+ /// For example with weighted averaging (FedAvg):
+ /// - Client 1 (1000 samples): model update A
+ /// - Client 2 (500 samples): model update B
+ /// - Client 3 (1500 samples): model update C
+ /// - Aggregated update = (1000*A + 500*B + 1500*C) / 3000
+ ///
+ /// Dictionary mapping client IDs to their trained models.
+ /// Dictionary mapping client IDs to their aggregation weights (typically based on data size).
+ /// The aggregated global model.
+ TModel Aggregate(Dictionary<int, TModel> clientModels, Dictionary<int, double> clientWeights);
+
+ ///
+ /// Gets the name of the aggregation strategy.
+ ///
+ ///
+ /// For Beginners: This helps identify which aggregation method is being used,
+ /// useful for logging, debugging, and comparing different strategies.
+ ///
+ /// A string describing the aggregation strategy (e.g., "FedAvg", "FedProx", "Krum").
+ string GetStrategyName();
+}
diff --git a/src/Interfaces/IClientModel.cs b/src/Interfaces/IClientModel.cs
new file mode 100644
index 000000000..b31772995
--- /dev/null
+++ b/src/Interfaces/IClientModel.cs
@@ -0,0 +1,129 @@
+namespace AiDotNet.Interfaces;
+
+///
+/// Defines the functionality for a client-side model in federated learning.
+///
+///
+/// This interface represents a model that exists on a client device or node in a federated
+/// learning system. Each client maintains its own copy of the global model and trains it
+/// on local data.
+///
+/// For Beginners: A client model is like a student's personal copy of study materials.
+/// Each student (client) has their own copy, studies it with their own resources, and
+/// contributes improvements back to the class.
+///
+/// Think of client models as distributed learners:
+/// - Each client has a copy of the global model
+/// - Clients train on their own private data
+/// - Local training happens independently and in parallel
+/// - Only model updates (not data) are sent to the server
+///
+/// For example, in smartphone keyboard prediction:
+/// - Each phone has a copy of the global typing prediction model
+/// - The phone learns from the user's typing patterns
+/// - It sends model improvements (not actual typed text) to the server
+/// - The server combines improvements from millions of phones
+/// - Each phone gets the improved model back
+///
+/// This design ensures:
+/// - Data privacy: Raw data never leaves the client
+/// - Personalization: Can adapt to local data distribution
+/// - Scalability: Training happens in parallel across all clients
+///
+/// The type of the local training data.
+/// The type of the model update to send to the server.
+public interface IClientModel<TData, TUpdate>
+{
+ ///
+ /// Trains the local model on the client's private data.
+ ///
+ ///
+ /// Local training is the core of federated learning where each client improves the model
+ /// using their own data without sharing that data with anyone.
+ ///
+ /// For Beginners: This is like studying independently with your own materials.
+ /// You use your personal notes and resources to learn, and later share what you learned,
+ /// not the actual materials.
+ ///
+ /// The training process:
+ /// 1. Receive the global model from the server
+ /// 2. Train on local data for specified number of epochs
+ /// 3. Compute the difference between updated and original model (the "update")
+ /// 4. Prepare this update to send back to the server
+ ///
+ /// For example:
+ /// - Client receives global model with accuracy 80%
+ /// - Trains on local data for 5 epochs
+ /// - Local model now has accuracy 85% on local data
+ /// - Computes weight changes (delta) that improved the model
+ /// - Sends these weight changes to server, not the local data
+ ///
+ /// The client's private training data.
+ /// Number of training iterations to perform on local data.
+ /// Step size for gradient descent optimization.
+ void TrainLocal(TData localData, int epochs, double learningRate);
+
+ ///
+ /// Computes and retrieves the model update to send to the server.
+ ///
+ ///
+ /// The model update represents the improvements the client made through local training.
+ /// This is typically the difference between the current model and the initial global model.
+ ///
+ /// For Beginners: This is like preparing a summary of what you learned from studying,
+ /// rather than sharing your entire study materials. You share the insights, not the sources.
+ ///
+ /// The update typically contains:
+ /// - Weight differences: New weights - original weights
+ /// - Gradients: Direction and magnitude of improvement
+ /// - Metadata: Number of local samples, local loss, etc.
+ ///
+ /// For example:
+ /// - Original weight for feature "age": 0.5
+ /// - After training, weight for "age": 0.6
+ /// - Update to send: +0.1
+ /// - This tells the server how to adjust that weight
+ ///
+ /// The model update containing weight changes or gradients.
+ TUpdate GetModelUpdate();
+
+ ///
+ /// Updates the local model with the new global model from the server.
+ ///
+ ///
+ /// After the server aggregates updates from all clients, it sends the improved global
+ /// model back to clients for the next round of training.
+ ///
+ /// For Beginners: This is like receiving the updated textbook that incorporates
+ /// everyone's contributions. You replace your old version with this improved version
+ /// before the next study session.
+ ///
+ /// The update process:
+ /// 1. Receive aggregated global model from server
+ /// 2. Replace local model weights with global model weights
+ /// 3. Optionally keep some personalized layers
+ /// 4. Ready for next round of local training
+ ///
+ /// For example:
+ /// - Round 1: Trained local model, sent update
+ /// - Server aggregated all updates
+ /// - Round 2: Receive improved global model
+ /// - Use this as starting point for next round of training
+ ///
+ /// The aggregated global model from the server.
+ void UpdateFromGlobal(TUpdate globalModelUpdate);
+
+ ///
+ /// Gets the number of training samples available on this client.
+ ///
+ ///
+ /// Sample count is used to weight client contributions during aggregation.
+ /// Clients with more data typically receive higher weights.
+ ///
+ /// For Beginners: This is like indicating how many practice problems you solved.
+ /// If you solved 1000 problems and someone else solved 100, your insights about
+ /// problem-solving patterns are likely more reliable.
+ ///
+ /// The number of training samples on this client.
+ int GetSampleCount();
+}
diff --git a/src/Interfaces/IFederatedTrainer.cs b/src/Interfaces/IFederatedTrainer.cs
new file mode 100644
index 000000000..c8b6c7bda
--- /dev/null
+++ b/src/Interfaces/IFederatedTrainer.cs
@@ -0,0 +1,127 @@
+namespace AiDotNet.Interfaces;
+
+///
+/// Defines the core functionality for federated learning trainers that coordinate distributed training across multiple clients.
+///
+///
+/// This interface represents the fundamental operations for federated learning systems where multiple clients
+/// (devices, institutions, edge nodes) collaboratively train a shared model without sharing their raw data.
+///
+/// For Beginners: Federated learning is like group study where everyone learns from their own materials
+/// but shares only their insights, not their actual study materials.
+///
+/// Think of federated learning as a privacy-preserving collaborative learning approach:
+/// - Multiple clients (hospitals, phones, banks) have their own local data
+/// - Each client trains a model on their local data independently
+/// - Only model updates (not raw data) are shared with a central server
+/// - The server aggregates these updates to improve the global model
+/// - The improved global model is sent back to clients for the next round
+///
+/// For example, in healthcare:
+/// - Multiple hospitals want to train a disease detection model
+/// - Each hospital has patient data that cannot be shared due to privacy regulations
+/// - Each hospital trains the model on their own data
+/// - Only the learned patterns (model weights) are shared and combined
+/// - This creates a better model while keeping patient data private
+///
+/// This interface provides methods for coordinating the federated training process.
+///
+/// The type of the model being trained.
+/// The type of the training data.
+/// The type of metadata returned by the training process.
+public interface IFederatedTrainer<TModel, TData, TMetadata>
+{
+ ///
+ /// Initializes the federated learning process with client configurations and the global model.
+ ///
+ ///
+ /// This method sets up the initial state for federated learning by:
+ /// - Initializing the global model that will be shared across all clients
+ /// - Registering client configurations (number of clients, data distribution, etc.)
+ /// - Setting up communication channels for model updates
+ ///
+ /// For Beginners: Initialization is like setting up a study group before the first session.
+ /// You need to know who's participating, what materials everyone has, and establish
+ /// how you'll share information.
+ ///
+ /// The initial global model to be distributed to clients.
+ /// The number of clients participating in federated learning.
+ void Initialize(TModel globalModel, int numberOfClients);
+
+ ///
+ /// Executes one round of federated learning where clients train locally and updates are aggregated.
+ ///
+ ///
+ /// A federated learning round consists of several steps:
+ /// 1. The global model is sent to selected clients
+ /// 2. Each client trains the model on their local data
+ /// 3. Clients send their model updates back to the server
+ /// 4. The server aggregates these updates using an aggregation strategy
+ /// 5. The global model is updated with the aggregated result
+ ///
+ /// For Beginners: Think of this as one iteration in a collaborative learning cycle.
+ /// Everyone gets the current version of the shared knowledge, studies independently,
+ /// then contributes their improvements. These improvements are combined to create
+ /// an even better version for the next round.
+ ///
+ /// For example:
+ /// - Round 1: Clients receive initial model, train for 5 epochs, send updates
+ /// - Server aggregates updates and improves global model
+ /// - Round 2: Clients receive improved model, train again, send updates
+ /// - This continues until the model reaches desired accuracy
+ ///
+ /// Dictionary mapping client IDs to their local training data.
+ /// Fraction of clients to select for this round (0.0 to 1.0).
+ /// Number of training epochs each client should perform locally.
+ /// Metadata about the training round including accuracy, loss, and convergence metrics.
+ TMetadata TrainRound(Dictionary<int, TData> clientData, double clientSelectionFraction = 1.0, int localEpochs = 1);
+
+ ///
+ /// Executes multiple rounds of federated learning until convergence or maximum rounds reached.
+ ///
+ ///
+ /// This method orchestrates the entire federated learning process by:
+ /// - Running multiple training rounds
+ /// - Monitoring convergence (when the model stops improving significantly)
+ /// - Tracking performance metrics across rounds
+ /// - Applying privacy mechanisms if configured
+ ///
+ /// For Beginners: This is the complete federated learning process from start to finish.
+ /// It's like running an entire semester of study group sessions, where you continue meeting
+ /// until everyone has mastered the material or you've run out of time.
+ ///
+ /// Dictionary mapping client IDs to their local training data.
+ /// Maximum number of federated learning rounds to execute.
+ /// Fraction of clients to select per round (0.0 to 1.0).
+ /// Number of training epochs each client performs per round.
+ /// Aggregated metadata across all training rounds.
+ TMetadata Train(Dictionary<int, TData> clientData, int rounds, double clientSelectionFraction = 1.0, int localEpochs = 1);
+
+ ///
+ /// Retrieves the current global model after federated training.
+ ///
+ ///
+ /// The global model represents the collective knowledge learned from all participating clients.
+ ///
+ /// For Beginners: This is the final product of the collaborative learning process -
+ /// a model that benefits from all participants' data without ever accessing their raw data directly.
+ ///
+ /// The trained global model.
+ TModel GetGlobalModel();
+
+ ///
+ /// Sets the aggregation strategy used to combine client updates.
+ ///
+ ///
+ /// Different aggregation strategies handle various challenges in federated learning:
+ /// - FedAvg: Simple weighted averaging of model updates
+ /// - FedProx: Handles clients with different computational capabilities
+ /// - FedBN: Special handling for batch normalization layers
+ ///
+ /// For Beginners: The aggregation strategy is the rule for combining everyone's
+ /// contributions. Different rules work better for different situations, like how you might
+ /// weight expert opinions more heavily in certain contexts.
+ ///
+ /// The aggregation strategy to use.
+ void SetAggregationStrategy(IAggregationStrategy<TModel> strategy);
+}
diff --git a/src/Interfaces/IPrivacyMechanism.cs b/src/Interfaces/IPrivacyMechanism.cs
new file mode 100644
index 000000000..502777893
--- /dev/null
+++ b/src/Interfaces/IPrivacyMechanism.cs
@@ -0,0 +1,90 @@
+namespace AiDotNet.Interfaces;
+
+///
+/// Defines privacy-preserving mechanisms for federated learning to protect client data.
+///
+///
+/// This interface represents techniques to ensure that model updates don't leak sensitive
+/// information about individual data points in clients' local datasets.
+///
+/// For Beginners: Privacy mechanisms are like filters that protect sensitive information
+/// while still allowing useful knowledge to be shared.
+///
+/// Think of privacy mechanisms as protective measures:
+/// - Differential Privacy: Adds carefully calibrated noise to make individual data unidentifiable
+/// - Secure Aggregation: Encrypts updates so the server only sees the combined result
+/// - Homomorphic Encryption: Allows computation on encrypted data
+///
+/// For example, in a hospital scenario:
+/// - Without privacy: Model updates might reveal information about specific patients
+/// - With differential privacy: Random noise is added so you can't identify individual patients
+/// - The noise is calibrated so the overall patterns remain accurate
+///
+/// Privacy mechanisms provide mathematical guarantees:
+/// - Epsilon (ε): Privacy budget - lower values mean stronger privacy
+/// - Delta (δ): Probability that privacy guarantee fails
+/// - Common setting: ε=1.0, δ=1e-5 means strong privacy with high confidence
+///
+/// The type of model to apply privacy mechanisms to.
+public interface IPrivacyMechanism<TModel>
+{
+ ///
+ /// Applies privacy-preserving techniques to a model update before sharing it.
+ ///
+ ///
+ /// This method transforms model updates to provide privacy guarantees while maintaining utility.
+ ///
+ /// For Beginners: This is like redacting sensitive parts of a document before sharing it.
+ /// You remove or obscure information that could identify individuals while keeping the
+ /// useful content intact.
+ ///
+ /// Common techniques:
+ /// - Differential Privacy: Adds random noise proportional to sensitivity
+ /// - Gradient Clipping: Limits the magnitude of updates to prevent outliers
+ /// - Local DP: Each client adds noise before sending updates
+ /// - Central DP: Server adds noise after aggregation
+ ///
+ /// For example with differential privacy:
+ /// 1. Client trains model and computes weight updates
+ /// 2. Applies gradient clipping to limit maximum change
+ /// 3. Adds calibrated Gaussian noise to each weight
+ /// 4. Sends noisy update to server
+ /// 5. Even if server is compromised, individual data remains private
+ ///
+ /// The model update to apply privacy to.
+ /// Privacy budget parameter - smaller values provide stronger privacy.
+ /// Probability of privacy guarantee failure - typically very small (e.g., 1e-5).
+ /// The model update with privacy mechanisms applied.
+ TModel ApplyPrivacy(TModel model, double epsilon, double delta);
+
+ ///
+ /// Gets the current privacy budget consumed by this mechanism.
+ ///
+ ///
+ /// Privacy budget is a finite resource in differential privacy. Each time you share
+ /// information, you "spend" some privacy budget. Once exhausted, you can no longer
+ /// provide strong privacy guarantees.
+ ///
+ /// For Beginners: Think of privacy budget like a bank account for privacy.
+ /// Each time you share data, you withdraw from this account. When the account is empty,
+ /// you've used up your privacy guarantees and should stop sharing.
+ ///
+ /// For example:
+ /// - Start with privacy budget ε=10
+ /// - Round 1: Share update with ε=1, remaining budget = 9
+ /// - Round 2: Share update with ε=1, remaining budget = 8
+ /// - After 10 rounds, budget is exhausted
+ ///
+ /// The amount of privacy budget consumed so far.
+ double GetPrivacyBudgetConsumed();
+
+ ///
+ /// Gets the name of the privacy mechanism.
+ ///
+ ///
+ /// For Beginners: This identifies which privacy technique is being used,
+ /// helpful for documentation and comparing different privacy approaches.
+ ///
+ /// A string describing the privacy mechanism (e.g., "Gaussian Mechanism", "Laplace Mechanism").
+ string GetMechanismName();
+}
diff --git a/src/Models/FederatedLearningMetadata.cs b/src/Models/FederatedLearningMetadata.cs
new file mode 100644
index 000000000..186fb5312
--- /dev/null
+++ b/src/Models/FederatedLearningMetadata.cs
@@ -0,0 +1,301 @@
+namespace AiDotNet.Models;
+
+///
+/// Contains metadata and metrics about federated learning training progress and results.
+///
+///
+/// This class tracks various metrics throughout the federated learning process to help
+/// monitor training progress, diagnose issues, and evaluate model quality.
+///
+/// For Beginners: Metadata is like a training diary that records what happened
+/// during federated learning - how long it took, how accurate the model became, which
+/// clients participated, etc.
+///
+/// Think of this as a comprehensive training report containing:
+/// - Performance metrics: Accuracy, loss, convergence
+/// - Resource usage: Time, communication costs
+/// - Participation: Which clients contributed
+/// - Privacy tracking: Privacy budget consumption
+///
+/// For example, after training you might see:
+/// - Total rounds: 50 (out of max 100)
+/// - Final accuracy: 92.5%
+/// - Training time: 2 hours
+/// - Total clients participated: 100
+/// - Privacy budget used: ε=5.0 (out of 10.0 total)
+///
+public class FederatedLearningMetadata
+{
+ ///
+ /// Gets or sets the number of federated learning rounds completed.
+ ///
+ ///
+ /// For Beginners: A round is one complete cycle where clients train and the
+ /// server aggregates updates. This counts how many such cycles were completed.
+ ///
+ public int RoundsCompleted { get; set; }
+
+ ///
+ /// Gets or sets the final global model accuracy on validation data.
+ ///
+ ///
+ /// For Beginners: Accuracy measures how often the model makes correct predictions.
+ /// For example, 0.95 means the model is correct 95% of the time.
+ ///
+ /// This is measured on validation data that wasn't used for training.
+ ///
+ public double FinalAccuracy { get; set; }
+
+ ///
+ /// Gets or sets the final global model loss value.
+ ///
+ ///
+ /// For Beginners: Loss measures how far the model's predictions are from the
+ /// true values. Lower loss indicates better performance.
+ ///
+ /// For example:
+ /// - Initial loss: 2.5
+ /// - Final loss: 0.3
+ /// - The model has improved significantly
+ ///
+ public double FinalLoss { get; set; }
+
+ ///
+ /// Gets or sets the total training time in seconds.
+ ///
+ ///
+ /// For Beginners: The total wall-clock time from start to finish of federated
+ /// learning, including all rounds, communication, and aggregation.
+ ///
+ public double TotalTrainingTimeSeconds { get; set; }
+
+ ///
+ /// Gets or sets the average time per round in seconds.
+ ///
+ ///
+ /// For Beginners: How long each round takes on average. Useful for estimating
+ /// how long future training runs will take.
+ ///
+ /// For example:
+ /// - 50 rounds completed in 5000 seconds
+ /// - Average: 100 seconds per round
+ /// - Next training with 100 rounds will take ~10,000 seconds
+ ///
+ public double AverageRoundTimeSeconds { get; set; }
+
+ ///
+ /// Gets or sets the history of loss values across all rounds.
+ ///
+ ///
+ /// For Beginners: A list showing how loss changed after each round.
+ /// Useful for plotting learning curves and diagnosing training issues.
+ ///
+ /// For example: [2.5, 1.8, 1.2, 0.9, 0.7, 0.5, 0.4, 0.35, 0.32, 0.3]
+ /// Shows steady improvement from 2.5 to 0.3
+ ///
+ public List<double> LossHistory { get; set; } = new List<double>();
+
+ ///
+ /// Gets or sets the history of accuracy values across all rounds.
+ ///
+ ///
+ /// For Beginners: A list showing how accuracy improved after each round.
+ ///
+ /// For example: [0.60, 0.70, 0.78, 0.84, 0.88, 0.91, 0.93, 0.94, 0.945, 0.95]
+ /// Shows accuracy improving from 60% to 95%
+ ///
+ public List<double> AccuracyHistory { get; set; } = new List<double>();
+
+ ///
+ /// Gets or sets the total number of clients that participated across all rounds.
+ ///
+ ///
+ /// For Beginners: How many different clients contributed to the model.
+ ///
+ /// For example:
+ /// - 100 clients available
+ /// - 10 clients selected per round
+ /// - Over 50 rounds, might have 80 unique participants
+ ///
+ public int TotalClientsParticipated { get; set; }
+
+ ///
+ /// Gets or sets the average number of clients selected per round.
+ ///
+ ///
+ /// For Beginners: How many clients were active in each round on average.
+ ///
+ /// For example:
+ /// - Round 1: 10 clients
+ /// - Round 2: 8 clients (some unavailable)
+ /// - Round 3: 10 clients
+ /// - Average: 9.3 clients per round
+ ///
+ public double AverageClientsPerRound { get; set; }
+
+ ///
+ /// Gets or sets the total communication cost in megabytes.
+ ///
+ ///
+ /// For Beginners: The total amount of data transferred between clients and server
+ /// throughout training. Important for understanding bandwidth requirements.
+ ///
+ /// For example:
+ /// - Each model update: 10 MB
+ /// - 10 clients per round
+ /// - 50 rounds
+ /// - Total: 10 MB × 10 × 50 × 2 (up and down) = 10,000 MB = 10 GB
+ ///
+ public double TotalCommunicationMB { get; set; }
+
+ ///
+ /// Gets or sets the total privacy budget (epsilon) consumed.
+ ///
+ ///
+ /// For Beginners: If differential privacy is used, this tracks how much privacy
+ /// budget has been spent. Privacy budget is finite - once exhausted, no more privacy
+ /// guarantees.
+ ///
+ /// For example:
+ /// - ε per round: 0.1
+ /// - 50 rounds completed
+ /// - Total consumed: 5.0
+ /// - If total budget is 10.0, have 5.0 remaining
+ ///
+ public double TotalPrivacyBudgetConsumed { get; set; }
+
+ ///
+ /// Gets or sets whether training converged before reaching maximum rounds.
+ ///
+ ///
+ /// For Beginners: Convergence means the model stopped improving significantly.
+ /// If true, training ended early because the model reached a good solution.
+ ///
+ /// For example:
+ /// - Max rounds: 100
+ /// - Converged at round 50
+ /// - Converged = true
+ /// - Saved time by stopping early
+ ///
+ public bool Converged { get; set; }
+
+ ///
+ /// Gets or sets the round at which convergence was detected.
+ ///
+ ///
+ /// For Beginners: Which round did the model stop improving significantly?
+ ///
+ /// Useful for:
+ /// - Setting better MaxRounds for future training
+ /// - Understanding training dynamics
+ /// - Comparing different algorithms
+ ///
+ public int ConvergenceRound { get; set; }
+
+ ///
+ /// Gets or sets the aggregation strategy used during training.
+ ///
+ ///
+ /// For Beginners: Records which aggregation algorithm was used (FedAvg, FedProx, etc.).
+ /// Important for reproducibility and understanding results.
+ ///
+ public string AggregationStrategyUsed { get; set; } = string.Empty;
+
+ ///
+ /// Gets or sets whether differential privacy was enabled.
+ ///
+ ///
+ /// For Beginners: Records whether privacy mechanisms were active during training.
+ /// Important for understanding any accuracy trade-offs.
+ ///
+ public bool DifferentialPrivacyEnabled { get; set; }
+
+ ///
+ /// Gets or sets whether secure aggregation was enabled.
+ ///
+ ///
+ /// For Beginners: Records whether client updates were encrypted during aggregation.
+ ///
+ public bool SecureAggregationEnabled { get; set; }
+
+ ///
+ /// Gets or sets additional notes or observations about the training run.
+ ///
+ ///
+ /// For Beginners: A freeform field for recording anything unusual or noteworthy
+ /// that happened during training.
+ ///
+ /// For example:
+ /// - "Client 5 dropped out after round 30"
+ /// - "Convergence was slower than expected"
+ /// - "High variance in client update quality"
+ ///
+ public string Notes { get; set; } = string.Empty;
+
+ ///
+ /// Gets or sets the per-round detailed metrics.
+ ///
+ ///
+ /// For Beginners: Detailed information about each individual round, including
+ /// which clients participated, their individual losses, communication costs, etc.
+ ///
+ /// Useful for:
+ /// - Detailed analysis of training dynamics
+ /// - Identifying problematic clients
+ /// - Understanding convergence patterns
+ ///
+ public List<RoundMetadata> RoundMetrics { get; set; } = new List<RoundMetadata>();
+}
+
+///
+/// Contains detailed metrics for a single federated learning round.
+///
+///
+/// For Beginners: Information about what happened in one specific training round.
+///
+public class RoundMetadata
+{
+ ///
+ /// Gets or sets the round number (0-indexed).
+ ///
+ public int RoundNumber { get; set; }
+
+ ///
+ /// Gets or sets the global model loss after this round.
+ ///
+ public double GlobalLoss { get; set; }
+
+ ///
+ /// Gets or sets the global model accuracy after this round.
+ ///
+ public double GlobalAccuracy { get; set; }
+
+ ///
+ /// Gets or sets the IDs of clients selected for this round.
+ ///
+ public List SelectedClientIds { get; set; } = new List();
+
+ ///
+ /// Gets or sets the time taken for this round in seconds.
+ ///
+ public double RoundTimeSeconds { get; set; }
+
+ ///
+ /// Gets or sets the communication cost for this round in megabytes.
+ ///
+ public double CommunicationMB { get; set; }
+
+ ///
+ /// Gets or sets the average local loss across selected clients.
+ ///
+ ///
+ /// For Beginners: The average loss that clients achieved on their local data
+ /// before sending updates. Comparing this to global loss can reveal overfitting.
+ ///
+ public double AverageLocalLoss { get; set; }
+
+ ///
+ /// Gets or sets the privacy budget consumed in this round.
+ ///
+ public double PrivacyBudgetConsumed { get; set; }
+}
diff --git a/src/Models/Options/FederatedLearningOptions.cs b/src/Models/Options/FederatedLearningOptions.cs
new file mode 100644
index 000000000..c3ed5ab4c
--- /dev/null
+++ b/src/Models/Options/FederatedLearningOptions.cs
@@ -0,0 +1,350 @@
+namespace AiDotNet.Models.Options;
+
+///
+/// Configuration options for federated learning training.
+///
+///
+/// This class contains all the configurable parameters needed to set up and run a federated learning system.
+///
+/// For Beginners: Options are like the settings panel for federated learning.
+/// Just as you configure settings for a video game (difficulty, graphics quality, etc.),
+/// these options let you configure how federated learning should work.
+///
+/// Key configuration areas:
+/// - Client Management: How many clients, how to select them
+/// - Training: Learning rates, epochs, batch sizes
+/// - Privacy: Differential privacy parameters
+/// - Communication: How often to aggregate, compression settings
+/// - Convergence: When to stop training
+///
+/// For example, a typical configuration might be:
+/// - 100 total clients (e.g., hospitals)
+/// - Select 10 clients per round (10% participation)
+/// - Each client trains for 5 local epochs
+/// - Use privacy budget ε=1.0, δ=1e-5
+/// - Run for maximum 100 rounds or until convergence
+///
+public class FederatedLearningOptions
+{
+ ///
+ /// Gets or sets the total number of clients participating in federated learning.
+ ///
+ ///
+ /// For Beginners: This is the total pool of participants available for training.
+ /// In each round, a subset of these clients may be selected.
+ ///
+ /// For example:
+ /// - Mobile keyboard app: Millions of phones
+ /// - Healthcare: 50 hospitals
+ /// - Financial: 100 bank branches
+ ///
+ public int NumberOfClients { get; set; } = 10;
+
+ ///
+ /// Gets or sets the fraction of clients to select for each training round (0.0 to 1.0).
+ ///
+ ///
+ /// For Beginners: Not all clients participate in every round. This setting controls
+ /// what percentage of clients are active in each round.
+ ///
+ /// Common values:
+ /// - 1.0: All clients participate (small deployments)
+ /// - 0.1: 10% participate (large deployments, reduces communication)
+ /// - 0.01: 1% participate (massive deployments like mobile devices)
+ ///
+ /// For example, with 1000 clients and fraction 0.1:
+ /// - Each round randomly selects 100 clients
+ /// - Reduces server load and communication costs
+ /// - Still converges if enough clients are selected
+ ///
+ public double ClientSelectionFraction { get; set; } = 1.0;
+
+ ///
+ /// Gets or sets the number of local training epochs each client performs per round.
+ ///
+ ///
+ /// For Beginners: An epoch is one complete pass through the client's local dataset.
+ /// More epochs mean more local training but also more computation time.
+ ///
+ /// Trade-offs:
+ /// - More epochs (5-10): Better local adaptation, slower rounds
+ /// - Fewer epochs (1-2): Faster rounds, more communication needed
+ ///
+ /// For example:
+ /// - Client has 1000 samples
+ /// - LocalEpochs = 5
+ /// - Client processes all 1000 samples 5 times before sending update
+ ///
+ public int LocalEpochs { get; set; } = 5;
+
+ ///
+ /// Gets or sets the maximum number of federated learning rounds to execute.
+ ///
+ ///
+ /// For Beginners: A round is one complete cycle where clients train and the server
+ /// aggregates updates. This sets the maximum number of such cycles.
+ ///
+ /// Typical values:
+ /// - Quick experiments: 10-50 rounds
+ /// - Production training: 100-1000 rounds
+ /// - Large scale: 1000+ rounds
+ ///
+ /// Training may stop early if convergence criteria are met.
+ ///
+ public int MaxRounds { get; set; } = 100;
+
+ ///
+ /// Gets or sets the learning rate for local client training.
+ ///
+ ///
+ /// For Beginners: Learning rate controls how big of a step the model takes when
+ /// learning from data. Too large and learning is unstable; too small and learning is slow.
+ ///
+ /// Common values:
+ /// - Deep learning: 0.001 - 0.01
+ /// - Traditional ML: 0.01 - 0.1
+ ///
+ /// For example:
+ /// - LearningRate = 0.01 means adjust weights by 1% of the gradient
+ /// - Higher values = faster learning but less stability
+ /// - Lower values = slower but more stable learning
+ ///
+ public double LearningRate { get; set; } = 0.01;
+
+ ///
+ /// Gets or sets the batch size for local training.
+ ///
+ ///
+ /// For Beginners: Instead of processing all data at once, we process it in smaller
+ /// batches. Batch size is how many samples to process before updating the model.
+ ///
+ /// Trade-offs:
+ /// - Larger batches (64-512): More stable gradients, requires more memory
+ /// - Smaller batches (8-32): Less memory, more noise in updates
+ ///
+ /// For example:
+ /// - Client has 1000 samples, BatchSize = 32
+ /// - Data is split into batches of 32 samples (31 full batches plus one final batch of 8)
+ /// - Model is updated after processing each batch
+ ///
+ public int BatchSize { get; set; } = 32;
+
+ ///
+ /// Gets or sets whether to use differential privacy.
+ ///
+ ///
+ /// For Beginners: Differential privacy adds mathematical noise to protect individual
+ /// data points from being identified in the model updates.
+ ///
+ /// When enabled:
+ /// - Privacy guarantees are provided
+ /// - Some accuracy may be sacrificed
+ /// - Individual contributions are hidden
+ ///
+ /// Use cases requiring privacy:
+ /// - Healthcare data
+ /// - Financial records
+ /// - Personal communications
+ ///
+ public bool UseDifferentialPrivacy { get; set; } = false;
+
+ ///
+ /// Gets or sets the epsilon (ε) parameter for differential privacy (privacy budget).
+ ///
+ ///
+ /// For Beginners: Epsilon controls the privacy-utility tradeoff. Lower values mean
+ /// stronger privacy but potentially less accurate models.
+ ///
+ /// Common values:
+ /// - ε = 0.1: Very strong privacy, significant accuracy loss
+ /// - ε = 1.0: Strong privacy, moderate accuracy loss (recommended)
+ /// - ε = 10.0: Weak privacy, minimal accuracy loss
+ ///
+ /// For example:
+ /// - With ε = 1.0, an adversary cannot distinguish whether any specific individual's
+ /// data was used in training (within factor e^1 ≈ 2.7)
+ ///
+ public double PrivacyEpsilon { get; set; } = 1.0;
+
+ ///
+ /// Gets or sets the delta (δ) parameter for differential privacy (failure probability).
+ ///
+ ///
+ /// For Beginners: Delta is the probability that the privacy guarantee fails.
+ /// It should be very small, typically much less than 1/number_of_data_points.
+ ///
+ /// Common values:
+ /// - δ = 1e-5 (0.00001): Standard choice
+ /// - δ = 1e-6: Stronger guarantee
+ ///
+ /// For example:
+ /// - δ = 1e-5 means there's a 0.001% chance privacy is compromised
+ /// - Should be smaller than 1/total_number_of_samples across all clients
+ ///
+ public double PrivacyDelta { get; set; } = 1e-5;
+
+ ///
+ /// Gets or sets whether to use secure aggregation.
+ ///
+ ///
+ /// For Beginners: Secure aggregation encrypts client updates so that the server
+ /// can only see the aggregated result, not individual contributions.
+ ///
+ /// Benefits:
+ /// - Server cannot see individual client updates
+ /// - Protects against honest-but-curious server
+ /// - Only the final aggregated model is visible
+ ///
+ /// For example:
+ /// - Without: Server sees each hospital's model update
+ /// - With: Server only sees combined update from all hospitals
+ /// - No single hospital's contribution is visible
+ ///
+ public bool UseSecureAggregation { get; set; } = false;
+
+ ///
+ /// Gets or sets the convergence threshold for early stopping.
+ ///
+ ///
+ /// For Beginners: Training stops early if improvement between rounds falls below
+ /// this threshold, indicating the model has converged (stopped improving significantly).
+ ///
+ /// For example:
+ /// - ConvergenceThreshold = 0.001
+ /// - If loss improves by less than 0.001 for several consecutive rounds, stop training
+ /// - Saves time and resources by avoiding unnecessary rounds
+ ///
+ public double ConvergenceThreshold { get; set; } = 0.001;
+
+ ///
+ /// Gets or sets the minimum number of rounds to train before checking convergence.
+ ///
+ ///
+ /// For Beginners: Don't check for convergence too early. Wait at least this many
+ /// rounds before considering early stopping.
+ ///
+ /// This prevents:
+ /// - Stopping too early due to initial volatility
+ /// - Missing later improvements
+ ///
+ /// Typical value: 10-20 rounds
+ ///
+ public int MinRoundsBeforeConvergence { get; set; } = 10;
+
+ ///
+ /// Gets or sets the aggregation strategy name to use.
+ ///
+ ///
+ /// For Beginners: This determines how client updates are combined.
+ ///
+ /// Available strategies:
+ /// - "FedAvg": Weighted average (standard choice)
+ /// - "FedProx": Handles system heterogeneity
+ /// - "FedBN": Special handling for batch normalization
+ ///
+ /// Different strategies work better for different scenarios.
+ ///
+ public string AggregationStrategy { get; set; } = "FedAvg";
+
+ ///
+ /// Gets or sets the proximal term coefficient for FedProx algorithm.
+ ///
+ ///
+ /// For Beginners: FedProx adds a penalty to prevent client models from
+ /// deviating too much from the global model. This parameter controls the penalty strength.
+ ///
+ /// Common values:
+ /// - 0.0: No proximal term (equivalent to FedAvg)
+ /// - 0.01 - 0.1: Mild constraint
+ /// - 1.0+: Strong constraint
+ ///
+ /// Use FedProx when:
+ /// - Clients have very different data distributions
+ /// - Some clients are much slower than others
+ /// - You want more stable convergence
+ ///
+ public double ProximalMu { get; set; } = 0.01;
+
+ ///
+ /// Gets or sets whether to enable personalization.
+ ///
+ ///
+ /// For Beginners: Personalization allows each client to maintain some client-specific
+ /// model parameters while sharing common parameters with other clients.
+ ///
+ /// Benefits:
+ /// - Better performance on local data
+ /// - Handles non-IID data (data that varies across clients)
+ /// - Combines benefits of global and local models
+ ///
+ /// For example:
+ /// - Global layers: Learn general patterns from all clients
+ /// - Personalized layers: Adapt to each client's specific data
+ ///
+ public bool EnablePersonalization { get; set; } = false;
+
+ ///
+ /// Gets or sets the fraction of model layers to keep personalized (not aggregated).
+ ///
+ ///
+ /// For Beginners: When personalization is enabled, this determines what fraction
+ /// of the model remains client-specific vs. shared globally.
+ ///
+ /// For example:
+ /// - PersonalizationLayerFraction = 0.2
+ /// - Last 20% of model layers stay personalized
+ /// - First 80% are aggregated globally
+ ///
+ /// Typical use:
+ /// - Output layers personalized, feature extractors shared
+ ///
+ public double PersonalizationLayerFraction { get; set; } = 0.2;
+
+ ///
+ /// Gets or sets whether to use gradient compression to reduce communication costs.
+ ///
+ ///
+ /// For Beginners: Compression reduces the size of model updates sent between
+ /// clients and server, saving bandwidth and time.
+ ///
+ /// Techniques:
+ /// - Quantization: Use fewer bits per parameter
+ /// - Sparsification: Send only top-k largest updates
+ /// - Sketching: Use randomized compression
+ ///
+ /// Trade-off:
+ /// - Reduces communication by 10-100x
+ /// - May slightly slow convergence
+ ///
+ public bool UseCompression { get; set; } = false;
+
+ ///
+ /// Gets or sets the compression ratio (0.0 to 1.0) if compression is enabled.
+ ///
+ ///
+ /// For Beginners: Controls how much to compress. Lower values mean more compression
+ /// but potentially more accuracy loss.
+ ///
+ /// For example:
+ /// - 0.01: Keep top 1% of gradients (99% compression)
+ /// - 0.1: Keep top 10% of gradients (90% compression)
+ /// - 1.0: No compression
+ ///
+ public double CompressionRatio { get; set; } = 0.1;
+
+ ///
+ /// Gets or sets a random seed for reproducibility.
+ ///
+ ///
+ /// For Beginners: Random seed makes randomness reproducible. Using the same
+ /// seed will produce the same random client selections, initializations, etc.
+ ///
+ /// Benefits:
+ /// - Reproducible experiments
+ /// - Easier debugging
+ /// - Fair comparison between methods
+ ///
+ /// Set to null for truly random behavior.
+ ///
+ public int? RandomSeed { get; set; } = null;
+}
diff --git a/tests/AiDotNet.Tests/FederatedLearning/FedAvgAggregationStrategyTests.cs b/tests/AiDotNet.Tests/FederatedLearning/FedAvgAggregationStrategyTests.cs
new file mode 100644
index 000000000..8568b3c58
--- /dev/null
+++ b/tests/AiDotNet.Tests/FederatedLearning/FedAvgAggregationStrategyTests.cs
@@ -0,0 +1,201 @@
+namespace AiDotNet.Tests.FederatedLearning;
+
+using Xunit;
+using AiDotNet.FederatedLearning.Aggregators;
+using System;
+using System.Collections.Generic;
+
+///
+/// Unit tests for FedAvg (Federated Averaging) aggregation strategy.
+///
+public class FedAvgAggregationStrategyTests
+{
+ [Fact]
+ public void Aggregate_WithEqualWeights_ReturnsAverageModel()
+ {
+ // Arrange
+ var strategy = new FedAvgAggregationStrategy<double>();
+
+ // Create two client models with simple parameters
+ var clientModels = new Dictionary<int, Dictionary<string, double[]>>
+ {
+ [0] = new Dictionary<string, double[]>
+ {
+ ["layer1"] = new double[] { 1.0, 2.0, 3.0 }
+ },
+ [1] = new Dictionary<string, double[]>
+ {
+ ["layer1"] = new double[] { 3.0, 4.0, 5.0 }
+ }
+ };
+
+ var clientWeights = new Dictionary<int, double>
+ {
+ [0] = 1.0,
+ [1] = 1.0
+ };
+
+ // Act
+ var aggregatedModel = strategy.Aggregate(clientModels, clientWeights);
+
+ // Assert
+ Assert.NotNull(aggregatedModel);
+ Assert.Contains("layer1", aggregatedModel.Keys);
+ Assert.Equal(3, aggregatedModel["layer1"].Length);
+
+ // Expected: (1+3)/2=2, (2+4)/2=3, (3+5)/2=4
+ Assert.Equal(2.0, aggregatedModel["layer1"][0], precision: 5);
+ Assert.Equal(3.0, aggregatedModel["layer1"][1], precision: 5);
+ Assert.Equal(4.0, aggregatedModel["layer1"][2], precision: 5);
+ }
+
+ [Fact]
+ public void Aggregate_WithDifferentWeights_ReturnsWeightedAverage()
+ {
+ // Arrange
+ var strategy = new FedAvgAggregationStrategy<double>();
+
+ var clientModels = new Dictionary<int, Dictionary<string, double[]>>
+ {
+ [0] = new Dictionary<string, double[]>
+ {
+ ["layer1"] = new double[] { 1.0, 2.0 }
+ },
+ [1] = new Dictionary<string, double[]>
+ {
+ ["layer1"] = new double[] { 3.0, 4.0 }
+ }
+ };
+
+ // Client 1 has 3x the weight (3x more data)
+ var clientWeights = new Dictionary<int, double>
+ {
+ [0] = 1.0,
+ [1] = 3.0
+ };
+
+ // Act
+ var aggregatedModel = strategy.Aggregate(clientModels, clientWeights);
+
+ // Assert
+ // Expected: (1*1 + 3*3)/(1+3) = 10/4 = 2.5
+ // (2*1 + 4*3)/(1+3) = 14/4 = 3.5
+ Assert.Equal(2.5, aggregatedModel["layer1"][0], precision: 5);
+ Assert.Equal(3.5, aggregatedModel["layer1"][1], precision: 5);
+ }
+
+ [Fact]
+ public void Aggregate_WithMultipleLayers_AggregatesAllLayers()
+ {
+ // Arrange
+ var strategy = new FedAvgAggregationStrategy<double>();
+
+ var clientModels = new Dictionary<int, Dictionary<string, double[]>>
+ {
+ [0] = new Dictionary<string, double[]>
+ {
+ ["layer1"] = new double[] { 1.0 },
+ ["layer2"] = new double[] { 2.0 }
+ },
+ [1] = new Dictionary<string, double[]>
+ {
+ ["layer1"] = new double[] { 3.0 },
+ ["layer2"] = new double[] { 4.0 }
+ }
+ };
+
+ var clientWeights = new Dictionary
+ {
+ [0] = 1.0,
+ [1] = 1.0
+ };
+
+ // Act
+ var aggregatedModel = strategy.Aggregate(clientModels, clientWeights);
+
+ // Assert
+ Assert.Equal(2, aggregatedModel.Count);
+ Assert.Contains("layer1", aggregatedModel.Keys);
+ Assert.Contains("layer2", aggregatedModel.Keys);
+ Assert.Equal(2.0, aggregatedModel["layer1"][0], precision: 5);
+ Assert.Equal(3.0, aggregatedModel["layer2"][0], precision: 5);
+ }
+
+ [Fact]
+ public void Aggregate_WithEmptyClientModels_ThrowsArgumentException()
+ {
+ // Arrange
+ var strategy = new FedAvgAggregationStrategy<double>();
+ var emptyModels = new Dictionary<int, Dictionary<string, double[]>>();
+ var clientWeights = new Dictionary<int, double>();
+
+ // Act & Assert
+ Assert.Throws<ArgumentException>(() => strategy.Aggregate(emptyModels, clientWeights));
+ }
+
+ [Fact]
+ public void Aggregate_WithNullClientModels_ThrowsArgumentException()
+ {
+ // Arrange
+ var strategy = new FedAvgAggregationStrategy<double>();
+ Dictionary<int, Dictionary<string, double[]>>? nullModels = null;
+ var clientWeights = new Dictionary<int, double>();
+
+ // Act & Assert
+ Assert.Throws<ArgumentException>(() => strategy.Aggregate(nullModels!, clientWeights));
+ }
+
+ [Fact]
+ public void GetStrategyName_ReturnsCorrectName()
+ {
+ // Arrange
+ var strategy = new FedAvgAggregationStrategy<double>();
+
+ // Act
+ var name = strategy.GetStrategyName();
+
+ // Assert
+ Assert.Equal("FedAvg", name);
+ }
+
+ [Fact]
+ public void Aggregate_WithThreeClients_ComputesCorrectWeightedAverage()
+ {
+ // Arrange
+ var strategy = new FedAvgAggregationStrategy<double>();
+
+ var clientModels = new Dictionary<int, Dictionary<string, double[]>>
+ {
+ [0] = new Dictionary<string, double[]>
+ {
+ ["weights"] = new double[] { 0.1, 0.2, 0.3 }
+ },
+ [1] = new Dictionary<string, double[]>
+ {
+ ["weights"] = new double[] { 0.2, 0.3, 0.4 }
+ },
+ [2] = new Dictionary<string, double[]>
+ {
+ ["weights"] = new double[] { 0.3, 0.4, 0.5 }
+ }
+ };
+
+ var clientWeights = new Dictionary<int, double>
+ {
+ [0] = 100.0, // 100 samples
+ [1] = 200.0, // 200 samples
+ [2] = 300.0 // 300 samples
+ };
+
+ // Act
+ var aggregatedModel = strategy.Aggregate(clientModels, clientWeights);
+
+ // Assert
+ // Expected: (0.1*100 + 0.2*200 + 0.3*300) / 600 = (10 + 40 + 90) / 600 = 140/600 = 0.2333...
+ // (0.2*100 + 0.3*200 + 0.4*300) / 600 = (20 + 60 + 120) / 600 = 200/600 = 0.3333...
+ // (0.3*100 + 0.4*200 + 0.5*300) / 600 = (30 + 80 + 150) / 600 = 260/600 = 0.4333...
+ Assert.Equal(0.2333333, aggregatedModel["weights"][0], precision: 5);
+ Assert.Equal(0.3333333, aggregatedModel["weights"][1], precision: 5);
+ Assert.Equal(0.4333333, aggregatedModel["weights"][2], precision: 5);
+ }
+}
diff --git a/tests/AiDotNet.Tests/FederatedLearning/GaussianDifferentialPrivacyTests.cs b/tests/AiDotNet.Tests/FederatedLearning/GaussianDifferentialPrivacyTests.cs
new file mode 100644
index 000000000..e4c3f1ded
--- /dev/null
+++ b/tests/AiDotNet.Tests/FederatedLearning/GaussianDifferentialPrivacyTests.cs
@@ -0,0 +1,199 @@
+namespace AiDotNet.Tests.FederatedLearning;
+
+using Xunit;
+using AiDotNet.FederatedLearning.Privacy;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+
+/// <summary>
+/// Unit tests for Gaussian Differential Privacy mechanism.
+/// </summary>
+public class GaussianDifferentialPrivacyTests
+{
+ [Fact]
+ public void Constructor_WithValidClipNorm_InitializesSuccessfully()
+ {
+ // Arrange & Act
+ var dp = new GaussianDifferentialPrivacy<double>(clipNorm: 1.0);
+
+ // Assert
+ Assert.NotNull(dp);
+ Assert.Equal(0.0, dp.GetPrivacyBudgetConsumed());
+ }
+
+ [Fact]
+ public void Constructor_WithNegativeClipNorm_ThrowsArgumentException()
+ {
+ // Act & Assert
+ Assert.Throws<ArgumentException>(() => new GaussianDifferentialPrivacy<double>(clipNorm: -1.0));
+ }
+
+ [Fact]
+ public void ApplyPrivacy_WithValidParameters_AddsNoiseToModel()
+ {
+ // Arrange
+ var dp = new GaussianDifferentialPrivacy<double>(clipNorm: 10.0, randomSeed: 42);
+
+ var originalModel = new Dictionary<string, double[]>
+ {
+ ["layer1"] = new double[] { 1.0, 2.0, 3.0 }
+ };
+
+ // Act
+ var noisyModel = dp.ApplyPrivacy(originalModel, epsilon: 1.0, delta: 1e-5);
+
+ // Assert
+ Assert.NotNull(noisyModel);
+ Assert.Contains("layer1", noisyModel.Keys);
+
+ // Model should be different due to noise
+ bool hasNoise = false;
+ for (int i = 0; i < originalModel["layer1"].Length; i++)
+ {
+ if (Math.Abs(noisyModel["layer1"][i] - originalModel["layer1"][i]) > 0.0001)
+ {
+ hasNoise = true;
+ break;
+ }
+ }
+ Assert.True(hasNoise, "Noise should have been added to the model");
+ }
+
+ [Fact]
+ public void ApplyPrivacy_UpdatesPrivacyBudget()
+ {
+ // Arrange
+ var dp = new GaussianDifferentialPrivacy<double>(clipNorm: 1.0);
+
+ var model = new Dictionary<string, double[]>
+ {
+ ["layer1"] = new double[] { 1.0 }
+ };
+
+ // Act
+ dp.ApplyPrivacy(model, epsilon: 0.5, delta: 1e-5);
+
+ // Assert
+ Assert.Equal(0.5, dp.GetPrivacyBudgetConsumed());
+
+ // Apply privacy again
+ dp.ApplyPrivacy(model, epsilon: 0.3, delta: 1e-5);
+ Assert.Equal(0.8, dp.GetPrivacyBudgetConsumed(), precision: 5);
+ }
+
+ [Fact]
+ public void ApplyPrivacy_WithZeroEpsilon_ThrowsArgumentException()
+ {
+ // Arrange
+ var dp = new GaussianDifferentialPrivacy<double>(clipNorm: 1.0);
+
+ var model = new Dictionary<string, double[]>
+ {
+ ["layer1"] = new double[] { 1.0 }
+ };
+
+ // Act & Assert
+ Assert.Throws<ArgumentException>(() => dp.ApplyPrivacy(model, epsilon: 0.0, delta: 1e-5));
+ }
+
+ [Fact]
+ public void ApplyPrivacy_WithInvalidDelta_ThrowsArgumentException()
+ {
+ // Arrange
+ var dp = new GaussianDifferentialPrivacy<double>(clipNorm: 1.0);
+
+ var model = new Dictionary<string, double[]>
+ {
+ ["layer1"] = new double[] { 1.0 }
+ };
+
+ // Act & Assert
+ Assert.Throws<ArgumentException>(() => dp.ApplyPrivacy(model, epsilon: 1.0, delta: 0.0));
+ Assert.Throws<ArgumentException>(() => dp.ApplyPrivacy(model, epsilon: 1.0, delta: 1.0));
+ }
+
+ [Fact]
+ public void ResetPrivacyBudget_ResetsToZero()
+ {
+ // Arrange
+ var dp = new GaussianDifferentialPrivacy<double>(clipNorm: 1.0);
+
+ var model = new Dictionary<string, double[]>
+ {
+ ["layer1"] = new double[] { 1.0 }
+ };
+
+ dp.ApplyPrivacy(model, epsilon: 1.0, delta: 1e-5);
+ Assert.Equal(1.0, dp.GetPrivacyBudgetConsumed());
+
+ // Act
+ dp.ResetPrivacyBudget();
+
+ // Assert
+ Assert.Equal(0.0, dp.GetPrivacyBudgetConsumed());
+ }
+
+ [Fact]
+ public void GetMechanismName_ReturnsCorrectName()
+ {
+ // Arrange
+ var dp = new GaussianDifferentialPrivacy<double>(clipNorm: 2.5);
+
+ // Act
+ var name = dp.GetMechanismName();
+
+ // Assert
+ Assert.Contains("Gaussian DP", name);
+ Assert.Contains("2.5", name);
+ }
+
+ [Fact]
+ public void ApplyPrivacy_WithSameSeed_ProducesSameNoise()
+ {
+ // Arrange
+ var dp1 = new GaussianDifferentialPrivacy<double>(clipNorm: 1.0, randomSeed: 123);
+ var dp2 = new GaussianDifferentialPrivacy<double>(clipNorm: 1.0, randomSeed: 123);
+
+ var model = new Dictionary<string, double[]>
+ {
+ ["layer1"] = new double[] { 1.0, 2.0, 3.0 }
+ };
+
+ // Act
+ var noisyModel1 = dp1.ApplyPrivacy(model, epsilon: 1.0, delta: 1e-5);
+ var noisyModel2 = dp2.ApplyPrivacy(model, epsilon: 1.0, delta: 1e-5);
+
+ // Assert
+ for (int i = 0; i < noisyModel1["layer1"].Length; i++)
+ {
+ Assert.Equal(noisyModel1["layer1"][i], noisyModel2["layer1"][i], precision: 10);
+ }
+ }
+
+ [Fact]
+ public void ApplyPrivacy_PerformsGradientClipping()
+ {
+ // Arrange
+ var clipNorm = 1.0;
+ var dp = new GaussianDifferentialPrivacy<double>(clipNorm: clipNorm, randomSeed: 42);
+
+ // Create model with large norm (sqrt(100 + 100 + 100) = ~17.3)
+ var model = new Dictionary<string, double[]>
+ {
+ ["layer1"] = new double[] { 10.0, 10.0, 10.0 }
+ };
+
+ // Act
+ var clippedModel = dp.ApplyPrivacy(model, epsilon: 10.0, delta: 1e-5);
+
+ // Assert
+ // Calculate L2 norm of clipped model
+ double sumSquares = clippedModel["layer1"].Sum(x => x * x);
+ double norm = Math.Sqrt(sumSquares);
+
+ // Norm should be approximately clipNorm (plus some noise variance)
+ // With high epsilon (10.0), noise is minimal, so norm should be close to clipNorm
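+ // (Assuming the standard Gaussian-mechanism calibration sigma = clipNorm * sqrt(2 * ln(1.25 / delta)) / epsilon,
+ // this gives sigma ~ 1.0 * sqrt(2 * ln(125000)) / 10 ~ 0.48 per coordinate, so the clipped-and-noised
+ // norm should typically stay well below 2 * clipNorm.)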
+ Assert.True(norm < clipNorm * 2.0, $"Norm {norm} should be reasonably close to clip norm {clipNorm}");
+ }
+}