Skip to content

Commit e2c6ebf

Browse files
Add files via upload
1 parent 36c109f commit e2c6ebf

File tree

2 files changed

+313
-0
lines changed

2 files changed

+313
-0
lines changed
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Packaging configuration for the breast-cancer federated-learning trainer.
from setuptools import find_packages, setup

# Runtime dependencies of the training script.
REQUIRED_PACKAGES = [
    'torch',
    'numpy',
    'pandas',
    'matplotlib',
    'scikit-learn',
    'google-cloud-storage',
    'google-cloud-aiplatform',
]

setup(
    name='breast_cancer_federated_learning',
    version='0.1',
    install_requires=REQUIRED_PACKAGES,
    packages=find_packages(),
    include_package_data=True,
    description='Breast Cancer Federated Learning Training Script',
)
Lines changed: 294 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,294 @@
1+
# Importing Libraries
2+
import os
3+
import torch
4+
import numpy as np
5+
import pandas as pd
6+
import torch.nn as nn
7+
import matplotlib.pyplot as plt
8+
9+
from torch.nn import Sequential
10+
from collections import OrderedDict
11+
from torch.utils.data import Dataset, DataLoader
12+
from sklearn.preprocessing import StandardScaler
13+
from google.cloud import storage
14+
from io import BytesIO
15+
16+
# BreastCancerDataset Class
class BreastCancerDataset(Dataset):
    """Torch Dataset over a breast-cancer dataframe.

    Assumes the first column is an ID, the last column is the diagnosis
    (malignant=1, benign=0), and the columns in between are numeric features.
    """

    def __init__(self, df):
        scaler = StandardScaler()
        # Standardize the features; the first (ID) and last (diagnosis)
        # columns are excluded.
        # BUG FIX: cast to float32 — StandardScaler returns float64, while
        # nn.Linear weights are float32, so an un-cast tensor raises a
        # dtype-mismatch error in the forward pass.
        self.X = torch.tensor(
            scaler.fit_transform(df.iloc[:, 1:-1].values), dtype=torch.float32
        )
        # BUG FIX: labels as float32 — BCELoss requires a floating-point
        # target, but integer diagnosis values produce an int64 tensor.
        self.y = torch.tensor(df.iloc[:, -1].values, dtype=torch.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]
28+
29+
# Function to Load Data from GCS
def load_dataset_from_gcs(bucket_name, file_path):
    """Download a CSV blob from Cloud Storage into a pandas DataFrame.

    Args:
        bucket_name: Name of the GCS bucket (without the gs:// prefix).
        file_path: Path of the CSV blob inside the bucket.

    Returns:
        A pandas DataFrame parsed from the downloaded CSV bytes.
    """
    client = storage.Client()
    bucket = client.get_bucket(bucket_name)
    blob = bucket.blob(file_path)
    # download_as_bytes replaces the deprecated download_as_string.
    data = blob.download_as_bytes()
    df = pd.read_csv(BytesIO(data))
    return df
38+
39+
# Client Class
class Client:
    """A federated-learning participant holding a local model and data.

    Records per-epoch accuracy and loss for both the train and validation
    phases in ``self.metrics``.
    """

    def __init__(self, name, model, train_loader, val_loader, optimizer, criterion):
        self.name = name
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.train_loader = train_loader
        self.val_loader = val_loader
        # Per-epoch metric history, appended to by train() and validate().
        self.metrics = {"train_acc": [], "train_loss": [], "val_acc": [], "val_loss": []}

        print(f"[INFO] Initialized client '{self.name}' with {len(train_loader.dataset)} train and {len(val_loader.dataset)} validation samples")

    def train(self):
        """
        Trains the model of the client for 1 epoch.
        """
        self.model.train()
        n_correct = 0
        loss_sum = 0.0

        # One pass over the training loader: predict, backprop, step.
        for features, targets in self.train_loader:
            self.optimizer.zero_grad()
            preds = self.model(features)
            targets = torch.unsqueeze(targets, 1)

            batch_loss = self.criterion(preds, targets)
            batch_loss.backward()
            self.optimizer.step()
            loss_sum += batch_loss.item()

            # Threshold the sigmoid outputs at 0.5 via rounding.
            n_correct += (torch.round(preds) == targets).sum().item()

        # Mean batch loss and overall accuracy for this epoch.
        self.metrics["train_acc"].append(n_correct / len(self.train_loader.dataset))
        self.metrics["train_loss"].append(loss_sum / len(self.train_loader))

    def validate(self):
        """
        Validates the model of the client based on the given validation data loader.
        """
        self.model.eval()
        loss_sum = 0
        n_correct = 0

        # No gradients needed for evaluation.
        with torch.no_grad():
            for features, targets in self.val_loader:
                preds = self.model(features)
                targets = torch.unsqueeze(targets, 1)
                loss_sum += self.criterion(preds, targets).item()
                n_correct += (torch.round(preds) == targets).sum().item()

        self.metrics["val_acc"].append(n_correct / len(self.val_loader.dataset))
        self.metrics["val_loss"].append(loss_sum / len(self.val_loader))
