Skip to content

Commit 869cdd3

Browse files
authored
Merge pull request #160 from openfheorg/156-add-deserialize-keymaps
Add (De)Serialize for key maps
2 parents c615798 + b86ef34 commit 869cdd3

File tree

3 files changed

+254
-11
lines changed

3 files changed

+254
-11
lines changed

src/include/pke/serialization.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,28 @@ T DeserializeFromStringWrapper(const std::string& str, const ST& sertype);
5656
template <typename T, typename ST>
5757
T DeserializeFromBytesWrapper(const py::bytes& bytes, const ST& sertype);
5858

59+
template <typename ST>
60+
std::string SerializeEvalMultKeyToStringWrapper(const ST& sertype, const std::string& id);
61+
62+
template <typename ST>
63+
py::bytes SerializeEvalMultKeyToBytesWrapper(const ST& sertype, const std::string& id);
64+
65+
template <typename ST>
66+
std::string SerializeEvalAutomorphismKeyToStringWrapper(const ST& sertype, const std::string& id);
67+
68+
template <typename ST>
69+
py::bytes SerializeEvalAutomorphismKeyToBytesWrapper(const ST& sertype, const std::string& id);
70+
71+
template <typename ST>
72+
void DeserializeEvalMultKeyFromStringWrapper(const std::string& data, const ST& sertype);
73+
74+
template <typename ST>
75+
void DeserializeEvalMultKeyFromBytesWrapper(const std::string& data, const ST& sertype);
76+
77+
template <typename ST>
78+
void DeserializeEvalAutomorphismKeyFromStringWrapper(const std::string& data, const ST& sertype);
79+
80+
template <typename ST>
81+
void DeserializeEvalAutomorphismKeyFromBytesWrapper(const std::string& data, const ST& sertype);
82+
5983
#endif // OPENFHE_SERIALIZATION_BINDINGS_H

src/lib/pke/serialization.cpp

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,87 @@ CryptoContext<DCRTPoly> DeserializeCCFromBytesWrapper(const py::bytes& bytes, co
137137
return obj;
138138
}
139139

