Skip to content

Commit 7beba6e

Browse files
authored
Update fhe_template_project.py
1 parent 9cae96b commit 7beba6e

File tree

1 file changed

+231
-57
lines changed

1 file changed

+231
-57
lines changed

fhe_template_project.py

Lines changed: 231 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,63 @@
55
import timeit
66
import networkx as nx
77
from random import random
8+
import matplotlib.pyplot as plt
9+
import numpy as np
10+
import csv
11+
12+
##### GLOBAL VARIABLES ####
13+
14+
# ARRAYS
15+
visitedArray = []
16+
decryptedPrevAdjMatrix = []
17+
queue = []
18+
res = []
19+
20+
# FLAGS
21+
initialize = notDone = True
22+
23+
# OTHER
24+
vector_size = 4096
25+
nodeCount = 0
26+
eps = 0.4
27+
28+
############################
29+
30+
31+
def BreadFirstTraversal(graph, s, nodeCount):
32+
try:
33+
res = []
34+
visitedArray = [False for i in range(nodeCount)]
35+
queue = [s]
36+
37+
while len(queue):
38+
print(queue)
39+
40+
# At each iteration, pop the element at the beginning of the queue
41+
elem = queue.pop(0)
42+
43+
# Update visitedArray array
44+
if not visitedArray[elem]:
45+
visitedArray[elem] = True
46+
res.append(elem)
47+
48+
# Add adjacent nodes of the current element
49+
for i in range(nodeCount):
50+
# add elem to the queue
51+
# if elem does not exist in the queue and not visitedArray
52+
# and if elem is reachable from the current node
53+
if not visitedArray[i] and queue.count(i) == 0 and graph[elem * nodeCount + i] == 1:
54+
queue.append(i)
55+
56+
print("BFS Traversal: " + str(res))
57+
except:
58+
# in case any error occurs
59+
return -1
60+
61+
# success
62+
return 1
63+
64+
865

966
# Using networkx, generate a random graph
1067
# You can change the way you generate the graph
@@ -24,7 +81,7 @@ def serializeGraphZeroOne(GG,vec_size):
2481
g = []
2582
for row in range(n):
2683
for column in range(n):
27-
if GG.has_edge(row, column) or row==column: # I assumed the vertices are connected to themselves
84+
if GG.has_edge(row, column) or row==column:
2885
weight = 1
2986
else:
3087
weight = 0
@@ -40,31 +97,109 @@ def serializeGraphZeroOne(GG,vec_size):
4097
def printGraph(graph,n):
4198
for row in range(n):
4299
for column in range(n):
43-
print("{:.5f}".format(graph[row*n+column]), end = '\t')
100+
print("{:.2f}".format(graph[row*n+column]), end = '\t')
44101
print()
45102

46103
# Eva requires special input, this function prepares the eva input
47104
# Eva will then encrypt them
48105
def prepareInput(n, m):
49106
input = {}
50107
GG = generateGraph(n,3,0.5)
51-
graph, graphdict = serializeGraphZeroOne(GG,m)
52-
input['Graph'] = graph
53-
return input
108+
serializedGraph, graphdict = serializeGraphZeroOne(GG,m)
109+
printGraph(serializedGraph,n)
110+
input['Graph'] = serializedGraph
111+
return input, serializedGraph
112+
113+
# Check adjacency matrix for reachable elements from the origin
114+
def maskReachableItemsInMatrix(graph, origin, nodeCount):
115+
adjMatrix = [0] * vector_size
116+
selectedNode = 0
117+
118+
for i in range(nodeCount):
119+
if queue.count(i) == 0 and not visitedArray[i]:
120+
# Imagine serialized 2D vector
121+
temp = [1 if j == selectedNode else 0 for j in range(vector_size)]
122+
adjMatrix += (graph<< (origin * nodeCount + i - selectedNode)) * temp
123+
selectedNode += 1
124+
125+
return adjMatrix
126+
127+
def updateDecryptedAdjMatrix(outputs):
128+
global nodeCount
129+
global decryptedPrevAdjMatrix
130+
global eps
131+
132+
for i in outputs:
133+
for j in range(nodeCount):
134+
# Use eps value for floating comparison
135+
checkPrevResultIsOne = (outputs[i][j] < 1.00 + eps) and (outputs[i][j] > 1.00 - eps)
136+
decryptedPrevAdjMatrix.append(True) if checkPrevResultIsOne else decryptedPrevAdjMatrix.append(False)
137+
138+
139+
def initAdjacencyMatrix( start, nodeCount, graph):
140+
global initialize
141+
if(initialize):
142+
initialize = False
143+
queue.append(0)
144+
adjMatrix = maskReachableItemsInMatrix(graph, 0, nodeCount)
145+
return 1, adjMatrix
146+
else:
147+
return 0, []
54148

