1+ # Importing Libraries
2+ import os
3+ import torch
4+ import numpy as np
5+ import pandas as pd
6+ import torch .nn as nn
7+ import matplotlib .pyplot as plt
8+
9+ from torch .nn import Sequential
10+ from collections import OrderedDict
11+ from torch .utils .data import Dataset , DataLoader
12+ from sklearn .preprocessing import StandardScaler
13+ from google .cloud import storage
14+ from io import BytesIO
15+
16+ # BreastCancerDataset Class
class BreastCancerDataset(Dataset):
    """PyTorch Dataset over a breast-cancer dataframe.

    Expects the dataframe layout used throughout this script: column 0 is the
    sample ID, the last column is the diagnosis label (malignant=1, benign=0),
    and everything in between is a numeric feature. Features are standardized
    with a freshly fitted StandardScaler.
    """

    def __init__(self, df):
        scaler = StandardScaler()
        # First (ID) and last (diagnosis) columns are excluded.
        # Bug fix: cast to float32 explicitly — StandardScaler returns float64
        # and torch.tensor would keep that dtype, which crashes against the
        # float32 weights of nn.Linear at training time.
        self.X = torch.tensor(
            scaler.fit_transform(df.iloc[:, 1:-1].values), dtype=torch.float32
        )
        # Diagnosis labels. Bug fix: BCELoss requires a floating-point target,
        # so the integer labels must be converted to float32 as well.
        self.y = torch.tensor(df.iloc[:, -1].values, dtype=torch.float32)

    def __len__(self):
        # number of samples
        return len(self.X)

    def __getitem__(self, idx):
        # (features, label) pair for one sample
        return self.X[idx], self.y[idx]
28+
# Function to Load Data from GCS
# Description: This function downloads a file from cloud storage and reads it into a pandas dataframe.
def load_dataset_from_gcs(bucket_name, file_path):
    """Download a CSV blob from Google Cloud Storage and parse it.

    Args:
        bucket_name: Name of the GCS bucket.
        file_path: Path of the CSV blob inside the bucket.

    Returns:
        A pandas DataFrame with the blob's CSV contents.
    """
    client = storage.Client()
    bucket = client.get_bucket(bucket_name)
    blob = bucket.blob(file_path)
    # download_as_bytes() replaces the deprecated download_as_string() alias;
    # both return raw bytes.
    data = blob.download_as_bytes()
    df = pd.read_csv(BytesIO(data))
    return df
38+
39+ # Client Class (same as above)
40+ class Client :
41+ def __init__ (self , name , model , train_loader , val_loader , optimizer , criterion ):
42+ self .name = name
43+ self .model = model
44+ self .optimizer = optimizer
45+ self .criterion = criterion
46+ self .train_loader = train_loader
47+ self .val_loader = val_loader
48+ self .metrics = dict ({"train_acc" : list (), "train_loss" : list (), "val_acc" : list (), "val_loss" : list ()})
49+
50+ print (f"[INFO] Initialized client '{ self .name } ' with { len (train_loader .dataset )} train and { len (val_loader .dataset )} validation samples" )
51+
52+ def train (self ):
53+ """
54+ Trains the model of the client for 1 epoch.
55+ """
56+ self .model .train ()
57+ correct_predictions = 0
58+ running_loss = 0.0
59+
60+ # iterate over training dataset
61+ for inputs , labels in self .train_loader :
62+ # make predictions
63+ self .optimizer .zero_grad ()
64+ outputs = self .model (inputs )
65+ labels = torch .unsqueeze (labels , 1 )
66+
67+ # apply gradient
68+ loss = self .criterion (outputs , labels )
69+ loss .backward ()
70+ self .optimizer .step ()
71+ running_loss += loss .item ()
72+
73+ # calculate number of correct predictions
74+ predicted = torch .round (outputs )
75+ correct_predictions += (predicted == labels ).sum ().item ()
76+
77+ # calculate overall loss and acc.
78+ epoch_loss = running_loss / len (self .train_loader )
79+ accuracy = correct_predictions / len (self .train_loader .dataset )
80+
81+ # save metrics
82+ self .metrics ["train_acc" ].append (accuracy )
83+ self .metrics ["train_loss" ].append (epoch_loss )
84+
85+ def validate (self ):
86+ """
87+ Validates the model of the client based on the given validation data loader.
88+ """
89+ self .model .eval ()
90+ total_loss = 0
91+ correct_predictions = 0
92+
93+ # iterate over validation data loader and make predictions
94+ with torch .no_grad ():
95+ for inputs , labels in self .val_loader :
96+ outputs = self .model (inputs )
97+ labels = torch .unsqueeze (labels , 1 )
98+ loss = self .criterion (outputs , labels )
99+
100+ total_loss += loss .item ()
101+ predicted = torch .round (outputs )
102+ correct_predictions += (predicted == labels ).sum ().item ()
103+
104+ # calculate overall loss and acc.
105+ average_loss = total_loss / len (self .val_loader )
106+ accuracy = correct_predictions / len (self .val_loader .dataset )
107+
108+ # save metrics
109+ self .metrics ["val_acc" ].append (accuracy )
110+ self .metrics ["val_loss" ].append (average_loss )
111+
# SimpleNN Model Definition (same as above)
class SimpleNN(nn.Module):
    """Small fully-connected binary classifier.

    Architecture: n_input -> 32 -> 16 -> 1 with ReLU activations and a final
    sigmoid, so the output is a probability in [0, 1].
    """

    def __init__(self, n_input):
        super(SimpleNN, self).__init__()
        # build the layer stack as a list, then unpack into Sequential
        layers = [
            nn.Linear(n_input, 32),
            nn.ReLU(),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Sigmoid(),
        ]
        self.NN = Sequential(*layers)

    def forward(self, x):
        # single forward pass through the whole stack
        return self.NN(x)