111+
112+
# SimpleNN Model Definition
class SimpleNN(nn.Module):
    """A small fully-connected binary classifier.

    Architecture: n_input -> 32 -> 16 -> 1, ReLU hidden activations and a
    sigmoid output in [0, 1].
    """

    def __init__(self, n_input):
        super(SimpleNN, self).__init__()
        self.NN = Sequential(
            nn.Linear(n_input, 32),
            nn.ReLU(),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Note: despite the original variable name "logits", the output has
        # already passed through the sigmoid, so these are probabilities.
        return self.NN(x)
128+
129+
# FedAvg Function
def fed_avg(global_state_dict, client_states, n_data_points):
    """
    Averages the weights of client models to update the global model by FedAvg.

    Each client's parameters are weighted by its share of the total number of
    training samples and the weighted tensors are summed per parameter key.

    Args:
        global_state_dict: The state dict of the global PyTorch model
            (used only for its set of parameter keys).
        client_states: A list of PyTorch model state dicts representing client
            models, aligned with n_data_points.
        n_data_points: A list with the number of data points per client.

    Returns:
        An OrderedDict with the weighted-average parameters for the global model.
    """
    total = sum(n_data_points)
    averaged_state_dict = OrderedDict()

    for key in global_state_dict.keys():
        # BUG FIX: the original wrote `averaged_state_dict[key] =+ ...`
        # (plain assignment of the unary-plus value) instead of `+=`, so only
        # the LAST client's weighted tensor survived and no averaging happened.
        averaged_state_dict[key] = sum(
            state[key] * (n / total)
            for state, n in zip(client_states, n_data_points)
        )

    return averaged_state_dict
150+
151+
# FLServer Class
class FLServer:
    """Coordinates federated training rounds across a set of clients."""

    def __init__(self, model, clients):
        self.model = model      # the global (aggregated) model
        self.clients = clients
        # Per-client training-set sizes: the FedAvg weighting factors.
        self.n_data_points = [len(client.train_loader.dataset) for client in self.clients]

    def run(self, epochs):
        """Run `epochs` federated rounds.

        Each round: local training on every client, FedAvg aggregation,
        redistribution of the global weights, then validation on every client.
        """
        for i in range(epochs):
            print(f"Epoch {i}")

            # local training on every client
            for client in self.clients:
                client.train()

            # aggregate the client models with FedAvg
            client_states = [client.model.state_dict() for client in self.clients]
            aggregated_state = fed_avg(self.model.state_dict(), client_states, self.n_data_points)
            self.model.load_state_dict(aggregated_state)

            # redistribute the aggregated global model to every client.
            # BUG FIX: the original iterated over the module-level
            # `fl_server.clients` instead of `self.clients`, which raises a
            # NameError (or touches the wrong server) whenever the instance
            # is bound to any other name.
            for client in self.clients:
                client.model.load_state_dict(aggregated_state)

            # run validation of the aggregated model on every client
            for client in self.clients:
                client.validate()
180+
181+
# Plotting Metrics
def plot_metrics(client):
    """Plot every recorded metric curve of a client against the epoch index."""
    plt.figure(figsize=(8, 4))
    for metric_name, values in client.metrics.items():
        plt.plot(range(len(values)), values, label=metric_name)

    # Accuracy and BCE loss both live comfortably in [0, 1] here.
    plt.ylim(bottom=0.0, top=1.0)
    plt.xlim(left=0)
    plt.xlabel("Epoch")
    plt.ylabel("Metric")
    plt.title(client.name)
    plt.legend()
    plt.show()
195+
196+
# Running Prediction on validation data
def run_prediction(model, bucket_name, validation_file_path):
    """Evaluate `model` accuracy on a validation CSV stored in GCS.

    Prints the accuracy with two decimals and returns it as a float.
    """
    model.eval()
    dataset = BreastCancerDataset(load_dataset_from_gcs(bucket_name, validation_file_path))
    loader = DataLoader(dataset, batch_size=1, shuffle=False)

    n_correct = 0
    with torch.no_grad():
        for features, targets in loader:
            probs = model(features)
            targets = torch.unsqueeze(targets, 1)
            # Round the sigmoid output to a hard 0/1 prediction.
            n_correct += (torch.round(probs) == targets).sum().item()

    accuracy = n_correct / len(loader.dataset)
    print(f"{accuracy:.2f}")
    return accuracy
214+
215+
# Main Function
def main():
    """Entry point for the training job.

    Trains a centralized baseline, then a 2-client federated model, evaluates
    both on the validation set, and uploads the federated model to GCS.
    """
    import argparse

    # arguments are parsed from the command line
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--bucket_name', type=str, required=True, help='GCS bucket name')
    parser.add_argument('--train_file', type=str, required=True, help='Path to the training file in GCS')
    parser.add_argument('--validation_file', type=str, required=True, help='Path to the validation file in GCS')
    parser.add_argument('--output_dir', type=str, required=True, help='Output directory for the model in GCS')
    parser.add_argument('--epochs', type=int, default=10, help='Number of epochs to train')
    parser.add_argument('--batch_size', type=int, default=50, help='Batch size for training')
    args = parser.parse_args()

    # Load datasets from GCS
    train_df = load_dataset_from_gcs(args.bucket_name, args.train_file)
    val_df = load_dataset_from_gcs(args.bucket_name, args.validation_file)

    train_data = BreastCancerDataset(train_df)
    val_data = BreastCancerDataset(val_df)

    train_dataloader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True)
    val_dataloader = DataLoader(val_data, batch_size=args.batch_size, shuffle=False)

    # Initialize model and client for centralized training
    model = SimpleNN(n_input=30)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    criterion = nn.BCELoss()
    central_client = Client("central", model, train_dataloader, val_dataloader, optimizer, criterion)

    # Centralized training
    for i in range(args.epochs):
        print(f"Epoch {i}")
        central_client.train()
        central_client.validate()

    plot_metrics(central_client)

    print("Accuracy of the centrally trained model:")
    # BUG FIX: the original referenced args.test_file, which the parser never
    # defines and which raised AttributeError; evaluate on the validation file.
    run_prediction(central_client.model, args.bucket_name, args.validation_file)

    # Federated Learning
    # NOTE(review): every Client below shares the SAME fed_model instance, so
    # the clients are not independent local models — confirm whether separate
    # model copies per client were intended.
    fed_model = SimpleNN(n_input=30)
    clients = list()
    for i in range(2):
        train_df = load_dataset_from_gcs(args.bucket_name, f"client_{i}/train_data.csv")
        val_df = load_dataset_from_gcs(args.bucket_name, f"client_{i}/val_data.csv")

        train_data = BreastCancerDataset(train_df)
        val_data = BreastCancerDataset(val_df)

        train_dataloader = DataLoader(train_data, batch_size=7, shuffle=True)
        val_dataloader = DataLoader(val_data, batch_size=7, shuffle=False)

        optimizer = torch.optim.SGD(fed_model.parameters(), lr=0.01, momentum=0.9)
        criterion = nn.BCELoss()

        clients.append(Client(f"client_{i}", fed_model, train_dataloader, val_dataloader, optimizer, criterion))

    fl_server = FLServer(fed_model, clients)

    # distribute the initial global weights to every client
    for client in fl_server.clients:
        client.model.load_state_dict(fl_server.model.state_dict())

    fl_server.run(epochs=args.epochs)

    for client in fl_server.clients:
        plot_metrics(client)

    print("Model trained with federated learning accuracy:")
    # BUG FIX: args.test_file does not exist (see above).
    run_prediction(fl_server.model, args.bucket_name, args.validation_file)

    # Save the model to GCS.
    # BUG FIX: the original called torch.save on the joined GCS path, which
    # fails unless that directory also exists on the local filesystem. Save
    # to a local file first, then upload it to the blob path in the bucket.
    client = storage.Client()
    bucket = client.get_bucket(args.bucket_name)
    model_blob_path = os.path.join(args.output_dir, "fed_model.pth")
    local_model_path = "fed_model.pth"
    torch.save(fed_model.state_dict(), local_model_path)
    bucket.blob(model_blob_path).upload_from_filename(local_model_path)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)