55149
# This is the dummy analytic service
56150
# You will implement this service based on your selected algorithm
57-
# you can other parameters using global variables !!! do not change the signature of this function
151+
# you can other parameters using global variables !!! do not change the signature of this function
152+
#
153+
# Note that you cannot compute everything using EVA/CKKS
154+
# For instance, comparison is not possible
155+
# You can add, subtract, multiply, negate, shift right/left
156+
# You will have to implement an interface with the trusted entity for comparison (send back the encrypted values, push the trusted entity to compare and get the comparison output)
58157
def graphanalticprogram(graph):
59-
reval = graph<<1 ## Check what kind of operators are there in EVA, this is left shift
60-
# Note that you cannot compute everything using EVA/CKKS
61-
# For instance, comparison is not possible
62-
# You can add, subtract, multiply, negate, shift right/left
63-
# You will have to implement an interface with the trusted entity for comparison (send back the encrypted values, push the trusted entity to compare and get the comparison output)
64-
return reval
65158

159+
global notDone
160+
global nodeCount
161+
global decryptedPrevAdjMatrix
162+
global initialize
163+
global queue
164+
global visitedArray
165+
166+
# initialize adjacency matrix starting from node 0
167+
returned, adjMatrix = initAdjacencyMatrix( 0, nodeCount, graph)
168+
if returned:
169+
return adjMatrix
170+
171+
# start from here if node is not zeroth node
172+
173+
origin = queue[0]
174+
curr = 0
175+
176+
print("queue:" + str(queue))
177+
if not visitedArray[origin]:
178+
visitedArray[origin] = True
179+
res.append(origin)
180+
# Remove first element from queue
181+
queue.pop(0)
182+
183+
184+
for i in range(nodeCount):
185+
186+
if not visitedArray[i] and queue.count(i) == 0:
187+
# reachable from the prev iteration
188+
if decryptedPrevAdjMatrix[curr]:
189+
queue.append(i)
190+
191+
curr += 1
192+
193+
if len(queue) == 0:
194+
return maskReachableItemsInMatrix(graph, origin, nodeCount)
195+
else:
196+
notDone = False
197+
return res
198+
199+
200+
66201
# Do not change this
67-
# the parameter n can be passed in the call from simulate function
202+
# the parameter n can be passed in the call from simulate function
68203
class EvaProgramDriver(EvaProgram):
69204
def __init__(self, name, vec_size=4096, n=4):
70205
self.n = n
@@ -76,80 +211,119 @@ def __enter__(self):
76211
def __exit__(self, exc_type, exc_value, traceback):
77212
super().__exit__(exc_type, exc_value, traceback)
78213