128+
# FedAvg Function (same as above)

def fed_avg(global_state_dict, client_states, n_data_points):
    """
    Averages the weights of client models to update the global model by FedAvg.

    Args:
        global_state_dict: The state dict of the global PyTorch model.
        client_states: A list of PyTorch models state dicts representing client models.
        n_data_points: A list with the number of data points per client.

    Returns:
        The state dict of the updated global PyTorch model.
    """
    averaged_state_dict = OrderedDict()
    total = sum(n_data_points)  # hoisted: invariant across keys and clients

    for key in global_state_dict.keys():
        # Weighted average of the clients' parameters, where each client's
        # weight is its share of the total training data.
        # Bug fix: the original wrote `= + state[key] * (n/total)` (unary plus,
        # not `+=`), so every iteration OVERWROTE the entry and only the last
        # client's weighted contribution survived.
        averaged_state_dict[key] = sum(
            state[key] * (n / total) for state, n in zip(client_states, n_data_points)
        )

    return averaged_state_dict
150+
# FLServer Class
class FLServer:
    """Coordinates federated training rounds across a set of clients."""

    def __init__(self, model, clients):
        self.model = model      # the global (aggregated) model
        self.clients = clients
        # training-set size per client, used as FedAvg weights
        self.n_data_points = [len(client.train_loader.dataset) for client in self.clients]

    def run(self, epochs):
        """Run `epochs` federated rounds: local training, FedAvg aggregation,
        redistribution of the global model, then per-client validation."""
        for i in range(epochs):
            print(f"Epoch {i}")

            # Step 2 of figure at the beginning of the tutorial
            for client in self.clients:
                client.train()

            # aggregate the models using FedAvg (Step 3 & 4 of figure at the beginning of the tutorial)
            client_states = [client.model.state_dict() for client in self.clients]  # Step 3
            aggregated_state = fed_avg(self.model.state_dict(), client_states, self.n_data_points)  # Step 4
            self.model.load_state_dict(aggregated_state)

            # redistribute central model (Step 1 of figure at the beginning of the tutorial)
            # Bug fix: the original iterated over the GLOBAL variable
            # `fl_server.clients`, which only worked because main() happened to
            # name its instance `fl_server`; any other name raised NameError.
            for client in self.clients:
                client.model.load_state_dict(aggregated_state)

            # run validation of aggregated model
            for client in self.clients:
                client.validate()

            # repeat for n epochs (Step 5 of figure at the beginning of the tutorial)
180+
# Plotting Metrics
def plot_metrics(client):
    """Plot a client's recorded accuracy/loss curves (one line per metric)."""
    plt.figure(figsize=(8, 4))
    # each metric is a per-epoch list; plot it against its epoch index
    for metric_name, values in client.metrics.items():
        plt.plot(range(len(values)), values, label=metric_name)

    # accuracies and BCE losses here live in [0, 1]
    plt.ylim(bottom=0.0, top=1.0)
    plt.xlim(left=0)
    plt.xlabel("Epoch")
    plt.ylabel("Metric")
    plt.title(client.name)
    plt.legend()
    plt.show()
