-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathparticipate.py
More file actions
110 lines (90 loc) · 2.93 KB
/
participate.py
File metadata and controls
110 lines (90 loc) · 2.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# Environment Variables
# NOTE(review): set before the imports below run; presumably this silences
# TensorFlow's native (C++) log output when the model package loads it --
# confirm against the TensorFlow logging documentation.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# Imports
import numpy as np
from cit_fl_participation.config import StorageConfig
from cit_fl_participation.participation import (
AbstractTrainer,
ConnectorGrpc,
S3GlobalWeightsReader,
S3LocalWeightsWriter,
ParticipationServicer,
)
from model.predictor import mlModel
# Functions
def export_weights_tf(self, model=None):
    '''
    Converts the tf-model-weights to a 1D numpy array

    This is a module-level function whose first parameter is named ``self``,
    while the callers in this file pass only the model. Both spellings are
    accepted: a single positional argument is treated as the model.

    Parameters:
        self: unused placeholder kept for backward compatibility; when it is
            the only argument it is interpreted as the model.
        model: object whose ``get_weights()`` returns a sequence of numpy
            arrays.

    Returns:
        1D numpy array holding every weight array flattened and concatenated
        in ``get_weights()`` order.
    '''
    if model is None:
        # Called as export_weights_tf(model): the model arrived in `self`.
        model = self
    return np.concatenate([
        w.flatten()
        for w in model.get_weights()
    ])
def import_weights_tf(self, model, weights=None):
    '''
    Imports weights into the tf-model from a given 1D numpy array

    This is a module-level function whose first parameter is named ``self``,
    while the callers in this file pass ``(model, weights)``. Both spellings
    are accepted: with two positional arguments they are treated as
    ``(model, weights)``.

    Parameters:
        self: unused placeholder kept for backward compatibility.
        model: object exposing ``get_weights()`` (sequence of numpy arrays
            that defines the target shapes) and ``set_weights(list)``.
        weights: flat array-like with exactly as many elements as the model's
            weight arrays hold in total.

    Raises:
        ValueError: if the number of supplied values does not match the total
            size of the model's weights.
    '''
    if weights is None:
        # Called as import_weights_tf(model, weights): shift the arguments.
        model, weights = self, model
    flat = np.asarray(weights).ravel()
    templates = model.get_weights()
    expected = sum(w.size for w in templates)
    if flat.size != expected:
        # Surplus values used to be dropped silently; fail loudly instead.
        raise ValueError(
            'weights has %d values but the model expects %d'
            % (flat.size, expected)
        )
    start = 0
    weights_list = []
    for w in templates:
        # Cut the next w.size values and restore the layer's original shape.
        weights_list.append(flat[start:start + w.size].reshape(w.shape))
        start += w.size
    model.set_weights(weights_list)
def round_callback(model:mlModel):
    '''
    Placeholder reporting hook; no scoring metric is computed yet.

    The model argument is currently unused: only a TODO notice is printed.
    '''
    notice = 'TODO: Setup scoring metric'
    print(notice)
# Classes
class MyTrainer(AbstractTrainer):
    '''
    Trainer that connects the local predictor to the participation loop.

    Holds the local data (X, y), the predictor and a reporting callback that
    is invoked with the predictor on construction, after every training round
    and once the final weights have been applied.
    '''

    def __init__(self, predictor:mlModel, X, y, round_callback):
        '''
        Parameters:
            predictor: wrapper whose ``model`` attribute is passed to the
                weight import/export helpers.
            X: local training inputs; ``X.shape[0]`` is reported as the
                sample count by ``train``.
            y: local training targets (stored, not used by this class yet).
            round_callback: callable taking the predictor.
        '''
        super().__init__() # not absolutely necessary
        self.X = X
        self.y = y
        self.round_callback = round_callback
        self.predictor = predictor
        print("Before FL")
        self.round_callback(self.predictor)

    def train(self, import_weights, epochs):
        '''
        Runs one round: load the given weights, train, export the result.

        Parameters:
            import_weights: flat numpy array of weights; an empty array
                leaves the current model weights untouched.
            epochs: currently unused, the training step is still a
                placeholder.

        Returns:
            Tuple of (flat 1D weight array, number of local samples).
        '''
        if import_weights.size > 0:
            # The helpers are declared as (self, model, ...); pass self
            # explicitly so the call matches their signature.
            import_weights_tf(self, self.predictor.model, import_weights)
        # Do some training
        # Use the injected callback rather than the module-level function.
        self.round_callback(self.predictor)
        return (export_weights_tf(self, self.predictor.model), self.X.shape[0])

    def training_finished(self, final_weights):
        '''
        Applies the final weights to the model and reports once more.

        Parameters:
            final_weights: flat numpy array of weights to load.
        '''
        # Previously read the undefined name `import_weights` (NameError).
        import_weights_tf(self, self.predictor.model, final_weights)
        print("Final Score")
        self.round_callback(self.predictor)
#### MAIN ####
def main():
    '''
    Wires up storage, the gRPC connector and the trainer, then participates.

    Every connection setting can be overridden through an environment
    variable; when a variable is unset the previous hard-coded demo value is
    used, so existing invocations behave as before:

        FL_API_TOKEN, FL_S3_ENDPOINT, FL_S3_BUCKET, FL_S3_ACCESS_KEY_ID,
        FL_S3_SECRET_ACCESS_KEY, FL_COORDINATOR_URL
    '''
    env = os.environ
    # select any token you want
    api_token = env.get("FL_API_TOKEN", "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
    # Fedml data storage
    # SECURITY: the defaults below are the demo credentials that were already
    # committed; real deployments should supply them via the environment and
    # the committed secret should be rotated.
    storageConfig = StorageConfig(
        endpoint=env.get("FL_S3_ENDPOINT", "http://49.12.108.67:9000"),
        bucket=env.get("FL_S3_BUCKET", "cit-fl-demo"),
        access_key_id=env.get("FL_S3_ACCESS_KEY_ID", "cit-fl-demo"),
        secret_access_key=env.get("FL_S3_SECRET_ACCESS_KEY", "Wy8#oS3U#$q7o40%"),
    )
    reader = S3GlobalWeightsReader(storageConfig)
    writer = S3LocalWeightsWriter(storageConfig)
    # FedML server connector
    connector = ConnectorGrpc(
        heartbeat_time=1,
        coordinator_url=env.get("FL_COORDINATOR_URL", "49.12.108.67:5051"),
        api_token=api_token,
        tsl_certificate=None,
        local_weights_writer=writer,
        global_weights_reader=reader
    )
    # load training data
    # Placeholder data: all-zero inputs and targets for eight samples.
    X, y = np.zeros(shape=(8, 512, 512, 3)), np.zeros(shape=(8))
    # create model
    predictor = mlModel()
    # create trainer
    trainer = MyTrainer(predictor, X, y, round_callback)
    # participate using trainer
    p = ParticipationServicer(connector, trainer)
    print('Begining federated learning...')
    p.start()
    print('Finished learning')


# Run the participation client only when executed as a script.
if __name__ == "__main__":
    main()