214+
79215
# Repeat the experiments and show averages with confidence intervals
80216
# You can modify the input parameters
81217
# n is the number of nodes in your graph
82218
# If you require additional parameters, add them
83219
def simulate(n):
84-
m = 4096*4
220+
global notDone
221+
global visitedArray
222+
global nodeCount
223+
global res
224+
global eps
225+
global initialize
226+
global decryptedPrevAdjMatrix
227+
228+
m = vector_size
85229
print("Will start simulation for ", n)
86230
config = {}
87231
config['warn_vec_size'] = 'false'
88232
config['lazy_relinearize'] = 'true'
89233
config['rescaler'] = 'always'
90234
config['balance_reductions'] = 'true'
91-
inputs = prepareInput(n, m)
235+
inputs, g= prepareInput(n, m)
92236

93-
graphanaltic = EvaProgramDriver("graphanaltic", vec_size=m,n=n)
94-
with graphanaltic:
95-
graph = Input('Graph')
96-
reval = graphanalticprogram(graph)
97-
Output('ReturnedValue', reval)
98237

99-
prog = graphanaltic
100-
prog.set_output_ranges(30)
101-
prog.set_input_scales(30)
102-
103-
start = timeit.default_timer()
104-
compiler = CKKSCompiler(config=config)
105-
compiled_multfunc, params, signature = compiler.compile(prog)
106-
compiletime = (timeit.default_timer() - start) * 1000.0 #ms
107-
108-
start = timeit.default_timer()
109-
public_ctx, secret_ctx = generate_keys(params)
110-
keygenerationtime = (timeit.default_timer() - start) * 1000.0 #ms
238+
nodeCount = n
239+
initialize = notDone = True
240+
totalCompiletime = totalKeygenerationtime = totalEncryptiontime = 0
241+
totalExecutiontime = totalDecryptiontime = totalReferenceexecutiontime = 0
242+
totalMse = 0
243+
res = []
244+
queue = []
245+
visitedArray = []
246+
247+
isSuccess = BreadFirstTraversal(g, 0, n)
248+
if(not isSuccess):
249+
raise Exception("BFS Algorithm failed.")
250+
111251

112-
start = timeit.default_timer()
113-
encInputs = public_ctx.encrypt(inputs, signature)
114-
encryptiontime = (timeit.default_timer() - start) * 1000.0 #ms
252+
# Clear and init all nodes as not visited
253+
visitedArray = [False] * nodeCount
115254

116-
start = timeit.default_timer()
117-
encOutputs = public_ctx.execute(compiled_multfunc, encInputs)
118-
executiontime = (timeit.default_timer() - start) * 1000.0 #ms
255+
while notDone:
119256

120-
start = timeit.default_timer()
121-
outputs = secret_ctx.decrypt(encOutputs, signature)
122-
decryptiontime = (timeit.default_timer() - start) * 1000.0 #ms
257+
graphanaltic = EvaProgramDriver("graphanaltic", vec_size=m,n=n)
258+
with graphanaltic:
259+
graph = Input('Graph')
260+
reval = graphanalticprogram(graph)
261+
Output('ReturnedValue', reval)
262+
263+
prog = graphanaltic
264+
prog.set_output_ranges(30)
265+
prog.set_input_scales(30)
266+
267+
start = timeit.default_timer()
268+
compiler = CKKSCompiler(config=config)
269+
compiled_multfunc, params, signature = compiler.compile(prog)
270+
totalCompiletime += (timeit.default_timer() - start) * 1000.0 #ms
123271

124-
start = timeit.default_timer()
125-
reference = evaluate(compiled_multfunc, inputs)
126-
referenceexecutiontime = (timeit.default_timer() - start) * 1000.0 #ms
272+
start = timeit.default_timer()
273+
public_ctx, secret_ctx = generate_keys(params)
274+
totalKeygenerationtime = (timeit.default_timer() - start) * 1000.0 #ms
127275