195+
# Running Prediction on validation data
def run_prediction(model, bucket_name, validation_file_path):
    """Evaluate `model`'s accuracy on a validation CSV stored in GCS.

    Downloads the file, wraps it in a BreastCancerDataset, scores one sample
    at a time, prints the accuracy, and returns it.
    """
    model.eval()
    dataset = BreastCancerDataset(load_dataset_from_gcs(bucket_name, validation_file_path))
    loader = DataLoader(dataset, batch_size=1, shuffle=False)

    n_correct = 0
    with torch.no_grad():
        for features, target in loader:
            # round the sigmoid output to a hard 0/1 prediction
            prediction = torch.round(model(features))
            target = torch.unsqueeze(target, 1)
            n_correct += (prediction == target).sum().item()

    accuracy = n_correct / len(loader.dataset)
    print(f"{accuracy:.2f}")
    return accuracy
214+
# Main Function
def main():
    """Entry point.

    Trains a centralized baseline model, then a 2-client federated model on
    per-client splits stored in GCS, plots metrics for both, and uploads the
    federated model's weights back to GCS.
    """
    import argparse
    # arguments are parsed from the command line
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--bucket_name', type=str, required=True, help='GCS bucket name')
    parser.add_argument('--train_file', type=str, required=True, help='Path to the training file in GCS')
    parser.add_argument('--validation_file', type=str, required=True, help='Path to the validation file in GCS')
    parser.add_argument('--output_dir', type=str, required=True, help='Output directory for the model in GCS')
    parser.add_argument('--epochs', type=int, default=10, help='Number of epochs to train')
    parser.add_argument('--batch_size', type=int, default=50, help='Batch size for training')
    args = parser.parse_args()

    # Load datasets from GCS
    train_df = load_dataset_from_gcs(args.bucket_name, args.train_file)
    val_df = load_dataset_from_gcs(args.bucket_name, args.validation_file)

    train_data = BreastCancerDataset(train_df)
    val_data = BreastCancerDataset(val_df)

    train_dataloader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True)
    val_dataloader = DataLoader(val_data, batch_size=args.batch_size, shuffle=False)

    # Initialize model and client for centralized training
    model = SimpleNN(n_input=30)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    criterion = nn.BCELoss()
    central_client = Client("central", model, train_dataloader, val_dataloader, optimizer, criterion)

    # Centralized training
    for i in range(args.epochs):
        print(f"Epoch {i}")
        central_client.train()
        central_client.validate()

    plot_metrics(central_client)

    print("Accuracy of the centrally trained model:")
    # Bug fix: the parser defines --validation_file; `args.test_file` does not
    # exist and raised AttributeError.
    run_prediction(central_client.model, args.bucket_name, args.validation_file)

    # Federated Learning
    fed_model = SimpleNN(n_input=30)
    clients = list()
    for i in range(2):
        train_df = load_dataset_from_gcs(args.bucket_name, f"client_{i}/train_data.csv")
        val_df = load_dataset_from_gcs(args.bucket_name, f"client_{i}/val_data.csv")

        train_data = BreastCancerDataset(train_df)
        val_data = BreastCancerDataset(val_df)

        train_dataloader = DataLoader(train_data, batch_size=7, shuffle=True)
        val_dataloader = DataLoader(val_data, batch_size=7, shuffle=False)

        # Bug fix: each client needs its OWN model instance. The original
        # handed the shared `fed_model` object (and an optimizer over its
        # parameters) to every client, so "federated" training was really just
        # sequential training of a single model.
        client_model = SimpleNN(n_input=30)
        optimizer = torch.optim.SGD(client_model.parameters(), lr=0.01, momentum=0.9)
        criterion = nn.BCELoss()

        clients.append(Client(f"client_{i}", client_model, train_dataloader, val_dataloader, optimizer, criterion))

    fl_server = FLServer(fed_model, clients)

    # distribute the initial global weights to all clients
    for client in fl_server.clients:
        client.model.load_state_dict(fl_server.model.state_dict())

    fl_server.run(epochs=args.epochs)

    for client in fl_server.clients:
        plot_metrics(client)

    print("Model trained with federated learning accuracy:")
    # Bug fix: same --validation_file / args.test_file mismatch as above.
    run_prediction(fl_server.model, args.bucket_name, args.validation_file)

    # Save the model to GCS.
    # Bug fix: the original passed os.path.join(output_dir, ...) straight to
    # torch.save, which fails when output_dir is a GCS prefix that does not
    # exist as a local directory. Serialize to a temp file, then upload it to
    # the intended blob path.
    import tempfile
    client = storage.Client()
    bucket = client.get_bucket(args.bucket_name)
    model_path = os.path.join(args.output_dir, "fed_model.pth")
    with tempfile.NamedTemporaryFile(suffix=".pth") as tmp:
        torch.save(fed_model.state_dict(), tmp.name)
        bucket.blob(model_path).upload_from_filename(tmp.name)
292+
# Run only when executed as a script, not on import.
if __name__ == "__main__":
    main()
# (removed non-code page residue: "0 commit comments" — it is not valid Python)