Skip to content

Commit 5e1b699

Browse files
committed
Update_FedAvg_doctests
1 parent f1cf808 commit 5e1b699

File tree

1 file changed

+78
-79
lines changed

1 file changed

+78
-79
lines changed

machine_learning/federated_averaging.py

Lines changed: 78 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,77 @@
11
"""
2-
Federated Averaging (FedAvg)
3-
https://arxiv.org/abs/1602.05629
4-
5-
This module provides a minimal, educational implementation of the Federated
6-
Learning paradigm using the Federated Averaging algorithm. Multiple clients
7-
compute local model updates on their private data and the server aggregates
8-
their updates by (weighted) averaging without collecting raw data.
9-
10-
Notes
11-
-----
12-
- This implementation is framework-agnostic and uses NumPy arrays to represent
13-
model parameters for simplicity and portability within this repository.
14-
- It demonstrates the mechanics of FedAvg, not production concerns like
15-
privacy amplification (e.g., differential privacy), robustness, or security.
16-
17-
Terminology
18-
-----------
19-
- Global model: a list of NumPy arrays representing model parameters.
20-
- Client update: new model parameters produced locally, or the delta from the
21-
global model; we aggregate parameters directly here for clarity.
22-
23-
Examples
24-
--------
25-
Create three synthetic "clients" whose local training produces simple parameter
26-
arrays, then aggregate them with FedAvg.
27-
28-
>>> import numpy as np
29-
>>> # Global model with two parameter tensors
30-
>>> global_model = [np.array([0.0, 0.0]), np.array([[0.0]])]
31-
>>> # Client models after local training
32-
>>> client_models = [
33-
... [np.array([1.0, 2.0]), np.array([[1.0]])],
34-
... [np.array([3.0, 4.0]), np.array([[3.0]])],
35-
... [np.array([5.0, 6.0]), np.array([[5.0]])],
36-
... ]
37-
>>> # Equal weights -> simple average
38-
>>> new_global = federated_average(client_models)
39-
>>> [arr.tolist() for arr in new_global]
40-
[[3.0, 4.0], [[3.0]]]
41-
42-
Weighted averaging by client data sizes:
43-
44-
>>> weights = np.array([10, 20, 30], dtype=float)
45-
>>> new_global_w = federated_average(client_models, weights)
46-
>>> [arr.tolist() for arr in new_global_w]
47-
[[3.6666666666666665, 4.666666666666666], [[3.6666666666666665]]]
48-
49-
Contract
50-
--------
51-
Inputs:
52-
- client_models: list[list[np.ndarray]]: each inner list mirrors model layers
53-
- weights: Optional[np.ndarray] of shape (num_clients,), non-negative, sums to > 0
54-
Output:
55-
- list[np.ndarray]: aggregated model parameters, same shapes as client models
56-
Error modes:
57-
- ValueError for empty clients, shape mismatch, or invalid weights
2+
Federated averaging (FedAvg) utilities.
3+
4+
This module provides a simple NumPy-based implementation of the FedAvg
5+
aggregation algorithm. It supports equal weighting and custom non-negative
6+
weights that are normalized internally.
7+
8+
Doctests
9+
========
10+
11+
Basic equal-weight averaging across two "clients" with two tensors each
12+
(vector and 2x2 matrix):
13+
14+
>>> A = [np.array([1.0, 2.0]), np.array([[1.0, 2.0], [3.0, 4.0]])]
15+
>>> B = [np.array([3.0, 4.0]), np.array([[5.0, 6.0], [7.0, 8.0]])]
16+
>>> eq = federated_average([A, B])
17+
>>> eq[0].tolist()
18+
[2.0, 3.0]
19+
>>> eq[1].tolist()
20+
[[3.0, 4.0], [5.0, 6.0]]
21+
22+
Weighted averaging with weights [2, 1] (normalized to [2/3, 1/3]):
23+
24+
>>> w = federated_average([A, B], weights=np.array([2.0, 1.0]))
25+
>>> w[0].tolist()
26+
[1.6666666666666665, 2.6666666666666665]
27+
>>> w[1].tolist()
28+
[[2.333333333333333, 3.333333333333333], [4.333333333333333, 5.333333333333333]]
29+
30+
Error cases:
31+
32+
- No clients
33+
34+
>>> federated_average([]) # doctest: +ELLIPSIS
35+
Traceback (most recent call last):
36+
...
37+
ValueError: client_models must be a non-empty list
38+
39+
- Mismatched number of tensors per client
40+
41+
>>> C = [np.array([1.0, 2.0])] # only one tensor
42+
>>> federated_average([A, C]) # doctest: +ELLIPSIS
43+
Traceback (most recent call last):
44+
...
45+
ValueError: All clients must have the same number of tensors
46+
47+
- Mismatched tensor shapes across clients
48+
49+
>>> C2 = [np.array([1.0, 2.0]), np.array([[1.0, 2.0]])] # second tensor has different shape
50+
>>> federated_average([A, C2]) # doctest: +ELLIPSIS
51+
Traceback (most recent call last):
52+
...
53+
ValueError: Client 2 tensor shape (1, 2) does not match (2, 2)
54+
55+
- Invalid weights: negative or wrong shape or zero-sum
56+
57+
>>> federated_average([A, B], weights=np.array([1.0, -1.0])) # doctest: +ELLIPSIS
58+
Traceback (most recent call last):
59+
...
60+
ValueError: weights must be non-negative
61+
62+
>>> federated_average([A, B], weights=np.array([0.0, 0.0])) # doctest: +ELLIPSIS
63+
Traceback (most recent call last):
64+
...
65+
ValueError: weights must sum to a positive value
66+
67+
>>> federated_average([A, B], weights=np.array([1.0, 2.0, 3.0])) # doctest: +ELLIPSIS
68+
Traceback (most recent call last):
69+
...
70+
ValueError: weights must have shape (2,)
5871
"""
5972