128-
# Change this if you want to output something or comment out the two lines below
129-
for key in outputs:
130-
print(key, float(outputs[key][0]), float(reference[key][0]))
276+
start = timeit.default_timer()
277+
encInputs = public_ctx.encrypt(inputs, signature)
278+
totalEncryptiontime += (timeit.default_timer() - start) * 1000.0 #ms
279+
280+
start = timeit.default_timer()
281+
encOutputs = public_ctx.execute(compiled_multfunc, encInputs)
282+
totalExecutiontime += (timeit.default_timer() - start) * 1000.0 #ms
283+
284+
start = timeit.default_timer()
285+
outputs = secret_ctx.decrypt(encOutputs, signature)
286+
totalDecryptiontime += (timeit.default_timer() - start) * 1000.0 #ms
287+
288+
decryptedPrevAdjMatrix.clear()
289+
290+
# Check previous adjacency matrix values and update the boolean matrix for next iteration
291+
292+
updateDecryptedAdjMatrix(outputs)
293+
294+
start = timeit.default_timer()
295+
reference = evaluate(compiled_multfunc, inputs)
296+
totalReferenceexecutiontime += (timeit.default_timer() - start) * 1000.0 #ms
131297

132-
mse = valuation_mse(outputs, reference) # since CKKS does approximate computations, this is an important measure that depicts the amount of error
298+
299+
totalMse += valuation_mse(outputs, reference) # since CKKS does approximate computations, this is an important measure that depicts the amount of error
133300

134-
return compiletime, keygenerationtime, encryptiontime, executiontime, decryptiontime, referenceexecutiontime, mse
301+
return totalCompiletime, totalKeygenerationtime, totalEncryptiontime, totalExecutiontime, totalDecryptiontime, totalReferenceexecutiontime, totalMse
302+
135303

136304

137305
if __name__ == "__main__":
138-
simcnt = 3 #The number of simulation runs, set it to 3 during development otherwise you will wait for a long time
306+
simcnt = 5 #The number of simulation runs, set it to 3 during development otherwise you will wait for a long time
139307
# For benchmarking you must set it to a large number, e.g., 100
140308
#Note that file is opened in append mode, previous results will be kept in the file
141-
resultfile = open("results.csv", "a") # Measurement results are collated in this file for you to plot later on
142-
resultfile.write("NodeCount,SimCnt,CompileTime,KeyGenerationTime,EncryptionTime,ExecutionTime,DecryptionTime,ReferenceExecutionTime,Mse\n")
309+
resultfile = open("results.csv", "w") # Measurement results are collated in this file for you to plot later on
310+
resultfile.write("NodeCount,SimCnt,totalCompiletime,KeyGenerationTime,EncryptionTime,ExecutionTime,DecryptionTime,ReferenceExecutionTime,Mse\n")
143311
resultfile.close()
144312

145313
print("Simulation campaing started:")
146-
for nc in range(36,64,4): # Node counts for experimenting various graph sizes
314+
315+
for nc in range(8,50,4): # Node counts for experimenting various graph sizes
147316
n = nc
317+
148318
resultfile = open("results.csv", "a")
149319
for i in range(simcnt):
150320
#Call the simulator
151-
compiletime, keygenerationtime, encryptiontime, executiontime, decryptiontime, referenceexecutiontime, mse = simulate(n)
152-
res = str(n) + "," + str(i) + "," + str(compiletime) + "," + str(keygenerationtime) + "," + str(encryptiontime) + "," + str(executiontime) + "," + str(decryptiontime) + "," + str(referenceexecutiontime) + "," + str(mse) + "\n"
153-
print(res)
321+
totalCompiletime, totalKeygenerationtime, totalEncryptiontime, totalExecutiontime, totalDecryptiontime, totalReferenceexecutiontime, totalMse = simulate(n)
322+
res = str(n) + "," + str(i) + "," + str(totalCompiletime) + "," + str(totalKeygenerationtime) + "," + str(totalEncryptiontime) + "," + str(totalExecutiontime) + "," + str(totalDecryptiontime) + "," + str(totalReferenceexecutiontime) + "," + str(totalMse) + "\n"
154323
resultfile.write(res)
155-
resultfile.close()
324+
325+
resultfile.close()
326+
327+
#plotResults()
328+
329+

0 commit comments

Comments
 (0)