diff --git a/src/FederatedLearning/Aggregators/FedAvgAggregationStrategy.cs b/src/FederatedLearning/Aggregators/FedAvgAggregationStrategy.cs new file mode 100644 index 000000000..a5efe6a29 --- /dev/null +++ b/src/FederatedLearning/Aggregators/FedAvgAggregationStrategy.cs @@ -0,0 +1,158 @@ +namespace AiDotNet.FederatedLearning.Aggregators; + +using AiDotNet.Interfaces; + +/// +/// Implements the Federated Averaging (FedAvg) aggregation strategy. +/// +/// +/// FedAvg is the foundational aggregation algorithm for federated learning, proposed by +/// McMahan et al. in 2017. It performs a weighted average of client model updates based +/// on the number of training samples each client has. +/// +/// For Beginners: FedAvg is like calculating a weighted class average where students +/// who solved more practice problems have more influence on the final answer. +/// +/// How FedAvg works: +/// 1. Each client trains on their local data and computes model updates +/// 2. Clients send their updated model weights to the server +/// 3. Server computes weighted average: weight = (client_samples / total_samples) +/// 4. New global model = Σ(weight_i × client_model_i) +/// +/// For example, with 3 hospitals: +/// - Hospital A: 1000 patients, model accuracy 90% +/// - Hospital B: 500 patients, model accuracy 88% +/// - Hospital C: 1500 patients, model accuracy 92% +/// +/// Total patients: 3000 +/// Hospital A weight: 1000/3000 = 0.333 +/// Hospital B weight: 500/3000 = 0.167 +/// Hospital C weight: 1500/3000 = 0.500 +/// +/// For each model parameter: +/// global_param = 0.333 × A_param + 0.167 × B_param + 0.500 × C_param +/// +/// Benefits: +/// - Simple and efficient +/// - Well-studied theoretically +/// - Works well when clients have similar data distributions (IID data) +/// +/// Limitations: +/// - Assumes clients are equally reliable +/// - Can struggle with non-IID data (different distributions across clients) +/// - No built-in handling for stragglers (slow clients) +/// +/// Reference: McMahan, H. B., et al. (2017). "Communication-Efficient Learning of Deep Networks +/// from Decentralized Data." AISTATS 2017. +/// +/// The numeric type for model parameters (e.g., double, float). +public class FedAvgAggregationStrategy : IAggregationStrategy> + where T : struct, IComparable, IConvertible +{ + /// + /// Aggregates client models using weighted averaging based on the number of samples. + /// + /// + /// This method implements the core FedAvg algorithm: + /// + /// Mathematical formulation: + /// w_global = Σ(n_k / n_total) × w_k + /// + /// where: + /// - w_global: global model weights + /// - w_k: client k's model weights + /// - n_k: number of samples at client k + /// - n_total: total samples across all clients + /// + /// For Beginners: This combines all client models into one by taking a weighted + /// average, where clients with more data have more influence. + /// + /// Step-by-step process: + /// 1. Calculate total samples across all clients + /// 2. For each client, compute weight = client_samples / total_samples + /// 3. For each model parameter, compute weighted sum + /// 4. 
Return the aggregated model + /// + /// For example, if we have 2 clients with a simple model (one parameter): + /// - Client 1: 300 samples, parameter value = 0.8 + /// - Client 2: 700 samples, parameter value = 0.6 + /// + /// Total samples: 1000 + /// Client 1 weight: 300/1000 = 0.3 + /// Client 2 weight: 700/1000 = 0.7 + /// Aggregated parameter: 0.3 × 0.8 + 0.7 × 0.6 = 0.24 + 0.42 = 0.66 + /// + /// Dictionary mapping client IDs to their model parameters. + /// Dictionary mapping client IDs to their sample counts (weights). + /// The aggregated global model parameters. + public Dictionary Aggregate( + Dictionary> clientModels, + Dictionary clientWeights) + { + if (clientModels == null || clientModels.Count == 0) + { + throw new ArgumentException("Client models cannot be null or empty.", nameof(clientModels)); + } + + if (clientWeights == null || clientWeights.Count == 0) + { + throw new ArgumentException("Client weights cannot be null or empty.", nameof(clientWeights)); + } + + // Calculate total weight (total number of samples across all clients) + double totalWeight = clientWeights.Values.Sum(); + + if (totalWeight <= 0) + { + throw new ArgumentException("Total weight must be positive.", nameof(clientWeights)); + } + + // Get the first client's model structure to initialize the aggregated model + var firstClientModel = clientModels.First().Value; + var aggregatedModel = new Dictionary(); + + // Initialize aggregated model with zeros + foreach (var layerName in firstClientModel.Keys) + { + aggregatedModel[layerName] = new T[firstClientModel[layerName].Length]; + } + + // Perform weighted aggregation + foreach (var clientId in clientModels.Keys) + { + var clientModel = clientModels[clientId]; + var clientWeight = clientWeights[clientId]; + + // Normalized weight for this client + double normalizedWeight = clientWeight / totalWeight; + + // Add weighted contribution from this client to the aggregated model + foreach (var layerName in clientModel.Keys) + { + var clientParams = clientModel[layerName]; + var aggregatedParams = aggregatedModel[layerName]; + + for (int i = 0; i < clientParams.Length; i++) + { + // Convert T to double, perform weighted addition, convert back to T + double currentValue = Convert.ToDouble(aggregatedParams[i]); + double clientValue = Convert.ToDouble(clientParams[i]); + double weightedValue = currentValue + (normalizedWeight * clientValue); + + aggregatedParams[i] = (T)Convert.ChangeType(weightedValue, typeof(T)); + } + } + } + + return aggregatedModel; + } + + /// + /// Gets the name of the aggregation strategy. + /// + /// The string "FedAvg". + public string GetStrategyName() + { + return "FedAvg"; + } +} diff --git a/src/FederatedLearning/Aggregators/FedBNAggregationStrategy.cs b/src/FederatedLearning/Aggregators/FedBNAggregationStrategy.cs new file mode 100644 index 000000000..ccee14b1f --- /dev/null +++ b/src/FederatedLearning/Aggregators/FedBNAggregationStrategy.cs @@ -0,0 +1,257 @@ +namespace AiDotNet.FederatedLearning.Aggregators; + +using AiDotNet.Interfaces; + +/// +/// Implements the Federated Batch Normalization (FedBN) aggregation strategy. +/// +/// +/// FedBN is a specialized aggregation strategy that handles batch normalization layers +/// differently from other layers in neural networks. Proposed by Li et al. in 2021, +/// it addresses the challenge of non-IID data by keeping batch normalization parameters local. 
+/// +/// For Beginners: FedBN recognizes that some parts of a neural network should remain +/// personalized to each client rather than being averaged globally. +/// +/// The key insight: +/// - Batch Normalization (BN) layers learn statistics specific to each client's data +/// - Averaging BN parameters across clients with different data distributions hurts performance +/// - Solution: Keep BN layers local, only aggregate other layers (Conv, FC, etc.) +/// +/// How FedBN works: +/// 1. During aggregation, identify batch normalization layers +/// 2. Aggregate only non-BN layers using weighted averaging +/// 3. Keep each client's BN layers unchanged (personalized) +/// 4. Send back global model with client-specific BN layers +/// +/// For example, in a CNN with layers: +/// - Conv1 (filters) → BN1 (normalization) → ReLU → Conv2 → BN2 → FC (classification) +/// +/// FedBN aggregates: +/// - ✓ Conv1 filters: Averaged across clients +/// - ✗ BN1 params: Kept local to each client +/// - ✓ Conv2 filters: Averaged across clients +/// - ✗ BN2 params: Kept local to each client +/// - ✓ FC weights: Averaged across clients +/// +/// Why this matters: +/// - Different clients may have different data ranges, distributions +/// - Hospital A images: brightness range [0, 100] +/// - Hospital B images: brightness range [50, 200] +/// - Each needs different normalization parameters +/// - Shared feature extractors (Conv layers) + personalized normalization works better +/// +/// When to use FedBN: +/// - Training deep neural networks (especially CNNs) +/// - Non-IID data with distribution shift +/// - Batch normalization or layer normalization in architecture +/// - Want to improve accuracy without changing training much +/// +/// Benefits: +/// - Significantly improves accuracy on non-IID data +/// - Simple modification to FedAvg +/// - No additional communication cost +/// - Each client keeps personalized normalization +/// +/// Limitations: +/// - Only helps when using batch normalization +/// - Doesn't address other heterogeneity challenges +/// - Requires identifying BN layers in model structure +/// +/// Reference: Li, X., et al. (2021). "Federated Learning on Non-IID Data Silos: An Experimental Study." +/// ICDE 2021. +/// +/// The numeric type for model parameters (e.g., double, float). +public class FedBNAggregationStrategy : IAggregationStrategy> + where T : struct, IComparable, IConvertible +{ + private readonly HashSet _batchNormLayerPatterns; + + /// + /// Initializes a new instance of the class. + /// + /// + /// For Beginners: Creates a FedBN aggregator that knows how to identify + /// batch normalization layers in your model. + /// + /// Common BN layer naming patterns: + /// - "bn", "batchnorm", "batch_norm": Explicit BN layers + /// - "gamma", "beta": BN trainable parameters + /// - "running_mean", "running_var": BN statistics + /// + /// The strategy will exclude these from aggregation, keeping them client-specific. + /// + /// + /// Patterns to identify batch normalization layers. If null, uses default patterns. + /// + public FedBNAggregationStrategy(HashSet? batchNormLayerPatterns = null) + { + // Default patterns for identifying batch normalization layers + _batchNormLayerPatterns = batchNormLayerPatterns ?? 
new HashSet(StringComparer.OrdinalIgnoreCase) + { + "bn", + "batchnorm", + "batch_norm", + "batch_normalization", + "gamma", + "beta", + "running_mean", + "running_var", + "moving_mean", + "moving_variance" + }; + } + + /// + /// Aggregates client models while keeping batch normalization layers local. + /// + /// + /// This method implements selective aggregation: + /// + /// For Beginners: Think of this as a smart averaging that knows some parameters + /// should stay personal (like BN layers) while others should be shared (like Conv layers). + /// + /// Step-by-step process: + /// 1. For each layer in the model: + /// - Check if it's a batch normalization layer (by name matching) + /// - If BN: Keep first client's values (could be any client's, they stay local) + /// - If not BN: Compute weighted average across all clients + /// 2. Return aggregated model + /// + /// Mathematical formulation: + /// For non-BN layers: + /// w_global[layer] = Σ(n_k / n_total) × w_k[layer] + /// + /// For BN layers: + /// w_global[layer] = w_client[layer] (keeps local values) + /// + /// For example, with 3 clients and a model with: + /// - "conv1": [0.5, 0.6, 0.7] at clients → Average these + /// - "bn1_gamma": [1.0, 1.2, 0.9] at clients → Keep local (don't average) + /// - "conv2": [0.3, 0.4, 0.5] at clients → Average these + /// - "bn2_beta": [0.1, 0.2, 0.15] at clients → Keep local (don't average) + /// + /// Note: In practice, each client would maintain their own BN parameters. + /// The "global" model returned includes BN params that each client will replace + /// with their local version upon receiving the update. + /// + /// Dictionary mapping client IDs to their model parameters. + /// Dictionary mapping client IDs to their sample counts (weights). + /// The aggregated global model parameters with BN layers excluded from aggregation. 
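+ ///
+ /// Illustrative usage (a sketch; client IDs, layer names, and values are hypothetical and assume double parameters):
+ ///
+ ///     var fedBn = new FedBNAggregationStrategy<double>();
+ ///     var models = new Dictionary<int, Dictionary<string, double[]>>
+ ///     {
+ ///         [0] = new Dictionary<string, double[]> { ["conv1"] = new[] { 0.5 }, ["bn1_gamma"] = new[] { 1.0 } },
+ ///         [1] = new Dictionary<string, double[]> { ["conv1"] = new[] { 0.7 }, ["bn1_gamma"] = new[] { 1.4 } }
+ ///     };
+ ///     var weights = new Dictionary<int, double> { [0] = 100.0, [1] = 100.0 };
+ ///     var globalModel = fedBn.Aggregate(models, weights);
+ ///     // "conv1" is averaged to 0.6; "bn1_gamma" keeps the first client's value (1.0), and in practice
+ ///     // each client restores its own local BN parameters after receiving the global model.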
+ public Dictionary Aggregate( + Dictionary> clientModels, + Dictionary clientWeights) + { + if (clientModels == null || clientModels.Count == 0) + { + throw new ArgumentException("Client models cannot be null or empty.", nameof(clientModels)); + } + + if (clientWeights == null || clientWeights.Count == 0) + { + throw new ArgumentException("Client weights cannot be null or empty.", nameof(clientWeights)); + } + + double totalWeight = clientWeights.Values.Sum(); + + if (totalWeight <= 0) + { + throw new ArgumentException("Total weight must be positive.", nameof(clientWeights)); + } + + var firstClientModel = clientModels.First().Value; + var aggregatedModel = new Dictionary(); + + // Process each layer + foreach (var layerName in firstClientModel.Keys) + { + // Check if this is a batch normalization layer + bool isBatchNormLayer = IsBatchNormalizationLayer(layerName); + + if (isBatchNormLayer) + { + // For BN layers, keep the first client's parameters (they stay local) + // In practice, each client will maintain their own BN params + aggregatedModel[layerName] = (T[])firstClientModel[layerName].Clone(); + } + else + { + // For non-BN layers, perform weighted aggregation (like FedAvg) + aggregatedModel[layerName] = new T[firstClientModel[layerName].Length]; + + foreach (var clientId in clientModels.Keys) + { + var clientModel = clientModels[clientId]; + var clientWeight = clientWeights[clientId]; + double normalizedWeight = clientWeight / totalWeight; + + var clientParams = clientModel[layerName]; + var aggregatedParams = aggregatedModel[layerName]; + + for (int i = 0; i < clientParams.Length; i++) + { + double currentValue = Convert.ToDouble(aggregatedParams[i]); + double clientValue = Convert.ToDouble(clientParams[i]); + double weightedValue = currentValue + (normalizedWeight * clientValue); + + aggregatedParams[i] = (T)Convert.ChangeType(weightedValue, typeof(T)); + } + } + } + } + + return aggregatedModel; + } + + /// + /// Determines whether a layer is a batch normalization layer based on its name. + /// + /// + /// For Beginners: This checks if a layer name contains any of the known + /// batch normalization patterns. + /// + /// For example: + /// - "conv1_weights" → false (not BN) + /// - "bn1_gamma" → true (contains "bn") + /// - "batch_norm_2_beta" → true (contains "batch_norm") + /// - "fc_bias" → false (not BN) + /// + /// The name of the layer to check. + /// True if the layer is a batch normalization layer, false otherwise. + private bool IsBatchNormalizationLayer(string layerName) + { + string lowerLayerName = layerName.ToLowerInvariant(); + + foreach (var pattern in _batchNormLayerPatterns) + { + if (lowerLayerName.Contains(pattern.ToLowerInvariant())) + { + return true; + } + } + + return false; + } + + /// + /// Gets the name of the aggregation strategy. + /// + /// The string "FedBN". + public string GetStrategyName() + { + return "FedBN"; + } + + /// + /// Gets the batch normalization layer patterns used for identification. + /// + /// + /// For Beginners: Returns the list of patterns used to recognize which + /// layers are batch normalization layers. + /// + /// A set of BN layer patterns. 
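+ ///
+ /// For example (an illustrative configuration; the pattern names are hypothetical):
+ ///
+ ///     var fedBn = new FedBNAggregationStrategy<double>(
+ ///         new HashSet<string>(StringComparer.OrdinalIgnoreCase) { "bn", "layer_norm" });
+ ///     // fedBn.GetBatchNormPatterns() now contains exactly "bn" and "layer_norm".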
+ public IReadOnlySet GetBatchNormPatterns() + { + return _batchNormLayerPatterns; + } +} diff --git a/src/FederatedLearning/Aggregators/FedProxAggregationStrategy.cs b/src/FederatedLearning/Aggregators/FedProxAggregationStrategy.cs new file mode 100644 index 000000000..0f033d666 --- /dev/null +++ b/src/FederatedLearning/Aggregators/FedProxAggregationStrategy.cs @@ -0,0 +1,194 @@ +namespace AiDotNet.FederatedLearning.Aggregators; + +using AiDotNet.Interfaces; + +/// +/// Implements the Federated Proximal (FedProx) aggregation strategy. +/// +/// +/// FedProx is an extension of FedAvg that handles system and statistical heterogeneity +/// in federated learning. It was proposed by Li et al. in 2020 to address challenges +/// when clients have different computational capabilities or data distributions. +/// +/// For Beginners: FedProx is like FedAvg with a "safety rope" that prevents +/// individual clients from pulling the shared model too far in their own direction. +/// +/// Key differences from FedAvg: +/// 1. Adds a proximal term to local training objective +/// 2. Prevents client models from deviating too much from global model +/// 3. Improves convergence when clients have heterogeneous data or capabilities +/// +/// How FedProx works: +/// During local training, each client minimizes: +/// Local Loss + (μ/2) × ||w - w_global||² +/// +/// where: +/// - Local Loss: Standard loss on client's data +/// - μ (mu): Proximal term coefficient (controls constraint strength) +/// - w: Client's current model weights +/// - w_global: Global model weights received from server +/// - ||w - w_global||²: Squared distance between client and global model +/// +/// For example, with μ = 0.01: +/// - Client trains on local data +/// - Proximal term penalizes large deviations from global model +/// - If client's data is very different, can still adapt but with limitation +/// - Prevents overfitting to local data distribution +/// +/// When to use FedProx: +/// - Non-IID data (different distributions across clients) +/// - System heterogeneity (some clients much slower/faster) +/// - Want more stable convergence than FedAvg +/// - Stragglers problem (some clients take much longer) +/// +/// Benefits: +/// - Better convergence on non-IID data +/// - More robust to stragglers +/// - Theoretically proven convergence guarantees +/// - Small computational overhead +/// +/// Limitations: +/// - Requires tuning μ parameter +/// - Slightly slower local training per iteration +/// - May converge slower if μ is too large +/// +/// Reference: Li, T., et al. (2020). "Federated Optimization in Heterogeneous Networks." +/// MLSys 2020. +/// +/// The numeric type for model parameters (e.g., double, float). +public class FedProxAggregationStrategy : IAggregationStrategy> + where T : struct, IComparable, IConvertible +{ + private readonly double _mu; + + /// + /// Initializes a new instance of the class. + /// + /// + /// For Beginners: Creates a FedProx aggregator with a specified proximal term strength. 
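+ ///
+ /// To illustrate where μ acts, a client-side update step could look like the following sketch
+ /// (not part of this class; plain SGD is assumed, and w, wGlobal, localGradient, mu, and learningRate
+ /// are hypothetical local variables):
+ ///
+ ///     for (int i = 0; i < w.Length; i++)
+ ///     {
+ ///         // gradient of the local loss plus the proximal pull towards the global weights
+ ///         double proximalGradient = localGradient[i] + mu * (w[i] - wGlobal[i]);
+ ///         w[i] -= learningRate * proximalGradient;
+ ///     }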
+ /// + /// The μ (mu) parameter controls the trade-off between local adaptation and global consistency: + /// - μ = 0: Equivalent to FedAvg (no constraint) + /// - μ = 0.01: Mild constraint (recommended starting point) + /// - μ = 0.1: Moderate constraint + /// - μ = 1.0+: Strong constraint (may be too restrictive) + /// + /// Recommendations: + /// - Start with μ = 0.01 + /// - Increase if convergence is unstable + /// - Decrease if convergence is too slow + /// + /// The proximal term coefficient (typically 0.01 to 1.0). + public FedProxAggregationStrategy(double mu = 0.01) + { + if (mu < 0) + { + throw new ArgumentException("Mu must be non-negative.", nameof(mu)); + } + + _mu = mu; + } + + /// + /// Aggregates client models using FedProx weighted averaging. + /// + /// + /// The aggregation step in FedProx is identical to FedAvg. The key difference is in + /// the local training objective (which includes the proximal term), not in aggregation. + /// + /// For Beginners: At the server side, FedProx aggregates just like FedAvg. + /// The magic happens during client-side training where the proximal term keeps + /// client models from straying too far. + /// + /// Aggregation formula (same as FedAvg): + /// w_global = Σ(n_k / n_total) × w_k + /// + /// The proximal term μ affects how w_k is computed during local training, but not + /// how we aggregate the models here. + /// + /// For implementation in local training (not shown here): + /// - Gradient = ∇Loss + μ(w - w_global) + /// - This additional term pulls weights towards global model + /// + /// Dictionary mapping client IDs to their model parameters. + /// Dictionary mapping client IDs to their sample counts (weights). + /// The aggregated global model parameters. + public Dictionary Aggregate( + Dictionary> clientModels, + Dictionary clientWeights) + { + if (clientModels == null || clientModels.Count == 0) + { + throw new ArgumentException("Client models cannot be null or empty.", nameof(clientModels)); + } + + if (clientWeights == null || clientWeights.Count == 0) + { + throw new ArgumentException("Client weights cannot be null or empty.", nameof(clientWeights)); + } + + // Calculate total weight + double totalWeight = clientWeights.Values.Sum(); + + if (totalWeight <= 0) + { + throw new ArgumentException("Total weight must be positive.", nameof(clientWeights)); + } + + // Initialize aggregated model + var firstClientModel = clientModels.First().Value; + var aggregatedModel = new Dictionary(); + + foreach (var layerName in firstClientModel.Keys) + { + aggregatedModel[layerName] = new T[firstClientModel[layerName].Length]; + } + + // Perform weighted aggregation (same as FedAvg) + foreach (var clientId in clientModels.Keys) + { + var clientModel = clientModels[clientId]; + var clientWeight = clientWeights[clientId]; + double normalizedWeight = clientWeight / totalWeight; + + foreach (var layerName in clientModel.Keys) + { + var clientParams = clientModel[layerName]; + var aggregatedParams = aggregatedModel[layerName]; + + for (int i = 0; i < clientParams.Length; i++) + { + double currentValue = Convert.ToDouble(aggregatedParams[i]); + double clientValue = Convert.ToDouble(clientParams[i]); + double weightedValue = currentValue + (normalizedWeight * clientValue); + + aggregatedParams[i] = (T)Convert.ChangeType(weightedValue, typeof(T)); + } + } + } + + return aggregatedModel; + } + + /// + /// Gets the name of the aggregation strategy. + /// + /// A string indicating "FedProx" with the μ parameter value. 
+ public string GetStrategyName() + { + return $"FedProx(μ={_mu})"; + } + + /// + /// Gets the proximal term coefficient μ. + /// + /// + /// For Beginners: Returns the strength of the constraint that keeps client + /// models from deviating too far from the global model. + /// + /// The μ parameter value. + public double GetMu() + { + return _mu; + } +} diff --git a/src/FederatedLearning/Personalization/PersonalizedFederatedLearning.cs b/src/FederatedLearning/Personalization/PersonalizedFederatedLearning.cs new file mode 100644 index 000000000..6f28d480a --- /dev/null +++ b/src/FederatedLearning/Personalization/PersonalizedFederatedLearning.cs @@ -0,0 +1,371 @@ +namespace AiDotNet.FederatedLearning.Personalization; + +using System; +using System.Collections.Generic; +using System.Linq; + +/// +/// Implements personalized federated learning where each client maintains some client-specific parameters. +/// +/// +/// Personalized Federated Learning (PFL) addresses the challenge of heterogeneous data distributions +/// across clients by allowing each client to maintain personalized model components while still +/// benefiting from collaborative learning. +/// +/// For Beginners: Personalized FL is like having a shared textbook but personal notes. +/// Everyone learns from the same core material (global model) but adapts it to their specific +/// needs (personalized layers). +/// +/// Key concept: +/// - Global layers: Shared across all clients, learn common patterns +/// - Personalized layers: Client-specific, adapt to local data distribution +/// - Clients train both but only global layers are aggregated +/// +/// How it works: +/// 1. Model is split into global and personalized parts +/// 2. During local training, both parts are updated +/// 3. Only global parts are sent to server for aggregation +/// 4. Personalized parts stay on the client +/// 5. Client receives updated global parts and keeps personalized parts +/// +/// For example, in healthcare: +/// - Hospital A: Urban population, young average age +/// - Hospital B: Rural population, old average age +/// - Hospital C: Suburban population, mixed age +/// +/// Model structure: +/// - Global layers (shared): General disease detection features +/// - Personalized layers: Adapt to local demographics +/// +/// Benefits: +/// - Better performance on non-IID data +/// - Each client gets a model optimized for their data +/// - Preserves privacy (personalized parts never leave client) +/// - Relatively simple to implement +/// +/// Common approaches: +/// 1. Layer-wise personalization: Last few layers personalized +/// 2. Feature-wise personalization: Some features personalized +/// 3. Meta-learning: Learn how to adapt quickly to local data +/// 4. Multi-task learning: Treat each client as a separate task +/// +/// When to use PFL: +/// - Clients have significantly different data distributions +/// - Standard FedAvg performance is poor +/// - Can afford client-side storage for personalized parameters +/// - Want better local performance even at cost of global performance +/// +/// Limitations: +/// - Requires more storage on client (for personalized params) +/// - May sacrifice some global model quality +/// - Need to choose which layers to personalize +/// - Risk of overfitting to local data +/// +/// Reference: +/// - Wang, K., et al. (2019). "Federated Evaluation of On-device Personalization." arXiv preprint. +/// - Fallah, A., et al. (2020). 
"Personalized Federated Learning with Theoretical Guarantees: A Model-Agnostic Meta-Learning Approach." NeurIPS 2020. +/// +/// The numeric type for model parameters (e.g., double, float). +public class PersonalizedFederatedLearning + where T : struct, IComparable, IConvertible +{ + private readonly double _personalizationFraction; + private readonly HashSet _personalizedLayers; + + /// + /// Initializes a new instance of the class. + /// + /// + /// For Beginners: Sets up personalized federated learning with a specified + /// fraction of the model kept personalized. + /// + /// The personalization fraction determines the split: + /// - 0.0: No personalization (standard federated learning) + /// - 0.2: Last 20% of layers personalized (common choice) + /// - 0.5: Half personalized, half global + /// - 1.0: Fully personalized (no collaboration) + /// + /// Typical strategy: + /// - Keep early layers (feature extractors) global + /// - Keep late layers (task-specific) personalized + /// + /// For example, in a CNN: + /// - Conv layers 1-3: Global (learn general visual features) + /// - Conv layers 4-5: Personalized (adapt to local image characteristics) + /// - FC layers: Personalized (task-specific classification) + /// + /// + /// The fraction of model layers to keep personalized (0.0 to 1.0). + /// Typically 0.2 for last 20% of layers. + /// + public PersonalizedFederatedLearning(double personalizationFraction = 0.2) + { + if (personalizationFraction < 0.0 || personalizationFraction > 1.0) + { + throw new ArgumentException("Personalization fraction must be between 0 and 1.", nameof(personalizationFraction)); + } + + _personalizationFraction = personalizationFraction; + _personalizedLayers = new HashSet(); + } + + /// + /// Identifies which layers should be personalized based on the model structure. + /// + /// + /// For Beginners: This decides which parts of the model will be personalized + /// vs. shared globally. + /// + /// Common strategies: + /// 1. Last-N layers: Personalize the final layers (default) + /// 2. By name: Personalize layers matching specific patterns + /// 3. By type: Personalize certain layer types (e.g., batch norm) + /// + /// For example, with 10 layers and 20% personalization: + /// - Layers 0-7: Global (shared) + /// - Layers 8-9: Personalized (last 2 layers = 20%) + /// + /// The intuition: + /// - Early layers learn low-level features (edges, textures) → should be shared + /// - Late layers learn high-level, task-specific features → can be personalized + /// + /// The model structure with layer names. + /// The strategy for selecting personalized layers ("last_n", "by_pattern"). + /// Optional patterns for "by_pattern" strategy. + public void IdentifyPersonalizedLayers( + Dictionary modelStructure, + string strategy = "last_n", + HashSet? 
customPatterns = null) + { + if (modelStructure == null || modelStructure.Count == 0) + { + throw new ArgumentException("Model structure cannot be null or empty.", nameof(modelStructure)); + } + + _personalizedLayers.Clear(); + + if (strategy == "last_n") + { + // Personalize the last N% of layers + int totalLayers = modelStructure.Count; + int personalizedCount = (int)Math.Ceiling(totalLayers * _personalizationFraction); + + var layerNames = modelStructure.Keys.ToList(); + + // Take the last personalizedCount layers + for (int i = totalLayers - personalizedCount; i < totalLayers; i++) + { + _personalizedLayers.Add(layerNames[i]); + } + } + else if (strategy == "by_pattern" && customPatterns != null) + { + // Personalize layers matching specific patterns + foreach (var layerName in modelStructure.Keys) + { + foreach (var pattern in customPatterns) + { + if (layerName.Contains(pattern, StringComparison.OrdinalIgnoreCase)) + { + _personalizedLayers.Add(layerName); + break; + } + } + } + } + else + { + throw new ArgumentException($"Unknown strategy: {strategy}. Use 'last_n' or 'by_pattern'.", nameof(strategy)); + } + } + + /// + /// Separates a model into global and personalized components. + /// + /// + /// For Beginners: Splits the model into two parts: + /// - Global part: Will be sent to server and aggregated + /// - Personalized part: Stays on client + /// + /// For example: + /// Original model: {layer1: [...], layer2: [...], layer3: [...], layer4: [...]} + /// If layers 3-4 are personalized: + /// - Global: {layer1: [...], layer2: [...]} + /// - Personalized: {layer3: [...], layer4: [...]} + /// + /// This separation enables: + /// - Efficient communication (only send global parts) + /// - Privacy (personalized parts never leave client) + /// - Flexibility (different personalization per client) + /// + /// The complete model with all layers. + /// Output: The global layers to be aggregated. + /// Output: The personalized layers to keep local. + public void SeparateModel( + Dictionary fullModel, + out Dictionary globalPart, + out Dictionary personalizedPart) + { + if (fullModel == null || fullModel.Count == 0) + { + throw new ArgumentException("Full model cannot be null or empty.", nameof(fullModel)); + } + + globalPart = new Dictionary(); + personalizedPart = new Dictionary(); + + foreach (var layer in fullModel) + { + if (_personalizedLayers.Contains(layer.Key)) + { + // This layer is personalized - keep local + personalizedPart[layer.Key] = (T[])layer.Value.Clone(); + } + else + { + // This layer is global - will be aggregated + globalPart[layer.Key] = (T[])layer.Value.Clone(); + } + } + } + + /// + /// Combines global model update with client's personalized layers. + /// + /// + /// For Beginners: After the server aggregates global layers, each client + /// combines the updated global layers with their own personalized layers to form + /// the complete model for the next round. + /// + /// Process: + /// 1. Receive updated global layers from server + /// 2. Retrieve client's personalized layers from local storage + /// 3. Merge them into one complete model + /// 4. 
Ready for next round of local training + /// + /// For example: + /// Server sends global update: {layer1: [...], layer2: [...]} + /// Client has personalized: {layer3: [...], layer4: [...]} + /// Combined model: {layer1: [...], layer2: [...], layer3: [...], layer4: [...]} + /// + /// This ensures: + /// - Global knowledge is incorporated (layers 1-2 updated) + /// - Local adaptation is preserved (layers 3-4 unchanged) + /// - Model structure remains consistent + /// + /// The updated global layers from server. + /// The client's personalized layers. + /// The complete model combining both parts. + public Dictionary CombineModels( + Dictionary globalUpdate, + Dictionary personalizedLayers) + { + if (globalUpdate == null) + { + throw new ArgumentNullException(nameof(globalUpdate)); + } + + if (personalizedLayers == null) + { + throw new ArgumentNullException(nameof(personalizedLayers)); + } + + var combinedModel = new Dictionary(); + + // Add all global layers + foreach (var layer in globalUpdate) + { + combinedModel[layer.Key] = (T[])layer.Value.Clone(); + } + + // Add all personalized layers + foreach (var layer in personalizedLayers) + { + combinedModel[layer.Key] = (T[])layer.Value.Clone(); + } + + return combinedModel; + } + + /// + /// Checks if a specific layer is personalized. + /// + /// + /// For Beginners: Returns whether a given layer should be kept local + /// (personalized) or sent to the server (global). + /// + /// The name of the layer to check. + /// True if the layer is personalized, false if global. + public bool IsLayerPersonalized(string layerName) + { + return _personalizedLayers.Contains(layerName); + } + + /// + /// Gets the set of all personalized layer names. + /// + /// + /// For Beginners: Returns the list of which layers are personalized. + /// Useful for logging, debugging, and understanding the model split. + /// + /// A read-only set of personalized layer names. + public IReadOnlySet GetPersonalizedLayers() + { + return _personalizedLayers; + } + + /// + /// Gets the personalization fraction. + /// + /// The fraction of layers that are personalized. + public double GetPersonalizationFraction() + { + return _personalizationFraction; + } + + /// + /// Calculates statistics about the model split. + /// + /// + /// For Beginners: Provides useful information about how the model is divided: + /// - How many parameters are global vs. personalized + /// - What percentage of the model is personalized + /// - Communication savings from personalization + /// + /// This helps understand the trade-offs: + /// - More personalized → Less communication, more storage per client + /// - More global → More communication, less storage per client + /// + /// The complete model. + /// A dictionary with statistics. 
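+ ///
+ /// Example (a sketch; assumes pfl is a configured PersonalizedFederatedLearning<double> and that
+ /// IdentifyPersonalizedLayers has already marked "fc" as personalized; layer names and sizes are hypothetical):
+ ///
+ ///     var fullModel = new Dictionary<string, double[]>
+ ///     {
+ ///         ["conv1"] = new double[100],
+ ///         ["fc"] = new double[10]
+ ///     };
+ ///     var stats = pfl.GetModelSplitStatistics(fullModel);
+ ///     // stats["personalized_parameters"] = 10, stats["global_parameters"] = 100,
+ ///     // stats["personalized_parameter_fraction"] ≈ 0.09 (10 of 110 parameters stay on the client).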
+ public Dictionary GetModelSplitStatistics(Dictionary fullModel) + { + if (fullModel == null || fullModel.Count == 0) + { + throw new ArgumentException("Full model cannot be null or empty.", nameof(fullModel)); + } + + int totalParams = fullModel.Values.Sum(layer => layer.Length); + int personalizedParams = fullModel + .Where(layer => _personalizedLayers.Contains(layer.Key)) + .Sum(layer => layer.Value.Length); + int globalParams = totalParams - personalizedParams; + + int totalLayers = fullModel.Count; + int personalizedLayerCount = _personalizedLayers.Count; + int globalLayerCount = totalLayers - personalizedLayerCount; + + return new Dictionary + { + ["total_parameters"] = totalParams, + ["global_parameters"] = globalParams, + ["personalized_parameters"] = personalizedParams, + ["global_parameter_fraction"] = totalParams > 0 ? (double)globalParams / totalParams : 0, + ["personalized_parameter_fraction"] = totalParams > 0 ? (double)personalizedParams / totalParams : 0, + ["total_layers"] = totalLayers, + ["global_layers"] = globalLayerCount, + ["personalized_layers"] = personalizedLayerCount, + ["communication_reduction"] = totalParams > 0 ? (double)personalizedParams / totalParams : 0 + }; + } +} diff --git a/src/FederatedLearning/Privacy/GaussianDifferentialPrivacy.cs b/src/FederatedLearning/Privacy/GaussianDifferentialPrivacy.cs new file mode 100644 index 000000000..5f99c92e5 --- /dev/null +++ b/src/FederatedLearning/Privacy/GaussianDifferentialPrivacy.cs @@ -0,0 +1,337 @@ +namespace AiDotNet.FederatedLearning.Privacy; + +using AiDotNet.Interfaces; +using System; + +/// +/// Implements differential privacy using the Gaussian mechanism. +/// +/// +/// The Gaussian mechanism provides (ε, δ)-differential privacy by adding calibrated +/// Gaussian (normal distribution) noise to model parameters. This is one of the most +/// widely used privacy mechanisms in federated learning. +/// +/// For Beginners: Differential privacy is like adding static to a phone conversation. +/// You add just enough noise that individual voices can't be identified, but the overall +/// message still gets through clearly. +/// +/// How Gaussian Differential Privacy works: +/// 1. Clip gradients/parameters to bound sensitivity (maximum change any single data point can cause) +/// 2. Add Gaussian noise: noise ~ N(0, σ²) where σ depends on ε, δ, and sensitivity +/// 3. 
The noise calibration ensures (ε, δ)-DP guarantee +/// +/// Mathematical formulation: +/// - Sensitivity Δ: Maximum L2 norm of gradient for any single training example +/// - Noise scale σ = (Δ/ε) × sqrt(2 × ln(1.25/δ)) +/// - For each parameter w: w_private = w + N(0, σ²) +/// +/// Privacy parameters: +/// - ε (epsilon): Privacy budget - smaller is more private +/// * ε = 0.1: Very strong privacy, significant noise +/// * ε = 1.0: Strong privacy, moderate noise (recommended) +/// * ε = 10: Weak privacy, minimal noise +/// +/// - δ (delta): Failure probability - should be very small +/// * Typically δ = 1/n² where n is dataset size +/// * Common choice: δ = 1e-5 +/// +/// For example, protecting hospital patient data: +/// - Original gradient: [0.5, -0.3, 0.8, -0.2] +/// - Clip to max norm 1.0: [0.45, -0.27, 0.72, -0.18] (clipped) +/// - Add Gaussian noise with σ=0.1: [0.47, -0.29, 0.75, -0.21] +/// - Result: Individual patient influence is masked by noise +/// +/// Privacy composition: +/// - Each time you share data, you consume privacy budget +/// - After T rounds with ε per round: total ε_total ≈ ε × sqrt(2T × ln(1/δ)) +/// - This is more efficient than naive composition (ε × T) +/// +/// Trade-offs: +/// - More privacy (smaller ε) → more noise → lower accuracy +/// - Less privacy (larger ε) → less noise → higher accuracy +/// - Must find acceptable balance for your application +/// +/// When to use Gaussian DP: +/// - Need provable privacy guarantees +/// - Working with sensitive data (healthcare, finance) +/// - Regulatory requirements (GDPR, HIPAA) +/// - Publishing models or sharing with untrusted parties +/// +/// Reference: Dwork, C., & Roth, A. (2014). "The Algorithmic Foundations of Differential Privacy." +/// Abadi, M., et al. (2016). "Deep Learning with Differential Privacy." CCS 2016. +/// +/// The numeric type for model parameters (e.g., double, float). +public class GaussianDifferentialPrivacy : IPrivacyMechanism> + where T : struct, IComparable, IConvertible +{ + private double _privacyBudgetConsumed; + private readonly double _clipNorm; + private readonly Random _random; + + /// + /// Initializes a new instance of the class. + /// + /// + /// For Beginners: Creates a differential privacy mechanism with a specified + /// gradient clipping threshold. + /// + /// Gradient clipping (clipNorm) is crucial for DP: + /// - Bounds the maximum influence any single data point can have + /// - Makes noise calibration possible + /// - Common values: 0.1 - 10.0 depending on model and data + /// + /// Lower clipNorm: + /// - Stronger privacy guarantee + /// - More aggressive clipping + /// - May slow convergence + /// + /// Higher clipNorm: + /// - Less clipping + /// - Faster convergence + /// - Requires more noise for same privacy + /// + /// Recommendations: + /// - Start with clipNorm = 1.0 + /// - Monitor gradient norms during training + /// - Adjust based on typical gradient magnitudes + /// + /// The maximum L2 norm for gradient clipping (sensitivity bound). + /// Optional random seed for reproducibility. + public GaussianDifferentialPrivacy(double clipNorm = 1.0, int? randomSeed = null) + { + if (clipNorm <= 0) + { + throw new ArgumentException("Clip norm must be positive.", nameof(clipNorm)); + } + + _clipNorm = clipNorm; + _privacyBudgetConsumed = 0.0; + _random = randomSeed.HasValue ? new Random(randomSeed.Value) : new Random(); + } + + /// + /// Applies differential privacy to model parameters by adding calibrated Gaussian noise. 
+ /// + /// + /// This method implements the Gaussian mechanism for (ε, δ)-differential privacy: + /// + /// For Beginners: This adds carefully calculated random noise to protect privacy + /// while maintaining model utility. + /// + /// Step-by-step process: + /// 1. Calculate current L2 norm of model parameters + /// 2. If norm > clipNorm, scale down parameters to clipNorm + /// 3. Calculate noise scale σ based on ε, δ, and sensitivity + /// 4. Add Gaussian noise N(0, σ²) to each parameter + /// 5. Update privacy budget consumed + /// + /// Mathematical details: + /// - Sensitivity Δ = clipNorm (worst-case parameter change) + /// - σ = (Δ/ε) × sqrt(2 × ln(1.25/δ)) + /// - Noise ~ N(0, σ²) added to each parameter independently + /// + /// For example, with ε=1.0, δ=1e-5, clipNorm=1.0: + /// - σ = (1.0/1.0) × sqrt(2 × ln(125000)) ≈ 4.7 + /// - Each parameter gets noise from N(0, 4.7²) + /// - Original params: [0.5, -0.3, 0.8] + /// - Noisy params: [0.52, -0.35, 0.83] (example with small noise realization) + /// + /// Privacy accounting: + /// - Each call consumes ε privacy budget + /// - Total budget accumulates: ε_total = ε_1 + ε_2 + ... (simplified) + /// - Advanced: Use Rényi DP for tighter composition bounds + /// + /// The model parameters to add noise to. + /// Privacy budget for this operation (smaller = more private). + /// Failure probability (typically 1e-5 or smaller). + /// The model with differential privacy applied. + public Dictionary ApplyPrivacy(Dictionary model, double epsilon, double delta) + { + if (model == null || model.Count == 0) + { + throw new ArgumentException("Model cannot be null or empty.", nameof(model)); + } + + if (epsilon <= 0) + { + throw new ArgumentException("Epsilon must be positive.", nameof(epsilon)); + } + + if (delta <= 0 || delta >= 1) + { + throw new ArgumentException("Delta must be in (0, 1).", nameof(delta)); + } + + // Create a copy of the model + var noisyModel = new Dictionary(); + foreach (var layer in model) + { + noisyModel[layer.Key] = (T[])layer.Value.Clone(); + } + + // Step 1: Gradient clipping - Calculate L2 norm of all parameters + double l2Norm = CalculateL2Norm(noisyModel); + + // If norm exceeds clip threshold, scale down + if (l2Norm > _clipNorm) + { + double scaleFactor = _clipNorm / l2Norm; + + foreach (var layerName in noisyModel.Keys) + { + var parameters = noisyModel[layerName]; + for (int i = 0; i < parameters.Length; i++) + { + double value = Convert.ToDouble(parameters[i]); + parameters[i] = (T)Convert.ChangeType(value * scaleFactor, typeof(T)); + } + } + } + + // Step 2: Calculate noise scale based on Gaussian mechanism + // σ = (Δ/ε) × sqrt(2 × ln(1.25/δ)) + // where Δ = clipNorm (sensitivity) + double sensitivity = _clipNorm; + double noiseSigma = (sensitivity / epsilon) * Math.Sqrt(2.0 * Math.Log(1.25 / delta)); + + // Step 3: Add Gaussian noise to each parameter + foreach (var layerName in noisyModel.Keys) + { + var parameters = noisyModel[layerName]; + for (int i = 0; i < parameters.Length; i++) + { + double value = Convert.ToDouble(parameters[i]); + double noise = GenerateGaussianNoise(0.0, noiseSigma); + parameters[i] = (T)Convert.ChangeType(value + noise, typeof(T)); + } + } + + // Update privacy budget consumed + _privacyBudgetConsumed += epsilon; + + return noisyModel; + } + + /// + /// Calculates the L2 norm (Euclidean norm) of all model parameters. + /// + /// + /// For Beginners: L2 norm is the "length" of the parameter vector in + /// high-dimensional space. It's calculated as sqrt(sum of squares). 
+ /// + /// For example, with parameters [3, 4]: + /// - L2 norm = sqrt(3² + 4²) = sqrt(9 + 16) = sqrt(25) = 5 + /// + /// Used for gradient clipping to bound sensitivity. + /// + /// The model to calculate norm for. + /// The L2 norm of all parameters. + private double CalculateL2Norm(Dictionary model) + { + double sumOfSquares = 0.0; + + foreach (var layer in model.Values) + { + foreach (var param in layer) + { + double value = Convert.ToDouble(param); + sumOfSquares += value * value; + } + } + + return Math.Sqrt(sumOfSquares); + } + + /// + /// Generates a sample from a Gaussian (normal) distribution. + /// + /// + /// Uses the Box-Muller transform to generate Gaussian random variables from + /// uniform random variables. + /// + /// For Beginners: This creates random noise from a bell curve distribution. + /// Most noise values will be close to the mean, with rare large values. + /// + /// Box-Muller transform: + /// - Generate two uniform random numbers U1, U2 in [0, 1] + /// - Z = sqrt(-2 × ln(U1)) × cos(2π × U2) + /// - Z follows standard normal N(0, 1) + /// - Scale and shift: X = mean + sigma × Z + /// + /// The mean of the Gaussian distribution. + /// The standard deviation of the Gaussian distribution. + /// A random sample from N(mean, sigma²). + private double GenerateGaussianNoise(double mean, double sigma) + { + // Box-Muller transform + double u1 = 1.0 - _random.NextDouble(); // Uniform(0,1] + double u2 = 1.0 - _random.NextDouble(); + double standardNormal = Math.Sqrt(-2.0 * Math.Log(u1)) * Math.Cos(2.0 * Math.PI * u2); + + return mean + sigma * standardNormal; + } + + /// + /// Gets the total privacy budget consumed so far. + /// + /// + /// For Beginners: Returns how much privacy budget has been used up. + /// Privacy budget is cumulative - once spent, it's gone. + /// + /// For example: + /// - Round 1: ε = 0.5 consumed, total = 0.5 + /// - Round 2: ε = 0.5 consumed, total = 1.0 + /// - Round 3: ε = 0.5 consumed, total = 1.5 + /// + /// If you started with total budget 10.0, you have 8.5 remaining. + /// + /// Note: This uses basic composition. Advanced composition (Rényi DP) gives + /// tighter bounds and would show less budget consumed. + /// + /// The cumulative privacy budget (epsilon) consumed. + public double GetPrivacyBudgetConsumed() + { + return _privacyBudgetConsumed; + } + + /// + /// Gets the name of the privacy mechanism. + /// + /// A string describing the mechanism. + public string GetMechanismName() + { + return $"Gaussian DP (clip={_clipNorm})"; + } + + /// + /// Gets the gradient clipping norm used for sensitivity bounding. + /// + /// + /// For Beginners: Returns the maximum allowed parameter norm. + /// Parameters larger than this are scaled down before adding noise. + /// + /// The clipping norm value. + public double GetClipNorm() + { + return _clipNorm; + } + + /// + /// Resets the privacy budget counter. + /// + /// + /// For Beginners: Resets the privacy budget tracker to zero. + /// + /// WARNING: This should only be used when starting a completely new training run. + /// Do not reset during active training as it would give false privacy accounting. 
+ /// + /// Use cases: + /// - Starting new experiment with same mechanism instance + /// - Testing and debugging + /// - Separate training phases with independent privacy guarantees + /// + public void ResetPrivacyBudget() + { + _privacyBudgetConsumed = 0.0; + } +} diff --git a/src/FederatedLearning/Privacy/SecureAggregation.cs b/src/FederatedLearning/Privacy/SecureAggregation.cs new file mode 100644 index 000000000..efe999339 --- /dev/null +++ b/src/FederatedLearning/Privacy/SecureAggregation.cs @@ -0,0 +1,381 @@ +namespace AiDotNet.FederatedLearning.Privacy; + +using System; +using System.Security.Cryptography; + +/// +/// Implements secure aggregation for federated learning using cryptographic techniques. +/// +/// +/// Secure aggregation is a cryptographic protocol that allows a server to compute the sum +/// of client updates without seeing individual contributions. Only the final aggregate is +/// visible to the server. +/// +/// For Beginners: Secure aggregation is like a secret ballot election where votes +/// are counted but individual votes remain private. +/// +/// How it works (simplified): +/// 1. Each client generates pairwise secret keys with other clients +/// 2. Clients mask their model updates with these secret keys +/// 3. Server receives masked updates: masked_update_i = update_i + Σ(secrets_ij) +/// 4. Secret masks cancel out when summing: Σ(masked_update_i) = Σ(update_i) +/// 5. Server gets the sum without seeing individual updates +/// +/// Example with 3 clients: +/// - Client 1 shares secrets: s₁₂ with Client 2, s₁₃ with Client 3 +/// - Client 2 shares secrets: s₂₁ with Client 1, s₂₃ with Client 3 +/// - Client 3 shares secrets: s₃₁ with Client 1, s₃₂ with Client 2 +/// +/// Note: s₁₂ = -s₂₁ (secrets cancel in pairs) +/// +/// Client 1 sends: update₁ + s₁₂ + s₁₃ +/// Client 2 sends: update₂ + s₂₁ + s₂₃ +/// Client 3 sends: update₃ + s₃₁ + s₃₂ +/// +/// Server computes sum: +/// (update₁ + s₁₂ + s₁₃) + (update₂ + s₂₁ + s₂₃) + (update₃ + s₃₁ + s₃₂) +/// = update₁ + update₂ + update₃ + (s₁₂ + s₂₁) + (s₁₃ + s₃₁) + (s₂₃ + s₃₂) +/// = update₁ + update₂ + update₃ + 0 + 0 + 0 +/// = Σ(updates) ← Only this is visible to server! +/// +/// This implementation uses a simplified version with random masking for demonstration. +/// Production systems should use proper cryptographic protocols like: +/// - Bonawitz et al.'s Secure Aggregation protocol +/// - Threshold homomorphic encryption +/// - Secret sharing schemes +/// +/// Benefits: +/// - Server cannot see individual client updates +/// - Protects against honest-but-curious server +/// - No trusted third party needed +/// - Computation overhead is reasonable +/// +/// Limitations: +/// - Requires coordination between clients +/// - All (or threshold) clients must participate for masks to cancel +/// - Dropout handling requires additional mechanisms +/// - Communication overhead for key exchange +/// +/// When to use Secure Aggregation: +/// - Don't fully trust the central server +/// - Regulatory requirements for data protection +/// - Want cryptographic privacy guarantees +/// - Willing to handle additional complexity +/// +/// Can be combined with differential privacy for stronger protection: +/// - Secure aggregation: Protects individual updates from server +/// - Differential privacy: Protects individual data points from anyone +/// +/// Reference: Bonawitz, K., et al. (2017). "Practical Secure Aggregation for Privacy-Preserving +/// Machine Learning." CCS 2017. 
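+///
+/// A typical round could look like the sketch below (client IDs, the parameter count, and update0..update2
+/// are illustrative; each updateN is that client's dictionary of layer name to parameter array):
+///
+///     var secureAgg = new SecureAggregation<double>(parameterCount: 1000);
+///     secureAgg.GeneratePairwiseSecrets(new List<int> { 0, 1, 2 });
+///
+///     // Each client masks its own update before sending it to the server.
+///     var masked0 = secureAgg.MaskUpdate(0, update0);
+///     var masked1 = secureAgg.MaskUpdate(1, update1);
+///     var masked2 = secureAgg.MaskUpdate(2, update2);
+///
+///     // The server only ever sees masked updates; the pairwise masks cancel during aggregation.
+///     var aggregated = secureAgg.AggregateSecurely(
+///         new Dictionary<int, Dictionary<string, double[]>> { [0] = masked0, [1] = masked1, [2] = masked2 },
+///         new Dictionary<int, double> { [0] = 1.0, [1] = 1.0, [2] = 1.0 });
+///     secureAgg.ClearSecrets();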
+/// +/// The numeric type for model parameters (e.g., double, float). +public class SecureAggregation + where T : struct, IComparable, IConvertible +{ + private readonly Dictionary> _pairwiseSecrets; + private readonly Random _random; + private readonly int _parameterCount; + + /// + /// Initializes a new instance of the class. + /// + /// + /// For Beginners: Sets up the secure aggregation protocol for a specific + /// number of model parameters. + /// + /// In practice, this would involve: + /// - Secure key exchange between clients + /// - Authenticated channels + /// - Agreement on random seed for deterministic mask generation + /// + /// This simplified implementation uses pseudorandom masks that cancel out. + /// + /// The total number of model parameters to protect. + /// Optional random seed for reproducibility. + public SecureAggregation(int parameterCount, int? randomSeed = null) + { + if (parameterCount <= 0) + { + throw new ArgumentException("Parameter count must be positive.", nameof(parameterCount)); + } + + _parameterCount = parameterCount; + _pairwiseSecrets = new Dictionary>(); + _random = randomSeed.HasValue ? new Random(randomSeed.Value) : new Random(); + } + + /// + /// Generates pairwise secrets between all clients. + /// + /// + /// For Beginners: This creates secret keys that clients will use to mask + /// their updates. The secrets are designed so they cancel out when aggregated. + /// + /// For each pair of clients (i, j): + /// - Generate random secret s_ij + /// - Set s_ji = -s_ij (so they cancel: s_ij + s_ji = 0) + /// + /// In production, this would use: + /// - Diffie-Hellman key exchange + /// - Public key infrastructure + /// - Secure random number generation + /// + /// List of all participating client IDs. + public void GeneratePairwiseSecrets(List clientIds) + { + if (clientIds == null || clientIds.Count < 2) + { + throw new ArgumentException("Need at least 2 clients for secure aggregation.", nameof(clientIds)); + } + + _pairwiseSecrets.Clear(); + + // Generate pairwise secrets for each pair of clients + for (int i = 0; i < clientIds.Count; i++) + { + int clientI = clientIds[i]; + _pairwiseSecrets[clientI] = new Dictionary(); + + for (int j = i + 1; j < clientIds.Count; j++) + { + int clientJ = clientIds[j]; + + // Generate random secret for this pair + double[] secret = new double[_parameterCount]; + for (int k = 0; k < _parameterCount; k++) + { + // Use cryptographically secure random in production + secret[k] = (_random.NextDouble() - 0.5) * 2.0; // Range: [-1, 1] + } + + // Store secret for client i with respect to client j + _pairwiseSecrets[clientI][clientJ] = secret; + + // Store negated secret for client j with respect to client i + // This ensures secrets cancel: secret_ij + secret_ji = 0 + if (!_pairwiseSecrets.ContainsKey(clientJ)) + { + _pairwiseSecrets[clientJ] = new Dictionary(); + } + + double[] negatedSecret = new double[_parameterCount]; + for (int k = 0; k < _parameterCount; k++) + { + negatedSecret[k] = -secret[k]; + } + _pairwiseSecrets[clientJ][clientI] = negatedSecret; + } + } + } + + /// + /// Masks a client's model update with pairwise secrets. + /// + /// + /// For Beginners: This adds secret masks to the client's update so the + /// server can't see the original values. Only after aggregating all clients + /// do the masks cancel out. 
+ ///
+ /// Mathematical operation:
+ /// masked_update = original_update + Σ(secrets_with_other_clients)
+ ///
+ /// For example, Client 1 with 3 clients total:
+ /// - Original update: [0.5, -0.3, 0.8]
+ /// - Secret with Client 2: [0.1, 0.2, -0.1]
+ /// - Secret with Client 3: [-0.2, 0.1, 0.15]
+ /// - Masked update: [0.4, 0.0, 0.85]
+ ///
+ /// Server sees: [0.4, 0.0, 0.85] ← Cannot recover original [0.5, -0.3, 0.8]
+ /// But after aggregating all clients, secrets cancel and server gets correct sum.
+ ///
+ /// The ID of the client whose update to mask.
+ /// The client's model update.
+ /// The masked model update.
+ public Dictionary MaskUpdate(int clientId, Dictionary clientUpdate)
+ {
+ if (clientUpdate == null || clientUpdate.Count == 0)
+ {
+ throw new ArgumentException("Client update cannot be null or empty.", nameof(clientUpdate));
+ }
+
+ if (!_pairwiseSecrets.ContainsKey(clientId))
+ {
+ throw new ArgumentException($"No secrets found for client {clientId}. Call GeneratePairwiseSecrets first.", nameof(clientId));
+ }
+
+ // Create masked update
+ var maskedUpdate = new Dictionary();
+
+ // Flatten all parameters to apply masks
+ var flatParams = FlattenParameters(clientUpdate);
+ var maskedFlatParams = new double[flatParams.Length];
+ Array.Copy(flatParams, maskedFlatParams, flatParams.Length);
+
+ // Add all pairwise secrets for this client
+ foreach (var otherClientSecrets in _pairwiseSecrets[clientId].Values)
+ {
+ for (int i = 0; i < Math.Min(maskedFlatParams.Length, otherClientSecrets.Length); i++)
+ {
+ maskedFlatParams[i] += otherClientSecrets[i];
+ }
+ }
+
+ // Unflatten back to original structure
+ int paramIndex = 0;
+ foreach (var layerName in clientUpdate.Keys)
+ {
+ var originalLayer = clientUpdate[layerName];
+ var maskedLayer = new T[originalLayer.Length];
+
+ for (int i = 0; i < originalLayer.Length && paramIndex < maskedFlatParams.Length; i++, paramIndex++)
+ {
+ maskedLayer[i] = (T)Convert.ChangeType(maskedFlatParams[paramIndex], typeof(T));
+ }
+
+ maskedUpdate[layerName] = maskedLayer;
+ }
+
+ return maskedUpdate;
+ }
+
+ ///
+ /// Aggregates masked updates from all clients, recovering the true sum.
+ ///
+ ///
+ /// For Beginners: This sums up all the masked updates. Because the secret
+ /// masks cancel out, the result is the true sum of client updates.
+ ///
+ /// Mathematical property:
+ /// Σ(masked_update_i) = Σ(update_i + secrets_i)
+ /// = Σ(update_i) + Σ(secrets_i)
+ /// = Σ(update_i) + 0 ← secrets cancel
+ /// = True sum of updates
+ ///
+ /// The server performs this aggregation without ever seeing individual updates!
+ ///
+ /// For example with 2 clients:
+ /// Client 1 masked: [0.4, 0.0, 0.85] = [0.5, -0.3, 0.8] + [-0.1, 0.3, 0.05]
+ /// Client 2 masked: [0.7, 0.1, 1.05] = [0.6, 0.4, 1.1] + [0.1, -0.3, -0.05]
+ ///
+ /// Sum of masked: [1.1, 0.1, 1.9]
+ /// True sum: [0.5, -0.3, 0.8] + [0.6, 0.4, 1.1] = [1.1, 0.1, 1.9] ← Matches!
+ /// (Note: Secrets [-0.1, 0.3, 0.05] + [0.1, -0.3, -0.05] = [0, 0, 0] ← Cancelled)
+ ///
+ /// Dictionary of client IDs to their masked updates.
+ /// Dictionary of client IDs to their aggregation weights.
+ /// The securely aggregated model (sum of original updates with masks cancelled).
+ public Dictionary AggregateSecurely( + Dictionary> maskedUpdates, + Dictionary clientWeights) + { + if (maskedUpdates == null || maskedUpdates.Count == 0) + { + throw new ArgumentException("Masked updates cannot be null or empty.", nameof(maskedUpdates)); + } + + // Get model structure from first client + var firstUpdate = maskedUpdates.First().Value; + var aggregatedUpdate = new Dictionary(); + + // Initialize aggregated update with zeros + foreach (var layerName in firstUpdate.Keys) + { + aggregatedUpdate[layerName] = new T[firstUpdate[layerName].Length]; + } + + // Sum all masked updates + // The pairwise secrets will cancel out, leaving only the true sum + foreach (var clientId in maskedUpdates.Keys) + { + var maskedUpdate = maskedUpdates[clientId]; + + foreach (var layerName in maskedUpdate.Keys) + { + var maskedParams = maskedUpdate[layerName]; + var aggregatedParams = aggregatedUpdate[layerName]; + + for (int i = 0; i < maskedParams.Length; i++) + { + double currentValue = Convert.ToDouble(aggregatedParams[i]); + double maskedValue = Convert.ToDouble(maskedParams[i]); + aggregatedParams[i] = (T)Convert.ChangeType(currentValue + maskedValue, typeof(T)); + } + } + } + + // If using weighted aggregation, divide by total weight + if (clientWeights != null && clientWeights.Count > 0) + { + double totalWeight = clientWeights.Values.Sum(); + + if (totalWeight > 0) + { + foreach (var layerName in aggregatedUpdate.Keys) + { + var aggregatedParams = aggregatedUpdate[layerName]; + + for (int i = 0; i < aggregatedParams.Length; i++) + { + double value = Convert.ToDouble(aggregatedParams[i]); + aggregatedParams[i] = (T)Convert.ChangeType(value / totalWeight, typeof(T)); + } + } + } + } + + return aggregatedUpdate; + } + + /// + /// Flattens a hierarchical model structure into a single parameter array. + /// + /// + /// For Beginners: Converts the model from a dictionary of layers to a + /// single flat array of all parameters. Makes it easier to apply masks uniformly. + /// + /// The model to flatten. + /// A flat array of all parameters. + private double[] FlattenParameters(Dictionary model) + { + int totalParams = model.Values.Sum(layer => layer.Length); + double[] flatParams = new double[totalParams]; + + int index = 0; + foreach (var layer in model.Values) + { + foreach (var param in layer) + { + flatParams[index++] = Convert.ToDouble(param); + } + } + + return flatParams; + } + + /// + /// Clears all stored pairwise secrets. + /// + /// + /// For Beginners: Removes all secret keys from memory. Should be called + /// after aggregation is complete for security. + /// + /// Security best practice: + /// - Generate fresh secrets for each round + /// - Clear old secrets to prevent reuse + /// - Minimize time secrets are stored in memory + /// + public void ClearSecrets() + { + _pairwiseSecrets.Clear(); + } + + /// + /// Gets the number of clients with stored secrets. + /// + /// The count of clients. + public int GetClientCount() + { + return _pairwiseSecrets.Count; + } +} diff --git a/src/FederatedLearning/README.md b/src/FederatedLearning/README.md new file mode 100644 index 000000000..1e0b4288c --- /dev/null +++ b/src/FederatedLearning/README.md @@ -0,0 +1,250 @@ +# Federated Learning Framework + +This directory contains a comprehensive implementation of Federated Learning (FL) for the AiDotNet library, addressing Issue #398 (Phase 3). 
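+
+The building blocks are designed to compose. For example, a server-side step that aggregates client updates and then adds differential-privacy noise to the result could look like this (an illustrative sketch; `clientModels` and `clientWeights` are placeholder dictionaries of client updates and sample counts):
+
+```csharp
+using AiDotNet.FederatedLearning.Aggregators;
+using AiDotNet.FederatedLearning.Privacy;
+
+var aggregator = new FedAvgAggregationStrategy<double>();
+var globalModel = aggregator.Aggregate(clientModels, clientWeights);
+
+var dp = new GaussianDifferentialPrivacy<double>(clipNorm: 1.0);
+var noisyGlobalModel = dp.ApplyPrivacy(globalModel, epsilon: 1.0, delta: 1e-5);
+```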
+
+## Overview
+
+Federated Learning is a privacy-preserving distributed machine learning approach where multiple clients (devices, institutions, edge nodes) collaboratively train a shared model without sharing their raw data. Only model updates are exchanged, so raw data never leaves each client.
+
+## Features Implemented
+
+### Core Algorithms
+
+#### 1. FedAvg (Federated Averaging)
+- **Location**: `Aggregators/FedAvgAggregationStrategy.cs`
+- **Description**: The foundational FL algorithm that performs weighted averaging of client model updates
+- **Use Case**: Standard federated learning with IID or mildly non-IID data
+- **Reference**: McMahan et al. (2017) - "Communication-Efficient Learning of Deep Networks from Decentralized Data"
+
+#### 2. FedProx (Federated Proximal)
+- **Location**: `Aggregators/FedProxAggregationStrategy.cs`
+- **Description**: Handles system heterogeneity by adding proximal terms to prevent client drift
+- **Use Case**: Non-IID data, heterogeneous client capabilities, stragglers
+- **Key Parameter**: μ (proximal term coefficient)
+- **Reference**: Li et al. (2020) - "Federated Optimization in Heterogeneous Networks"
+
+#### 3. FedBN (Federated Batch Normalization)
+- **Location**: `Aggregators/FedBNAggregationStrategy.cs`
+- **Description**: Keeps batch normalization layers local while aggregating other layers
+- **Use Case**: Deep neural networks with batch normalization, non-IID data with distribution shift
+- **Reference**: Li et al. (2021) - "FedBN: Federated Learning on Non-IID Features via Local Batch Normalization"
+
+### Privacy Features
+
+#### 1. Differential Privacy
+- **Location**: `Privacy/GaussianDifferentialPrivacy.cs`
+- **Description**: Implements (ε, δ)-differential privacy using the Gaussian mechanism
+- **Features**:
+  - Gradient clipping for sensitivity bounding
+  - Calibrated Gaussian noise addition
+  - Privacy budget tracking
+  - Configurable ε (privacy budget) and δ (failure probability)
+- **Reference**: Dwork & Roth (2014) - "The Algorithmic Foundations of Differential Privacy"
+
+#### 2. Secure Aggregation
+- **Location**: `Privacy/SecureAggregation.cs`
+- **Description**: Cryptographic protocol to aggregate updates without revealing individual contributions
+- **Features**:
+  - Pairwise secret masking
+  - Server only sees aggregated result
+  - Protection against honest-but-curious server
+- **Reference**: Bonawitz et al. (2017) - "Practical Secure Aggregation for Privacy-Preserving Machine Learning"
+
+### Personalization
+
+#### Personalized Federated Learning
+- **Location**: `Personalization/PersonalizedFederatedLearning.cs`
+- **Description**: Enables client-specific model layers while maintaining shared global layers
+- **Features**:
+  - Layer-wise personalization
+  - Configurable personalization fraction
+  - Model split statistics
+  - Flexible personalization strategies
+- **Use Case**: Non-IID data, client-specific adaptations, multi-task learning
+- **Reference**: Fallah et al.
(2020) - "Personalized Federated Learning: A Meta-Learning Approach" + +## Architecture + +### Interfaces + +All core interfaces are located in `src/Interfaces/`: + +- **IFederatedTrainer**: Main trainer for orchestrating federated learning +- **IAggregationStrategy**: Strategy pattern for different aggregation algorithms +- **IPrivacyMechanism**: Privacy-preserving mechanisms for protecting client data +- **IClientModel**: Client-side model operations + +### Configuration + +- **FederatedLearningOptions** (`src/Models/Options/FederatedLearningOptions.cs`): Comprehensive configuration options + - Client management (number of clients, selection fraction) + - Training parameters (learning rate, epochs, batch size) + - Privacy settings (differential privacy, secure aggregation) + - Convergence criteria + - Personalization settings + - Compression options + +- **FederatedLearningMetadata** (`src/Models/FederatedLearningMetadata.cs`): Training metrics and statistics + - Performance metrics (accuracy, loss) + - Resource usage (time, communication) + - Privacy budget tracking + - Per-round detailed metrics + +## Usage Examples + +### Basic FedAvg + +```csharp +using AiDotNet.FederatedLearning.Aggregators; +using AiDotNet.Models.Options; + +// Create FedAvg aggregation strategy +var aggregator = new FedAvgAggregationStrategy(); + +// Define client models and weights +var clientModels = new Dictionary> +{ + [0] = new Dictionary { ["layer1"] = new[] { 1.0, 2.0 } }, + [1] = new Dictionary { ["layer1"] = new[] { 3.0, 4.0 } } +}; + +var clientWeights = new Dictionary +{ + [0] = 100.0, // 100 samples + [1] = 200.0 // 200 samples +}; + +// Aggregate models +var globalModel = aggregator.Aggregate(clientModels, clientWeights); +``` + +### Differential Privacy + +```csharp +using AiDotNet.FederatedLearning.Privacy; + +// Create differential privacy mechanism +var dp = new GaussianDifferentialPrivacy( + clipNorm: 1.0, + randomSeed: 42 // For reproducibility +); + +// Apply privacy to model +var privateModel = dp.ApplyPrivacy( + model: clientModel, + epsilon: 1.0, // Privacy budget + delta: 1e-5 // Failure probability +); + +// Check privacy budget consumed +Console.WriteLine($"Privacy budget used: {dp.GetPrivacyBudgetConsumed()}"); +``` + +### Personalized Federated Learning + +```csharp +using AiDotNet.FederatedLearning.Personalization; + +// Create personalization handler +var pfl = new PersonalizedFederatedLearning( + personalizationFraction: 0.2 // Keep last 20% of layers personalized +); + +// Identify which layers to personalize +pfl.IdentifyPersonalizedLayers(modelStructure, strategy: "last_n"); + +// Separate model into global and personalized parts +pfl.SeparateModel( + fullModel: clientModel, + out var globalPart, + out var personalizedPart +); + +// Send only globalPart to server for aggregation +// Keep personalizedPart on client + +// After receiving aggregated global model from server +var updatedFullModel = pfl.CombineModels( + globalUpdate: aggregatedGlobalModel, + personalizedLayers: personalizedPart +); +``` + +## Configuration Options + +Key parameters in `FederatedLearningOptions`: + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `NumberOfClients` | int | 10 | Total number of participating clients | +| `ClientSelectionFraction` | double | 1.0 | Fraction of clients selected per round (0.0-1.0) | +| `LocalEpochs` | int | 5 | Number of epochs each client trains locally | +| `MaxRounds` | int | 100 | Maximum federated learning rounds | +| 
`LearningRate` | double | 0.01 | Learning rate for local training |
+| `BatchSize` | int | 32 | Batch size for local training |
+| `UseDifferentialPrivacy` | bool | false | Enable differential privacy |
+| `PrivacyEpsilon` | double | 1.0 | Privacy budget (ε) |
+| `PrivacyDelta` | double | 1e-5 | Privacy failure probability (δ) |
+| `UseSecureAggregation` | bool | false | Enable secure aggregation |
+| `AggregationStrategy` | string | "FedAvg" | Aggregation algorithm to use |
+| `ProximalMu` | double | 0.01 | FedProx proximal term coefficient |
+| `EnablePersonalization` | bool | false | Enable personalized layers |
+| `PersonalizationLayerFraction` | double | 0.2 | Fraction of layers to personalize |
+
+## Success Criteria (from Issue #398)
+
+✅ **Core Algorithms**: FedAvg, FedProx, FedBN implemented
+✅ **Privacy Features**: Differential Privacy and Secure Aggregation implemented
+✅ **Personalization**: Personalized Federated Learning implemented
+✅ **Architecture**: Clean interfaces (IFederatedTrainer, IAggregationStrategy)
+✅ **Configuration**: Comprehensive options and metadata classes
+✅ **Documentation**: Extensive XML documentation with beginner-friendly explanations
+✅ **Testing**: Unit tests for core algorithms and privacy mechanisms
+
+## Testing
+
+Unit tests are located in `tests/AiDotNet.Tests/FederatedLearning/`:
+
+- `FedAvgAggregationStrategyTests.cs`: Tests for FedAvg aggregation
+- `GaussianDifferentialPrivacyTests.cs`: Tests for differential privacy mechanism
+
+Run tests:
+```bash
+dotnet test tests/AiDotNet.Tests/
+```
+
+## Future Enhancements
+
+Potential extensions for future phases:
+
+1. **LEAF Benchmark Integration**: Add support for LEAF federated datasets
+2. **Communication Efficiency**: Implement gradient compression and quantization
+3. **Advanced Privacy**: Add Rényi Differential Privacy for tighter composition bounds
+4. **Byzantine Robustness**: Implement Krum, Median, Trimmed Mean aggregation
+5. **Meta-Learning**: Add MAML (Model-Agnostic Meta-Learning) for personalization
+6. **Asynchronous FL**: Support for asynchronous client updates
+7. **Vertical FL**: Support for vertically partitioned data
+8. **Cross-Silo FL**: Enhanced support for enterprise federated learning scenarios
+
+## References
+
+1. McMahan, H. B., et al. (2017). "Communication-Efficient Learning of Deep Networks from Decentralized Data." AISTATS 2017.
+2. Li, T., et al. (2020). "Federated Optimization in Heterogeneous Networks." MLSys 2020.
+3. Li, X., et al. (2021). "FedBN: Federated Learning on Non-IID Features via Local Batch Normalization." ICLR 2021.
+4. Bonawitz, K., et al. (2017). "Practical Secure Aggregation for Privacy-Preserving Machine Learning." CCS 2017.
+5. Dwork, C., & Roth, A. (2014). "The Algorithmic Foundations of Differential Privacy." Foundations and Trends in Theoretical Computer Science.
+6. Fallah, A., et al. (2020). "Personalized Federated Learning with Theoretical Guarantees: A Model-Agnostic Meta-Learning Approach." NeurIPS 2020.
+
+## Contributing
+
+When adding new federated learning algorithms or features:
+
+1. Follow the existing interface patterns (IAggregationStrategy, IPrivacyMechanism, etc.)
+2. Add comprehensive XML documentation with beginner-friendly explanations
+3. Include mathematical formulations and references
+4. Add unit tests for new functionality
+5. Update this README with usage examples
+
+## License
+
+This implementation is part of the AiDotNet library and is licensed under Apache-2.0.
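The Contributing checklist above asks new strategies to follow the existing `IAggregationStrategy` pattern. As a hedged illustration (not part of this PR), the sketch below implements coordinate-wise median aggregation, one of the robust methods listed under Future Enhancements. The `Aggregate`/`GetStrategyName` shape comes from `src/Interfaces/IAggregationStrategy.cs` in this diff; the choice of `Dictionary<string, double[]>` as the model type mirrors the FedAvg strategy and is an assumption.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using AiDotNet.Interfaces;

/// <summary>
/// Illustrative coordinate-wise median aggregation (robust to outlier updates).
/// The interface shape follows src/Interfaces/IAggregationStrategy.cs in this PR;
/// the Dictionary&lt;string, double[]&gt; model type is an assumed instantiation.
/// </summary>
public class MedianAggregationStrategy : IAggregationStrategy<Dictionary<string, double[]>>
{
    public Dictionary<string, double[]> Aggregate(
        Dictionary<int, Dictionary<string, double[]>> clientModels,
        Dictionary<int, double> clientWeights)
    {
        if (clientModels == null || clientModels.Count == 0)
        {
            throw new ArgumentException("Client models cannot be null or empty.", nameof(clientModels));
        }

        // Note: the median is unweighted; clientWeights are intentionally ignored here.
        var aggregated = new Dictionary<string, double[]>();
        var firstModel = clientModels.First().Value;

        foreach (var layerName in firstModel.Keys)
        {
            int length = firstModel[layerName].Length;
            var layer = new double[length];

            for (int i = 0; i < length; i++)
            {
                // Collect the i-th parameter from every client and take its median.
                var values = clientModels.Values
                    .Select(model => model[layerName][i])
                    .OrderBy(v => v)
                    .ToArray();

                int mid = values.Length / 2;
                layer[i] = values.Length % 2 == 1
                    ? values[mid]
                    : (values[mid - 1] + values[mid]) / 2.0;
            }

            aggregated[layerName] = layer;
        }

        return aggregated;
    }

    public string GetStrategyName() => "Median";
}
```

Because each coordinate's median ignores extreme values, a few corrupted or outlier client updates cannot pull the aggregate arbitrarily far, which is the property the Byzantine-robustness item above is aiming at.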
diff --git a/src/Interfaces/IAggregationStrategy.cs b/src/Interfaces/IAggregationStrategy.cs new file mode 100644 index 000000000..7595ab88b --- /dev/null +++ b/src/Interfaces/IAggregationStrategy.cs @@ -0,0 +1,67 @@ +namespace AiDotNet.Interfaces; + +/// +/// Defines strategies for aggregating model updates from multiple clients in federated learning. +/// +/// +/// This interface represents different methods for combining model updates from distributed clients +/// into a single improved global model. +/// +/// For Beginners: An aggregation strategy is like a voting system or consensus mechanism +/// that decides how to combine different opinions into a single decision. +/// +/// Think of aggregation strategies as different ways to combine contributions: +/// - Simple average: Everyone's input counts equally +/// - Weighted average: Some contributors' inputs count more based on criteria (data size, accuracy) +/// - Robust methods: Ignore outliers or malicious contributions +/// +/// For example, in a federated learning scenario with hospitals: +/// - Hospital A has 10,000 patients: gets weight of 10,000 +/// - Hospital B has 5,000 patients: gets weight of 5,000 +/// - The aggregation strategy might weight Hospital A's updates more heavily +/// +/// Different strategies handle different challenges: +/// - FedAvg: Standard weighted averaging +/// - FedProx: Handles clients with different update frequencies +/// - Krum: Robust to Byzantine (malicious) clients +/// - Median aggregation: Resistant to outliers +/// +/// The type of model being aggregated. +public interface IAggregationStrategy +{ + /// + /// Aggregates model updates from multiple clients into a single global model update. + /// + /// + /// This method combines model updates from clients using the strategy's specific algorithm. + /// + /// For Beginners: Aggregation is like combining multiple rough drafts of a document + /// into one polished version that incorporates the best parts of each. + /// + /// The aggregation process typically: + /// 1. Takes model updates (weight changes) from each client + /// 2. Considers the weight or importance of each client (based on data size, accuracy, etc.) + /// 3. Combines these updates using the strategy's algorithm + /// 4. Returns a single aggregated model that represents the collective improvement + /// + /// For example with weighted averaging (FedAvg): + /// - Client 1 (1000 samples): model update A + /// - Client 2 (500 samples): model update B + /// - Client 3 (1500 samples): model update C + /// - Aggregated update = (1000*A + 500*B + 1500*C) / 3000 + /// + /// Dictionary mapping client IDs to their trained models. + /// Dictionary mapping client IDs to their aggregation weights (typically based on data size). + /// The aggregated global model. + TModel Aggregate(Dictionary clientModels, Dictionary clientWeights); + + /// + /// Gets the name of the aggregation strategy. + /// + /// + /// For Beginners: This helps identify which aggregation method is being used, + /// useful for logging, debugging, and comparing different strategies. + /// + /// A string describing the aggregation strategy (e.g., "FedAvg", "FedProx", "Krum"). + string GetStrategyName(); +} diff --git a/src/Interfaces/IClientModel.cs b/src/Interfaces/IClientModel.cs new file mode 100644 index 000000000..b31772995 --- /dev/null +++ b/src/Interfaces/IClientModel.cs @@ -0,0 +1,129 @@ +namespace AiDotNet.Interfaces; + +/// +/// Defines the functionality for a client-side model in federated learning. 
+/// +/// +/// This interface represents a model that exists on a client device or node in a federated +/// learning system. Each client maintains its own copy of the global model and trains it +/// on local data. +/// +/// For Beginners: A client model is like a student's personal copy of study materials. +/// Each student (client) has their own copy, studies it with their own resources, and +/// contributes improvements back to the class. +/// +/// Think of client models as distributed learners: +/// - Each client has a copy of the global model +/// - Clients train on their own private data +/// - Local training happens independently and in parallel +/// - Only model updates (not data) are sent to the server +/// +/// For example, in smartphone keyboard prediction: +/// - Each phone has a copy of the global typing prediction model +/// - The phone learns from the user's typing patterns +/// - It sends model improvements (not actual typed text) to the server +/// - The server combines improvements from millions of phones +/// - Each phone gets the improved model back +/// +/// This design ensures: +/// - Data privacy: Raw data never leaves the client +/// - Personalization: Can adapt to local data distribution +/// - Scalability: Training happens in parallel across all clients +/// +/// The type of the local training data. +/// The type of the model update to send to the server. +public interface IClientModel +{ + /// + /// Trains the local model on the client's private data. + /// + /// + /// Local training is the core of federated learning where each client improves the model + /// using their own data without sharing that data with anyone. + /// + /// For Beginners: This is like studying independently with your own materials. + /// You use your personal notes and resources to learn, and later share what you learned, + /// not the actual materials. + /// + /// The training process: + /// 1. Receive the global model from the server + /// 2. Train on local data for specified number of epochs + /// 3. Compute the difference between updated and original model (the "update") + /// 4. Prepare this update to send back to the server + /// + /// For example: + /// - Client receives global model with accuracy 80% + /// - Trains on local data for 5 epochs + /// - Local model now has accuracy 85% on local data + /// - Computes weight changes (delta) that improved the model + /// - Sends these weight changes to server, not the local data + /// + /// The client's private training data. + /// Number of training iterations to perform on local data. + /// Step size for gradient descent optimization. + void TrainLocal(TData localData, int epochs, double learningRate); + + /// + /// Computes and retrieves the model update to send to the server. + /// + /// + /// The model update represents the improvements the client made through local training. + /// This is typically the difference between the current model and the initial global model. + /// + /// For Beginners: This is like preparing a summary of what you learned from studying, + /// rather than sharing your entire study materials. You share the insights, not the sources. + /// + /// The update typically contains: + /// - Weight differences: New weights - original weights + /// - Gradients: Direction and magnitude of improvement + /// - Metadata: Number of local samples, local loss, etc. 
+ /// + /// For example: + /// - Original weight for feature "age": 0.5 + /// - After training, weight for "age": 0.6 + /// - Update to send: +0.1 + /// - This tells the server how to adjust that weight + /// + /// The model update containing weight changes or gradients. + TUpdate GetModelUpdate(); + + /// + /// Updates the local model with the new global model from the server. + /// + /// + /// After the server aggregates updates from all clients, it sends the improved global + /// model back to clients for the next round of training. + /// + /// For Beginners: This is like receiving the updated textbook that incorporates + /// everyone's contributions. You replace your old version with this improved version + /// before the next study session. + /// + /// The update process: + /// 1. Receive aggregated global model from server + /// 2. Replace local model weights with global model weights + /// 3. Optionally keep some personalized layers + /// 4. Ready for next round of local training + /// + /// For example: + /// - Round 1: Trained local model, sent update + /// - Server aggregated all updates + /// - Round 2: Receive improved global model + /// - Use this as starting point for next round of training + /// + /// The aggregated global model from the server. + void UpdateFromGlobal(TUpdate globalModelUpdate); + + /// + /// Gets the number of training samples available on this client. + /// + /// + /// Sample count is used to weight client contributions during aggregation. + /// Clients with more data typically receive higher weights. + /// + /// For Beginners: This is like indicating how many practice problems you solved. + /// If you solved 1000 problems and someone else solved 100, your insights about + /// problem-solving patterns are likely more reliable. + /// + /// The number of training samples on this client. + int GetSampleCount(); +} diff --git a/src/Interfaces/IFederatedTrainer.cs b/src/Interfaces/IFederatedTrainer.cs new file mode 100644 index 000000000..c8b6c7bda --- /dev/null +++ b/src/Interfaces/IFederatedTrainer.cs @@ -0,0 +1,127 @@ +namespace AiDotNet.Interfaces; + +/// +/// Defines the core functionality for federated learning trainers that coordinate distributed training across multiple clients. +/// +/// +/// This interface represents the fundamental operations for federated learning systems where multiple clients +/// (devices, institutions, edge nodes) collaboratively train a shared model without sharing their raw data. +/// +/// For Beginners: Federated learning is like group study where everyone learns from their own materials +/// but shares only their insights, not their actual study materials. 
+/// +/// Think of federated learning as a privacy-preserving collaborative learning approach: +/// - Multiple clients (hospitals, phones, banks) have their own local data +/// - Each client trains a model on their local data independently +/// - Only model updates (not raw data) are shared with a central server +/// - The server aggregates these updates to improve the global model +/// - The improved global model is sent back to clients for the next round +/// +/// For example, in healthcare: +/// - Multiple hospitals want to train a disease detection model +/// - Each hospital has patient data that cannot be shared due to privacy regulations +/// - Each hospital trains the model on their own data +/// - Only the learned patterns (model weights) are shared and combined +/// - This creates a better model while keeping patient data private +/// +/// This interface provides methods for coordinating the federated training process. +/// +/// The type of the model being trained. +/// The type of the training data. +/// The type of metadata returned by the training process. +public interface IFederatedTrainer +{ + /// + /// Initializes the federated learning process with client configurations and the global model. + /// + /// + /// This method sets up the initial state for federated learning by: + /// - Initializing the global model that will be shared across all clients + /// - Registering client configurations (number of clients, data distribution, etc.) + /// - Setting up communication channels for model updates + /// + /// For Beginners: Initialization is like setting up a study group before the first session. + /// You need to know who's participating, what materials everyone has, and establish + /// how you'll share information. + /// + /// The initial global model to be distributed to clients. + /// The number of clients participating in federated learning. + void Initialize(TModel globalModel, int numberOfClients); + + /// + /// Executes one round of federated learning where clients train locally and updates are aggregated. + /// + /// + /// A federated learning round consists of several steps: + /// 1. The global model is sent to selected clients + /// 2. Each client trains the model on their local data + /// 3. Clients send their model updates back to the server + /// 4. The server aggregates these updates using an aggregation strategy + /// 5. The global model is updated with the aggregated result + /// + /// For Beginners: Think of this as one iteration in a collaborative learning cycle. + /// Everyone gets the current version of the shared knowledge, studies independently, + /// then contributes their improvements. These improvements are combined to create + /// an even better version for the next round. + /// + /// For example: + /// - Round 1: Clients receive initial model, train for 5 epochs, send updates + /// - Server aggregates updates and improves global model + /// - Round 2: Clients receive improved model, train again, send updates + /// - This continues until the model reaches desired accuracy + /// + /// Dictionary mapping client IDs to their local training data. + /// Fraction of clients to select for this round (0.0 to 1.0). + /// Number of training epochs each client should perform locally. + /// Metadata about the training round including accuracy, loss, and convergence metrics. 
+ TMetadata TrainRound(Dictionary clientData, double clientSelectionFraction = 1.0, int localEpochs = 1); + + /// + /// Executes multiple rounds of federated learning until convergence or maximum rounds reached. + /// + /// + /// This method orchestrates the entire federated learning process by: + /// - Running multiple training rounds + /// - Monitoring convergence (when the model stops improving significantly) + /// - Tracking performance metrics across rounds + /// - Applying privacy mechanisms if configured + /// + /// For Beginners: This is the complete federated learning process from start to finish. + /// It's like running an entire semester of study group sessions, where you continue meeting + /// until everyone has mastered the material or you've run out of time. + /// + /// Dictionary mapping client IDs to their local training data. + /// Maximum number of federated learning rounds to execute. + /// Fraction of clients to select per round (0.0 to 1.0). + /// Number of training epochs each client performs per round. + /// Aggregated metadata across all training rounds. + TMetadata Train(Dictionary clientData, int rounds, double clientSelectionFraction = 1.0, int localEpochs = 1); + + /// + /// Retrieves the current global model after federated training. + /// + /// + /// The global model represents the collective knowledge learned from all participating clients. + /// + /// For Beginners: This is the final product of the collaborative learning process - + /// a model that benefits from all participants' data without ever accessing their raw data directly. + /// + /// The trained global model. + TModel GetGlobalModel(); + + /// + /// Sets the aggregation strategy used to combine client updates. + /// + /// + /// Different aggregation strategies handle various challenges in federated learning: + /// - FedAvg: Simple weighted averaging of model updates + /// - FedProx: Handles clients with different computational capabilities + /// - FedBN: Special handling for batch normalization layers + /// + /// For Beginners: The aggregation strategy is the rule for combining everyone's + /// contributions. Different rules work better for different situations, like how you might + /// weight expert opinions more heavily in certain contexts. + /// + /// The aggregation strategy to use. + void SetAggregationStrategy(IAggregationStrategy strategy); +} diff --git a/src/Interfaces/IPrivacyMechanism.cs b/src/Interfaces/IPrivacyMechanism.cs new file mode 100644 index 000000000..502777893 --- /dev/null +++ b/src/Interfaces/IPrivacyMechanism.cs @@ -0,0 +1,90 @@ +namespace AiDotNet.Interfaces; + +/// +/// Defines privacy-preserving mechanisms for federated learning to protect client data. +/// +/// +/// This interface represents techniques to ensure that model updates don't leak sensitive +/// information about individual data points in clients' local datasets. +/// +/// For Beginners: Privacy mechanisms are like filters that protect sensitive information +/// while still allowing useful knowledge to be shared. 
+/// +/// Think of privacy mechanisms as protective measures: +/// - Differential Privacy: Adds carefully calibrated noise to make individual data unidentifiable +/// - Secure Aggregation: Encrypts updates so the server only sees the combined result +/// - Homomorphic Encryption: Allows computation on encrypted data +/// +/// For example, in a hospital scenario: +/// - Without privacy: Model updates might reveal information about specific patients +/// - With differential privacy: Random noise is added so you can't identify individual patients +/// - The noise is calibrated so the overall patterns remain accurate +/// +/// Privacy mechanisms provide mathematical guarantees: +/// - Epsilon (ε): Privacy budget - lower values mean stronger privacy +/// - Delta (δ): Probability that privacy guarantee fails +/// - Common setting: ε=1.0, δ=1e-5 means strong privacy with high confidence +/// +/// The type of model to apply privacy mechanisms to. +public interface IPrivacyMechanism +{ + /// + /// Applies privacy-preserving techniques to a model update before sharing it. + /// + /// + /// This method transforms model updates to provide privacy guarantees while maintaining utility. + /// + /// For Beginners: This is like redacting sensitive parts of a document before sharing it. + /// You remove or obscure information that could identify individuals while keeping the + /// useful content intact. + /// + /// Common techniques: + /// - Differential Privacy: Adds random noise proportional to sensitivity + /// - Gradient Clipping: Limits the magnitude of updates to prevent outliers + /// - Local DP: Each client adds noise before sending updates + /// - Central DP: Server adds noise after aggregation + /// + /// For example with differential privacy: + /// 1. Client trains model and computes weight updates + /// 2. Applies gradient clipping to limit maximum change + /// 3. Adds calibrated Gaussian noise to each weight + /// 4. Sends noisy update to server + /// 5. Even if server is compromised, individual data remains private + /// + /// The model update to apply privacy to. + /// Privacy budget parameter - smaller values provide stronger privacy. + /// Probability of privacy guarantee failure - typically very small (e.g., 1e-5). + /// The model update with privacy mechanisms applied. + TModel ApplyPrivacy(TModel model, double epsilon, double delta); + + /// + /// Gets the current privacy budget consumed by this mechanism. + /// + /// + /// Privacy budget is a finite resource in differential privacy. Each time you share + /// information, you "spend" some privacy budget. Once exhausted, you can no longer + /// provide strong privacy guarantees. + /// + /// For Beginners: Think of privacy budget like a bank account for privacy. + /// Each time you share data, you withdraw from this account. When the account is empty, + /// you've used up your privacy guarantees and should stop sharing. + /// + /// For example: + /// - Start with privacy budget ε=10 + /// - Round 1: Share update with ε=1, remaining budget = 9 + /// - Round 2: Share update with ε=1, remaining budget = 8 + /// - After 10 rounds, budget is exhausted + /// + /// The amount of privacy budget consumed so far. + double GetPrivacyBudgetConsumed(); + + /// + /// Gets the name of the privacy mechanism. + /// + /// + /// For Beginners: This identifies which privacy technique is being used, + /// helpful for documentation and comparing different privacy approaches. 
+ /// + /// A string describing the privacy mechanism (e.g., "Gaussian Mechanism", "Laplace Mechanism"). + string GetMechanismName(); +} diff --git a/src/Models/FederatedLearningMetadata.cs b/src/Models/FederatedLearningMetadata.cs new file mode 100644 index 000000000..186fb5312 --- /dev/null +++ b/src/Models/FederatedLearningMetadata.cs @@ -0,0 +1,301 @@ +namespace AiDotNet.Models; + +/// +/// Contains metadata and metrics about federated learning training progress and results. +/// +/// +/// This class tracks various metrics throughout the federated learning process to help +/// monitor training progress, diagnose issues, and evaluate model quality. +/// +/// For Beginners: Metadata is like a training diary that records what happened +/// during federated learning - how long it took, how accurate the model became, which +/// clients participated, etc. +/// +/// Think of this as a comprehensive training report containing: +/// - Performance metrics: Accuracy, loss, convergence +/// - Resource usage: Time, communication costs +/// - Participation: Which clients contributed +/// - Privacy tracking: Privacy budget consumption +/// +/// For example, after training you might see: +/// - Total rounds: 50 (out of max 100) +/// - Final accuracy: 92.5% +/// - Training time: 2 hours +/// - Total clients participated: 100 +/// - Privacy budget used: ε=5.0 (out of 10.0 total) +/// +public class FederatedLearningMetadata +{ + /// + /// Gets or sets the number of federated learning rounds completed. + /// + /// + /// For Beginners: A round is one complete cycle where clients train and the + /// server aggregates updates. This counts how many such cycles were completed. + /// + public int RoundsCompleted { get; set; } + + /// + /// Gets or sets the final global model accuracy on validation data. + /// + /// + /// For Beginners: Accuracy measures how often the model makes correct predictions. + /// For example, 0.95 means the model is correct 95% of the time. + /// + /// This is measured on validation data that wasn't used for training. + /// + public double FinalAccuracy { get; set; } + + /// + /// Gets or sets the final global model loss value. + /// + /// + /// For Beginners: Loss measures how far the model's predictions are from the + /// true values. Lower loss indicates better performance. + /// + /// For example: + /// - Initial loss: 2.5 + /// - Final loss: 0.3 + /// - The model has improved significantly + /// + public double FinalLoss { get; set; } + + /// + /// Gets or sets the total training time in seconds. + /// + /// + /// For Beginners: The total wall-clock time from start to finish of federated + /// learning, including all rounds, communication, and aggregation. + /// + public double TotalTrainingTimeSeconds { get; set; } + + /// + /// Gets or sets the average time per round in seconds. + /// + /// + /// For Beginners: How long each round takes on average. Useful for estimating + /// how long future training runs will take. + /// + /// For example: + /// - 50 rounds completed in 5000 seconds + /// - Average: 100 seconds per round + /// - Next training with 100 rounds will take ~10,000 seconds + /// + public double AverageRoundTimeSeconds { get; set; } + + /// + /// Gets or sets the history of loss values across all rounds. + /// + /// + /// For Beginners: A list showing how loss changed after each round. + /// Useful for plotting learning curves and diagnosing training issues. 
+ /// + /// For example: [2.5, 1.8, 1.2, 0.9, 0.7, 0.5, 0.4, 0.35, 0.32, 0.3] + /// Shows steady improvement from 2.5 to 0.3 + /// + public List LossHistory { get; set; } = new List(); + + /// + /// Gets or sets the history of accuracy values across all rounds. + /// + /// + /// For Beginners: A list showing how accuracy improved after each round. + /// + /// For example: [0.60, 0.70, 0.78, 0.84, 0.88, 0.91, 0.93, 0.94, 0.945, 0.95] + /// Shows accuracy improving from 60% to 95% + /// + public List AccuracyHistory { get; set; } = new List(); + + /// + /// Gets or sets the total number of clients that participated across all rounds. + /// + /// + /// For Beginners: How many different clients contributed to the model. + /// + /// For example: + /// - 100 clients available + /// - 10 clients selected per round + /// - Over 50 rounds, might have 80 unique participants + /// + public int TotalClientsParticipated { get; set; } + + /// + /// Gets or sets the average number of clients selected per round. + /// + /// + /// For Beginners: How many clients were active in each round on average. + /// + /// For example: + /// - Round 1: 10 clients + /// - Round 2: 8 clients (some unavailable) + /// - Round 3: 10 clients + /// - Average: 9.3 clients per round + /// + public double AverageClientsPerRound { get; set; } + + /// + /// Gets or sets the total communication cost in megabytes. + /// + /// + /// For Beginners: The total amount of data transferred between clients and server + /// throughout training. Important for understanding bandwidth requirements. + /// + /// For example: + /// - Each model update: 10 MB + /// - 10 clients per round + /// - 50 rounds + /// - Total: 10 MB × 10 × 50 × 2 (up and down) = 10,000 MB = 10 GB + /// + public double TotalCommunicationMB { get; set; } + + /// + /// Gets or sets the total privacy budget (epsilon) consumed. + /// + /// + /// For Beginners: If differential privacy is used, this tracks how much privacy + /// budget has been spent. Privacy budget is finite - once exhausted, no more privacy + /// guarantees. + /// + /// For example: + /// - ε per round: 0.1 + /// - 50 rounds completed + /// - Total consumed: 5.0 + /// - If total budget is 10.0, have 5.0 remaining + /// + public double TotalPrivacyBudgetConsumed { get; set; } + + /// + /// Gets or sets whether training converged before reaching maximum rounds. + /// + /// + /// For Beginners: Convergence means the model stopped improving significantly. + /// If true, training ended early because the model reached a good solution. + /// + /// For example: + /// - Max rounds: 100 + /// - Converged at round 50 + /// - Converged = true + /// - Saved time by stopping early + /// + public bool Converged { get; set; } + + /// + /// Gets or sets the round at which convergence was detected. + /// + /// + /// For Beginners: Which round did the model stop improving significantly? + /// + /// Useful for: + /// - Setting better MaxRounds for future training + /// - Understanding training dynamics + /// - Comparing different algorithms + /// + public int ConvergenceRound { get; set; } + + /// + /// Gets or sets the aggregation strategy used during training. + /// + /// + /// For Beginners: Records which aggregation algorithm was used (FedAvg, FedProx, etc.). + /// Important for reproducibility and understanding results. + /// + public string AggregationStrategyUsed { get; set; } = string.Empty; + + /// + /// Gets or sets whether differential privacy was enabled. 
+ /// + /// + /// For Beginners: Records whether privacy mechanisms were active during training. + /// Important for understanding any accuracy trade-offs. + /// + public bool DifferentialPrivacyEnabled { get; set; } + + /// + /// Gets or sets whether secure aggregation was enabled. + /// + /// + /// For Beginners: Records whether client updates were encrypted during aggregation. + /// + public bool SecureAggregationEnabled { get; set; } + + /// + /// Gets or sets additional notes or observations about the training run. + /// + /// + /// For Beginners: A freeform field for recording anything unusual or noteworthy + /// that happened during training. + /// + /// For example: + /// - "Client 5 dropped out after round 30" + /// - "Convergence was slower than expected" + /// - "High variance in client update quality" + /// + public string Notes { get; set; } = string.Empty; + + /// + /// Gets or sets the per-round detailed metrics. + /// + /// + /// For Beginners: Detailed information about each individual round, including + /// which clients participated, their individual losses, communication costs, etc. + /// + /// Useful for: + /// - Detailed analysis of training dynamics + /// - Identifying problematic clients + /// - Understanding convergence patterns + /// + public List RoundMetrics { get; set; } = new List(); +} + +/// +/// Contains detailed metrics for a single federated learning round. +/// +/// +/// For Beginners: Information about what happened in one specific training round. +/// +public class RoundMetadata +{ + /// + /// Gets or sets the round number (0-indexed). + /// + public int RoundNumber { get; set; } + + /// + /// Gets or sets the global model loss after this round. + /// + public double GlobalLoss { get; set; } + + /// + /// Gets or sets the global model accuracy after this round. + /// + public double GlobalAccuracy { get; set; } + + /// + /// Gets or sets the IDs of clients selected for this round. + /// + public List SelectedClientIds { get; set; } = new List(); + + /// + /// Gets or sets the time taken for this round in seconds. + /// + public double RoundTimeSeconds { get; set; } + + /// + /// Gets or sets the communication cost for this round in megabytes. + /// + public double CommunicationMB { get; set; } + + /// + /// Gets or sets the average local loss across selected clients. + /// + /// + /// For Beginners: The average loss that clients achieved on their local data + /// before sending updates. Comparing this to global loss can reveal overfitting. + /// + public double AverageLocalLoss { get; set; } + + /// + /// Gets or sets the privacy budget consumed in this round. + /// + public double PrivacyBudgetConsumed { get; set; } +} diff --git a/src/Models/Options/FederatedLearningOptions.cs b/src/Models/Options/FederatedLearningOptions.cs new file mode 100644 index 000000000..c3ed5ab4c --- /dev/null +++ b/src/Models/Options/FederatedLearningOptions.cs @@ -0,0 +1,350 @@ +namespace AiDotNet.Models.Options; + +/// +/// Configuration options for federated learning training. +/// +/// +/// This class contains all the configurable parameters needed to set up and run a federated learning system. +/// +/// For Beginners: Options are like the settings panel for federated learning. +/// Just as you configure settings for a video game (difficulty, graphics quality, etc.), +/// these options let you configure how federated learning should work. 
+/// +/// Key configuration areas: +/// - Client Management: How many clients, how to select them +/// - Training: Learning rates, epochs, batch sizes +/// - Privacy: Differential privacy parameters +/// - Communication: How often to aggregate, compression settings +/// - Convergence: When to stop training +/// +/// For example, a typical configuration might be: +/// - 100 total clients (e.g., hospitals) +/// - Select 10 clients per round (10% participation) +/// - Each client trains for 5 local epochs +/// - Use privacy budget ε=1.0, δ=1e-5 +/// - Run for maximum 100 rounds or until convergence +/// +public class FederatedLearningOptions +{ + /// + /// Gets or sets the total number of clients participating in federated learning. + /// + /// + /// For Beginners: This is the total pool of participants available for training. + /// In each round, a subset of these clients may be selected. + /// + /// For example: + /// - Mobile keyboard app: Millions of phones + /// - Healthcare: 50 hospitals + /// - Financial: 100 bank branches + /// + public int NumberOfClients { get; set; } = 10; + + /// + /// Gets or sets the fraction of clients to select for each training round (0.0 to 1.0). + /// + /// + /// For Beginners: Not all clients participate in every round. This setting controls + /// what percentage of clients are active in each round. + /// + /// Common values: + /// - 1.0: All clients participate (small deployments) + /// - 0.1: 10% participate (large deployments, reduces communication) + /// - 0.01: 1% participate (massive deployments like mobile devices) + /// + /// For example, with 1000 clients and fraction 0.1: + /// - Each round randomly selects 100 clients + /// - Reduces server load and communication costs + /// - Still converges if enough clients are selected + /// + public double ClientSelectionFraction { get; set; } = 1.0; + + /// + /// Gets or sets the number of local training epochs each client performs per round. + /// + /// + /// For Beginners: An epoch is one complete pass through the client's local dataset. + /// More epochs mean more local training but also more computation time. + /// + /// Trade-offs: + /// - More epochs (5-10): Better local adaptation, slower rounds + /// - Fewer epochs (1-2): Faster rounds, more communication needed + /// + /// For example: + /// - Client has 1000 samples + /// - LocalEpochs = 5 + /// - Client processes all 1000 samples 5 times before sending update + /// + public int LocalEpochs { get; set; } = 5; + + /// + /// Gets or sets the maximum number of federated learning rounds to execute. + /// + /// + /// For Beginners: A round is one complete cycle where clients train and the server + /// aggregates updates. This sets the maximum number of such cycles. + /// + /// Typical values: + /// - Quick experiments: 10-50 rounds + /// - Production training: 100-1000 rounds + /// - Large scale: 1000+ rounds + /// + /// Training may stop early if convergence criteria are met. + /// + public int MaxRounds { get; set; } = 100; + + /// + /// Gets or sets the learning rate for local client training. + /// + /// + /// For Beginners: Learning rate controls how big of a step the model takes when + /// learning from data. Too large and learning is unstable; too small and learning is slow. 
+ ///
+ /// Common values:
+ /// - Deep learning: 0.001 - 0.01
+ /// - Traditional ML: 0.01 - 0.1
+ ///
+ /// For example:
+ /// - LearningRate = 0.01 means adjust weights by 1% of the gradient
+ /// - Higher values = faster learning but less stability
+ /// - Lower values = slower but more stable learning
+ ///
+ public double LearningRate { get; set; } = 0.01;
+
+ ///
+ /// Gets or sets the batch size for local training.
+ ///
+ ///
+ /// For Beginners: Instead of processing all data at once, we process it in smaller
+ /// batches. Batch size is how many samples to process before updating the model.
+ ///
+ /// Trade-offs:
+ /// - Larger batches (64-512): More stable gradients, requires more memory
+ /// - Smaller batches (8-32): Less memory, more noise in updates
+ ///
+ /// For example:
+ /// - Client has 1000 samples, BatchSize = 32
+ /// - Data is split into ~31 batches of 32 samples each (the last batch may be smaller)
+ /// - Model is updated after processing each batch
+ ///
+ public int BatchSize { get; set; } = 32;
+
+ ///
+ /// Gets or sets whether to use differential privacy.
+ ///
+ ///
+ /// For Beginners: Differential privacy adds mathematical noise to protect individual
+ /// data points from being identified in the model updates.
+ ///
+ /// When enabled:
+ /// - Privacy guarantees are provided
+ /// - Some accuracy may be sacrificed
+ /// - Individual contributions are hidden
+ ///
+ /// Use cases requiring privacy:
+ /// - Healthcare data
+ /// - Financial records
+ /// - Personal communications
+ ///
+ public bool UseDifferentialPrivacy { get; set; } = false;
+
+ ///
+ /// Gets or sets the epsilon (ε) parameter for differential privacy (privacy budget).
+ ///
+ ///
+ /// For Beginners: Epsilon controls the privacy-utility tradeoff. Lower values mean
+ /// stronger privacy but potentially less accurate models.
+ ///
+ /// Common values:
+ /// - ε = 0.1: Very strong privacy, significant accuracy loss
+ /// - ε = 1.0: Strong privacy, moderate accuracy loss (recommended)
+ /// - ε = 10.0: Weak privacy, minimal accuracy loss
+ ///
+ /// For example:
+ /// - With ε = 1.0, an adversary cannot distinguish whether any specific individual's
+ ///   data was used in training (within factor e^1 ≈ 2.7)
+ ///
+ public double PrivacyEpsilon { get; set; } = 1.0;
+
+ ///
+ /// Gets or sets the delta (δ) parameter for differential privacy (failure probability).
+ ///
+ ///
+ /// For Beginners: Delta is the probability that the privacy guarantee fails.
+ /// It should be very small, typically much less than 1/number_of_data_points.
+ ///
+ /// Common values:
+ /// - δ = 1e-5 (0.00001): Standard choice
+ /// - δ = 1e-6: Stronger guarantee
+ ///
+ /// For example:
+ /// - δ = 1e-5 means there's a 0.001% chance privacy is compromised
+ /// - Should be smaller than 1/total_number_of_samples across all clients
+ ///
+ public double PrivacyDelta { get; set; } = 1e-5;
+
+ ///
+ /// Gets or sets whether to use secure aggregation.
+ ///
+ ///
+ /// For Beginners: Secure aggregation encrypts client updates so that the server
+ /// can only see the aggregated result, not individual contributions.
+ /// + /// Benefits: + /// - Server cannot see individual client updates + /// - Protects against honest-but-curious server + /// - Only the final aggregated model is visible + /// + /// For example: + /// - Without: Server sees each hospital's model update + /// - With: Server only sees combined update from all hospitals + /// - No single hospital's contribution is visible + /// + public bool UseSecureAggregation { get; set; } = false; + + /// + /// Gets or sets the convergence threshold for early stopping. + /// + /// + /// For Beginners: Training stops early if improvement between rounds falls below + /// this threshold, indicating the model has converged (stopped improving significantly). + /// + /// For example: + /// - ConvergenceThreshold = 0.001 + /// - If loss improves by less than 0.001 for several consecutive rounds, stop training + /// - Saves time and resources by avoiding unnecessary rounds + /// + public double ConvergenceThreshold { get; set; } = 0.001; + + /// + /// Gets or sets the minimum number of rounds to train before checking convergence. + /// + /// + /// For Beginners: Don't check for convergence too early. Wait at least this many + /// rounds before considering early stopping. + /// + /// This prevents: + /// - Stopping too early due to initial volatility + /// - Missing later improvements + /// + /// Typical value: 10-20 rounds + /// + public int MinRoundsBeforeConvergence { get; set; } = 10; + + /// + /// Gets or sets the aggregation strategy name to use. + /// + /// + /// For Beginners: This determines how client updates are combined. + /// + /// Available strategies: + /// - "FedAvg": Weighted average (standard choice) + /// - "FedProx": Handles system heterogeneity + /// - "FedBN": Special handling for batch normalization + /// + /// Different strategies work better for different scenarios. + /// + public string AggregationStrategy { get; set; } = "FedAvg"; + + /// + /// Gets or sets the proximal term coefficient for FedProx algorithm. + /// + /// + /// For Beginners: FedProx adds a penalty to prevent client models from + /// deviating too much from the global model. This parameter controls the penalty strength. + /// + /// Common values: + /// - 0.0: No proximal term (equivalent to FedAvg) + /// - 0.01 - 0.1: Mild constraint + /// - 1.0+: Strong constraint + /// + /// Use FedProx when: + /// - Clients have very different data distributions + /// - Some clients are much slower than others + /// - You want more stable convergence + /// + public double ProximalMu { get; set; } = 0.01; + + /// + /// Gets or sets whether to enable personalization. + /// + /// + /// For Beginners: Personalization allows each client to maintain some client-specific + /// model parameters while sharing common parameters with other clients. + /// + /// Benefits: + /// - Better performance on local data + /// - Handles non-IID data (data that varies across clients) + /// - Combines benefits of global and local models + /// + /// For example: + /// - Global layers: Learn general patterns from all clients + /// - Personalized layers: Adapt to each client's specific data + /// + public bool EnablePersonalization { get; set; } = false; + + /// + /// Gets or sets the fraction of model layers to keep personalized (not aggregated). + /// + /// + /// For Beginners: When personalization is enabled, this determines what fraction + /// of the model remains client-specific vs. shared globally. 
+ /// + /// For example: + /// - PersonalizationLayerFraction = 0.2 + /// - Last 20% of model layers stay personalized + /// - First 80% are aggregated globally + /// + /// Typical use: + /// - Output layers personalized, feature extractors shared + /// + public double PersonalizationLayerFraction { get; set; } = 0.2; + + /// + /// Gets or sets whether to use gradient compression to reduce communication costs. + /// + /// + /// For Beginners: Compression reduces the size of model updates sent between + /// clients and server, saving bandwidth and time. + /// + /// Techniques: + /// - Quantization: Use fewer bits per parameter + /// - Sparsification: Send only top-k largest updates + /// - Sketching: Use randomized compression + /// + /// Trade-off: + /// - Reduces communication by 10-100x + /// - May slightly slow convergence + /// + public bool UseCompression { get; set; } = false; + + /// + /// Gets or sets the compression ratio (0.0 to 1.0) if compression is enabled. + /// + /// + /// For Beginners: Controls how much to compress. Lower values mean more compression + /// but potentially more accuracy loss. + /// + /// For example: + /// - 0.01: Keep top 1% of gradients (99% compression) + /// - 0.1: Keep top 10% of gradients (90% compression) + /// - 1.0: No compression + /// + public double CompressionRatio { get; set; } = 0.1; + + /// + /// Gets or sets a random seed for reproducibility. + /// + /// + /// For Beginners: Random seed makes randomness reproducible. Using the same + /// seed will produce the same random client selections, initializations, etc. + /// + /// Benefits: + /// - Reproducible experiments + /// - Easier debugging + /// - Fair comparison between methods + /// + /// Set to null for truly random behavior. + /// + public int? RandomSeed { get; set; } = null; +} diff --git a/tests/AiDotNet.Tests/FederatedLearning/FedAvgAggregationStrategyTests.cs b/tests/AiDotNet.Tests/FederatedLearning/FedAvgAggregationStrategyTests.cs new file mode 100644 index 000000000..8568b3c58 --- /dev/null +++ b/tests/AiDotNet.Tests/FederatedLearning/FedAvgAggregationStrategyTests.cs @@ -0,0 +1,201 @@ +namespace AiDotNet.Tests.FederatedLearning; + +using Xunit; +using AiDotNet.FederatedLearning.Aggregators; +using System; +using System.Collections.Generic; + +/// +/// Unit tests for FedAvg (Federated Averaging) aggregation strategy. 
+/// +public class FedAvgAggregationStrategyTests +{ + [Fact] + public void Aggregate_WithEqualWeights_ReturnsAverageModel() + { + // Arrange + var strategy = new FedAvgAggregationStrategy(); + + // Create two client models with simple parameters + var clientModels = new Dictionary> + { + [0] = new Dictionary + { + ["layer1"] = new double[] { 1.0, 2.0, 3.0 } + }, + [1] = new Dictionary + { + ["layer1"] = new double[] { 3.0, 4.0, 5.0 } + } + }; + + var clientWeights = new Dictionary + { + [0] = 1.0, + [1] = 1.0 + }; + + // Act + var aggregatedModel = strategy.Aggregate(clientModels, clientWeights); + + // Assert + Assert.NotNull(aggregatedModel); + Assert.Contains("layer1", aggregatedModel.Keys); + Assert.Equal(3, aggregatedModel["layer1"].Length); + + // Expected: (1+3)/2=2, (2+4)/2=3, (3+5)/2=4 + Assert.Equal(2.0, aggregatedModel["layer1"][0], precision: 5); + Assert.Equal(3.0, aggregatedModel["layer1"][1], precision: 5); + Assert.Equal(4.0, aggregatedModel["layer1"][2], precision: 5); + } + + [Fact] + public void Aggregate_WithDifferentWeights_ReturnsWeightedAverage() + { + // Arrange + var strategy = new FedAvgAggregationStrategy(); + + var clientModels = new Dictionary> + { + [0] = new Dictionary + { + ["layer1"] = new double[] { 1.0, 2.0 } + }, + [1] = new Dictionary + { + ["layer1"] = new double[] { 3.0, 4.0 } + } + }; + + // Client 1 has 3x the weight (3x more data) + var clientWeights = new Dictionary + { + [0] = 1.0, + [1] = 3.0 + }; + + // Act + var aggregatedModel = strategy.Aggregate(clientModels, clientWeights); + + // Assert + // Expected: (1*1 + 3*3)/(1+3) = 10/4 = 2.5 + // (2*1 + 4*3)/(1+3) = 14/4 = 3.5 + Assert.Equal(2.5, aggregatedModel["layer1"][0], precision: 5); + Assert.Equal(3.5, aggregatedModel["layer1"][1], precision: 5); + } + + [Fact] + public void Aggregate_WithMultipleLayers_AggregatesAllLayers() + { + // Arrange + var strategy = new FedAvgAggregationStrategy(); + + var clientModels = new Dictionary> + { + [0] = new Dictionary + { + ["layer1"] = new double[] { 1.0 }, + ["layer2"] = new double[] { 2.0 } + }, + [1] = new Dictionary + { + ["layer1"] = new double[] { 3.0 }, + ["layer2"] = new double[] { 4.0 } + } + }; + + var clientWeights = new Dictionary + { + [0] = 1.0, + [1] = 1.0 + }; + + // Act + var aggregatedModel = strategy.Aggregate(clientModels, clientWeights); + + // Assert + Assert.Equal(2, aggregatedModel.Count); + Assert.Contains("layer1", aggregatedModel.Keys); + Assert.Contains("layer2", aggregatedModel.Keys); + Assert.Equal(2.0, aggregatedModel["layer1"][0], precision: 5); + Assert.Equal(3.0, aggregatedModel["layer2"][0], precision: 5); + } + + [Fact] + public void Aggregate_WithEmptyClientModels_ThrowsArgumentException() + { + // Arrange + var strategy = new FedAvgAggregationStrategy(); + var emptyModels = new Dictionary>(); + var clientWeights = new Dictionary(); + + // Act & Assert + Assert.Throws(() => strategy.Aggregate(emptyModels, clientWeights)); + } + + [Fact] + public void Aggregate_WithNullClientModels_ThrowsArgumentException() + { + // Arrange + var strategy = new FedAvgAggregationStrategy(); + Dictionary>? 
nullModels = null; + var clientWeights = new Dictionary(); + + // Act & Assert + Assert.Throws(() => strategy.Aggregate(nullModels!, clientWeights)); + } + + [Fact] + public void GetStrategyName_ReturnsCorrectName() + { + // Arrange + var strategy = new FedAvgAggregationStrategy(); + + // Act + var name = strategy.GetStrategyName(); + + // Assert + Assert.Equal("FedAvg", name); + } + + [Fact] + public void Aggregate_WithThreeClients_ComputesCorrectWeightedAverage() + { + // Arrange + var strategy = new FedAvgAggregationStrategy(); + + var clientModels = new Dictionary> + { + [0] = new Dictionary + { + ["weights"] = new double[] { 0.1, 0.2, 0.3 } + }, + [1] = new Dictionary + { + ["weights"] = new double[] { 0.2, 0.3, 0.4 } + }, + [2] = new Dictionary + { + ["weights"] = new double[] { 0.3, 0.4, 0.5 } + } + }; + + var clientWeights = new Dictionary + { + [0] = 100.0, // 100 samples + [1] = 200.0, // 200 samples + [2] = 300.0 // 300 samples + }; + + // Act + var aggregatedModel = strategy.Aggregate(clientModels, clientWeights); + + // Assert + // Expected: (0.1*100 + 0.2*200 + 0.3*300) / 600 = (10 + 40 + 90) / 600 = 140/600 = 0.2333... + // (0.2*100 + 0.3*200 + 0.4*300) / 600 = (20 + 60 + 120) / 600 = 200/600 = 0.3333... + // (0.3*100 + 0.4*200 + 0.5*300) / 600 = (30 + 80 + 150) / 600 = 260/600 = 0.4333... + Assert.Equal(0.2333333, aggregatedModel["weights"][0], precision: 5); + Assert.Equal(0.3333333, aggregatedModel["weights"][1], precision: 5); + Assert.Equal(0.4333333, aggregatedModel["weights"][2], precision: 5); + } +} diff --git a/tests/AiDotNet.Tests/FederatedLearning/GaussianDifferentialPrivacyTests.cs b/tests/AiDotNet.Tests/FederatedLearning/GaussianDifferentialPrivacyTests.cs new file mode 100644 index 000000000..e4c3f1ded --- /dev/null +++ b/tests/AiDotNet.Tests/FederatedLearning/GaussianDifferentialPrivacyTests.cs @@ -0,0 +1,199 @@ +namespace AiDotNet.Tests.FederatedLearning; + +using Xunit; +using AiDotNet.FederatedLearning.Privacy; +using System; +using System.Collections.Generic; +using System.Linq; + +/// +/// Unit tests for Gaussian Differential Privacy mechanism. 
+}
diff --git a/tests/AiDotNet.Tests/FederatedLearning/GaussianDifferentialPrivacyTests.cs b/tests/AiDotNet.Tests/FederatedLearning/GaussianDifferentialPrivacyTests.cs
new file mode 100644
index 000000000..e4c3f1ded
--- /dev/null
+++ b/tests/AiDotNet.Tests/FederatedLearning/GaussianDifferentialPrivacyTests.cs
@@ -0,0 +1,199 @@
+namespace AiDotNet.Tests.FederatedLearning;
+
+using Xunit;
+using AiDotNet.FederatedLearning.Privacy;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+
+/// <summary>
+/// Unit tests for the Gaussian differential privacy mechanism.
+/// </summary>
+public class GaussianDifferentialPrivacyTests
+{
+    [Fact]
+    public void Constructor_WithValidClipNorm_InitializesSuccessfully()
+    {
+        // Arrange & Act
+        var dp = new GaussianDifferentialPrivacy(clipNorm: 1.0);
+
+        // Assert
+        Assert.NotNull(dp);
+        Assert.Equal(0.0, dp.GetPrivacyBudgetConsumed());
+    }
+
+    [Fact]
+    public void Constructor_WithNegativeClipNorm_ThrowsArgumentException()
+    {
+        // Act & Assert
+        Assert.Throws<ArgumentException>(() => new GaussianDifferentialPrivacy(clipNorm: -1.0));
+    }
+
+    [Fact]
+    public void ApplyPrivacy_WithValidParameters_AddsNoiseToModel()
+    {
+        // Arrange
+        var dp = new GaussianDifferentialPrivacy(clipNorm: 10.0, randomSeed: 42);
+
+        var originalModel = new Dictionary<string, double[]>
+        {
+            ["layer1"] = new double[] { 1.0, 2.0, 3.0 }
+        };
+
+        // Act
+        var noisyModel = dp.ApplyPrivacy(originalModel, epsilon: 1.0, delta: 1e-5);
+
+        // Assert
+        Assert.NotNull(noisyModel);
+        Assert.Contains("layer1", noisyModel.Keys);
+
+        // Model should differ from the original due to added noise
+        bool hasNoise = false;
+        for (int i = 0; i < originalModel["layer1"].Length; i++)
+        {
+            if (Math.Abs(noisyModel["layer1"][i] - originalModel["layer1"][i]) > 0.0001)
+            {
+                hasNoise = true;
+                break;
+            }
+        }
+        Assert.True(hasNoise, "Noise should have been added to the model");
+    }
+
+    [Fact]
+    public void ApplyPrivacy_UpdatesPrivacyBudget()
+    {
+        // Arrange
+        var dp = new GaussianDifferentialPrivacy(clipNorm: 1.0);
+
+        var model = new Dictionary<string, double[]>
+        {
+            ["layer1"] = new double[] { 1.0 }
+        };
+
+        // Act
+        dp.ApplyPrivacy(model, epsilon: 0.5, delta: 1e-5);
+
+        // Assert
+        Assert.Equal(0.5, dp.GetPrivacyBudgetConsumed());
+
+        // Apply privacy again; the consumed budget should accumulate
+        dp.ApplyPrivacy(model, epsilon: 0.3, delta: 1e-5);
+        Assert.Equal(0.8, dp.GetPrivacyBudgetConsumed(), precision: 5);
+    }
+
+    [Fact]
+    public void ApplyPrivacy_WithZeroEpsilon_ThrowsArgumentException()
+    {
+        // Arrange
+        var dp = new GaussianDifferentialPrivacy(clipNorm: 1.0);
+
+        var model = new Dictionary<string, double[]>
+        {
+            ["layer1"] = new double[] { 1.0 }
+        };
+
+        // Act & Assert
+        Assert.Throws<ArgumentException>(() => dp.ApplyPrivacy(model, epsilon: 0.0, delta: 1e-5));
+    }
+
+    [Fact]
+    public void ApplyPrivacy_WithInvalidDelta_ThrowsArgumentException()
+    {
+        // Arrange
+        var dp = new GaussianDifferentialPrivacy(clipNorm: 1.0);
+
+        var model = new Dictionary<string, double[]>
+        {
+            ["layer1"] = new double[] { 1.0 }
+        };
+
+        // Act & Assert
+        Assert.Throws<ArgumentException>(() => dp.ApplyPrivacy(model, epsilon: 1.0, delta: 0.0));
+        Assert.Throws<ArgumentException>(() => dp.ApplyPrivacy(model, epsilon: 1.0, delta: 1.0));
+    }
+
+    [Fact]
+    public void ResetPrivacyBudget_ResetsToZero()
+    {
+        // Arrange
+        var dp = new GaussianDifferentialPrivacy(clipNorm: 1.0);
+
+        var model = new Dictionary<string, double[]>
+        {
+            ["layer1"] = new double[] { 1.0 }
+        };
+
+        dp.ApplyPrivacy(model, epsilon: 1.0, delta: 1e-5);
+        Assert.Equal(1.0, dp.GetPrivacyBudgetConsumed());
+
+        // Act
+        dp.ResetPrivacyBudget();
+
+        // Assert
+        Assert.Equal(0.0, dp.GetPrivacyBudgetConsumed());
+    }
+
+    [Fact]
+    public void GetMechanismName_ReturnsCorrectName()
+    {
+        // Arrange
+        var dp = new GaussianDifferentialPrivacy(clipNorm: 2.5);
+
+        // Act
+        var name = dp.GetMechanismName();
+
+        // Assert
+        Assert.Contains("Gaussian DP", name);
+        Assert.Contains("2.5", name);
+    }
+
+    [Fact]
+    public void ApplyPrivacy_WithSameSeed_ProducesSameNoise()
+    {
+        // Arrange
+        var dp1 = new GaussianDifferentialPrivacy(clipNorm: 1.0, randomSeed: 123);
+        var dp2 = new GaussianDifferentialPrivacy(clipNorm: 1.0, randomSeed: 123);
+
+        var model = new Dictionary<string, double[]>
+        {
+            ["layer1"] = new double[] { 1.0, 2.0, 3.0 }
+        };
+
+        // Act
+        var noisyModel1 = dp1.ApplyPrivacy(model, epsilon: 1.0, delta: 1e-5);
+        var noisyModel2 = dp2.ApplyPrivacy(model, epsilon: 1.0, delta: 1e-5);
+
+        // Assert
+        for (int i = 0; i < noisyModel1["layer1"].Length; i++)
+        {
+            Assert.Equal(noisyModel1["layer1"][i], noisyModel2["layer1"][i], precision: 10);
+        }
+    }
+
+    [Fact]
+    public void ApplyPrivacy_PerformsGradientClipping()
+    {
+        // Arrange
+        var clipNorm = 1.0;
+        var dp = new GaussianDifferentialPrivacy(clipNorm: clipNorm, randomSeed: 42);
+
+        // Create a model with a large L2 norm: sqrt(100 + 100 + 100) ≈ 17.3
+        var model = new Dictionary<string, double[]>
+        {
+            ["layer1"] = new double[] { 10.0, 10.0, 10.0 }
+        };
+
+        // Act
+        var clippedModel = dp.ApplyPrivacy(model, epsilon: 10.0, delta: 1e-5);
+
+        // Assert
+        // Calculate the L2 norm of the clipped model
+        double sumSquares = clippedModel["layer1"].Sum(x => x * x);
+        double norm = Math.Sqrt(sumSquares);
+
+        // With a high epsilon (10.0) the added noise is small, so the norm should be
+        // close to the clip norm rather than the original ~17.3
+        Assert.True(norm < clipNorm * 2.0, $"Norm {norm} should be reasonably close to clip norm {clipNorm}");
+    }
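+
+    // Illustrative sketch, not part of the original test suite: the same-seed test above
+    // implies the noise is seed-dependent, so two mechanisms constructed with different
+    // seeds should, with overwhelming probability, perturb at least one parameter differently.
+    [Fact]
+    public void ApplyPrivacy_WithDifferentSeeds_ProducesDifferentNoise()
+    {
+        // Arrange
+        var dp1 = new GaussianDifferentialPrivacy(clipNorm: 1.0, randomSeed: 123);
+        var dp2 = new GaussianDifferentialPrivacy(clipNorm: 1.0, randomSeed: 456);
+
+        var model = new Dictionary<string, double[]>
+        {
+            ["layer1"] = new double[] { 1.0, 2.0, 3.0 }
+        };
+
+        // Act
+        var noisyModel1 = dp1.ApplyPrivacy(model, epsilon: 1.0, delta: 1e-5);
+        var noisyModel2 = dp2.ApplyPrivacy(model, epsilon: 1.0, delta: 1e-5);
+
+        // Assert: at least one parameter should differ between the two seeds
+        bool anyDifferent = false;
+        for (int i = 0; i < noisyModel1["layer1"].Length; i++)
+        {
+            if (Math.Abs(noisyModel1["layer1"][i] - noisyModel2["layer1"][i]) > 1e-9)
+            {
+                anyDifferent = true;
+                break;
+            }
+        }
+        Assert.True(anyDifferent, "Different seeds should produce different noise");
+    }
+}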