140+
template <typename ST>
141+
std::string SerializeEvalMultKeyToStringWrapper(const ST& sertype, const std::string& id) {
142+
std::ostringstream oss;
143+
bool res = CryptoContextImpl<DCRTPoly>::SerializeEvalMultKey(oss, sertype, id);
144+
if (!res) {
145+
throw std::runtime_error("Failed to serialize EvalMultKey");
146+
}
147+
return oss.str();
148+
}
149+
150+
template <typename ST>
151+
py::bytes SerializeEvalMultKeyToBytesWrapper(const ST& sertype, const std::string& id) {
152+
std::ostringstream oss(std::ios::binary);
153+
bool res = CryptoContextImpl<DCRTPoly>::SerializeEvalMultKey(oss, sertype, id);
154+
if (!res) {
155+
throw std::runtime_error("Failed to serialize EvalMultKey");
156+
}
157+
std::string str = oss.str();
158+
return py::bytes(str);
159+
}
160+
161+
162+
template <typename ST>
163+
std::string SerializeEvalAutomorphismKeyToStringWrapper(const ST& sertype, const std::string& id) {
164+
std::ostringstream oss;
165+
bool res = CryptoContextImpl<DCRTPoly>::SerializeEvalAutomorphismKey(oss, sertype, id);
166+
if (!res) {
167+
throw std::runtime_error("Failed to serialize EvalAutomorphismKey");
168+
}
169+
return oss.str();
170+
}
171+
172+
173+
template <typename ST>
174+
py::bytes SerializeEvalAutomorphismKeyToBytesWrapper(const ST& sertype, const std::string& id) {
175+
std::ostringstream oss(std::ios::binary);
176+
bool res = CryptoContextImpl<DCRTPoly>::SerializeEvalAutomorphismKey(oss, sertype, id);
177+
if (!res) {
178+
throw std::runtime_error("Failed to serialize EvalAutomorphismKey");
179+
}
180+
return oss.str();
181+
}
182+
183+
template <typename ST>
184+
void DeserializeEvalMultKeyFromStringWrapper(const std::string& data, const ST& sertype) {
185+
std::istringstream iss(data);
186+
bool res = CryptoContextImpl<DCRTPoly>::DeserializeEvalMultKey<ST>(iss, sertype);
187+
if (!res) {
188+
throw std::runtime_error("Failed to deserialize EvalMultKey");
189+
}
190+
}
191+
192+
template <typename ST>
193+
void DeserializeEvalMultKeyFromBytesWrapper(const std::string& data, const ST& sertype) {
194+
std::string str(data);
195+
std::istringstream iss(str, std::ios::binary);
196+
bool res = CryptoContextImpl<DCRTPoly>::DeserializeEvalMultKey<ST>(iss, sertype);
197+
if (!res) {
198+
throw std::runtime_error("Failed to deserialize EvalMultKey");
199+
}
200+
}
201+
202+
template <typename ST>
203+
void DeserializeEvalAutomorphismKeyFromStringWrapper(const std::string& data, const ST& sertype) {
204+
std::istringstream iss(data);
205+
std::map<std::string, std::shared_ptr<std::map<usint, EvalKey<DCRTPoly>>>> keyMap;
206+
bool res = CryptoContextImpl<DCRTPoly>::DeserializeEvalAutomorphismKey<ST>(iss, sertype);
207+
if (!res) {
208+
throw std::runtime_error("Failed to deserialize EvalAutomorphismKey");
209+
}
210+
}
211+
212+
template <typename ST>
213+
void DeserializeEvalAutomorphismKeyFromBytesWrapper(const std::string& data, const ST& sertype) {
214+
std::string str(data);
215+
std::istringstream iss(str, std::ios::binary);
216+
bool res = CryptoContextImpl<DCRTPoly>::DeserializeEvalAutomorphismKey<ST>(iss, sertype);
217+
if (!res) {
218+
throw std::runtime_error("Failed to deserialize EvalAutomorphismKey");
219+
}
220+
}
140221

