Skip to content

Commit 184d846

Browse files
author
Zelaikha
committed
Updated FL tutorial: added explanations and a draft Vertex AI process
1 parent ecbdd95 commit 184d846

File tree

6 files changed

+2019
-538
lines changed

6 files changed

+2019
-538
lines changed
519 KB
Loading

notebooks/FederatedLearning/GCP_FederatedLearning.ipynb

Lines changed: 1613 additions & 275 deletions
Large diffs are not rendered by default.

notebooks/FederatedLearning/scripts/fl_packages/_init_.py  (NOTE(review): filename is likely meant to be `__init__.py` — with single underscores Python will not treat `fl_packages` as a package)

Whitespace-only changes.
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
import os
2+
import torch
3+
4+
import pickle
5+
import numpy as np
6+
import pandas as pd
7+
import torch.nn as nn
8+
import matplotlib.pyplot as plt
9+
10+
from kfp.v2 import compiler
11+
from torch.nn import Sequential
12+
from collections import OrderedDict
13+
from google.cloud import aiplatform
14+
from torch.utils.data import Dataset
15+
from torch.utils.data import DataLoader
16+
from kfp.v2.dsl import component, Output, Dataset, Model, Input, Artifact, Metrics
17+
from sklearn.preprocessing import StandardScaler
18+
19+
class SimpleNN(nn.Module):
    """A small fully connected binary classifier.

    Architecture: n_input -> 32 -> 16 -> 1, with ReLU on the hidden
    layers and a final Sigmoid so the output is a probability in [0, 1].
    """

    def __init__(self, n_input):
        super(SimpleNN, self).__init__()
        # Build the layer stack as a list first, then wrap it so that
        # forward() is a single call through the whole network.
        layers = [
            nn.Linear(n_input, 32),
            nn.ReLU(),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Sigmoid(),
        ]
        self.NN = nn.Sequential(*layers)

    def forward(self, x):
        """Forward pass; x is expected to have shape (batch, n_input)."""
        return self.NN(x)
32+
33+
class BreastCancerDataset(Dataset):
    """Torch dataset over the breast-cancer dataframe.

    Expected df layout: column 0 is the sample ID, the last column is the
    diagnosis label (malignant=1, benign=0), everything in between is a
    numeric feature.

    NOTE(review): this module imports `Dataset` from torch.utils.data and
    then from kfp.v2.dsl — the later kfp import shadows the torch one, so
    this class may be subclassing the wrong `Dataset`; confirm the import
    order in the package header.
    """

    def __init__(self, df):
        # Standardize features in-line: (x - mean) / population std.
        # This reproduces sklearn's StandardScaler (including mapping
        # zero-variance columns to 0) without the sklearn dependency.
        features = df.iloc[:, 1:-1].to_numpy(dtype=np.float64)  # ID and diagnosis columns excluded
        mean = features.mean(axis=0)
        std = features.std(axis=0)
        std[std == 0.0] = 1.0  # constant columns: avoid division by zero
        # Cast to float32: nn.Linear weights default to float32, and a
        # float64 input tensor would raise a dtype mismatch in forward().
        self.X = torch.tensor((features - mean) / std, dtype=torch.float32)
        # Diagnosis labels (malignant=1, benign=0) as float32 — BCELoss
        # (presumably the criterion, given the Sigmoid head) needs float
        # targets; int64 labels would fail there.
        self.y = torch.tensor(df.iloc[:, -1].to_numpy(), dtype=torch.float32)

    def __len__(self):
        """Number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the (features, label) pair for sample `idx`."""
        return self.X[idx], self.y[idx]
