Skip to content

Commit aec8efe

Browse files
Added examples/pke/interactive-bootstrapping.py, added symbols to bindings and got rid of a warning from CMakeLists.txt (#228)
Co-authored-by: Dmitriy Suponitskiy <[email protected]>
1 parent 07a81b4 commit aec8efe

File tree

3 files changed

+290
-22
lines changed

3 files changed

+290
-22
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ if(APPLE)
1616
endif()
1717

1818
find_package(OpenFHE 1.3.0 REQUIRED)
19+
set(PYBIND11_FINDPYTHON ON)
1920
find_package(pybind11 REQUIRED)
2021

2122
# "CMAKE_INTERPROCEDURAL_OPTIMIZATION ON" (ON is the default value) causes link failure. see
Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
from openfhe import *
2+
3+
4+
def main():
5+
# the scaling technigue can be changed to FIXEDMANUAL, FIXEDAUTO, or FLEXIBLEAUTOEXT
6+
ThresholdFHE(FLEXIBLEAUTO)
7+
Chebyshev(FLEXIBLEAUTO)
8+
9+
def ThresholdFHE(scaleTech):
10+
# if scaleTech not in [FIXEDMANUAL, FIXEDAUTO, FLEXIBLEAUTOEXT]:
11+
# errMsg = "ERROR: Scaling technique is not supported!"
12+
# raise Exception(errMsg)
13+
14+
print(f"Threshold FHE example with Scaling Technique {scaleTech}")
15+
16+
parameters = CCParamsCKKSRNS()
17+
# 1 extra level needs to be added for FIXED* modes (2 extra levels for FLEXIBLE* modes) to the multiplicative depth
18+
# to support 2-party interactive bootstrapping
19+
depth = 7
20+
parameters.SetMultiplicativeDepth(depth)
21+
parameters.SetScalingModSize(50)
22+
parameters.SetBatchSize(16)
23+
parameters.SetScalingTechnique(scaleTech)
24+
25+
cc = GenCryptoContext(parameters)
26+
cc.Enable(PKE)
27+
cc.Enable(LEVELEDSHE)
28+
cc.Enable(ADVANCEDSHE)
29+
cc.Enable(MULTIPARTY)
30+
31+
#############################################################
32+
# Perform Key Generation Operation
33+
#############################################################
34+
35+
print("Running key generation (used for source data)...")
36+
print("Round 1 (party A) started.")
37+
38+
kp1 = cc.KeyGen()
39+
evalMultKey = cc.KeySwitchGen(kp1.secretKey, kp1.secretKey)
40+
41+
print("Round 1 of key generation completed.")
42+
#############################################################
43+
print("Round 2 (party B) started.")
44+
print("Joint public key for (s_a + s_b) is generated...")
45+
kp2 = cc.MultipartyKeyGen(kp1.publicKey)
46+
47+
input = [-0.9, -0.8, -0.6, -0.4, -0.2, 0., 0.2, 0.4, 0.6, 0.8, 0.9]
48+
49+
# This plaintext only has 3 RNS limbs, the minimum needed to perform 2-party interactive bootstrapping for FLEXIBLEAUTO
50+
plaintext1 = cc.MakeCKKSPackedPlaintext(input, 1, depth - 2)
51+
ciphertext1 = cc.Encrypt(kp2.publicKey, plaintext1)
52+
53+
# INTERACTIVE BOOTSTRAPPING STARTS
54+
55+
# under the hood it reduces to two towers
56+
ciphertext1 = cc.IntBootAdjustScale(ciphertext1)
57+
print("IntBootAdjustScale Succeeded")
58+
59+
# masked decryption on the server: c0 = b + a*s0
60+
ciphertextOutput1 = cc.IntBootDecrypt(kp1.secretKey, ciphertext1)
61+
print("IntBootDecrypt on Server Succeeded")
62+
63+
ciphertext2 = ciphertext1.Clone()
64+
ciphertext2.SetElements([ciphertext2.GetElements()[1]])
65+
66+
# masked decryption on the client: c1 = a*s1
67+
ciphertextOutput2 = cc.IntBootDecrypt(kp2.secretKey, ciphertext2)
68+
print("IntBootDecrypt on Client Succeeded")
69+
70+
# Encryption of masked decryption c1 = a*s1
71+
ciphertextOutput2 = cc.IntBootEncrypt(kp2.publicKey, ciphertextOutput2)
72+
print("IntBootEncrypt on Client Succeeded")
73+
74+
# Compute Enc(c1) + c0
75+
ciphertextOutput = cc.IntBootAdd(ciphertextOutput2, ciphertextOutput1)
76+
print("IntBootAdd on Server Succeeded")
77+
78+
# INTERACTIVE BOOTSTRAPPING ENDS
79+
80+
# distributed decryption
81+
ciphertextPartial1 = cc.MultipartyDecryptLead([ciphertextOutput], kp1.secretKey)
82+
ciphertextPartial2 = cc.MultipartyDecryptMain([ciphertextOutput], kp2.secretKey)
83+
84+
partialCiphertextVec = [ciphertextPartial1[0], ciphertextPartial2[0]]
85+
plaintextMultiparty = cc.MultipartyDecryptFusion(partialCiphertextVec)
86+
87+
plaintextMultiparty.SetLength(len(input))
88+
89+
print(f"Original plaintext \n\t {plaintext1.GetCKKSPackedValue()}")
90+
print(f"Result after bootstrapping \n\t {plaintextMultiparty.GetCKKSPackedValue()}")
91+
92+
def Chebyshev(scaleTech):
93+
# if scaleTech not in [FIXEDMANUAL, FIXEDAUTO, FLEXIBLEAUTOEXT]:
94+
# errMsg = "ERROR: Scaling technique is not supported!"
95+
# raise Exception(errMsg)
96+
97+
print(f"Threshold FHE example with Scaling Technique {scaleTech}")
98+
99+
parameters = CCParamsCKKSRNS()
100+
# 1 extra level needs to be added for FIXED* modes (2 extra levels for FLEXIBLE* modes) to the multiplicative depth
101+
# to support 2-party interactive bootstrapping
102+
parameters.SetMultiplicativeDepth(8)
103+
parameters.SetScalingModSize(50)
104+
parameters.SetBatchSize(16)
105+
parameters.SetScalingTechnique(scaleTech)
106+
107+
cc = GenCryptoContext(parameters)
108+
# enable features that you wish to use
109+
cc.Enable(PKE)
110+
cc.Enable(LEVELEDSHE)
111+
cc.Enable(ADVANCEDSHE)
112+
cc.Enable(MULTIPARTY)
113+
114+
############################################################
115+
# Perform Key Generation Operation
116+
############################################################
117+
118+
print("Running key generation (used for source data)...")
119+
print("Round 1 (party A) started.")
120+
121+
kp1 = cc.KeyGen()
122+
123+
evalMultKey = cc.KeySwitchGen(kp1.secretKey, kp1.secretKey)
124+
cc.EvalSumKeyGen(kp1.secretKey)
125+
evalSumKeys = cc.GetEvalSumKeyMap(kp1.secretKey.GetKeyTag())
126+
127+
print("Round 1 of key generation completed.")
128+
############################################################
129+
print("Round 2 (party B) started.")
130+
print("Joint public key for (s_a + s_b) is generated...")
131+
kp2 = cc.MultipartyKeyGen(kp1.publicKey)
132+
133+
evalMultKey2 = cc.MultiKeySwitchGen(kp2.secretKey, kp2.secretKey, evalMultKey)
134+
135+
print("Joint evaluation multiplication key for (s_a + s_b) is generated...")
136+
evalMultAB = cc.MultiAddEvalKeys(evalMultKey, evalMultKey2, kp2.publicKey.GetKeyTag())
137+
138+
print("Joint evaluation multiplication key (s_a + s_b) is transformed into s_b*(s_a + s_b)...")
139+
evalMultBAB = cc.MultiMultEvalKey(kp2.secretKey, evalMultAB, kp2.publicKey.GetKeyTag())
140+
141+
evalSumKeysB = cc.MultiEvalSumKeyGen(kp2.secretKey, evalSumKeys, kp2.publicKey.GetKeyTag())
142+
143+
print("Joint evaluation summation key for (s_a + s_b) is generated...")
144+
evalSumKeysJoin = cc.MultiAddEvalSumKeys(evalSumKeys, evalSumKeysB, kp2.publicKey.GetKeyTag())
145+
146+
cc.InsertEvalSumKey(evalSumKeysJoin)
147+
148+
print("Round 2 of key generation completed.")
149+
150+
print("Round 3 (party A) started.")
151+
print("Joint key (s_a + s_b) is transformed into s_a*(s_a + s_b)...")
152+
evalMultAAB = cc.MultiMultEvalKey(kp1.secretKey, evalMultAB, kp2.publicKey.GetKeyTag())
153+
154+
print("Computing the final evaluation multiplication key for (s_a + s_b)*(s_a + s_b)...")
155+
evalMultFinal = cc.MultiAddEvalMultKeys(evalMultAAB, evalMultBAB, evalMultAB.GetKeyTag())
156+
157+
cc.InsertEvalMultKey([evalMultFinal])
158+
159+
print("Round 3 of key generation completed.")
160+
161+
input = [-4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0]
162+
163+
coefficients = [1.0, 0.558971, 0.0, -0.0943712, 0.0, 0.0215023, 0.0, -0.00505348, 0.0, 0.00119324,
164+
0.0, -0.000281928, 0.0, 0.0000664347, 0.0, -0.0000148709]
165+
166+
a = -4
167+
b = 4
168+
169+
plaintext1 = cc.MakeCKKSPackedPlaintext(input)
170+
171+
ciphertext1 = cc.Encrypt(kp2.publicKey, plaintext1)
172+
173+
# The Chebyshev series interpolation requires 6 levels
174+
ciphertext1 = cc.EvalChebyshevSeries(ciphertext1, coefficients, a, b)
175+
print("Ran Chebyshev interpolation")
176+
177+
# INTERACTIVE BOOTSTRAPPING STARTS
178+
179+
ciphertext1 = cc.IntBootAdjustScale(ciphertext1)
180+
print("IntBootAdjustScale Succeeded")
181+
182+
# masked decryption on the server: c0 = b + a*s0
183+
ciphertextOutput1 = cc.IntBootDecrypt(kp1.secretKey, ciphertext1)
184+
print("IntBootDecrypt on Server Succeeded")
185+
186+
ciphertext2 = ciphertext1.Clone()
187+
ciphertext2.SetElements([ciphertext2.GetElements()[1]])
188+
189+
# masked decryption on the client: c1 = a*s1
190+
ciphertextOutput2 = cc.IntBootDecrypt(kp2.secretKey, ciphertext2)
191+
print("IntBootDecrypt on Client Succeeded")
192+
193+
# Encryption of masked decryption c1 = a*s1
194+
ciphertextOutput2 = cc.IntBootEncrypt(kp2.publicKey, ciphertextOutput2)
195+
print("IntBootEncrypt on Client Succeeded")
196+
197+
# Compute Enc(c1) + c0
198+
ciphertextOutput = cc.IntBootAdd(ciphertextOutput2, ciphertextOutput1)
199+
print("IntBootAdd on Server Succeeded")
200+
201+
# INTERACTIVE BOOTSTRAPPING ENDS
202+
203+
# distributed decryption
204+
205+
ciphertextPartial1 = cc.MultipartyDecryptLead([ciphertextOutput], kp1.secretKey)
206+
207+
ciphertextPartial2 = cc.MultipartyDecryptMain([ciphertextOutput], kp2.secretKey)
208+
209+
partialCiphertextVec = [ciphertextPartial1[0], ciphertextPartial2[0]]
210+
plaintextMultiparty = cc.MultipartyDecryptFusion(partialCiphertextVec)
211+
212+
plaintextMultiparty.SetLength(len(input))
213+
214+
print(f"\n Original Plaintext #1: \n {plaintext1}")
215+
216+
print(f"\n Results of evaluating the polynomial with coefficients {coefficients} \n")
217+
print(f"\n Ciphertext result: {plaintextMultiparty}")
218+
219+
print("\n Plaintext result: ( 0.0179885, 0.0474289, 0.119205, 0.268936, 0.5, 0.731064, 0.880795, 0.952571, 0.982011 ) \n")
220+
221+
print("\n Exact result: ( 0.0179862, 0.0474259, 0.119203, 0.268941, 0.5, 0.731059, 0.880797, 0.952574, 0.982014 ) \n")
222+
223+
print("\n Another round of Chebyshev interpolation after interactive bootstrapping: \n")
224+
225+
ciphertextOutput = cc.EvalChebyshevSeries(ciphertextOutput, coefficients, a, b)
226+
print("Ran Chebyshev interpolation")
227+
228+
# distributed decryption
229+
230+
ciphertextPartial1 = cc.MultipartyDecryptLead([ciphertextOutput], kp1.secretKey)
231+
232+
ciphertextPartial2 = cc.MultipartyDecryptMain([ciphertextOutput], kp2.secretKey)
233+
234+
partialCiphertextVec = [ciphertextPartial1[0], ciphertextPartial2[0]]
235+
plaintextMultiparty = cc.MultipartyDecryptFusion(partialCiphertextVec)
236+
237+
plaintextMultiparty.SetLength(len(input))
238+
239+
print(f"\n Ciphertext result: {plaintextMultiparty}")
240+
241+
print("\n Plaintext result: ( 0.504497, 0.511855, 0.529766, 0.566832, 0.622459, 0.675039, 0.706987, 0.721632, 0.727508 )")
242+
243+
244+
if __name__ == "__main__":
245+
main()

src/lib/bindings.cpp

Lines changed: 44 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ inline std::shared_ptr<CryptoParametersRNS> GetParamsRNSChecked(const CryptoCont
6161
return ptr;
6262
}
6363

64+
void bind_DCRTPoly(py::module &m) {
65+
py::class_<DCRTPoly>(m, "DCRTPoly").def(py::init<>());
66+
}
67+
6468
template <typename T>
6569
void bind_parameters(py::module &m,const std::string name)
6670
{
@@ -1312,28 +1316,45 @@ void bind_encodings(py::module &m)
13121316

13131317
void bind_ciphertext(py::module &m)
13141318
{
1315-
py::class_<CiphertextImpl<DCRTPoly>, std::shared_ptr<CiphertextImpl<DCRTPoly>>>(m, "Ciphertext")
1316-
.def(py::init<>())
1317-
.def("__add__", [](const Ciphertext<DCRTPoly> &a, const Ciphertext<DCRTPoly> &b)
1318-
{return a + b; },py::is_operator(),pybind11::keep_alive<0, 1>())
1319-
// .def(py::self + py::self);
1320-
// .def("GetDepth", &CiphertextImpl<DCRTPoly>::GetDepth)
1321-
// .def("SetDepth", &CiphertextImpl<DCRTPoly>::SetDepth)
1322-
.def("GetLevel", &CiphertextImpl<DCRTPoly>::GetLevel,
1323-
ctx_GetLevel_docs)
1324-
.def("SetLevel", &CiphertextImpl<DCRTPoly>::SetLevel,
1325-
ctx_SetLevel_docs,
1326-
py::arg("level"))
1327-
.def("Clone", &CiphertextImpl<DCRTPoly>::Clone)
1328-
.def("RemoveElement", &RemoveElementWrapper, cc_RemoveElement_docs)
1329-
// .def("GetHopLevel", &CiphertextImpl<DCRTPoly>::GetHopLevel)
1330-
// .def("SetHopLevel", &CiphertextImpl<DCRTPoly>::SetHopLevel)
1331-
// .def("GetScalingFactor", &CiphertextImpl<DCRTPoly>::GetScalingFactor)
1332-
// .def("SetScalingFactor", &CiphertextImpl<DCRTPoly>::SetScalingFactor)
1333-
.def("GetSlots", &CiphertextImpl<DCRTPoly>::GetSlots)
1334-
.def("SetSlots", &CiphertextImpl<DCRTPoly>::SetSlots)
1335-
.def("GetNoiseScaleDeg", &CiphertextImpl<DCRTPoly>::GetNoiseScaleDeg)
1336-
.def("SetNoiseScaleDeg", &CiphertextImpl<DCRTPoly>::SetNoiseScaleDeg);
1319+
py::class_<CiphertextImpl<DCRTPoly>,
1320+
std::shared_ptr<CiphertextImpl<DCRTPoly>>>(m, "Ciphertext")
1321+
.def(py::init<>())
1322+
.def(
1323+
"__add__",
1324+
[](const Ciphertext<DCRTPoly> &a, const Ciphertext<DCRTPoly> &b) {
1325+
return a + b;
1326+
},
1327+
py::is_operator(), pybind11::keep_alive<0, 1>())
1328+
// .def(py::self + py::self);
1329+
// .def("GetDepth", &CiphertextImpl<DCRTPoly>::GetDepth)
1330+
// .def("SetDepth", &CiphertextImpl<DCRTPoly>::SetDepth)
1331+
.def("GetLevel", &CiphertextImpl<DCRTPoly>::GetLevel, ctx_GetLevel_docs)
1332+
.def("SetLevel", &CiphertextImpl<DCRTPoly>::SetLevel, ctx_SetLevel_docs,
1333+
py::arg("level"))
1334+
.def("Clone", &CiphertextImpl<DCRTPoly>::Clone)
1335+
.def("RemoveElement", &RemoveElementWrapper, cc_RemoveElement_docs)
1336+
// .def("GetHopLevel", &CiphertextImpl<DCRTPoly>::GetHopLevel)
1337+
// .def("SetHopLevel", &CiphertextImpl<DCRTPoly>::SetHopLevel)
1338+
// .def("GetScalingFactor", &CiphertextImpl<DCRTPoly>::GetScalingFactor)
1339+
// .def("SetScalingFactor", &CiphertextImpl<DCRTPoly>::SetScalingFactor)
1340+
.def("GetSlots", &CiphertextImpl<DCRTPoly>::GetSlots)
1341+
.def("SetSlots", &CiphertextImpl<DCRTPoly>::SetSlots)
1342+
.def("GetNoiseScaleDeg", &CiphertextImpl<DCRTPoly>::GetNoiseScaleDeg)
1343+
.def("SetNoiseScaleDeg", &CiphertextImpl<DCRTPoly>::SetNoiseScaleDeg)
1344+
.def("GetElements", [](const CiphertextImpl<DCRTPoly>& self) -> const std::vector<DCRTPoly> & {
1345+
return self.GetElements();
1346+
},
1347+
py::return_value_policy::reference_internal)
1348+
.def("GetElementsMutable", [](CiphertextImpl<DCRTPoly>& self) -> std::vector<DCRTPoly> & {
1349+
return self.GetElements();
1350+
},
1351+
py::return_value_policy::reference_internal)
1352+
.def("SetElements", [](CiphertextImpl<DCRTPoly>& self, const std::vector<DCRTPoly> &elems) {
1353+
self.SetElements(elems);
1354+
})
1355+
.def("SetElementsMove", [](CiphertextImpl<DCRTPoly>& self, std::vector<DCRTPoly> &&elems) {
1356+
self.SetElements(std::move(elems));
1357+
});
13371358
}
13381359

13391360
void bind_schemes(py::module &m){
@@ -1400,6 +1421,7 @@ PYBIND11_MODULE(openfhe, m)
14001421
{
14011422
m.doc() = "Open-Source Fully Homomorphic Encryption Library";
14021423
// binfhe library
1424+
bind_DCRTPoly(m);
14031425
bind_binfhe_enums(m);
14041426
bind_binfhe_context(m);
14051427
bind_binfhe_keys(m);

0 commit comments

Comments
 (0)