Skip to content

Commit 3482f0d

Browse files
committed
improve vess api to accept IntoIterator instead
1 parent cf5f0cd commit 3482f0d

File tree

2 files changed

+61
-32
lines changed

2 files changed

+61
-32
lines changed

timeboost-crypto/src/mre.rs

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,8 @@ pub struct Ciphertext<C: CurveGroup, H: Digest = sha2::Sha256> {
174174
/// - `aad` is the associated data
175175
/// - `C` is the DL group, `H` is the choice of H_enc whose output space = message space
176176
/// - preprocess messages to pad them to proper length before passing in
177-
pub fn encrypt<C, H, R>(
178-
recipients: &[EncryptionKey<C>],
177+
pub fn encrypt<'a, C, H, R, I>(
178+
recipients: I,
179179
messages: &[Vec<u8>],
180180
aad: &[u8],
181181
rng: &mut R,
@@ -184,14 +184,17 @@ where
184184
C: CurveGroup,
185185
H: Digest,
186186
R: Rng + CryptoRng,
187+
I: IntoIterator<Item = &'a EncryptionKey<C>>,
188+
I::IntoIter: ExactSizeIterator,
187189
{
188190
// input validation
189-
if recipients.is_empty() || messages.is_empty() {
191+
let recipients_iter = recipients.into_iter();
192+
if messages.is_empty() {
190193
return Err(MultiRecvEncError::EmptyInput);
191194
}
192-
if recipients.len() != messages.len() {
195+
if recipients_iter.len() != messages.len() {
193196
return Err(MultiRecvEncError::MismatchedInputLength(
194-
recipients.len(),
197+
recipients_iter.len(),
195198
messages.len(),
196199
));
197200
}
@@ -210,8 +213,7 @@ where
210213
let epk = C::generator().mul(&esk);
211214

212215
// generate recipient-specific ciphertext parts
213-
let cts = recipients
214-
.iter()
216+
let cts = recipients_iter
215217
.zip(messages.iter())
216218
.enumerate()
217219
.map(|(idx, (pk, msg))| {
@@ -228,7 +230,7 @@ where
228230

229231
// TODO(alex): use SIMD vectorized XOR when `std::simd` move out of nightly,
230232
// or rayon as an intermediate improvement
231-
let ct = Output::<H>::from_iter(k.iter().zip(msg).map(|(ki, m)| ki ^ m));
233+
let ct = Output::<H>::from_iter(k.iter().zip(msg.iter()).map(|(ki, m)| ki ^ m));
232234
Ok(ct)
233235
})
234236
.collect::<Result<Vec<_>, MultiRecvEncError>>()?;
@@ -276,7 +278,7 @@ impl From<ark_serialize::SerializationError> for MultiRecvEncError {
276278

277279
#[cfg(test)]
278280
mod tests {
279-
use std::iter::repeat_with;
281+
use std::{collections::BTreeMap, iter::repeat_with};
280282

281283
use ark_bls12_381::G1Projective;
282284
use ark_std::rand;
@@ -290,8 +292,12 @@ mod tests {
290292
let n = 10; // num of recipients
291293
let recv_sks: Vec<DecryptionKey<G1Projective>> =
292294
repeat_with(|| DecryptionKey::rand(rng)).take(n).collect();
293-
let recv_pks: Vec<EncryptionKey<G1Projective>> =
294-
recv_sks.iter().map(EncryptionKey::from).collect();
295+
// collecting into a BTreeSet to demonstrate flexible encrypt() input type
296+
let recv_pks: BTreeMap<usize, EncryptionKey<G1Projective>> = recv_sks
297+
.iter()
298+
.enumerate()
299+
.map(|(i, sk)| (i, EncryptionKey::from(sk)))
300+
.collect();
295301
let labeled_sks: Vec<LabeledDecryptionKey<G1Projective>> = recv_sks
296302
.into_iter()
297303
.enumerate()
@@ -302,7 +308,7 @@ mod tests {
302308
.collect::<Vec<_>>();
303309
let aad = b"Alice";
304310

305-
let mre_ct = encrypt::<G1Projective, H, _>(&recv_pks, &msgs, aad, rng).unwrap();
311+
let mre_ct = encrypt::<G1Projective, H, _, _>(recv_pks.values(), &msgs, aad, rng).unwrap();
306312
for i in 0..n {
307313
let ct = mre_ct.get_recipient_ct(i).unwrap();
308314
assert_eq!(

timeboost-crypto/src/vess.rs

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -209,11 +209,11 @@ impl<C: CurveGroup> ShoupVess<C> {
209209

210210
// deterministically generate the `i`-th dealing from a random seed
211211
// each dealing contains (Shamir poly + Feldman commitment + MRE ciphertext)
212-
fn new_dealing(
212+
fn new_dealing<'a, I>(
213213
&self,
214214
ith: usize,
215215
seed: &[u8; 32],
216-
recipients: &[mre::EncryptionKey<C>],
216+
recipients: I,
217217
aad: &[u8],
218218
) -> Result<
219219
(
@@ -222,7 +222,11 @@ impl<C: CurveGroup> ShoupVess<C> {
222222
MultiRecvCiphertext<C, sha2::Sha256>,
223223
),
224224
VessError,
225-
> {
225+
>
226+
where
227+
I: IntoIterator<Item = &'a mre::EncryptionKey<C>>,
228+
I::IntoIter: ExactSizeIterator,
229+
{
226230
let mut rng = ChaCha20Rng::from_seed(*seed);
227231
let vss_secret = C::ScalarField::rand(&mut rng);
228232

@@ -231,7 +235,7 @@ impl<C: CurveGroup> ShoupVess<C> {
231235
let serialized_shares: Vec<Vec<u8>> =
232236
FeldmanVss::<C>::compute_serialized_shares(&self.vss_pp, &poly).collect();
233237

234-
let mre_ct = mre::encrypt::<C, sha2::Sha256, _>(
238+
let mre_ct = mre::encrypt::<C, sha2::Sha256, _, _>(
235239
recipients,
236240
&serialized_shares,
237241
&self.indexed_aad(aad, ith),
@@ -258,9 +262,9 @@ impl<C: CurveGroup> ShoupVess<C> {
258262
/// step 1.a is split as two internal steps in the two APIs above. r_k is 32 bytes and
259263
/// SpongeFish's built-in prover private coin toss.
260264
/// - random subset seed s: see [`Self::map_subset_seed()`]
261-
pub fn encrypted_shares(
265+
pub fn encrypted_shares<'a, I>(
262266
&self,
263-
recipients: &[mre::EncryptionKey<C>],
267+
recipients: I,
264268
secret: C::ScalarField,
265269
aad: &[u8],
266270
) -> Result<
@@ -269,11 +273,16 @@ impl<C: CurveGroup> ShoupVess<C> {
269273
<FeldmanVss<C> as VerifiableSecretSharing>::Commitment,
270274
),
271275
VessError,
272-
> {
273-
// input validation
276+
>
277+
where
278+
I: IntoIterator<Item = &'a mre::EncryptionKey<C>>,
279+
I::IntoIter: ExactSizeIterator + Clone + Sync,
280+
{
281+
// input validation - check length without consuming the iterator
282+
let recipients_iter = recipients.into_iter();
274283
let n = self.vss_pp.n.get();
275-
if recipients.len() != n {
276-
return Err(VessError::WrongRecipientsLength(n, recipients.len()));
284+
if recipients_iter.len() != n {
285+
return Err(VessError::WrongRecipientsLength(n, recipients_iter.len()));
277286
}
278287

279288
let mut prover_state = self.io_pattern(aad).to_prover_state();
@@ -293,7 +302,7 @@ impl<C: CurveGroup> ShoupVess<C> {
293302
)> = seeds
294303
.par_iter()
295304
.enumerate()
296-
.map(|(i, r)| self.new_dealing(i, r, recipients, aad))
305+
.map(|(i, r)| self.new_dealing(i, r, recipients_iter.clone(), aad))
297306
.collect::<Result<_, VessError>>()?;
298307

299308
// compute h:= H_compress(aad, dealings)
@@ -392,13 +401,17 @@ impl<C: CurveGroup> ShoupVess<C> {
392401

393402
/// Verify if the ciphertext (for all recipients) correctly encrypting valid secret shares,
394403
/// verifiable by anyone.
395-
pub fn verify(
404+
pub fn verify<'a, I>(
396405
&self,
397-
recipients: &[mre::EncryptionKey<C>],
406+
recipients: I,
398407
ct: &VessCiphertext,
399408
comm: &<FeldmanVss<C> as VerifiableSecretSharing>::Commitment,
400409
aad: &[u8],
401-
) -> Result<(), VessError> {
410+
) -> Result<(), VessError>
411+
where
412+
I: IntoIterator<Item = &'a mre::EncryptionKey<C>> + Clone,
413+
I::IntoIter: ExactSizeIterator,
414+
{
402415
let mut verifier_state = self.io_pattern(aad).to_verifier_state(&ct.transcript);
403416

404417
// verifier logic until Step 4b
@@ -454,7 +467,8 @@ impl<C: CurveGroup> ShoupVess<C> {
454467
let seed = seeds
455468
.pop_front()
456469
.expect("subset_size < num_repetitions, so seeds.len() > 0");
457-
let (_poly, cm, mre_ct) = self.new_dealing(i, &seed, recipients, aad)?;
470+
let (_poly, cm, mre_ct) =
471+
self.new_dealing(i, &seed, recipients.clone(), aad)?;
458472

459473
hasher.update(serialize_to_vec![cm]?);
460474
hasher.update(mre_ct.to_bytes());
@@ -703,7 +717,11 @@ mod tests {
703717
UniformRand,
704718
rand::{SeedableRng, rngs::StdRng},
705719
};
706-
use std::{collections::BTreeSet, iter::repeat_with, num::NonZeroUsize};
720+
use std::{
721+
collections::{BTreeMap, BTreeSet},
722+
iter::repeat_with,
723+
num::NonZeroUsize,
724+
};
707725

708726
type H = sha2::Sha256;
709727
type Vss = FeldmanVss<G1Projective>;
@@ -717,18 +735,23 @@ mod tests {
717735
repeat_with(|| mre::DecryptionKey::rand(rng))
718736
.take(n)
719737
.collect();
720-
let recv_pks: Vec<mre::EncryptionKey<G1Projective>> =
721-
recv_sks.iter().map(mre::EncryptionKey::from).collect();
738+
let recv_pks: BTreeMap<usize, mre::EncryptionKey<G1Projective>> = recv_sks
739+
.iter()
740+
.enumerate()
741+
.map(|(i, sk)| (i, mre::EncryptionKey::from(sk)))
742+
.collect();
722743
let labeled_sks: Vec<LabeledDecryptionKey<G1Projective>> = recv_sks
723744
.into_iter()
724745
.enumerate()
725746
.map(|(i, sk)| sk.label(i))
726747
.collect();
727748

728749
let aad = b"Associated data";
729-
let (ct, comm) = vess.encrypted_shares(&recv_pks, secret, aad).unwrap();
750+
let (ct, comm) = vess
751+
.encrypted_shares(recv_pks.values(), secret, aad)
752+
.unwrap();
730753

731-
assert!(vess.verify(&recv_pks, &ct, &comm, aad).is_ok());
754+
assert!(vess.verify(recv_pks.values(), &ct, &comm, aad).is_ok());
732755
for labeled_recv_sk in labeled_sks {
733756
let share = vess.decrypt_share(&labeled_recv_sk, &ct, aad).unwrap();
734757
assert!(Vss::verify(&vess.vss_pp, labeled_recv_sk.node_idx, &share, &comm).is_ok());

0 commit comments

Comments
 (0)