Skip to content

Commit 300731a

Browse files
authored
Merge pull request #18 from b5li/rns
Add gadget-based rns relinearization key
2 parents 8eecd05 + d1a1505 commit 300731a

File tree

6 files changed

+1065
-3
lines changed

6 files changed

+1065
-3
lines changed

shell_encryption/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
load("@rules_license//rules:license.bzl", "license")
1615
load("@rules_cc//cc:defs.bzl", "cc_library")
16+
load("@rules_license//rules:license.bzl", "license")
1717
load("@rules_proto//proto:defs.bzl", "proto_library")
1818

1919
package(default_visibility = ["//visibility:public"])

shell_encryption/rns/BUILD

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -772,3 +772,59 @@ cc_test(
772772
"@com_google_absl//absl/strings",
773773
],
774774
)
775+
776+
# Gadget-based relinearization key.
777+
cc_library(
778+
name = "rns_relinearization_key",
779+
srcs = ["rns_relinearization_key.cc"],
780+
hdrs = ["rns_relinearization_key.h"],
781+
deps = [
782+
":error_distribution",
783+
":rns_bfv_ciphertext",
784+
":rns_bgv_ciphertext",
785+
":rns_ciphertext",
786+
":rns_gadget",
787+
":rns_modulus",
788+
":rns_polynomial",
789+
":rns_secret_key",
790+
":serialization_cc_proto",
791+
"//shell_encryption:integral_types",
792+
"//shell_encryption:montgomery",
793+
"//shell_encryption:statusor_fork",
794+
"//shell_encryption/prng",
795+
"//shell_encryption/prng:single_thread_chacha_prng",
796+
"//shell_encryption/prng:single_thread_hkdf_prng",
797+
"@com_google_absl//absl/log:check",
798+
"@com_google_absl//absl/numeric:int128",
799+
"@com_google_absl//absl/status",
800+
"@com_google_absl//absl/status:statusor",
801+
"@com_google_absl//absl/strings",
802+
"@com_google_absl//absl/types:span",
803+
],
804+
)
805+
806+
cc_test(
807+
name = "rns_relinearization_key_test",
808+
srcs = ["rns_relinearization_key_test.cc"],
809+
deps = [
810+
":finite_field_encoder",
811+
":rns_bfv_ciphertext",
812+
":rns_bgv_ciphertext",
813+
":rns_context",
814+
":rns_error_params",
815+
":rns_gadget",
816+
":rns_modulus",
817+
":rns_polynomial",
818+
":rns_relinearization_key",
819+
":rns_secret_key",
820+
"//shell_encryption/rns/testing:parameters",
821+
"//shell_encryption/rns/testing:testing_utils",
822+
"//shell_encryption/testing:matchers",
823+
"//shell_encryption/testing:status_testing",
824+
"//shell_encryption/testing:testing_prng",
825+
"@com_github_google_googletest//:gtest_main",
826+
"@com_google_absl//absl/log:check",
827+
"@com_google_absl//absl/status",
828+
"@com_google_absl//absl/strings",
829+
],
830+
)

