# Federated Learning

Federated Learning is a specialized form of distributed machine learning
that enables multiple client devices (such as smartphones and personal
computers) to collaboratively train a model without sharing their
private datasets. This approach primarily aims to enhance privacy
protection for users.

In a federated learning system, training data remains on the clients'
devices, with only model parameters exchanged among participants. This
contrasts with traditional distributed machine learning, where a single
entity collects the entire dataset into a data center. For instance, an
input method software company might record user input on mobile devices
and upload this data to servers for model training. In federated
learning, data stays on users' devices, and the model is trained and
updated locally, with parameters shared among participants to update the
model.

Federated learning systems can be classified into two types:
cross-device and cross-organizational. The main difference lies in the
nature of the clients involved.

In cross-device federated learning, the clients are typically personal
user devices. These devices often have limited computational
capabilities, unstable communication, and may not always be online. On
the other hand, in cross-organizational federated learning, clients are
usually servers of large institutions like hospitals and banks. These
clients, though fewer in number, possess strong computational
capabilities and stable network connections.

A notable example of federated learning deployment is its optimization
for mobile phone input methods, as demonstrated by Google's research.
Predicting the user's next input word using machine learning models can
significantly enhance user experience. Traditionally, user input would
be collected on the service provider's servers as training data.
However, due to privacy concerns, users may not want their input
content collected. Federated learning addresses this by sending only
model parameters to user devices. Client programs locally record input
data, update model parameters, and then upload the updated parameters
back to the server. By aggregating updates from multiple users, the
central server can improve the model's accuracy without accessing users'
private data.

Other examples of federated learning systems are often found in the
healthcare and financial sectors. For example, in healthcare, multiple
hospitals can collaboratively train models through federated learning
without sharing patients' raw data, thereby enhancing diagnostic
support.

## Key Operations in Federated Learning Systems


:label:`ch010/ch10-federated-learning-systems`

We use an input method's next-word prediction task to illustrate the
typical architecture of a federated learning system. As shown in Figure
1, there is usually a server and multiple clients. The server can be a
cloud server belonging to the input method provider, while the clients
run the input method programs on user devices.

To train the latest model using user data, the input method provider can
initiate a federated learning-based model training session. During this
session, they must select a training algorithm, such as FedSGD or
FedAvg. For illustration, we will use the widely adopted FedAvg
algorithm defined in
Algorithm :numref:`fedavg`. FedAvg primarily has the following steps:

1. Model parameter initialization: In the first step, the server
   initializes the model parameters as input parameters for the
   federated learning system.

2. Client selection: The system selects a batch of clients from the
   user pool based on the following criteria: (i) devices connected to
   a stable local network (e.g., Wi-Fi), (ii) users not actively using
   the device, and (iii) devices being charged. The current model
   parameters are then broadcast to these selected clients.

3. Local updates: Clients receive the model parameters and conduct
   local model training, such as using the SGD algorithm. Unlike
   typical distributed training, federated learning performs multiple
   rounds of gradient updates locally to reduce the cost of uploading
   parameters each round. After several updates, clients upload their
   latest local model parameters back to the server.

4. Global aggregation and update: The server calculates a weighted
   average of the received model parameters to obtain the new global
   model parameters. This process repeats until the model accuracy
   meets the requirements or the loss is sufficiently low.
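The four steps above can be sketched as a minimal single-process simulation. This is an illustrative sketch, not a deployable system: the "model" is a toy linear regression standing in for a real next-word predictor, the clients are in-memory arrays, and the names `local_update` and `fedavg_round` are our own, not from any federated learning library.

```python
import numpy as np

def local_update(params, data, lr=0.1, epochs=5):
    """Step 3: run several local SGD steps on one client's private data.
    The 'model' here is a toy linear regression; data is a (X, y) pair."""
    X, y = data
    w = params.copy()
    for _ in range(epochs):
        grad = 2 * X.T @ (X @ w - y) / len(y)  # gradient of mean squared error
        w -= lr * grad
    return w

def fedavg_round(global_params, clients):
    """Steps 2-4 for one round: broadcast, local training, weighted average."""
    updates, sizes = [], []
    for data in clients:                       # 'selection' picks every client here
        updates.append(local_update(global_params, data))
        sizes.append(len(data[1]))
    weights = np.array(sizes) / sum(sizes)     # weight clients by dataset size
    return sum(w_i * u for w_i, u in zip(weights, updates))

# Three clients whose private data follows the same relation y = 2x
rng = np.random.default_rng(0)
clients = []
for n in (20, 30, 50):
    X = rng.normal(size=(n, 1))
    clients.append((X, 2 * X[:, 0]))

params = np.zeros(1)             # step 1: server initializes parameters
for _ in range(10):              # repeat rounds until convergence
    params = fedavg_round(params, clients)
print(params)                    # converges toward [2.]
```

Note that the server only ever sees each client's parameter vector, never the `(X, y)` data itself, which is the privacy property the text describes.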


:label:`fedavg`

Deploying the FedAvg system in practice presents several challenges.
First, because the clients and servers involved in federated learning
lack low-latency, high-bandwidth network connections, averaging
gradients after every local update is not feasible. To address this,
practitioners instead average model parameters only every few local
rounds. Further, the configurations of devices used for model training
can vary significantly, leading to large differences in the time
selected clients take to complete training. To address this,
practitioners often select more clients than needed during client
selection. The server then aggregates parameters and moves to the next
iteration once a sufficient number of clients have returned their model
parameters.
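The over-selection idea can be sketched as follows. This is a simplified sequential simulation under assumed parameters (a 1.3x over-selection factor and a 20% dropout rate are illustrative choices, not values from the text), with `run_round` and `toy_train` being hypothetical names.

```python
import random

random.seed(42)  # make the toy run deterministic

def run_round(global_params, candidates, needed, train_fn):
    """Over-select clients, then aggregate once `needed` updates arrive.
    train_fn returning None models a client that drops out or times out."""
    n_selected = min(len(candidates), int(needed * 1.3))  # select ~30% extra
    updates = []
    for client in random.sample(candidates, n_selected):
        result = train_fn(global_params, client)
        if result is None:              # straggler or dropout: skip it
            continue
        updates.append(result)
        if len(updates) >= needed:      # quorum reached: ignore the rest
            break
    if not updates:                     # everyone dropped out: keep old params
        return global_params
    return sum(updates) / len(updates)

# Toy client update: move the scalar parameter halfway toward the client's
# local value; about 20% of clients fail to report back.
def toy_train(param, client_value):
    if random.random() < 0.2:
        return None
    return param + 0.5 * (client_value - param)

param = 0.0
for _ in range(20):
    param = run_round(param, list(range(10)), needed=5, train_fn=toy_train)
```

In a real deployment the per-client training runs concurrently and the server waits on a timeout rather than looping sequentially, but the aggregation rule (proceed once a quorum of updates has arrived) is the same.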