
Commit b1b55da

Additional Decrypter refactor.
1 parent 989a887 commit b1b55da

File tree

3 files changed: +136 −110 lines changed


timeboost-sequencer/Cargo.toml

Lines changed: 1 addition & 0 deletions

@@ -19,6 +19,7 @@ ethereum_ssz = { workspace = true }
 metrics = { path = "../metrics" }
 multisig = { path = "../multisig" }
 parking_lot = { workspace = true }
+rayon = { workspace = true }
 sailfish = { path = "../sailfish" }
 serde = { workspace = true }
 thiserror = { workspace = true }

timeboost-sequencer/src/decrypt.rs

Lines changed: 120 additions & 94 deletions
@@ -5,11 +5,14 @@ use cliquenet::overlay::{Data, DataError, NetworkDown, Overlay};
 use cliquenet::{
     AddressableCommittee, MAX_MESSAGE_SIZE, Network, NetworkError, NetworkMetrics, Role,
 };
+use rayon::iter::ParallelIterator;
+use rayon::prelude::*;
+
 use multisig::{CommitteeId, PublicKey};
 use parking_lot::RwLock;
 use sailfish::types::{Evidence, Round, RoundNumber};
 use serde::{Deserialize, Serialize};
-use std::collections::{BTreeMap, HashMap, HashSet, VecDeque};
+use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque};
 use std::result::Result as StdResult;
 use std::sync::Arc;
 use timeboost_crypto::prelude::{DkgDecKey, LabeledDkgDecKey, Vess, Vss};

@@ -22,7 +25,7 @@ use timeboost_types::{
 };
 use tokio::spawn;
 use tokio::sync::mpsc::{Receiver, Sender, channel};
-use tokio::task::JoinHandle;
+use tokio::task::{JoinError, JoinHandle};
 use tracing::{debug, error, info, trace, warn};

 use crate::config::DecrypterConfig;

@@ -65,16 +68,13 @@ enum Command {

 /// Holds key material for the next committee
 struct NextKey {
-    next_dkg_key: LabeledDkgDecKey,
-    next_dec_key: DecryptionKey,
+    dkg_key: LabeledDkgDecKey,
+    dec_key: DecryptionKey,
 }

 impl NextKey {
-    pub fn new(next_dkg_key: LabeledDkgDecKey, next_dec_key: DecryptionKey) -> Self {
-        Self {
-            next_dkg_key,
-            next_dec_key,
-        }
+    pub fn new(dkg_key: LabeledDkgDecKey, dec_key: DecryptionKey) -> Self {
+        Self { dkg_key, dec_key }
     }
 }

@@ -99,7 +99,7 @@ enum NextCommittee {
 /// ciphertexts. The shares are exchanged with other decrypters, and once a sufficient number are
 /// collected, shares can be combined ("hatching") to obtain the plaintext.
 ///
-/// In addition, the `Decrypter` extracts DKG bundles from inclusion lists and combines them to derive
+/// In addition, the `Decrypter` extracts DKG bundles from candidate lists and combines them to derive
 /// the keys used for threshold decryption/combining.
 pub struct Decrypter {
     /// Public key of the node.

@@ -111,7 +111,7 @@ pub struct Decrypter {
     incls: VecDeque<(RoundNumber, bool)>,
     /// Sender end of the worker commands
     worker_tx: Sender<Command>,
-    /// Receiver end of the Worker response
+    /// Receiver end of the Worker response.
     worker_rx: Receiver<InclusionList>,
     /// Worker task handle.
     worker: JoinHandle<EndOfPlay>,

@@ -121,7 +121,7 @@ pub struct Decrypter {
     key_stores: Arc<RwLock<KeyStoreVec<2>>>,
     /// Current committee.
     current: CommitteeId,
-    /// Metrics to keep track of decrypter status
+    /// Metrics to keep track of decrypter status.
     metrics: Arc<SequencerMetrics>,
 }

@@ -452,11 +452,11 @@ struct Worker {
     /// Number of rounds to retain.
     retain: usize,

-    /// State of the node holding generated threshold decryption keys.
+    /// Operational state of the node.
     #[builder(default)]
     state: WorkerState,

-    /// The next committee ID and its round number (if any).
+    /// The next committee ID, its round number and generated key (if any).
     #[builder(default)]
     next_committee: NextCommittee,

@@ -804,7 +804,7 @@ impl Worker {
         if !is_resharing {
             if let Some(subset) = acc.try_finalize() {
                 let dec_key = subset
-                    .extract_key(&key_store, &self.dkg_sk, None)
+                    .extract_key(&key_store.clone(), &self.dkg_sk, None)
                     .map_err(|e| DecrypterError::Dkg(e.to_string()))?;
                 self.dec_key.set(dec_key);
                 self.state = WorkerState::Running;
@@ -890,24 +890,6 @@ impl Worker {
         Ok(())
     }

-    /// scan through the inclusion list and extract the ciphertexts from encrypted
-    /// bundle/tx while preserving the order.
-    ///
-    /// dev: Option<_> return type indicates potential failure in ciphertext deserialization
-    fn extract_ciphertexts(incl: &InclusionList) -> impl Iterator<Item = Option<Ciphertext>> {
-        incl.priority_bundles()
-            .iter()
-            .filter(move |pb| pb.bundle().is_encrypted())
-            .map(|pb| pb.bundle().data())
-            .chain(
-                incl.regular_bundles()
-                    .iter()
-                    .filter(move |b| b.is_encrypted())
-                    .map(|b| b.data()),
-            )
-            .map(|bytes| deserialize::<Ciphertext>(bytes).ok())
-    }
-
     /// Produce decryption shares for each encrypted bundle inside the inclusion list.
     ///
     /// NOTE: when a ciphertext is malformed, we will skip decrypting it (treat as garbage) here.
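The removed `extract_ciphertexts` helper is replaced at its call sites by a new `InclusionList::filter_ciphertexts()` accessor (added elsewhere in this commit; its definition is not part of this file's diff), which yields the raw bytes of encrypted bundles and leaves deserialization to the caller. A minimal, self-contained sketch of that iterator shape, using toy types rather than the real `InclusionList`:

```rust
// Illustrative only: toy types standing in for timeboost's InclusionList and
// bundle types. The point is the iterator shape: chain priority and regular
// bundles, keep the encrypted ones, and yield their raw bytes in order.
struct Bundle {
    encrypted: bool,
    data: Vec<u8>,
}

struct List {
    priority: Vec<Bundle>,
    regular: Vec<Bundle>,
}

impl List {
    /// Yield the payload bytes of every encrypted bundle, priority bundles
    /// first, preserving order within each group.
    fn filter_ciphertexts(&self) -> impl Iterator<Item = &[u8]> {
        self.priority
            .iter()
            .chain(self.regular.iter())
            .filter(|b| b.encrypted)
            .map(|b| b.data.as_slice())
    }
}

fn main() {
    let list = List {
        priority: vec![Bundle { encrypted: true, data: vec![1] }],
        regular: vec![
            Bundle { encrypted: false, data: vec![2] },
            Bundle { encrypted: true, data: vec![3] },
        ],
    };
    let cts: Vec<&[u8]> = list.filter_ciphertexts().collect();
    assert_eq!(cts, vec![[1u8].as_slice(), [3u8].as_slice()]);
}
```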
@@ -918,17 +900,20 @@
             })?;

         let round = Round::new(incl.round(), self.current);
-        let dec_shares = Self::extract_ciphertexts(incl)
-            .map(|optional_ct| {
-                optional_ct.and_then(|ct| {
-                    // TODO: (anders) consider using committee_id as part of `aad`.
-                    <DecryptionScheme as ThresholdEncScheme>::decrypt(
-                        dec_key.privkey(),
-                        &ct,
-                        &THRES_AAD.to_vec(),
-                    )
-                    .ok() // decryption failure result in None
-                })
+        let ciphertexts: Vec<_> = incl
+            .filter_ciphertexts()
+            .filter_map(|bytes| deserialize::<_, Ciphertext>(&bytes).ok())
+            .collect();
+
+        let dec_shares = ciphertexts
+            .par_iter()
+            .map(|ct| {
+                <DecryptionScheme as ThresholdEncScheme>::decrypt(
+                    dec_key.privkey(),
+                    ct,
+                    &THRES_AAD.to_vec(),
+                )
+                .ok()
             })
             .collect::<Vec<_>>();

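Producing a decryption share per ciphertext is CPU-bound and independent across ciphertexts, so the refactor first collects the deserialized ciphertexts into a `Vec` and then maps over them with rayon's `par_iter`, which fans the work out over a thread pool while keeping results in input order. A minimal sketch of the pattern, with a hypothetical `try_decrypt` standing in for the threshold decryption call:

```rust
// Minimal sketch of the data-parallel mapping above. `try_decrypt` is a
// stand-in for <DecryptionScheme as ThresholdEncScheme>::decrypt; failures
// map to None so the output stays index-aligned with the input.
use rayon::prelude::*;

fn try_decrypt(ciphertext: &[u8]) -> Option<Vec<u8>> {
    // placeholder: pretend decryption fails on empty input
    (!ciphertext.is_empty()).then(|| ciphertext.to_vec())
}

fn main() {
    let ciphertexts: Vec<Vec<u8>> = vec![vec![1, 2], vec![], vec![3]];
    let shares: Vec<Option<Vec<u8>>> = ciphertexts
        .par_iter()                 // rayon thread pool instead of .iter()
        .map(|ct| try_decrypt(ct))
        .collect();                 // result order matches input order
    assert_eq!(shares, vec![Some(vec![1, 2]), None, Some(vec![3])]);
}
```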
@@ -1011,7 +996,7 @@ impl Worker {
             return Ok(Some(incl));
         }

-        let ciphertexts = Self::extract_ciphertexts(&incl);
+        let ciphertexts = incl.filter_ciphertexts().map(|b| deserialize(b).ok());
         let Some(dec_shares) = self.dec_shares.get(&round) else {
             return Ok(None);
         };
@@ -1042,51 +1027,90 @@

         let mut decrypted: Vec<Option<Plaintext>> = vec![];

-        // Now, after immutable borrow is dropped, get mutable access
         let Some(per_ct_opt_dec_shares) = self.dec_shares.get_mut(&round) else {
             return Ok(None);
         };

-        for (opt_ct, opt_dec_shares) in ciphertexts.into_iter().zip(per_ct_opt_dec_shares) {
-            // only Some(_) for valid ciphertext's decryption shares
-            let dec_shares = opt_dec_shares
-                .iter()
-                .filter_map(|s| s.as_ref())
-                .collect::<Vec<_>>();
+        // define the result of a combine operation
+        #[derive(Debug)]
+        enum CombineResult {
+            Success(Plaintext),
+            FaultySubset(BTreeSet<u32>),
+            Error(ThresholdEncError),
+            InsufficientShares,
+        }

-            if dec_shares.len() < key_store.committee().one_honest_threshold().into() {
-                decrypted.push(None);
-                continue;
-            }
+        // process each ciphertext in parallel using spawn_blocking
+        let combine_results = tokio::task::spawn_blocking({
+            let key_store = key_store.clone();
+            let dec_key = dec_key.clone();
+            let ciphertexts: Vec<_> = ciphertexts.collect();
+            let mut per_ct_opt_dec_shares = per_ct_opt_dec_shares.clone();
+
+            move || {
+                ciphertexts
+                    .into_par_iter()
+                    .zip(per_ct_opt_dec_shares.par_iter_mut())
+                    .map(|(maybe_ct, decryption_shares)| {
+                        // Collect valid decryption shares
+                        let valid_shares: Vec<_> = decryption_shares
+                            .iter()
+                            .filter_map(|s| s.as_ref())
+                            .cloned()
+                            .collect();
+
+                        // check if we have enough shares
+                        let threshold: usize = key_store.committee().one_honest_threshold().into();
+                        if valid_shares.len() < threshold {
+                            return CombineResult::InsufficientShares;
+                        }

-            if let Some(ct) = opt_ct {
-                match DecryptionScheme::combine(
-                    key_store.committee(),
-                    dec_key.combkey(),
-                    dec_shares,
-                    &ct,
-                    &THRES_AAD.to_vec(),
-                ) {
-                    Ok(pt) => decrypted.push(Some(pt)),
-                    // with f+1 decryption shares, which means ciphertext is valid, we just need to
-                    // remove bad decryption shares and wait for enough shares from honest nodes
-                    Err(ThresholdEncError::FaultySubset(wrong_indices)) => {
-                        opt_dec_shares.retain(|opt_s| {
-                            opt_s
-                                .clone()
-                                .is_none_or(|s| !wrong_indices.contains(&s.index()))
-                        });
-                        warn!(node = %self.label, ?wrong_indices, "combine found faulty subset");
-                        // not ready to hatch this ciphertext, thus the containing inclusion list
-                        return Ok(None);
-                    }
-                    Err(e) => {
-                        warn!(node = %self.label, error = ?e, "error in combine");
-                        return Err(DecrypterError::Decryption(e));
-                    }
+                        // skip if no ciphertext
+                        let Some(ct) = maybe_ct else {
+                            return CombineResult::InsufficientShares;
+                        };
+
+                        // attempt to combine shares
+                        match DecryptionScheme::combine(
+                            key_store.committee(),
+                            dec_key.combkey(),
+                            valid_shares.iter().collect(),
+                            &ct,
+                            &THRES_AAD.to_vec(),
+                        ) {
+                            Ok(pt) => CombineResult::Success(pt),
+                            Err(ThresholdEncError::FaultySubset(wrong_indices)) => {
+                                CombineResult::FaultySubset(wrong_indices.into())
+                            }
+                            Err(e) => CombineResult::Error(e),
+                        }
+                    })
+                    .collect::<Vec<_>>()
+            }
+        })
+        .await?;
+
+        for (result, decryption_shares) in combine_results.into_iter().zip(per_ct_opt_dec_shares) {
+            match result {
+                CombineResult::Success(pt) => decrypted.push(Some(pt)),
+                CombineResult::FaultySubset(wrong_indices) => {
+                    // Remove faulty decryption shares
+                    decryption_shares.retain(|opt_s| {
+                        opt_s
+                            .clone()
+                            .is_none_or(|s| !wrong_indices.contains(&s.index()))
+                    });
+                    warn!(node = %self.label, ?wrong_indices, "combine found faulty subset");
+                    // Not ready to hatch this ciphertext
+                    return Ok(None);
+                }
+                CombineResult::Error(e) => {
+                    warn!(node = %self.label, error = ?e, "error in combine");
+                    return Err(DecrypterError::Decryption(e));
+                }
+                CombineResult::InsufficientShares => {
+                    decrypted.push(None);
                 }
-            } else {
-                decrypted.push(None);
             }
         }

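The combine step wraps the rayon work in `tokio::task::spawn_blocking` so the CPU-heavy combining does not stall the async runtime; the cloned key store, keys, and share vectors are moved into the closure, and the `CombineResult`s are folded back into `decrypted` (with faulty shares pruned) only after the blocking task finishes. Awaiting the join handle yields `Result<_, JoinError>`, which is why the commit also adds a `JoinErr` variant to `DecrypterError` (last hunk below). A self-contained sketch of the same pattern, with illustrative names (`Error`, `square_all`) rather than the codebase's:

```rust
// Sketch of the spawn_blocking pattern: move CPU-bound rayon work off the
// async runtime, then propagate a failed join as an error via `#[from]`.
use rayon::prelude::*;
use tokio::task::JoinError;

#[derive(Debug, thiserror::Error)]
enum Error {
    #[error("failed to join task: {0}")]
    Join(#[from] JoinError),
}

async fn square_all(values: Vec<u64>) -> Result<Vec<u64>, Error> {
    // spawn_blocking runs the closure on a dedicated blocking thread;
    // awaiting the handle yields Result<_, JoinError>, converted by `?`.
    let out = tokio::task::spawn_blocking(move || {
        values.into_par_iter().map(|v| v * v).collect::<Vec<_>>()
    })
    .await?;
    Ok(out)
}

#[tokio::main]
async fn main() -> Result<(), Error> {
    let squares = square_all((1..=4).collect()).await?;
    assert_eq!(squares, vec![1, 4, 9, 16]);
    Ok(())
}
```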
@@ -1188,11 +1212,11 @@ impl Worker {
             return Err(DecrypterError::Dkg("accumulator incomplete".into()));
         };

-        let Some(new_pos) = new.committee().get_index(&self.label) else {
+        let Some(new_node_idx) = new.committee().get_index(&self.label) else {
             return Err(DecrypterError::Dkg("node not found in committee".into()));
         };

-        let new_dkg_sk = DkgDecKey::from(self.dkg_sk.clone()).label(new_pos.into());
+        let new_dkg_sk = DkgDecKey::from(self.dkg_sk.clone()).label(new_node_idx.into());
         let new_dec_key = subset
             .extract_key(&new, &new_dkg_sk, Some(&old))
             .map_err(|e| DecrypterError::Dkg(format!("key extraction failed: {e}")))?;

@@ -1254,13 +1278,9 @@ impl Worker {
             .map_err(|_: NetworkDown| EndOfPlay::NetworkDown)?;

         // update keys if also member of next committee
-        if let Some(NextKey {
-            next_dkg_key,
-            next_dec_key,
-        }) = next_key
-        {
-            self.dec_key.set(next_dec_key.clone());
-            self.dkg_sk = next_dkg_key.clone();
+        if let Some(NextKey { dkg_key, dec_key }) = next_key {
+            self.dec_key.set(dec_key.clone());
+            self.dkg_sk = dkg_key.clone();
         }
         self.current = start.committee();
         self.next_committee = NextCommittee::Del(*start);
@@ -1353,9 +1373,12 @@ fn serialize<T: Serialize>(d: &T) -> Result<Data> {
     Ok(Data::try_from(b.into_inner())?)
 }

-fn deserialize<T: for<'de> serde::Deserialize<'de>>(d: &bytes::Bytes) -> Result<T> {
+fn deserialize<B, T: for<'de> serde::Deserialize<'de>>(d: B) -> Result<T>
+where
+    B: AsRef<[u8]>,
+{
     bincode::serde::decode_from_slice(
-        d,
+        d.as_ref(),
         bincode::config::standard().with_limit::<MAX_MESSAGE_SIZE>(),
     )
     .map(|(msg, _)| msg)
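`deserialize` is generalized from `&bytes::Bytes` to any `B: AsRef<[u8]>`, so callers can pass `Bytes`, `Vec<u8>`, or plain slices without conversion. A standalone sketch of the same shape, assuming the bincode 2 serde API already used in this file; the `LIMIT` constant and error handling here are illustrative:

```rust
// One helper accepts &[u8], Vec<u8>, or bytes::Bytes alike via AsRef<[u8]>.
use serde::de::DeserializeOwned;

const LIMIT: usize = 1024 * 1024; // illustrative size limit

fn deserialize<B, T>(d: B) -> Result<T, bincode::error::DecodeError>
where
    B: AsRef<[u8]>,
    T: DeserializeOwned,
{
    bincode::serde::decode_from_slice(
        d.as_ref(),
        bincode::config::standard().with_limit::<LIMIT>(),
    )
    .map(|(msg, _)| msg)
}

fn main() {
    let encoded = bincode::serde::encode_to_vec(42u32, bincode::config::standard()).unwrap();
    let a: u32 = deserialize(encoded.as_slice()).unwrap();         // &[u8]
    let b: u32 = deserialize(encoded.clone()).unwrap();            // Vec<u8>
    let c: u32 = deserialize(bytes::Bytes::from(encoded)).unwrap(); // bytes::Bytes
    assert_eq!((a, b, c), (42, 42, 42));
}
```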
@@ -1390,6 +1413,9 @@ pub enum DecrypterError {
     #[error("unexpected internal err: {0}")]
     Internal(String),

+    #[error("failed to join task: {0}")]
+    JoinErr(#[from] JoinError),
+
     #[error("empty set of valid decryption shares")]
     EmptyDecShares,