shell_encryption/rns/rns_bgv_ciphertext_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ TYPED_TEST(RnsBgvCiphertextTest, ModReducedCiphertextDecrypts) {
269269
if (this->main_moduli_.size() < 2) {
270270
// There is only one prime modulus in the moduli chain, so we cannot perform
271271
// modulus reduction further.
272-
return;
272+
GTEST_SKIP() << "Insufficient number of prime moduli for ModReduce.";
273273
}
274274

275275
ASSERT_OK_AND_ASSIGN(RnsRlweSecretKey<TypeParam> key, this->SampleKey());
@@ -649,7 +649,7 @@ TYPED_TEST(RnsBgvCiphertextPackedTest, HomomorphicMulWithCiphertext) {
649649
// The test parameters are not suitable for homomorphic multiplication as
650650
// the error in the product ciphertext is expected to be larger than the
651651
// ciphertext modulus.
652-
return;
652+
continue;
653653
}
654654

655655
ASSERT_OK_AND_ASSIGN(RnsRlweSecretKey<TypeParam> key, this->SampleKey());
Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "shell_encryption/rns/rns_relinearization_key.h"
16+
17+
#include <memory>
18+
#include <string>
19+
#include <utility>
20+
#include <vector>
21+
22+
#include "absl/log/check.h"
23+
#include "absl/numeric/int128.h"
24+
#include "absl/status/status.h"
25+
#include "absl/status/statusor.h"
26+
#include "absl/strings/str_cat.h"
27+
#include "absl/strings/string_view.h"
28+
#include "absl/types/span.h"
29+
#include "shell_encryption/integral_types.h"
30+
#include "shell_encryption/montgomery.h"
31+
#include "shell_encryption/prng/prng.h"
32+
#include "shell_encryption/prng/single_thread_chacha_prng.h"
33+
#include "shell_encryption/prng/single_thread_hkdf_prng.h"
34+
#include "shell_encryption/rns/error_distribution.h"
35+
#include "shell_encryption/rns/rns_ciphertext.h"
36+
#include "shell_encryption/rns/rns_gadget.h"
37+
#include "shell_encryption/rns/rns_modulus.h"
38+
#include "shell_encryption/rns/rns_polynomial.h"
39+
#include "shell_encryption/rns/rns_secret_key.h"
40+
#include "shell_encryption/status_macros.h"
41+
42+
namespace rlwe {
43+
44+
template <typename ModularInt>
45+
absl::StatusOr<RnsRelinKey<ModularInt>> RnsRelinKey<ModularInt>::Create(
46+
const RnsRlweSecretKey<ModularInt>& secret_key, int degree, int variance,
47+
const RnsGadget<ModularInt>* gadget, PrngType prng_type,
48+
Integer error_scalar) {
49+
if (degree <= 1) {
50+
return absl::InvalidArgumentError("`degree` must be at least 2.");
51+
}
52+
if (gadget == nullptr) {
53+
return absl::InvalidArgumentError("`gadget` must not be null.");
54+
}
55+
56+
// Sample a PRNG seed for sampling the random polynomials in `key_as`.
57+
std::string prng_pad_seed;
58+
if (prng_type == PRNG_TYPE_HKDF) {
59+
RLWE_ASSIGN_OR_RETURN(prng_pad_seed, SingleThreadHkdfPrng::GenerateSeed());
60+
} else if (prng_type == PRNG_TYPE_CHACHA) {
61+
RLWE_ASSIGN_OR_RETURN(prng_pad_seed,
62+
SingleThreadChaChaPrng::GenerateSeed());
63+
} else {
64+
return absl::InvalidArgumentError("PrngType not specified correctly.");
65+
}
66+
67+
// The relinearization key is a k*(l-1) x 2 matrix [[b2, a2], ..., [bl, al]],
68+
// where each block [bi, ai] consists of k RLWE encryptions of the target
69+
// secret g * s^i under the canonical secret key s(X). In the following, we
70+
// first sample the random polynomials ai's, and then generate bi's.
71+
RLWE_ASSIGN_OR_RETURN(
72+
std::vector<RnsPolynomial<ModularInt>> key_as,
73+
SampleRandomPad(gadget->Dimension(), degree, secret_key.LogN(),
74+
secret_key.Moduli(), prng_pad_seed, prng_type));
75+
76+
return CreateWithRandomPad(std::move(key_as), secret_key, degree, variance,
77+
gadget, prng_pad_seed, prng_type, error_scalar);
78+
}
79+
80+
template <typename ModularInt>
81+
absl::StatusOr<std::vector<RnsPolynomial<ModularInt>>>
82+
RnsRelinKey<ModularInt>::SampleRandomPad(
83+
int dimension, int degree, int log_n,
84+
absl::Span<const PrimeModulus<ModularInt>* const> moduli,
85+
absl::string_view prng_seed, PrngType prng_type) {
86+
if (dimension <= 0) {
87+
return absl::InvalidArgumentError("`dimension` must be positive.");
88+
}
89+
if (degree <= 1) {
90+
return absl::InvalidArgumentError("`degree` must be at least 2.");
91+
}
92+
if (log_n <= 0) {
93+
return absl::InvalidArgumentError("`log_n` must be positive.");
94+
}
95+
// Create a PRNG for sampling the random polynomials in `key_as`.
96+
std::unique_ptr<SecurePrng> prng_pad;
97+
if (prng_type == PRNG_TYPE_HKDF) {
98+
RLWE_ASSIGN_OR_RETURN(prng_pad, SingleThreadHkdfPrng::Create(prng_seed));
99+
} else if (prng_type == PRNG_TYPE_CHACHA) {
100+
RLWE_ASSIGN_OR_RETURN(prng_pad, SingleThreadChaChaPrng::Create(prng_seed));
101+
} else {
102+
return absl::InvalidArgumentError("PrngType not specified correctly.");
103+
}
104+
105+
std::vector<RnsPolynomial<ModularInt>> key_as;
106+
key_as.reserve(dimension * (degree - 1));
107+
for (int i = 0; i < dimension * (degree - 1); ++i) {
108+
RLWE_ASSIGN_OR_RETURN(auto a, RnsPolynomial<ModularInt>::SampleUniform(
109+
log_n, prng_pad.get(), moduli));
110+
key_as.push_back(std::move(a));
111+
}
112+
return key_as;
113+
}
114+
115+
template <typename ModularInt>
116+
absl::StatusOr<RnsRelinKey<ModularInt>>
117+
RnsRelinKey<ModularInt>::CreateWithRandomPad(
118+
std::vector<RnsPolynomial<ModularInt>> pads,
119+
const RnsRlweSecretKey<ModularInt>& secret_key, int degree, int variance,
120+
const RnsGadget<ModularInt>* gadget, absl::string_view prng_pad_seed,
121+
PrngType prng_type, Integer error_scalar) {
122+
if (variance <= 0) {
123+
return absl::InvalidArgumentError("`variance` must be positive.");
124+
}
125+
126+
// Create the PRNGs for sampling the encryption randomness.
127+
std::unique_ptr<SecurePrng> prng_encryption;
128+
std::string prng_encryption_seed;
129+
if (prng_type == PRNG_TYPE_HKDF) {
130+
RLWE_ASSIGN_OR_RETURN(prng_encryption_seed,
131+
SingleThreadHkdfPrng::GenerateSeed());
132+
RLWE_ASSIGN_OR_RETURN(prng_encryption,
133+
SingleThreadHkdfPrng::Create(prng_encryption_seed));
134+
} else {
135+
RLWE_ASSIGN_OR_RETURN(prng_encryption_seed,
136+
SingleThreadChaChaPrng::GenerateSeed());
137+
RLWE_ASSIGN_OR_RETURN(prng_encryption,
138+
SingleThreadChaChaPrng::Create(prng_encryption_seed));
139+
}
140+
141+
const RnsPolynomial<ModularInt>& target_key = secret_key.Key();
142+
RnsPolynomial<ModularInt> secret = target_key;
143+
144+
// The relinearization key is a k*(l-1) x 2 matrix [[b2, a2], ..., [bl, al]],
145+
// where ai consists of k uniformly random polynomials mod `moduli` and bi =
146+
// -ai * s + t * ei + gadget * s^i, for t = `error_scalar`.
147+
int log_n = secret_key.LogN();
148+
int k = gadget->Dimension();
149+
std::vector<RnsPolynomial<ModularInt>> key_bs;
150+
key_bs.reserve(k);
151+
int index = 0; // for polynomials in `pads`.
152+
auto moduli = secret_key.Moduli();
153+
for (int i = 2; i <= degree; ++i) {
154+
RLWE_RETURN_IF_ERROR(secret.MulInPlace(target_key, moduli)); // s^i
155+
for (int j = 0; j < k; ++j) {
156+
// a = -u
157+
RLWE_ASSIGN_OR_RETURN(RnsPolynomial<ModularInt> u,
158+
pads[index++].Negate(moduli));
159+
160+
RnsPolynomial<ModularInt> z = secret;
161+
RLWE_RETURN_IF_ERROR(z.MulInPlace(gadget->Component(j), moduli));
162+
163+
RLWE_ASSIGN_OR_RETURN(
164+
RnsPolynomial<ModularInt> b,
165+
SampleError<ModularInt>(log_n, variance, moduli,
166+
prng_encryption.get())); // b = e
167+
RLWE_RETURN_IF_ERROR(b.MulInPlace(error_scalar, moduli)); // b = t * e
168+
RLWE_RETURN_IF_ERROR(b.AddInPlace(z, moduli)); // b = t * e + g[j] * s'
169+
RLWE_RETURN_IF_ERROR(b.FusedMulAddInPlace(u, target_key, moduli));
170+
171+
key_bs.push_back(std::move(b));
172+
}
173+
}
174+
175+
// Store the RNS moduli.
176+
std::vector<const PrimeModulus<ModularInt>*> moduli_vector;
177+
moduli_vector.insert(moduli_vector.begin(), moduli.begin(), moduli.end());
178+
179+
return RnsRelinKey(/*key_as=*/std::move(pads), std::move(key_bs), gadget,
180+
degree, std::move(moduli_vector), prng_pad_seed,
181+
prng_type);
182+
}
183+
184+
template <typename ModularInt>
185+
absl::StatusOr<std::vector<RnsPolynomial<ModularInt>>>
186+
RnsRelinKey<ModularInt>::ApplyToRlweCiphertext(
187+
const RnsRlweCiphertext<ModularInt>& ciphertext) const {
188+
if (ciphertext.Degree() > degree_) {
189+
return absl::InvalidArgumentError(
190+
absl::StrCat("`ciphertext` degree is larger than degree of this "
191+
"relinearization key, ",
192+
degree_, "."));
193+
}
194+
if (ciphertext.NumModuli() != moduli_.size()) {
195+
return absl::InvalidArgumentError(
196+
"`ciphertext` does not have a matching RNS moduli set.");
197+
}
198+
if (ciphertext.PowerOfS() != 1) {
199+
return absl::InvalidArgumentError(
200+
"Relinearization key can only apply to a ciphertext of power 1.");
201+
}
202+
203+
// Apply the relinearization key with blocks [b2, a2], ..., [bl, al] to a
204+
// degree-l ciphertext (c0, ..., cl) to get a new ciphertext (c0', c1') =
205+
// (c0, c1) + sum(g^-1(ci) * [bi, ai], i = 2..l).
206+
int k = gadget_->Dimension();
207+
int l = ciphertext.Degree();
208+
RLWE_ASSIGN_OR_RETURN(RnsPolynomial<ModularInt> c0_new,
209+
ciphertext.Component(0));
210+
RLWE_ASSIGN_OR_RETURN(RnsPolynomial<ModularInt> c1_new,
211+
ciphertext.Component(1));
212+
for (int i = 2; i <= l; ++i) {
213+
RLWE_ASSIGN_OR_RETURN(RnsPolynomial<ModularInt> ci,
214+
ciphertext.Component(i));
215+
if (ci.IsNttForm()) {
216+
RLWE_RETURN_IF_ERROR(ci.ConvertToCoeffForm(moduli_));
217+
}
218+
RLWE_ASSIGN_OR_RETURN(std::vector<RnsPolynomial<ModularInt>> ci_digits,
219+
gadget_->Decompose(ci, moduli_));
220+
for (int j = 0; j < k; ++j) {
221+
RLWE_RETURN_IF_ERROR(ci_digits[j].ConvertToNttForm(moduli_));
222+
int index = (i - 2) * l + j;
223+
RLWE_RETURN_IF_ERROR(
224+
c0_new.FusedMulAddInPlace(ci_digits[j], key_bs_[index], moduli_));
225+
RLWE_RETURN_IF_ERROR(
226+
c1_new.FusedMulAddInPlace(ci_digits[j], key_as_[index], moduli_));
227+
}
228+
}
229+
230+
return std::vector<RnsPolynomial<ModularInt>>{std::move(c0_new),
231+
std::move(c1_new)};
232+
}
233+
234+
template class RnsRelinKey<MontgomeryInt<Uint16>>;
235+
template class RnsRelinKey<MontgomeryInt<Uint32>>;
236+
template class RnsRelinKey<MontgomeryInt<Uint64>>;
237+
template class RnsRelinKey<MontgomeryInt<absl::uint128>>;
238+
#ifdef ABSL_HAVE_INTRINSIC_INT128
239+
template class RnsRelinKey<MontgomeryInt<unsigned __int128>>;
240+
#endif
241+
242+
} // namespace rlwe

0 commit comments

Comments
 (0)