141222
void bind_serialization(pybind11::module &m) {
142223
// Json Serialization
@@ -182,6 +263,14 @@ void bind_serialization(pybind11::module &m) {
182263
py::arg("obj"), py::arg("sertype"));
183264
m.def("DeserializeEvalKeyString", &DeserializeFromStringWrapper<EvalKey<DCRTPoly>, SerType::SERJSON>,
184265
py::arg("str"), py::arg("sertype"));
266+
m.def("SerializeEvalMultKeyString", &SerializeEvalMultKeyToStringWrapper<SerType::SERJSON>,
267+
py::arg("sertype"), py::arg("id") = "");
268+
m.def("DeserializeEvalMultKeyString", &DeserializeEvalMultKeyFromStringWrapper<SerType::SERJSON>,
269+
py::arg("sertype"), py::arg("id") = "");
270+
m.def("SerializeEvalAutomorphismKeyString", &SerializeEvalAutomorphismKeyToStringWrapper<SerType::SERJSON>,
271+
py::arg("sertype"), py::arg("id") = "");
272+
m.def("DeserializeEvalAutomorphismKeyString", &DeserializeEvalAutomorphismKeyFromStringWrapper<SerType::SERJSON>,
273+
py::arg("sertype"), py::arg("id") = "");
185274

186275
// Binary Serialization
187276
m.def("SerializeToFile", static_cast<bool (*)(const std::string&,const CryptoContext<DCRTPoly>&, const SerType::SERBINARY&)>(&Serial::SerializeToFile<DCRTPoly>),
@@ -226,4 +315,12 @@ void bind_serialization(pybind11::module &m) {
226315
py::arg("obj"), py::arg("sertype"));
227316
m.def("DeserializeEvalKeyString", &DeserializeFromBytesWrapper<EvalKey<DCRTPoly>, SerType::SERBINARY>,
228317
py::arg("str"), py::arg("sertype"));
318+
m.def("SerializeEvalMultKeyString", &SerializeEvalMultKeyToBytesWrapper<SerType::SERBINARY>,
319+
py::arg("sertype"), py::arg("id") = "");
320+
m.def("DeserializeEvalMultKeyString", &DeserializeEvalMultKeyFromBytesWrapper<SerType::SERBINARY>,
321+
py::arg("sertype"), py::arg("id") = "");
322+
m.def("SerializeEvalAutomorphismKeyString", &SerializeEvalAutomorphismKeyToBytesWrapper<SerType::SERBINARY>,
323+
py::arg("sertype"), py::arg("id") = "");
324+
m.def("DeserializeEvalAutomorphismKeyString", &DeserializeEvalAutomorphismKeyFromBytesWrapper<SerType::SERBINARY>,
325+
py::arg("sertype"), py::arg("id") = "");
229326
}

tests/test_serial_cc.py

Lines changed: 133 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,34 +40,82 @@ def test_serial_cryptocontext(tmp_path):
4040
assert fhe.SerializeToFile(str(tmp_path / "ciphertext12.json"), ct1, fhe.JSON)
4141

4242

43+
VECTOR1_ROTATION = 1
44+
VECTOR2_ROTATION = 2
45+
VECTOR3_ROTATION = -1
46+
VECTOR4_ROTATION = -2
47+
4348
@pytest.mark.parametrize("mode", [fhe.JSON, fhe.BINARY])
4449
def test_serial_cryptocontext_str(mode):
4550
parameters = fhe.CCParamsBFVRNS()
4651
parameters.SetPlaintextModulus(65537)
4752
parameters.SetMultiplicativeDepth(2)
4853

4954
cryptoContext = fhe.GenCryptoContext(parameters)
50-
cryptoContext.Enable(fhe.PKESchemeFeature.PKE)
55+
cryptoContext.Enable(fhe.PKE)
56+
cryptoContext.Enable(fhe.KEYSWITCH)
57+
cryptoContext.Enable(fhe.LEVELEDSHE)
5158
cryptoContext.Enable(fhe.PKESchemeFeature.PRE)
5259

5360
keypair = cryptoContext.KeyGen()
54-
vectorOfInts = list(range(12))
55-
plaintext = cryptoContext.MakePackedPlaintext(vectorOfInts)
56-
ciphertext = cryptoContext.Encrypt(keypair.publicKey, plaintext)
57-
evalKey = cryptoContext.ReKeyGen(keypair.secretKey, keypair.publicKey)
5861

62+
# First plaintext vector is encoded
63+
vectorOfInts1 = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
64+
plaintext1 = cryptoContext.MakePackedPlaintext(vectorOfInts1)
65+
66+
# Second plaintext vector is encoded
67+
vectorOfInts2 = [3, 2, 1, 4, 5, 6, 7, 8, 9, 10, 11, 12]
68+
plaintext2 = cryptoContext.MakePackedPlaintext(vectorOfInts2)
69+
70+
# Third plaintext vector is encoded
71+
vectorOfInts3 = [1, 2, 5, 2, 5, 6, 7, 8, 9, 10, 11, 12]
72+
plaintext3 = cryptoContext.MakePackedPlaintext(vectorOfInts3)
73+
74+
# Create a final array adding the three vectors
75+
initialPlaintextAddResult = [vectorOfInts1[i] + vectorOfInts2[i] + vectorOfInts3[i] for i in range(len(vectorOfInts1))]
76+
initialPlaintextAddResult = cryptoContext.MakePackedPlaintext(initialPlaintextAddResult)
77+
78+
# Multiply the values
79+
initialPlaintextMultResult = [vectorOfInts1[i] * vectorOfInts2[i] * vectorOfInts3[i] for i in range(len(vectorOfInts1))]
80+
initialPlaintextMultResult = cryptoContext.MakePackedPlaintext(initialPlaintextMultResult)
81+
82+
# Rotate the values
83+
initialPlaintextRot1 = rotate_vector(vectorOfInts1, VECTOR1_ROTATION)
84+
initialPlaintextRot1 = cryptoContext.MakePackedPlaintext(initialPlaintextRot1)
85+
initialPlaintextRot2 = rotate_vector(vectorOfInts2, VECTOR2_ROTATION)
86+
initialPlaintextRot2 = cryptoContext.MakePackedPlaintext(initialPlaintextRot2)
87+
initialPlaintextRot3 = rotate_vector(vectorOfInts3, VECTOR3_ROTATION)
88+
initialPlaintextRot3 = cryptoContext.MakePackedPlaintext(initialPlaintextRot3)
89+
initialPlaintextRot4 = rotate_vector(vectorOfInts3, VECTOR4_ROTATION)
90+
initialPlaintextRot4 = cryptoContext.MakePackedPlaintext(initialPlaintextRot4)
91+
92+
# The encoded vectors are encrypted
93+
ciphertext1 = cryptoContext.Encrypt(keypair.publicKey, plaintext1)
94+
ciphertext2 = cryptoContext.Encrypt(keypair.publicKey, plaintext2)
95+
ciphertext3 = cryptoContext.Encrypt(keypair.publicKey, plaintext3)
96+
97+
evalKey = cryptoContext.ReKeyGen(keypair.secretKey, keypair.publicKey)
98+
cryptoContext.EvalMultKeyGen(keypair.secretKey)
99+
cryptoContext.EvalRotateKeyGen(keypair.secretKey, [VECTOR1_ROTATION, VECTOR2_ROTATION, VECTOR3_ROTATION, VECTOR4_ROTATION])
59100

60101
cryptoContext_ser = fhe.Serialize(cryptoContext, mode)
61102
LOGGER.debug("The cryptocontext has been serialized.")
62103
publickey_ser = fhe.Serialize(keypair.publicKey, mode)
63104
LOGGER.debug("The public key has been serialized.")
64105
secretkey_ser = fhe.Serialize(keypair.secretKey, mode)
65106
LOGGER.debug("The private key has been serialized.")
66-
ciphertext_ser = fhe.Serialize(ciphertext, mode)
67-
LOGGER.debug("The ciphertext has been serialized.")
107+
ciphertext1_ser = fhe.Serialize(ciphertext1, mode)
108+
LOGGER.debug("The ciphertext 1 has been serialized.")
109+
ciphertext2_ser = fhe.Serialize(ciphertext2, mode)
110+
LOGGER.debug("The ciphertext 2 has been serialized.")
111+
ciphertext3_ser = fhe.Serialize(ciphertext3, mode)
112+
LOGGER.debug("The ciphertext 3 has been serialized.")
68113
evalKey_ser = fhe.Serialize(evalKey, mode)
69114
LOGGER.debug("The evaluation key has been serialized.")
70-
115+
multKey_ser = fhe.SerializeEvalMultKeyString(mode, "")
116+
LOGGER.debug("The relinearization key has been serialized.")
117+
automorphismKey_ser = fhe.SerializeEvalAutomorphismKeyString(mode, "")
118+
LOGGER.debug("The rotation evaluation keys have been serialized.")
71119

72120
cryptoContext.ClearEvalMultKeys()
73121
cryptoContext.ClearEvalAutomorphismKeys()
@@ -85,10 +133,84 @@ def test_serial_cryptocontext_str(mode):
85133
assert isinstance(sk, fhe.PrivateKey)
86134
LOGGER.debug("The private key has been deserialized.")
87135

88-
ct = fhe.DeserializeCiphertextString(ciphertext_ser, mode)
89-
assert isinstance(ct, fhe.Ciphertext)
90-
LOGGER.debug("The ciphertext has been reserialized.")
136+
ct1 = fhe.DeserializeCiphertextString(ciphertext1_ser, mode)
137+
assert isinstance(ct1, fhe.Ciphertext)
138+
LOGGER.debug("The ciphertext 1 has been reserialized.")
139+
140+
ct2 = fhe.DeserializeCiphertextString(ciphertext2_ser, mode)
141+
assert isinstance(ct2, fhe.Ciphertext)
142+
LOGGER.debug("The ciphertext 2 has been reserialized.")
143+
144+
ct3 = fhe.DeserializeCiphertextString(ciphertext3_ser, mode)
145+
assert isinstance(ct3, fhe.Ciphertext)
146+
LOGGER.debug("The ciphertext 3 has been reserialized.")
91147

92148
ek = fhe.DeserializeEvalKeyString(evalKey_ser, mode)
93149
assert isinstance(ek, fhe.EvalKey)
94150
LOGGER.debug("The evaluation key has been deserialized.")
151+
152+
fhe.DeserializeEvalMultKeyString(multKey_ser, mode)
153+
LOGGER.debug("The relinearization key has been deserialized.")
154+
155+
fhe.DeserializeEvalAutomorphismKeyString(automorphismKey_ser, mode)
156+
LOGGER.debug("The rotation evaluation keys have been deserialized.")
157+
158+
# Homomorphic addition
159+
160+
ciphertextAdd12 = cc.EvalAdd(ct1, ct2)
161+
ciphertextAddResult = cc.EvalAdd(ciphertextAdd12, ct3)
162+
163+
# Homomorphic multiplication
164+
ciphertextMult12 = cc.EvalMult(ct1, ct2)
165+
ciphertextMultResult = cc.EvalMult(ciphertextMult12, ct3)
166+
167+
# Homomorphic rotation
168+
ciphertextRot1 = cc.EvalRotate(ct1, VECTOR1_ROTATION)
169+
ciphertextRot2 = cc.EvalRotate(ct2, VECTOR2_ROTATION)
170+
ciphertextRot3 = cc.EvalRotate(ct3, VECTOR3_ROTATION)
171+
ciphertextRot4 = cc.EvalRotate(ct3, VECTOR4_ROTATION)
172+
173+
# Decrypt the result of additions
174+
plaintextAddResult = cc.Decrypt(sk, ciphertextAddResult)
175+
176+
# Decrypt the result of multiplications
177+
plaintextMultResult = cc.Decrypt(sk, ciphertextMultResult)
178+
179+
# Decrypt the result of rotations
180+
plaintextRot1 = cc.Decrypt(sk, ciphertextRot1)
181+
plaintextRot2 = cc.Decrypt(sk, ciphertextRot2)
182+
plaintextRot3 = cc.Decrypt(sk, ciphertextRot3)
183+
plaintextRot4 = cc.Decrypt(sk, ciphertextRot4)
184+
185+
# Shows only the same number of elements as in the original plaintext vector
186+
# By default it will show all coefficients in the BFV-encoded polynomial
187+
plaintextRot1.SetLength(len(vectorOfInts1))
188+
plaintextRot2.SetLength(len(vectorOfInts1))
189+
plaintextRot3.SetLength(len(vectorOfInts1))
190+
plaintextRot4.SetLength(len(vectorOfInts1))
191+
192+
assert str(plaintextAddResult) == str(initialPlaintextAddResult)
193+
assert str(plaintextMultResult) == str(initialPlaintextMultResult)
194+
assert str(plaintextRot1) == str(initialPlaintextRot1)
195+
assert str(plaintextRot2) == str(initialPlaintextRot2)
196+
assert str(plaintextRot3) == str(initialPlaintextRot3)
197+
assert str(plaintextRot4) == str(initialPlaintextRot4)
198+
199+
def rotate_vector(vector, rotation):
200+
"""
201+
Rotate a vector by a specified number of positions.
202+
Positive values rotate left, negative values rotate right.
203+
204+
:param vector: List[int], the vector to rotate.
205+
:param rotation: int, the number of positions to rotate.
206+
:return: List[int], the rotated vector.
207+
"""
208+
n = len(vector)
209+
if rotation > 0:
210+
rotated = vector[rotation:] + [0] * rotation
211+
elif rotation < 0:
212+
rotation = abs(rotation)
213+
rotated = [0] * rotation + vector[:n - rotation]
214+
else:
215+
rotated = vector
216+
return rotated

0 commit comments

Comments
 (0)