|
1 | 1 | """ |
2 | | -Federated Averaging (FedAvg) |
3 | | -https://arxiv.org/abs/1602.05629 |
4 | | -
|
5 | | -This module provides a minimal, educational implementation of the Federated |
6 | | -Learning paradigm using the Federated Averaging algorithm. Multiple clients |
7 | | -compute local model updates on their private data and the server aggregates |
8 | | -their updates by (weighted) averaging without collecting raw data. |
9 | | -
|
10 | | -Notes |
11 | | ------ |
12 | | -- This implementation is framework-agnostic and uses NumPy arrays to represent |
13 | | - model parameters for simplicity and portability within this repository. |
14 | | -- It demonstrates the mechanics of FedAvg, not production concerns like |
15 | | - privacy amplification (e.g., differential privacy), robustness, or security. |
16 | | -
|
17 | | -Terminology |
18 | | ------------ |
19 | | -- Global model: a list of NumPy arrays representing model parameters. |
20 | | -- Client update: new model parameters produced locally, or the delta from the |
21 | | - global model; we aggregate parameters directly here for clarity. |
22 | | -
|
23 | | -Examples |
24 | | --------- |
25 | | -Create three synthetic "clients" whose local training produces simple parameter |
26 | | -arrays, then aggregate them with FedAvg. |
27 | | -
|
28 | | ->>> import numpy as np |
29 | | ->>> # Global model with two parameter tensors |
30 | | ->>> global_model = [np.array([0.0, 0.0]), np.array([[0.0]])] |
31 | | ->>> # Client models after local training |
32 | | ->>> client_models = [ |
33 | | -... [np.array([1.0, 2.0]), np.array([[1.0]])], |
34 | | -... [np.array([3.0, 4.0]), np.array([[3.0]])], |
35 | | -... [np.array([5.0, 6.0]), np.array([[5.0]])], |
36 | | -... ] |
37 | | ->>> # Equal weights -> simple average |
38 | | ->>> new_global = federated_average(client_models) |
39 | | ->>> [arr.tolist() for arr in new_global] |
40 | | -[[3.0, 4.0], [[3.0]]] |
41 | | -
|
42 | | -Weighted averaging by client data sizes: |
43 | | -
|
44 | | ->>> weights = np.array([10, 20, 30], dtype=float) |
45 | | ->>> new_global_w = federated_average(client_models, weights) |
46 | | ->>> [arr.tolist() for arr in new_global_w] |
47 | | -[[3.6666666666666665, 4.666666666666666], [[3.6666666666666665]]] |
48 | | -
|
49 | | -Contract |
50 | | --------- |
51 | | -Inputs: |
52 | | - - client_models: list[list[np.ndarray]]: each inner list mirrors model layers |
53 | | - - weights: Optional[np.ndarray] of shape (num_clients,), non-negative, sums to > 0 |
54 | | -Output: |
55 | | - - list[np.ndarray]: aggregated model parameters, same shapes as client models |
56 | | -Error modes: |
57 | | - - ValueError for empty clients, shape mismatch, or invalid weights |
| 2 | +Federated averaging (FedAvg) utilities. |
| 3 | +
|
| 4 | +This module provides a simple NumPy-based implementation of the FedAvg |
| 5 | +aggregation algorithm. It supports equal weighting and custom non-negative |
| 6 | +weights that are normalized internally. |
| 7 | +
|
| 8 | +Doctests |
| 9 | +======== |
| 10 | +
|
| 11 | +Basic equal-weight averaging across two "clients" with two tensors each |
| 12 | +(vector and 2x2 matrix): |
| 13 | +
|
| 14 | +>>> A = [np.array([1.0, 2.0]), np.array([[1.0, 2.0], [3.0, 4.0]])] |
| 15 | +>>> B = [np.array([3.0, 4.0]), np.array([[5.0, 6.0], [7.0, 8.0]])] |
| 16 | +>>> eq = federated_average([A, B]) |
| 17 | +>>> eq[0].tolist() |
| 18 | +[2.0, 3.0] |
| 19 | +>>> eq[1].tolist() |
| 20 | +[[3.0, 4.0], [5.0, 6.0]] |
| 21 | +
|
| 22 | +Weighted averaging with weights [2, 1] (normalized to [2/3, 1/3]): |
| 23 | +
|
| 24 | +>>> w = federated_average([A, B], weights=np.array([2.0, 1.0])) |
| 25 | +>>> w[0].tolist() |
| 26 | +[1.6666666666666665, 2.6666666666666665] |
| 27 | +>>> w[1].tolist() |
| 28 | +[[2.333333333333333, 3.333333333333333], [4.333333333333333, 5.333333333333333]] |
| 29 | +
|
| 30 | +Error cases: |
| 31 | +
|
| 32 | +- No clients |
| 33 | +
|
| 34 | +>>> federated_average([]) # doctest: +ELLIPSIS |
| 35 | +Traceback (most recent call last): |
| 36 | +... |
| 37 | +ValueError: client_models must be a non-empty list |
| 38 | +
|
| 39 | +- Mismatched number of tensors per client |
| 40 | +
|
| 41 | +>>> C = [np.array([1.0, 2.0])] # only one tensor |
| 42 | +>>> federated_average([A, C]) # doctest: +ELLIPSIS |
| 43 | +Traceback (most recent call last): |
| 44 | +... |
| 45 | +ValueError: All clients must have the same number of tensors |
| 46 | +
|
| 47 | +- Mismatched tensor shapes across clients |
| 48 | +
|
| 49 | +>>> C2 = [np.array([1.0, 2.0]), np.array([[1.0, 2.0]])] # second tensor has different shape |
| 50 | +>>> federated_average([A, C2]) # doctest: +ELLIPSIS |
| 51 | +Traceback (most recent call last): |
| 52 | +... |
| 53 | +ValueError: Client 2 tensor shape (1, 2) does not match (2, 2) |
| 54 | +
|
| 55 | +- Invalid weights: negative or wrong shape or zero-sum |
| 56 | +
|
| 57 | +>>> federated_average([A, B], weights=np.array([1.0, -1.0])) # doctest: +ELLIPSIS |
| 58 | +Traceback (most recent call last): |
| 59 | +... |
| 60 | +ValueError: weights must be non-negative |
| 61 | +
|
| 62 | +>>> federated_average([A, B], weights=np.array([0.0, 0.0])) # doctest: +ELLIPSIS |
| 63 | +Traceback (most recent call last): |
| 64 | +... |
| 65 | +ValueError: weights must sum to a positive value |
| 66 | +
|
| 67 | +>>> federated_average([A, B], weights=np.array([1.0, 2.0, 3.0])) # doctest: +ELLIPSIS |
| 68 | +Traceback (most recent call last): |
| 69 | +... |
| 70 | +ValueError: weights must have shape (2,) |
58 | 71 | """ |
59 | 72 |
|
60 | 73 | from __future__ import annotations |
61 | | - |
62 | 74 | from typing import Iterable, List, Sequence |
63 | | - |
64 | 75 | import numpy as np |
65 | 76 |
|
66 | 77 |
|
@@ -94,33 +105,21 @@ def federated_average( |
94 | 105 | client_models: Sequence[Sequence[np.ndarray]], |
95 | 106 | weights: np.ndarray | None = None, |
96 | 107 | ) -> List[np.ndarray]: |
97 | | - """ |
98 | | - Aggregate client model parameters using (weighted) averaging. |
| 108 | + """Compute the weighted average of clients' model tensors. |
99 | 109 |
|
100 | 110 | Parameters |
101 | 111 | ---------- |
102 | | - client_models : list[list[np.ndarray]] |
103 | | - Model parameters for each client; all clients must have same shapes. |
104 | | - weights : np.ndarray | None |
105 | | - Optional non-negative weights per client. If None, equal weights. |
| 112 | + client_models : Sequence[Sequence[np.ndarray]] |
| 113 | + A list of clients, each being a sequence of NumPy arrays (tensors). |
| 114 | + All clients must have the same number of tensors with identical shapes. |
| 115 | + weights : np.ndarray | None, optional |
| 116 | + A 1-D array of non-negative weights, one per client. If None, |
| 117 | + equal weighting is used. Weights are normalized to sum to 1. |
106 | 118 |
|
107 | 119 | Returns |
108 | 120 | ------- |
109 | | - list[np.ndarray] |
110 | | - Aggregated model parameters (same shapes as client tensors). |
111 | | -
|
112 | | - Examples |
113 | | - -------- |
114 | | - >>> import numpy as np |
115 | | - >>> cm = [ |
116 | | - ... [np.array([1.0, 2.0])], |
117 | | - ... [np.array([3.0, 4.0])], |
118 | | - ... ] |
119 | | - >>> [arr.tolist() for arr in federated_average(cm)] |
120 | | - [[2.0, 3.0]] |
121 | | - >>> w = np.array([1.0, 3.0]) |
122 | | - >>> [arr.tolist() for arr in federated_average(cm, w)] |
123 | | - [[2.5, 3.5]] |
| 121 | + List[np.ndarray] |
| 122 | + The list of aggregated tensors with the same shapes as the inputs. |
124 | 123 | """ |
125 | 124 | _validate_clients(client_models) |
126 | 125 | num_clients = len(client_models) |
|
0 commit comments