Skip to content

Commit 5d5447d

Browse files
committed
more readable parallel iter by pre-collecting subset_map and non_subset_map
1 parent f5f080d commit 5d5447d

File tree

1 file changed

+28
-33
lines changed

1 file changed

+28
-33
lines changed

timeboost-crypto/src/vess.rs

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use spongefish::{
2121
GroupDomainSeparator, GroupToUnit,
2222
},
2323
};
24-
use std::collections::VecDeque;
24+
use std::collections::{HashMap, VecDeque};
2525
use thiserror::Error;
2626

2727
use crate::{
@@ -512,39 +512,32 @@ impl<C: CurveGroup> ShoupVess<C> {
512512
// k in S, then homomorphically shift commitment; k notin S, reproduce dealing from seed
513513
let subset_indices = self.map_subset_seed(subset_seed);
514514

515-
// Convert subset_indices to a HashSet for O(1) lookup
516-
let subset_set: std::collections::HashSet<usize> = subset_indices.iter().copied().collect();
517-
518-
// Create vectors to store the shifted polys and mre_cts data in the correct order
519-
let mut shifted_polys_vec = Vec::new();
520-
let mut mre_cts_vec = Vec::new();
521-
let mut seeds_vec = Vec::new();
522-
523-
// Convert VecDeque to Vec to preserve order for subset items
524-
for poly in shifted_polys.iter() {
525-
shifted_polys_vec.push(poly.clone());
526-
}
527-
for ct in mre_cts.iter() {
528-
mre_cts_vec.push(ct.clone());
529-
}
530-
for seed in seeds.iter() {
531-
seeds_vec.push(*seed);
532-
}
515+
// subset_map tracks the sample i \in S \subseteq [N] and its position in `shifted_polys`
516+
// and `mre_cts`;
517+
// non_subset_map tracks the sample j \notin S \subseteq [N] and its position in `seeds`
518+
let (subset_map, non_subset_map) = {
519+
let mut sm = HashMap::with_capacity(self.subset_size);
520+
let mut nm = HashMap::with_capacity(self.num_repetition - self.subset_size);
521+
522+
let mut non_subset_pos = 0usize;
523+
for i in 0..self.num_repetition {
524+
if let Some(pos) = subset_indices.iter().position(|&x| x == i) {
525+
sm.insert(i, pos);
526+
} else {
527+
nm.insert(i, non_subset_pos);
528+
non_subset_pos += 1;
529+
}
530+
}
531+
(sm, nm)
532+
};
533533

534534
// Compute all hash data in parallel
535535
let hash_data = (0..self.num_repetition)
536536
.into_par_iter()
537537
.map(|i| {
538-
if subset_set.contains(&i) {
538+
if let Some(pos) = subset_map.get(&i) {
539539
// k in S, shift the commitment
540-
// Find the position of i in subset_indices to get the correct poly/ct
541-
// note(alex): binary search will be slower in benchmark
542-
let subset_pos = subset_indices
543-
.iter()
544-
.position(|&x| x == i)
545-
.expect("i should be in subset_indices");
546-
547-
let shifted_comm = vss_pp.commit(&shifted_polys_vec[subset_pos]);
540+
let shifted_comm = vss_pp.commit(shifted_polys[*pos].as_ref());
548541

549542
let mut unshifted_comm = vec![];
550543
for (shifted, delta) in shifted_comm.into_iter().zip(comm.iter()) {
@@ -553,14 +546,12 @@ impl<C: CurveGroup> ShoupVess<C> {
553546
}
554547
let unshifted_comm = C::normalize_batch(&unshifted_comm);
555548
let unshifted_comm_bytes = serialize_to_vec![unshifted_comm]?;
556-
let mre_ct_bytes = mre_cts_vec[subset_pos].to_bytes();
549+
let mre_ct_bytes = mre_cts[*pos].to_bytes();
557550

558551
Ok((unshifted_comm_bytes, mre_ct_bytes))
559-
} else {
552+
} else if let Some(pos) = non_subset_map.get(&i) {
560553
// k notin S, reproduce the dealing deterministically from seed
561-
// Find the position of i among non-subset indices
562-
let non_subset_pos = (0..i).filter(|&j| !subset_set.contains(&j)).count();
563-
let seed = seeds_vec[non_subset_pos];
554+
let seed = seeds[*pos];
564555

565556
let (_poly, cm, mre_ct) =
566557
self.new_dealing(&vss_pp, i, &seed, recipients.clone(), aad)?;
@@ -569,6 +560,8 @@ impl<C: CurveGroup> ShoupVess<C> {
569560
let mre_ct_bytes = mre_ct.to_bytes();
570561

571562
Ok((cm_bytes, mre_ct_bytes))
563+
} else {
564+
Err(VessError::Unreachable)
572565
}
573566
})
574567
.collect::<Result<Vec<_>, VessError>>()?;
@@ -869,6 +862,8 @@ pub enum VessError {
869862
FailedVerification,
870863
#[error("decryption fail")]
871864
DecryptionFailed,
865+
#[error("impossible happens, some function contracts violated")]
866+
Unreachable,
872867
}
873868

874869
impl From<ark_serialize::SerializationError> for VessError {

0 commit comments

Comments
 (0)