1+ import os
2+ import torch
3+
4+ import pickle
5+ import numpy as np
6+ import pandas as pd
7+ import torch .nn as nn
8+ import matplotlib .pyplot as plt
9+
10+ from kfp .v2 import compiler
11+ from torch .nn import Sequential
12+ from collections import OrderedDict
13+ from google .cloud import aiplatform
14+ from torch .utils .data import Dataset
15+ from torch .utils .data import DataLoader
16+ from kfp .v2 .dsl import component , Output , Dataset , Model , Input , Artifact , Metrics
17+ from sklearn .preprocessing import StandardScaler
18+
class SimpleNN(nn.Module):
    """A small fully connected binary classifier.

    Architecture: n_input -> 32 -> 16 -> 1 with ReLU activations and a
    final sigmoid, so the forward pass yields a probability in (0, 1).
    """

    def __init__(self, n_input):
        super(SimpleNN, self).__init__()
        # Build the stack as a list first, then hand it to Sequential;
        # layer order (and therefore state_dict keys) is unchanged.
        layers = [
            nn.Linear(n_input, 32),
            nn.ReLU(),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Sigmoid(),
        ]
        self.NN = Sequential(*layers)

    def forward(self, x):
        """Return per-sample probabilities for input batch `x`."""
        return self.NN(x)
32+
class BreastCancerDataset(torch.utils.data.Dataset):
    """Torch dataset wrapping the breast-cancer dataframe.

    NOTE(review): the base class is referenced via its full module path
    because the file-level `from kfp.v2.dsl import ... Dataset ...` shadows
    the earlier `from torch.utils.data import Dataset` — the bare name
    `Dataset` here would otherwise be the KFP artifact class.
    """

    def __init__(self, df):
        # Standardize the features; first (ID) and last (diagnosis)
        # columns are excluded from the feature matrix.
        scaler = StandardScaler()
        features = scaler.fit_transform(df.iloc[:, 1:-1].values)
        # Cast to float32: StandardScaler returns float64, and feeding
        # float64 inputs/targets to the float32 model and BCELoss
        # raises a dtype mismatch at runtime.
        self.X = torch.tensor(features, dtype=torch.float32)
        # diagnosis labels (malignant=1, benign=0) as float targets
        self.y = torch.tensor(df.iloc[:, -1].values, dtype=torch.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]
44+
class Client:
    """One federated-learning participant.

    Owns a local model copy, a train/validation split, and the
    optimizer/criterion used for local updates. Per-epoch metrics are
    accumulated in `self.metrics`.
    """

    def __init__(self, name, model, train_loader, val_loader, optimizer, criterion):
        self.name = name
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.train_loader = train_loader
        self.val_loader = val_loader
        # one growing list per tracked metric (one entry per epoch)
        self.metrics = {"train_acc": [], "train_loss": [], "val_acc": [], "val_loss": []}

        print(f"[INFO] Initialized client '{self.name}' with {len(train_loader.dataset)} train and {len(val_loader.dataset)} validation samples")


    def train(self):
        """
        Trains the model of the client for 1 epoch.
        """
        self.model.train()
        hits = 0
        loss_sum = 0.0

        for features, targets in self.train_loader:
            # forward pass
            self.optimizer.zero_grad()
            probs = self.model(features)
            targets = torch.unsqueeze(targets, 1)

            # backward pass and parameter update
            batch_loss = self.criterion(probs, targets)
            batch_loss.backward()
            self.optimizer.step()
            loss_sum += batch_loss.item()

            # probabilities rounded to 0/1 give the hard predictions
            hits += (torch.round(probs) == targets).sum().item()

        # mean loss per batch; accuracy over the whole training set
        self.metrics["train_acc"].append(hits / len(self.train_loader.dataset))
        self.metrics["train_loss"].append(loss_sum / len(self.train_loader))

    def validate(self):
        """
        Validates the model of the client based on the given validation data loader.
        """
        self.model.eval()
        hits = 0
        loss_sum = 0.0

        # gradients are not needed for evaluation
        with torch.no_grad():
            for features, targets in self.val_loader:
                probs = self.model(features)
                targets = torch.unsqueeze(targets, 1)

                loss_sum += self.criterion(probs, targets).item()
                hits += (torch.round(probs) == targets).sum().item()

        # mean loss per batch; accuracy over the whole validation set
        self.metrics["val_acc"].append(hits / len(self.val_loader.dataset))
        self.metrics["val_loss"].append(loss_sum / len(self.val_loader))
117+
class FLServer:
    """Central federated-learning server coordinating a set of clients."""

    def __init__(self, model, clients):
        self.model = model
        self.clients = clients
        # per-client training-set sizes, used as the FedAvg weights
        self.n_data_points = [len(client.train_loader.dataset) for client in self.clients]

    def run(self, epochs):
        """Run `epochs` rounds of federated training.

        Each round: local training on every client (Step 2), collection of
        client weights (Step 3), FedAvg aggregation into the central model
        (Step 4), redistribution of the aggregated weights (Step 1 of the
        next round), then validation on every client. Repeats for `epochs`
        rounds (Step 5).
        """
        for epoch in range(epochs):
            print(f"Epoch {epoch}")

            # local training on each client (Step 2)
            for client in self.clients:
                client.train()

            # aggregate the client models using FedAvg (Steps 3 & 4)
            client_states = [client.model.state_dict() for client in self.clients]  # Step 3
            aggregated_state = fed_avg(self.model.state_dict(), client_states, self.n_data_points)  # Step 4
            self.model.load_state_dict(aggregated_state)

            # redistribute the central model (Step 1).
            # Bug fix: the original iterated over the global `fl_server.clients`
            # instead of `self.clients`, which breaks for any other instance.
            for client in self.clients:
                client.model.load_state_dict(aggregated_state)

            # run validation of the aggregated model
            for client in self.clients:
                client.validate()