6073
from __future__ import annotations
61-
6274
from typing import Iterable, List, Sequence
63-
6475
import numpy as np
6576

6677

@@ -94,33 +105,21 @@ def federated_average(
94105
client_models: Sequence[Sequence[np.ndarray]],
95106
weights: np.ndarray | None = None,
96107
) -> List[np.ndarray]:
97-
"""
98-
Aggregate client model parameters using (weighted) averaging.
108+
"""Compute the weighted average of clients' model tensors.
99109
100110
Parameters
101111
----------
102-
client_models : list[list[np.ndarray]]
103-
Model parameters for each client; all clients must have same shapes.
104-
weights : np.ndarray | None
105-
Optional non-negative weights per client. If None, equal weights.
112+
client_models : Sequence[Sequence[np.ndarray]]
113+
A list of clients, each being a sequence of NumPy arrays (tensors).
114+
All clients must have the same number of tensors with identical shapes.
115+
weights : np.ndarray | None, optional
116+
A 1-D array of non-negative weights, one per client. If None,
117+
equal weighting is used. Weights are normalized to sum to 1.
106118
107119
Returns
108120
-------
109-
list[np.ndarray]
110-
Aggregated model parameters (same shapes as client tensors).
111-
112-
Examples
113-
--------
114-
>>> import numpy as np
115-
>>> cm = [
116-
... [np.array([1.0, 2.0])],
117-
... [np.array([3.0, 4.0])],
118-
... ]
119-
>>> [arr.tolist() for arr in federated_average(cm)]
120-
[[2.0, 3.0]]
121-
>>> w = np.array([1.0, 3.0])
122-
>>> [arr.tolist() for arr in federated_average(cm, w)]
123-
[[2.5, 3.5]]
121+
List[np.ndarray]
122+
The list of aggregated tensors with the same shapes as the inputs.
124123
"""
125124
_validate_clients(client_models)
126125
num_clients = len(client_models)

0 commit comments

Comments
 (0)