44+
45+
class Client:
    """One federated-learning participant.

    Owns a local model, its data loaders, optimizer and loss function,
    and records one entry per metric list for every train()/validate()
    call.
    """

    def __init__(self, name, model, train_loader, val_loader, optimizer, criterion):
        self.name = name
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.train_loader = train_loader
        self.val_loader = val_loader
        # One list per tracked metric; appended to once per epoch.
        self.metrics = {"train_acc": [], "train_loss": [], "val_acc": [], "val_loss": []}

        print(f"[INFO] Initialized client '{self.name}' with {len(train_loader.dataset)} train and {len(val_loader.dataset)} validation samples")


    def train(self):
        """
        Trains the model of the client for 1 epoch.
        """
        self.model.train()
        n_correct = 0
        loss_sum = 0.0

        # one full pass over the local training data
        for batch_x, batch_y in self.train_loader:
            # forward pass
            self.optimizer.zero_grad()
            preds = self.model(batch_x)
            batch_y = torch.unsqueeze(batch_y, 1)

            # backward pass and parameter update
            batch_loss = self.criterion(preds, batch_y)
            batch_loss.backward()
            self.optimizer.step()
            loss_sum += batch_loss.item()

            # threshold the sigmoid outputs at 0.5 and count hits
            n_correct += (torch.round(preds) == batch_y).sum().item()

        # mean loss per batch, accuracy over all samples
        self.metrics["train_acc"].append(n_correct / len(self.train_loader.dataset))
        self.metrics["train_loss"].append(loss_sum / len(self.train_loader))

    def validate(self):
        """
        Validates the model of the client based on the given validation data loader.
        """
        self.model.eval()
        loss_sum = 0
        n_correct = 0

        # no gradients needed for evaluation
        with torch.no_grad():
            for batch_x, batch_y in self.val_loader:
                preds = self.model(batch_x)
                batch_y = torch.unsqueeze(batch_y, 1)
                loss_sum += self.criterion(preds, batch_y).item()
                n_correct += (torch.round(preds) == batch_y).sum().item()

        # mean loss per batch, accuracy over all samples
        self.metrics["val_acc"].append(n_correct / len(self.val_loader.dataset))
        self.metrics["val_loss"].append(loss_sum / len(self.val_loader))
117+
118+
class FLServer:
    """Federated-learning coordinator.

    Each round: every client trains locally, the client weights are
    aggregated with FedAvg, the aggregate is loaded back into the central
    model and redistributed to the clients, which then validate it.

    Relies on a module-level `fed_avg(global_state, client_states,
    n_data_points)` helper defined elsewhere in this package.
    """

    def __init__(self, model, clients):
        self.model = model  # the central (global) model
        self.clients = clients
        # Per-client training-set sizes: FedAvg weights each client's
        # contribution by how much data it trained on.
        self.n_data_points = [len(client.train_loader.dataset) for client in self.clients]

    def run(self, epochs):
        """Run `epochs` federated rounds of train -> aggregate -> redistribute -> validate."""
        for i in range(epochs):
            print(f"Epoch {i}")

            # Step 2 of figure at the beginning of the tutorial
            for client in self.clients:
                client.train()

            # aggregate the models using FedAvg (Step 3 & 4 of figure at the beginning of the tutorial)
            client_states = [client.model.state_dict() for client in self.clients]  # Step 3
            aggregated_state = fed_avg(self.model.state_dict(), client_states, self.n_data_points)  # Step 4
            self.model.load_state_dict(aggregated_state)

            # redistribute central model (Step 1 of figure at the beginning of the tutorial)
            # BUG FIX: iterate over self.clients — the original referenced the
            # global name `fl_server`, which is undefined inside this class and
            # would raise a NameError at runtime.
            for client in self.clients:
                client.model.load_state_dict(aggregated_state)

            # run validation of aggregated model
            for client in self.clients:
                client.validate()

            # repeat for n epochs (Step 5 of figure at the beginning of the tutorial)
Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,17 @@
1-
from setuptools import find_packages
2-
from setuptools import setup
1+
from setuptools import find_packages, setup
2+
3+
# File: setup.py
4+
from setuptools import setup, find_packages
35

46
setup(
5-
name='breast_cancer_federated_learning',
6-
version='0.1',
7+
name="my_package",
8+
version="0.1",
9+
packages=find_packages(),
710
install_requires=[
8-
'torch',
9-
'numpy',
10-
'pandas',
11-
'matplotlib',
12-
'scikit-learn',
13-
'google-cloud-storage',
14-
'google-cloud-aiplatform',
11+
"torch",
12+
"pandas",
13+
"scikit-learn",
14+
"matplotlib",
15+
"ordereddict"  # NOTE(review): unnecessary — `OrderedDict` ships in the stdlib `collections` module; the PyPI `ordereddict` package is a Python 2 backport and should be removed
1516
],
16-
packages=find_packages(),
17-
include_package_data=True,
18-
description='Breast Cancer Federated Learning Training Script',
1917
)

0 commit comments

Comments
 (0)