diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md new file mode 100644 index 0000000..6f3770c --- /dev/null +++ b/ARCHITECTURE.md @@ -0,0 +1,99 @@ +# Architecture overview +This repository contains a gRPC server which coordinates multi-party threshold protocols. + +## Terminology +- A **client** is a regular client in the gRPC sense. +- A **device** is a **client** with an issued certificate. In some parts of the code, it might mean one **share** of a **participant** because those used to be synonymous. +- A **task** is an abstracted communication among multiple **participants** with multiple phases and a result if it completes successfully. +- A **participant** is a **device** and its **shares** within a **group** or a **task**. +- A **share** is a unit of computation and voting power in the protocol layer and on the client side. +- A **threshold** is the minimum number of accepting **shares** needed to start a **task**. It is mostly a property of a **group**, but is also used in the context of a **task**: + - The **threshold** of a **group task** is the threshold of its resulting **group**. + - The **threshold** of a **threshold task** is the threshold of the **group** within which it runs. +- A **voting task** is the first task phase after a **task**'s creation where the **task participants** can either accept or reject the **task**. A participant stays a participant even if it rejects the task, it receives the task result if the task finishes successfully. +- A **declined task** is a task phase which a **voting task** enters immediately once it's impossible to gather enough accepting votes. +- A **running task** is a task phase which a **voting task** enters once it gathers enough accepting votes and a **communicator** is created. It gives control to a **protocol**. +- A **failed task** is a task phase which is entered upon certain failures. +- A **finished task** is a task phase which a **running task** enters once its **protocol** successfully computes a result. +- A **communicator** is a bridge between **devices** and a **protocol**. It handles **message** gathering, relaying, broadcasting and **protocol** result collection. +- A **message** is serialized data exchanged among **active shares**. All messages pass through the server. Depending on the origin, they are categorized into **client** and **server** messages. Relayed messages are also bundled into **server messages**. +- A **group task** is a task which establishes a **group**. All its **participants** need to accept in order for the task to start. +- A **group** is an abstraction over a shared key and a **threshold**. +- A **threshold task** is a task which needs a **group** to be created and is an umbrella term for **sign**, **sign pdf** and **decrypt tasks**. In order for a threshold task to start, at least the **group**'s **threshold** number of **participants** need to accept. +- A **protocol** is an abstraction for actual multi-party threshold protocols. +- An **active share** is a **share** which participates in the computation of a **protocol**. +- A **protocol index** is the index of an **active share** and corresponds to the way **protocols** manage their active parties. The assignment of indices to shares is rather complicated, see [*Protocol index assignment*](#protocol-index-assignment). + +## Module structure +- `persistence` contains code related to the server persistence. +- `state` contains and manages all of the server's state. +- `interfaces` contains the modules `grpc` and `timer`, which define long-running services + - `grpc` provides the server's gRPC endpoints, handles client registration and certificates + - `timer` periodically runs checks over the state +- `task_store` manages the persistence and caching of tasks +- `task` contains the logic for task computation +- `protocol` contains the logic for protocol computation +- `communicator` defines the communicator +- `error` contains definitions of error variants +- `utils` contains a few helper functions + +## Persistence +Most of the server state is persisted throughout server restarts, but some state is deliberately ephemeral and kept only in the RAM. The ephemeral state is mostly data which changes "rapidly", namely activity timestamps and messages exchanged during protocol computation. + +Persistence is handled in the `state` module, with the exception of `task_store`, which is only used within `state`. This is to decouple the logic from bookkeeping. + +The `persistence` module is supposed to be a "dumb" interface for communicating with the DB. In particular, it shouldn't validate data, perform complex logic, etc... + +## State machines and state changes +Much of the actual logic can be easily modeled using state machines. We use typestates to enforce valid state transitions. For example, a running task cannot change into a voting task. + +Functions which update some state return a state change enum, which enforces handling of all possible situations and explicitly defines the logic. For example, saving a vote in a voting task can have three outcomes: +- The task is accepted and transitions into a running task +- The task is declined and transitions into a declined task +- The task does not have enough votes to determine an outcome and it stays as a voting task + +## Protocol index assignment +Multiple shares per device were implemented in an ad-hoc way and devices do not understand share indices. Instead, they accept a vector of *k* messages, one per active share, and they implicitly assign indices *[0..(k-1)]* to the messages. The index assignment algorithm thus has to more or less work like this: +1. Gather all candidate shares sorted by their corresponding device id - this lets us deterministically recover correct indices without persisting the index assignment. +2. Assign indices *[0..n]* to the sorted shares. +3. For each device, get the range of indices assigned to its shares. +4. Choose the active shares such that for each device, they are chosen from the start of its range. + +For example, consider a `3-of-1,2,3` setup with devices `A`, `B`, `C`. +1. Gather sorted candidate shares: `[A, B, B, C, C, C]` +2. Assign indices: `{0: A, 1: B, 2: B, 3: C, 4: C, 5: C}` +3. Get the index ranges per device: `{A: [0..0], B: [1..2], C: [3..5]}` +4. Choose 3 active shares from range beginnings: `{1: B, 3: C, 4: C}` + +==> The protocol indices are thus `[1, 3, 4]`. + +# Guides for common changes +This section provides guides for certain changes to the codebase which may be common. + +## Adding a new protocol type +Adding new protocols must be coordinated with the `meesign-crypto` repository. + +Protocols are defined throughout several places in the codebase: +1. The `proto/meesign.proto` files in this and `meesign-crypto` repositories define a `ProtocolType` enumeration. Both must be extended. +2. A few trait implementations for `ProtocolType` must be extended in `persistence/enums.rs`. +3. The `protocol_type` enum must be extended in the DB migrations. +4. A module should be added into `protocols`, similar to other protocols defined there. +5. The module needs to create a type implementing the `Protocol` trait for each of its variants, for example `Group`, `Sign`, ... +6. The module should use constants from `meesign_crypto::protocols::`. + +The overall structure should reflect the way other protocols are implemented. See the `protocols/frost.rs` module for example. + +## Adding a new task type +Adding new task types must be coordinated with the `meesign-client` repository. + +If the task follows the usual task phases (voting, declined, running, failed, finished), then it should follow the structure of other task types already established in this repo. Otherwise, it must be handled exceptionally. + +Here is a general process for when the new task type follows the usual task phases: +1. The `proto/meesign.proto` file defines a `TaskType` enumeration. It must be extended. +2. A few trait implementations for `TaskType` must be extended in `persistence/enums.rs`. +3. The `task_type` enum must be extended in the DB migrations. +4. The `RunningTaskContext` enum in `tasks/mod.rs` must be extended. +5. A module should be added into `tasks`, similar to other tasks defined there. +6. The module needs to create a type implementing the `RunningTask` trait. + +The overall structure should reflect the way other tasks are implemented. See the `tasks/sign.rs` module for example. diff --git a/migrations/2025-08-27-171451_initial_schema/down.sql b/migrations/2025-08-27-171451_initial_schema/down.sql index 2ba7be8..3af5e3c 100644 --- a/migrations/2025-08-27-171451_initial_schema/down.sql +++ b/migrations/2025-08-27-171451_initial_schema/down.sql @@ -1,3 +1,4 @@ +DROP TABLE active_task_participant; DROP TABLE task_participant; DROP TABLE task_result; DROP TABLE task; diff --git a/migrations/2025-08-27-171451_initial_schema/up.sql b/migrations/2025-08-27-171451_initial_schema/up.sql index 2eb10b0..3965648 100644 --- a/migrations/2025-08-27-171451_initial_schema/up.sql +++ b/migrations/2025-08-27-171451_initial_schema/up.sql @@ -84,3 +84,12 @@ CREATE TABLE task_participant ( "acknowledgment" boolean, PRIMARY KEY ("task_id", "device_id") ); + +CREATE TABLE active_task_participant ( + "task_id" uuid NOT NULL REFERENCES task("id"), + "device_id" bytea NOT NULL REFERENCES device("id"), + "active_shares" integer NOT NULL CHECK ("active_shares" > 0), + PRIMARY KEY ("task_id", "device_id"), + FOREIGN KEY ("task_id", "device_id") + REFERENCES task_participant("task_id", "device_id") +); diff --git a/src/communicator.rs b/src/communicator.rs index 51e38ae..8671d35 100644 --- a/src/communicator.rs +++ b/src/communicator.rs @@ -1,25 +1,15 @@ -use crate::persistence::PgPool; -use crate::persistence::{Device, Participant}; +use crate::persistence::Device; use crate::proto::ProtocolType; use meesign_crypto::auth::verify_broadcast; use meesign_crypto::proto::{ClientMessage, Message, ServerMessage}; -use rand::prelude::SliceRandom; -use rand::thread_rng; use std::collections::HashMap; -use tonic::codegen::Arc; /// Communication state of a Task pub struct Communicator { /// The minimal number of parties needed to successfully complete the task threshold: u32, - /// Ordered list of devices - device_list: Vec, - /// Ordered list of active devices (participating in the protocol) - active_devices: Option>>, - /// A mapping of device identifiers to their Task decision weight (0 - no decision, positive - accept, negative - reject) - decisions: HashMap, i8>, - /// A mapping of device identifiers to their Task acknowledgement - acknowledgements: HashMap, bool>, + /// A mapping of protocol indices to the active shares' devices + active_shares: HashMap, /// A mapping of protocol indices to incoming messages input: HashMap, /// A mapping of protocol indices to outgoing messages @@ -29,40 +19,29 @@ pub struct Communicator { } impl Communicator { - /// Constructs a new Communicator instance with given Participants, threshold, ProtocolType, decisions and acknowledgements + /// Constructs a new Communicator instance. /// /// # Arguments - /// * `participants` - List of distinct participants sorted by device id - /// * `threshold` - The minimal number of devices to successfully complete the task + /// * `threshold` - The minimal number of devices to successfully complete the task. + /// * `protocol_type` - The protocol type of the task. + /// * `active_shares` - A mapping of protocol indices to the active shares' devices. pub fn new( - participants: Vec, threshold: u32, protocol_type: ProtocolType, - decisions: HashMap, i8>, - acknowledgements: HashMap, bool>, + active_shares: HashMap, ) -> Self { - let device_list: Vec = participants - .into_iter() - .flat_map(|p| std::iter::repeat(p.device).take(p.shares as usize)) - .collect(); - - assert!(device_list.len() > 1); - assert!(threshold <= device_list.len() as u32); + assert!(active_shares.len() > 1); + assert!(threshold <= active_shares.len() as u32); // TODO uncomment once is_sorted is stabilized // assert!(devices.is_sorted()); - let mut communicator = Communicator { + Communicator { threshold, - device_list, - active_devices: None, - decisions, - acknowledgements, + active_shares, input: HashMap::new(), output: HashMap::new(), protocol_type, - }; - communicator.clear_input(); - communicator + } } /// Clears incoming message buffers @@ -74,15 +53,11 @@ impl Communicator { /// /// # Arguments /// - /// * `from_identifier` - identifier of the sender device - /// * `messages` - vector containing messages from each of the sender device's shares - pub fn receive_messages( - &mut self, - from_identifier: &[u8], - messages: Vec, - ) -> bool { - let from_indices = self.identifier_to_indices(from_identifier); - if messages.is_empty() || from_indices.len() != messages.len() { + /// * `sender_id` - identifier of the sender device + /// * `messages` - a message from each of the sender device's shares + pub fn receive_messages(&mut self, sender_id: &[u8], messages: Vec) -> bool { + let sender_indices = self.identifier_to_indices(sender_id); + if messages.is_empty() || sender_indices.len() != messages.len() { return false; } @@ -90,7 +65,7 @@ impl Communicator { assert!(msg.broadcast.is_some() || msg.unicasts.len() == self.threshold as usize - 1); } - self.input.extend(from_indices.into_iter().zip(messages)); + self.input.extend(sender_indices.into_iter().zip(messages)); true } @@ -153,43 +128,27 @@ impl Communicator { /// Check whether incoming buffers contain messages from all active devices pub fn round_received(&self) -> bool { - if self.active_devices.is_none() { - return false; - } - - self.input.len() == self.active_devices.as_ref().unwrap().len() + self.input.len() == self.active_shares.len() } /// Get all messages for a given device pub fn get_messages(&self, device_id: &[u8]) -> Vec> { self.identifier_to_indices(device_id) .iter() - .map(|idx| self.output.get(idx).map(Vec::clone).unwrap_or_default()) + .map(|idx| self.output.get(idx).cloned().unwrap_or_default()) .collect() } /// Get the final message pub fn get_final_message(&self) -> Option> { - if self.input.len() == 0 { + if self.input.is_empty() { return None; } - let active_devices = self.get_active_devices()?; - let protocol_indices = self.get_protocol_indices(); - let mut final_message = None; - for (&sender, msg) in &self.input { - let device_index = protocol_indices - .iter() - .position(|&idx| idx == sender) - .unwrap(); - let device = &active_devices[device_index]; - let cert_der = &self - .device_list - .iter() - .find(|dev| dev.identifier() == device) - .unwrap() - .certificate; + for (sender, msg) in &self.input { + let sender_index = sender - self.protocol_type.index_offset(); + let cert_der = &self.active_shares[&sender_index].certificate; // NOTE: Verify all signed broadcasts and check that the messages are all equal let msg = verify_broadcast(msg.broadcast.as_ref().unwrap(), cert_der).ok()?; @@ -202,170 +161,26 @@ impl Communicator { final_message } - /// Sets the active devices - /// - /// Picks which devices shall participate in the protocol - /// Considers only those devices which accepted participation - /// If enough devices are available, additionaly filters by response latency - pub fn set_active_devices(&mut self, pg_pool: Option>) -> Vec> { - assert!(self.accept_count() >= self.threshold); - let agreeing_devices = self - .device_list - .iter() - .filter(|device| self.decisions.get(device.identifier()) > Some(&0)) - .collect::>(); - - let connected_devices: Vec<_> = match pg_pool { - Some(_pg_pool) => { - todo!(); - //let latest_acceptable_time = Local::now() - Duration::seconds(5); - // agreeing_devices - // .iter() - // .filter(|device| device.last_active() > &latest_acceptable_time) - // .map(Deref::deref) - // .collect() - } - None => agreeing_devices.clone(), - }; - - let (devices, indices): (&Vec<&Device>, Vec<_>) = - if connected_devices.len() >= self.threshold as usize { - (&connected_devices, (0..connected_devices.len()).collect()) - } else { - (&agreeing_devices, (0..agreeing_devices.len()).collect()) - }; - let mut indices = indices - .choose_multiple(&mut thread_rng(), self.threshold as usize) - .cloned() - .collect::>(); - indices.sort(); - - self.active_devices = Some( - devices - .iter() - .enumerate() - .filter(|(idx, _)| indices.contains(idx)) - .map(|(_, device)| device.identifier().to_vec()) - .collect(), - ); - assert_eq!( - self.active_devices.as_ref().unwrap().len(), - self.threshold as usize - ); - - self.active_devices.as_ref().unwrap().clone() - } - - /// Get the active devices - pub fn get_active_devices(&self) -> Option>> { - self.active_devices.clone() - } - - /// Save a decision by the given device - /// - /// # Returns - /// `false` if the `device_id` is invalid or has already decided - /// `true` otherwise - pub fn decide(&mut self, device_id: &[u8], decision: bool) -> bool { - if !self.decisions.contains_key(device_id) || self.decisions[device_id] != 0 { - return false; - } - let votes = self - .device_list - .iter() - .filter(|x| x.identifier() == device_id) - .count() as i8; - self.decisions - .insert(device_id.to_vec(), if decision { votes } else { -votes }); - true - } - - /// Get the number of Task accepts - pub fn accept_count(&self) -> u32 { - self.decisions - .iter() - .filter(|x| *x.1 > 0) - .map(|x| *x.1 as i32) - .sum::() - .abs() as u32 - } - - /// Get the number of Task rejects - pub fn reject_count(&self) -> u32 { - self.decisions - .iter() - .filter(|x| *x.1 < 0) - .map(|x| *x.1 as i32) - .sum::() - .abs() as u32 - } - - /// Check whether a device submitted its decision - pub fn device_decided(&self, device_id: &[u8]) -> bool { - if let Some(d) = self.decisions.get(device_id) { - *d != 0 - } else { - false - } - } - - /// Save an acknowledgement by the given device - /// - /// # Returns - /// `false` if `device` is invalid or has already acknowledged this task's output - /// `true` otherwise - pub fn acknowledge(&mut self, device: &[u8]) -> bool { - if !self.acknowledgements.contains_key(device) || self.acknowledgements[device] { - return false; - } - self.acknowledgements.insert(device.to_vec(), true); - true - } - - /// Check whether a device has acknowledged this task's output - pub fn device_acknowledged(&self, device: &[u8]) -> bool { - *self.acknowledgements.get(device).unwrap_or(&false) - } - /// Get the protocol indices of active devices pub fn get_protocol_indices(&self) -> Vec { - assert!(self.active_devices.is_some()); - - let active_devices = self.get_active_devices().unwrap(); - let mut devices_iter = self.device_list.iter().enumerate(); - let mut indices: Vec = Vec::new(); - - for device in &active_devices { - while let Some((idx, dev)) = devices_iter.next() { - if dev.identifier() == device { - indices.push(idx as u32 + self.protocol_type.index_offset()); - break; - } - } - } - + let mut indices: Vec = self + .active_shares + .keys() + .map(|idx| *idx + self.protocol_type.index_offset()) + .collect(); + indices.sort(); indices } /// Get the protocol indices of an active device pub fn identifier_to_indices(&self, device_id: &[u8]) -> Vec { - if self.active_devices.is_none() { - return Vec::new(); - } - - let mut devices_iter = self.device_list.iter().enumerate(); - let mut indices = Vec::new(); - - for device in self.get_active_devices().unwrap() { - if device == device_id { - let (idx, _) = devices_iter - .find(|(_, dev)| dev.identifier() == &device) - .unwrap(); - - indices.push(idx as u32 + self.protocol_type.index_offset()); - } - } - + let mut indices: Vec = self + .active_shares + .iter() + .filter(|(_, device)| device.id == device_id) + .map(|(idx, _)| *idx + self.protocol_type.index_offset()) + .collect(); + indices.sort(); indices } } @@ -379,101 +194,23 @@ mod tests { #[test] #[should_panic] fn communicator_with_no_devices() { - new_communicator(vec![], 0, ProtocolType::Gg18); + new_communicator(&[], 0, ProtocolType::Gg18); } #[test] #[should_panic] fn communicator_too_large_threshold() { - new_communicator(prepare_participants(2), 3, ProtocolType::Gg18); - } - - #[test] - fn empty_communicator() { - let participants = prepare_participants(5); - let d0 = participants[0].device.identifier().clone(); - let communicator = new_communicator(participants, 3, ProtocolType::Gg18); - assert_eq!(communicator.accept_count(), 0); - assert_eq!(communicator.reject_count(), 0); - assert_eq!(communicator.round_received(), false); - assert_eq!(communicator.get_messages(&d0), Vec::>::new()); - assert_eq!( - communicator.get_messages(&[0x00, 0x00]), - Vec::>::new() - ); - assert_eq!(communicator.device_decided(&d0), false); - assert_eq!(communicator.device_decided(&[0x00, 0x00]), false); - assert_eq!(communicator.waiting_for(&d0), false); - assert_eq!(communicator.waiting_for(&[0x00, 0x00]), false); - assert_eq!(communicator.get_active_devices(), None); - assert_eq!(communicator.get_final_message(), None); + new_communicator(&[(true, 1), (true, 1)], 3, ProtocolType::Gg18); } #[test] fn valid_communicator() { - let participants = prepare_participants(5); - let mut communicator = new_communicator(participants.clone(), 3, ProtocolType::Gg18); - assert_eq!( - communicator.device_decided(participants[0].device.identifier()), - false - ); - communicator.decide(participants[0].device.identifier(), true); - assert_eq!(communicator.accept_count(), 1); - assert_eq!(communicator.reject_count(), 0); - assert_eq!( - communicator.device_decided(participants[0].device.identifier()), - true - ); - assert_eq!( - communicator.device_decided(participants[2].device.identifier()), - false - ); - communicator.decide(participants[2].device.identifier(), false); - assert_eq!(communicator.accept_count(), 1); - assert_eq!(communicator.reject_count(), 1); - assert_eq!( - communicator.device_decided(participants[2].device.identifier()), - true - ); - assert_eq!( - communicator.device_decided(participants[4].device.identifier()), - false - ); - communicator.decide(participants[4].device.identifier(), true); - assert_eq!(communicator.accept_count(), 2); - assert_eq!(communicator.reject_count(), 1); - assert_eq!( - communicator.device_decided(participants[4].device.identifier()), - true - ); - assert_eq!( - communicator.device_decided(participants[1].device.identifier()), - false - ); - communicator.decide(participants[1].device.identifier(), true); - assert_eq!(communicator.accept_count(), 3); - assert_eq!(communicator.reject_count(), 1); - assert_eq!( - communicator.device_decided(participants[1].device.identifier()), - true - ); - assert_eq!( - communicator.device_decided(participants[3].device.identifier()), - false + let mut communicator = new_communicator( + &[(true, 1), (true, 1), (false, 1), (false, 1), (true, 1)], + 3, + ProtocolType::Gg18, ); - assert_eq!(communicator.get_active_devices(), None); - communicator.set_active_devices(None); let active_indices = [0, 1, 4]; - assert_eq!( - communicator.get_active_devices(), - Some( - active_indices - .iter() - .map(|idx| participants[*idx].device.identifier().to_vec()) - .collect() - ) - ); - assert_eq!( &communicator .get_protocol_indices() @@ -482,18 +219,11 @@ mod tests { .collect::>(), &active_indices ); - - for idx in 0..participants.len() { - assert_eq!( - communicator.waiting_for(participants[idx].device.identifier()), - active_indices.contains(&idx) - ); - } assert_eq!(communicator.round_received(), false); - for idx in 0..participants.len() { + for idx in 0..5 { assert_eq!( communicator.receive_messages( - participants[idx].device.identifier(), + &[idx as u8], vec![ClientMessage { protocol_type: 0, unicasts: active_indices @@ -513,8 +243,8 @@ mod tests { } assert_eq!(communicator.round_received(), true); - for idx in 0..participants.len() { - let msgs = communicator.get_messages(participants[idx].device.identifier()); + for idx in 0..5 { + let msgs = communicator.get_messages(&[idx as u8]); let expected: Vec> = if active_indices.contains(&idx) { vec![vec![]] } else { @@ -523,83 +253,28 @@ mod tests { assert_eq!(msgs, expected); } communicator.relay(); - for idx in 0..participants.len() { + for idx in 0..5 { assert_eq!( - !communicator - .get_messages(participants[idx].device.identifier()) - .is_empty(), + !communicator.get_messages(&[idx as u8]).is_empty(), active_indices.contains(&idx) ); } assert_eq!(communicator.round_received(), false); - for participant in participants { - assert_eq!( - communicator.device_acknowledged(participant.device.identifier()), - false - ); - assert_eq!( - communicator.acknowledge(participant.device.identifier()), - true - ); - assert_eq!( - communicator.device_acknowledged(participant.device.identifier()), - true - ); - } - } - - #[test] - fn unknown_device_decide() { - let participants = prepare_participants(3); - let mut communicator = new_communicator( - participants.iter().cloned().take(2).collect(), - 2, - ProtocolType::Gg18, - ); - assert_eq!( - communicator.decide(participants[2].device.identifier(), true), - false - ); - } - - #[test] - fn repeated_device_decide() { - let participants = prepare_participants(2); - let mut communicator = new_communicator(participants.clone(), 2, ProtocolType::Gg18); - assert_eq!( - communicator.decide(participants[0].device.identifier(), true), - true - ); - assert_eq!( - communicator.decide(participants[0].device.identifier(), true), - false - ); } #[test] fn repeated_devices() { - let participants = prepare_participants(1); - let participants = vec![participants[0].clone(), participants[0].clone()]; - let mut communicator = new_communicator(participants.clone(), 2, ProtocolType::Gg18); - assert_eq!( - communicator.decide(participants[0].device.identifier(), true), - true - ); - communicator.set_active_devices(None); + let communicator = new_communicator(&[(true, 2)], 2, ProtocolType::Gg18); assert_eq!(communicator.get_protocol_indices(), vec![0, 1]); } #[test] #[should_panic] fn not_enough_messages() { - let participants = prepare_participants(3); - let mut communicator = new_communicator(participants.clone(), 3, ProtocolType::Gg18); - communicator.decide(participants[0].device.identifier(), true); - communicator.decide(participants[1].device.identifier(), true); - communicator.decide(participants[2].device.identifier(), true); - communicator.set_active_devices(None); + let mut communicator = + new_communicator(&[(true, 1), (true, 1), (true, 1)], 3, ProtocolType::Gg18); communicator.receive_messages( - participants[0].device.identifier(), + &[0], vec![ClientMessage { protocol_type: 0, unicasts: HashMap::new(), @@ -611,14 +286,10 @@ mod tests { #[test] #[should_panic] fn too_many_messages() { - let participants = prepare_participants(3); - let mut communicator = new_communicator(participants.clone(), 3, ProtocolType::Gg18); - communicator.decide(participants[0].device.identifier(), true); - communicator.decide(participants[1].device.identifier(), true); - communicator.decide(participants[2].device.identifier(), true); - communicator.set_active_devices(None); + let mut communicator = + new_communicator(&[(true, 1), (true, 1), (true, 1)], 3, ProtocolType::Gg18); communicator.receive_messages( - participants[0].device.identifier(), + &[0], vec![ClientMessage { protocol_type: 0, unicasts: (0..6 as u32).map(|i| (i, vec![])).collect(), @@ -627,77 +298,21 @@ mod tests { ); } - #[test] - #[should_panic] - fn not_enough_accepts() { - let participants = prepare_participants(5); - let mut communicator = new_communicator(participants.clone(), 3, ProtocolType::Gg18); - communicator.decide(participants[0].device.identifier(), true); - communicator.decide(participants[2].device.identifier(), false); - communicator.decide(participants[4].device.identifier(), true); - communicator.set_active_devices(None); - } - - #[test] - fn more_than_threshold_accepts() { - let threshold = 3; - let participants = prepare_participants(5); - let mut communicator = - new_communicator(participants.clone(), threshold, ProtocolType::Gg18); - for participant in participants { - communicator.decide(participant.device.identifier(), true); - } - communicator.set_active_devices(None); - assert_eq!( - communicator.get_active_devices().as_ref().map(Vec::len), - Some(threshold as usize) - ); - } - #[test] fn send_all() { - let participants = prepare_participants(3); - let mut communicator = new_communicator(participants.clone(), 2, ProtocolType::Gg18); - communicator.decide(participants[0].device.identifier(), true); - communicator.decide(participants[2].device.identifier(), true); - communicator.set_active_devices(None); - assert_eq!( - communicator.get_active_devices(), - Some(vec![ - participants[0].device.identifier().to_vec(), - participants[2].device.identifier().to_vec() - ]) - ); + let mut communicator = + new_communicator(&[(true, 1), (false, 1), (true, 1)], 2, ProtocolType::Gg18); communicator.send_all(|idx| vec![idx as u8]); - assert_eq!( - communicator.get_messages(participants[0].device.identifier()), - vec![vec![0]] - ); - assert_eq!( - communicator.get_messages(participants[1].device.identifier()), - Vec::>::new() - ); - assert_eq!( - communicator.get_messages(participants[2].device.identifier()), - vec![vec![2]] - ); + assert_eq!(communicator.get_messages(&[0]), vec![vec![0]]); + assert_eq!(communicator.get_messages(&[1]), Vec::>::new()); + assert_eq!(communicator.get_messages(&[2]), vec![vec![2]]); } #[test] fn protocol_init() { use meesign_crypto::proto::ProtocolInit; - let participants = prepare_participants(3); - let mut communicator = new_communicator(participants.clone(), 2, ProtocolType::Frost); - communicator.decide(participants[0].device.identifier(), true); - communicator.decide(participants[2].device.identifier(), true); - communicator.set_active_devices(None); - assert_eq!( - communicator.get_active_devices(), - Some(vec![ - participants[0].device.identifier().to_vec(), - participants[2].device.identifier().to_vec() - ]) - ); + let mut communicator = + new_communicator(&[(true, 1), (false, 1), (true, 1)], 2, ProtocolType::Frost); communicator.send_all(|idx| { ProtocolInit { protocol_type: ProtocolType::Frost as i32, @@ -708,7 +323,7 @@ mod tests { .encode_to_vec() }); assert_eq!( - communicator.get_messages(participants[0].device.identifier()), + communicator.get_messages(&[0]), vec![ProtocolInit { protocol_type: ProtocolType::Frost as i32, indices: Vec::new(), @@ -717,12 +332,9 @@ mod tests { } .encode_to_vec()] ); + assert_eq!(communicator.get_messages(&[1]), Vec::new() as Vec>); assert_eq!( - communicator.get_messages(participants[1].device.identifier()), - Vec::new() as Vec> - ); - assert_eq!( - communicator.get_messages(participants[2].device.identifier()), + communicator.get_messages(&[2]), vec![ProtocolInit { protocol_type: ProtocolType::Frost as i32, indices: Vec::new(), @@ -733,50 +345,17 @@ mod tests { ); } - #[test] - fn unknown_device_acknowledgement() { - let participants = prepare_participants(3); - let mut communicator = new_communicator( - participants.iter().cloned().take(2).collect(), - 2, - ProtocolType::Gg18, - ); - assert_eq!( - communicator.acknowledge(participants[2].device.identifier()), - false - ); - } - - #[test] - fn repeated_device_acknowledgement() { - let participants = prepare_participants(2); - let mut communicator = new_communicator(participants.clone(), 2, ProtocolType::Gg18); - assert_eq!( - communicator.acknowledge(participants[0].device.identifier()), - true - ); - assert_eq!( - communicator.acknowledge(participants[0].device.identifier()), - false - ); - } - #[test] fn broadcast_messages() { - let participants = prepare_participants(3); - let mut communicator = new_communicator(participants.clone(), 2, ProtocolType::Frost); - - communicator.decide(participants[0].device.identifier(), true); - communicator.decide(participants[1].device.identifier(), true); - communicator.decide(participants[2].device.identifier(), false); - communicator.set_active_devices(None); + let mut communicator = + new_communicator(&[(true, 1), (true, 1), (false, 1)], 2, ProtocolType::Frost); assert_eq!(communicator.get_protocol_indices(), vec![1, 2]); for i in 0..2 { assert_eq!( communicator.receive_messages( - participants[i].device.identifier(), + &[i as u8], vec![ClientMessage { protocol_type: ProtocolType::Frost.into(), unicasts: HashMap::new(), @@ -788,12 +367,10 @@ mod tests { } assert_eq!(communicator.round_received(), true); - eprintln!("input: {:?}", communicator.input); communicator.relay(); - eprintln!("output: {:?}", communicator.output); assert_eq!( - communicator.get_messages(participants[0].device.identifier()), + communicator.get_messages(&[0]), vec![ServerMessage { protocol_type: ProtocolType::Frost.into(), unicasts: HashMap::new(), @@ -802,7 +379,7 @@ mod tests { .encode_to_vec()], ); assert_eq!( - communicator.get_messages(participants[1].device.identifier()), + communicator.get_messages(&[1]), vec![ServerMessage { protocol_type: ProtocolType::Frost.into(), unicasts: HashMap::new(), @@ -813,39 +390,26 @@ mod tests { } fn new_communicator( - participants: Vec, + decisions_shares: &[(bool, u32)], threshold: u32, protocol_type: ProtocolType, ) -> Communicator { - let decisions = participants - .iter() - .map(|p| (p.device.identifier().clone(), 0)) - .collect(); - let acknowledgements = participants - .iter() - .map(|p| (p.device.identifier().clone(), false)) - .collect(); - Communicator::new( - participants, - threshold, - protocol_type, - decisions, - acknowledgements, - ) - } - - fn prepare_participants(n: usize) -> Vec { - assert!(n < u8::MAX as usize); - (0..n) - .map(|i| { + let active_shares = decisions_shares + .into_iter() + .enumerate() + .flat_map(|(idx, &(accept, shares))| { let device = Device::new( - vec![i as u8], - format!("d{}", i), + vec![idx as u8], + format!("d{}", idx), DeviceKind::User, - vec![0xf0 | i as u8], + vec![0xf0 | idx as u8], ); - Participant { device, shares: 1 } + std::iter::repeat_n((accept, device), shares as usize) }) - .collect() + .enumerate() + .filter(|(_, (accept, _))| *accept) + .map(|(idx, (_, device))| (idx as u32, device)) + .collect(); + Communicator::new(threshold, protocol_type, active_shares) } } diff --git a/src/group.rs b/src/group.rs deleted file mode 100644 index cb082e2..0000000 --- a/src/group.rs +++ /dev/null @@ -1,209 +0,0 @@ -use crate::persistence::{Group as GroupModel, Participant}; -use crate::proto::{KeyType, ProtocolType}; -#[derive(Clone)] -pub struct Group { - identifier: Vec, - name: String, - participants: Vec, - threshold: u32, - protocol: ProtocolType, - key_type: KeyType, - certificate: Option>, - note: Option, -} - -impl Group { - pub fn new( - identifier: Vec, - name: String, - threshold: u32, - participants: Vec, - protocol: ProtocolType, - key_type: KeyType, - certificate: Option>, - note: Option, - ) -> Self { - assert!(!identifier.is_empty()); - assert!(threshold >= 1); - Group { - identifier, - name, - participants, - threshold, - protocol, - key_type, - certificate, - note, - } - } - - pub fn identifier(&self) -> &[u8] { - &self.identifier - } - - pub fn name(&self) -> &str { - &self.name - } - pub fn participants(&self) -> &Vec { - &self.participants - } - pub fn threshold(&self) -> u32 { - self.threshold - } - - pub fn reject_threshold(&self) -> u32 { - let total_parties: u32 = self.participants.iter().map(|p| p.shares).sum(); - total_parties - self.threshold + 1 // rejects >= threshold_reject => fail - } - - pub fn protocol(&self) -> ProtocolType { - self.protocol - } - - pub fn key_type(&self) -> KeyType { - self.key_type - } - - pub fn certificate(&self) -> Option<&Vec> { - self.certificate.as_ref() - } - - pub fn note(&self) -> Option<&str> { - self.note.as_deref() - } - - // TODO: consider merging Group with GroupModel - pub fn from_model(value: GroupModel, participants: Vec) -> Self { - Self { - identifier: value.id, - name: value.name, - threshold: value.threshold as u32, - participants, - protocol: value.protocol.into(), - key_type: value.key_type.into(), - certificate: value.certificate, - note: value.note, - } - } -} - -// impl From<&Group> for crate::proto::Group { -// fn from(group: &Group) -> Self { -// crate::proto::Group { -// identifier: group.identifier().to_vec(), -// name: group.name().to_owned(), -// threshold: group.threshold(), -// device_ids: group -// .devices() -// .iter() -// .map(|x| x.identifier()) -// .map(Vec::from) -// .collect(), -// protocol: group.protocol().into(), -// key_type: group.key_type().into(), -// note: group.note().map(String::from), -// } -// } -// } - -#[cfg(test)] -mod tests { - use crate::persistence::{Device, DeviceKind}; - use std::vec; - - use super::*; - - #[test] - #[should_panic] - fn empty_identifier() { - Group::new( - vec![], - String::from("Sample Group"), - 2, - vec![], - ProtocolType::Gg18, - KeyType::SignPdf, - None, - None, - ); - } - - // #[test] - // fn protobuf_group() { - // let group = Group::new( - // vec![0x00], - // String::from("Sample Group"), - // prepare_devices(3), - // 2, - // ProtocolType::Gg18, - // KeyType::SignPdf, - // None, - // None, - // ); - // let protobuf = crate::proto::Group::from(&group); - // assert_eq!(protobuf.identifier, group.identifier()); - // assert_eq!(protobuf.name, group.name()); - // assert_eq!(protobuf.threshold, group.threshold()); - // assert_eq!( - // protobuf.device_ids, - // group - // .devices() - // .iter() - // .map(|device| device.identifier()) - // .map(Vec::from) - // .collect::>() - // ); - // assert_eq!(protobuf.protocol, group.protocol() as i32); - // assert_eq!(protobuf.key_type, group.key_type() as i32); - // } - - #[test] - fn sample_group() { - let identifier = vec![0x01, 0x02, 0x03, 0x04]; - let name = String::from("Sample Group"); - let mut participants = prepare_participants(6); - let extra_participant = participants.pop().unwrap(); - let threshold = 3; - let protocol_type = ProtocolType::Gg18; - let key_type = KeyType::SignPdf; - let group = Group::new( - identifier.clone(), - name.clone(), - threshold, - participants.clone(), - protocol_type, - key_type, - None, - Some("time policy".into()), - ); - assert_eq!(group.identifier(), &identifier); - assert_eq!(group.name(), &name); - assert_eq!(group.threshold(), threshold); - assert_eq!(group.reject_threshold(), 3); - for (a, b) in group.participants().iter().zip(participants.iter()) { - assert_eq!(a.device.identifier(), b.device.identifier()); - } - assert!(!group - .participants() - .iter() - .any(|p| p.device.identifier() == extra_participant.device.identifier())); - assert_eq!(group.protocol(), protocol_type.into()); - assert_eq!(group.key_type(), key_type.into()); - assert_eq!(group.certificate(), None); - } - - fn prepare_participants(n: usize) -> Vec { - assert!(n < u8::MAX as usize); - (0..n) - .map(|i| { - let device = Device::new( - vec![i as u8], - format!("d{}", i), - DeviceKind::User, - vec![0xf0 | i as u8], - ); - Participant { device, shares: 1 } - }) - .collect() - } -} diff --git a/src/interfaces/grpc.rs b/src/interfaces/grpc.rs index de4081a..ab4b48f 100644 --- a/src/interfaces/grpc.rs +++ b/src/interfaces/grpc.rs @@ -9,7 +9,6 @@ use openssl::x509::extension::{ use openssl::x509::{X509Builder, X509NameBuilder, X509Req}; use rand::Rng; use tokio::sync::mpsc; -use tokio::sync::Mutex; use tokio_stream::wrappers::ReceiverStream; use tokio_stream::Stream; use tonic::codegen::Arc; @@ -22,14 +21,15 @@ use crate::proto::{Group, KeyType, MeeSign, MeeSignServer, ProtocolType}; use crate::state::State; use crate::{proto as msg, utils, CA_CERT, CA_KEY}; +use meesign_crypto::proto::{ClientMessage, Message as _}; use std::pin::Pin; pub struct MeeSignService { - state: Arc>, + state: Arc, } impl MeeSignService { - pub fn new(state: Arc>) -> Self { + pub fn new(state: Arc) -> Self { MeeSignService { state } } @@ -39,8 +39,8 @@ impl MeeSignService { required: bool, ) -> Result<(), Status> { if let Some(certs) = certs { - let device_id = certs.get(0).map(cert_to_id).unwrap_or(vec![]); - if !self.state.lock().await.device_exists(&device_id).await? { + let device_id = certs.first().map(cert_to_id).unwrap_or(vec![]); + if !self.state.device_exists(&device_id) { return Err(Status::unauthenticated("Unknown device certificate")); } } else if required { @@ -80,11 +80,10 @@ impl MeeSign for MeeSignService { let kind = DeviceKind::User; // TODO info!("RegistrationRequest name={:?}", name); - let state = self.state.lock().await; - if let Ok(certificate) = issue_certificate(&name, &csr) { let identifier = cert_to_id(&certificate); - match state + match self + .state .add_device(&identifier, &name, &kind, &certificate) .await { @@ -115,10 +114,7 @@ impl MeeSign for MeeSignService { let data = request.data; info!("SignRequest group_id={}", utils::hextrunc(&group_id)); - let mut state = self.state.lock().await; - let task_id = state.add_sign_task(&group_id, &name, &data).await?; - let task_model = state.get_task(&task_id).await?; - let task = state.format_task(task_model, None, None).await?; + let task = self.state.add_sign_task(&group_id, &name, &data).await?; Ok(Response::new(task)) } @@ -135,12 +131,10 @@ impl MeeSign for MeeSignService { let data_type = request.data_type; info!("DecryptRequest group_id={}", utils::hextrunc(&group_id)); - let mut state = self.state.lock().await; - let task_id = state + let task = self + .state .add_decrypt_task(&group_id, &name, &data, &data_type) .await?; - let task_model = state.get_task(&task_id).await?; - let task = state.format_task(task_model, None, None).await?; Ok(Response::new(task)) } @@ -152,26 +146,20 @@ impl MeeSign for MeeSignService { let request = request.into_inner(); let task_id = Uuid::from_slice(&request.task_id).unwrap(); - let device_id = request.device_id; - let device_id = if device_id.is_none() { - None - } else { - Some(device_id.as_ref().unwrap().as_slice()) - }; + let device_id = request.device_id.as_deref(); debug!( "TaskRequest task_id={} device_id={}", utils::hextrunc(task_id.as_bytes()), utils::hextrunc(device_id.unwrap_or(&[])) ); - let state = self.state.lock().await; - if device_id.is_some() { - state.activate_device(device_id.unwrap()); + if let Some(device_id) = device_id { + self.state.activate_device(device_id); } - let task_model = state.get_task(&task_id).await?; - let request = Some(task_model.request.clone()); - - let task = state.format_task(task_model, device_id, request).await?; + let task = self + .state + .get_formatted_voting_task(&task_id, device_id) + .await?; Ok(Response::new(task)) } @@ -183,7 +171,7 @@ impl MeeSign for MeeSignService { let device_id = request .peer_certs() - .and_then(|certs| certs.get(0).map(cert_to_id)) + .and_then(|certs| certs.first().map(cert_to_id)) .unwrap(); let request = request.into_inner(); @@ -206,10 +194,16 @@ impl MeeSign for MeeSignService { attempt ); - let mut state = self.state.lock().await; - state.activate_device(&device_id); - let result = state - .update_task(&task_id, &device_id, &data, attempt) + let messages = data + .into_iter() + .map(|bytes| ClientMessage::decode(bytes.as_slice())) + .collect::, _>>() + .map_err(|_| Status::invalid_argument("Invalid ClientMessage data."))?; + + self.state.activate_device(&device_id); + let result = self + .state + .update_task(&task_id, &device_id, messages, attempt) .await; match result { @@ -241,22 +235,15 @@ impl MeeSign for MeeSignService { .unwrap_or_else(|| "unknown".to_string()); debug!("TasksRequest device_id={}", device_str); - let state = self.state.lock().await; - - let task_models = if let Some(device_id) = &device_id { - state.activate_device(device_id); - state.get_active_device_tasks(device_id).await? + let tasks = if let Some(device_id) = &device_id { + self.state.activate_device(device_id); + self.state + .get_formatted_active_device_tasks(device_id) + .await? } else { - state.get_tasks().await? + self.state.get_formatted_tasks().await? }; - let mut tasks = Vec::new(); - for task_model in task_models { - let task = state - .format_task(task_model, device_id.as_deref(), None) - .await?; - tasks.push(task); - } Ok(Response::new(msg::Tasks { tasks })) } @@ -274,18 +261,17 @@ impl MeeSign for MeeSignService { .unwrap_or_else(|| "unknown".to_string()); debug!("GroupsRequest device_id={}", device_str); - let state = self.state.lock().await; // TODO: refactor, consider storing device IDS in the group model directly let groups = if let Some(device_id) = device_id { - state.activate_device(&device_id); - state + self.state.activate_device(&device_id); + self.state .get_device_groups(&device_id) .await? .into_iter() .map(Group::from_model) .collect() } else { - state + self.state .get_groups() .await? .into_iter() @@ -324,8 +310,8 @@ impl MeeSign for MeeSignService { .iter() .map(|device_id| device_id.as_ref()) .collect(); - let mut state = self.state.lock().await; - match state + match self + .state .add_group_task( &name, &device_id_references, @@ -336,13 +322,7 @@ impl MeeSign for MeeSignService { ) .await { - Ok(task_id) => { - state.send_updates(&task_id).await?; - // TODO: use group task - let task_model = state.get_task(&task_id).await?; - let task = state.format_task(task_model, None, None).await?; - Ok(Response::new(task)) - } + Ok(task) => Ok(Response::new(task)), Err(err) => { error!("{}", err); Err(Status::failed_precondition("Request failed")) @@ -361,10 +341,7 @@ impl MeeSign for MeeSignService { let resp = msg::Devices { devices: self .state - .lock() - .await .get_devices() - .await? .into_iter() .map(|(device, last_active)| msg::Device { identifier: device.id, @@ -383,7 +360,7 @@ impl MeeSign for MeeSignService { let device_id = request .peer_certs() - .and_then(|certs| certs.get(0).map(cert_to_id)); + .and_then(|certs| certs.first().map(cert_to_id)); let device_str = device_id .as_ref() @@ -393,10 +370,7 @@ impl MeeSign for MeeSignService { debug!("LogRequest device_id={} message={}", device_str, message); if device_id.is_some() { - self.state - .lock() - .await - .activate_device(device_id.as_ref().unwrap()); + self.state.activate_device(device_id.as_ref().unwrap()); } Ok(Response::new(msg::Resp { @@ -412,7 +386,7 @@ impl MeeSign for MeeSignService { let device_id = request .peer_certs() - .and_then(|certs| certs.get(0).map(cert_to_id)) + .and_then(|certs| certs.first().map(cert_to_id)) .unwrap(); let request = request.into_inner(); @@ -426,19 +400,15 @@ impl MeeSign for MeeSignService { accept ); - let state = self.state.clone(); - tokio::task::spawn(async move { - let mut state = state.lock().await; - state.activate_device(&device_id); - if let Err(err) = state.decide_task(&task_id, &device_id, accept).await { - error!( - "Couldn't decide task {} for device {}: {}", - task_id, - utils::hextrunc(&device_id), - err - ); - } - }); + self.state.activate_device(&device_id); + if let Err(err) = self.state.decide_task(&task_id, &device_id, accept).await { + error!( + "Couldn't decide task {} for device {}: {}", + task_id, + utils::hextrunc(&device_id), + err + ); + } Ok(Response::new(msg::Resp { message: "OK".into(), @@ -453,7 +423,7 @@ impl MeeSign for MeeSignService { let device_id = request .peer_certs() - .and_then(|certs| certs.get(0).map(cert_to_id)) + .and_then(|certs| certs.first().map(cert_to_id)) .unwrap(); let task_id = request.into_inner().task_id; @@ -464,11 +434,10 @@ impl MeeSign for MeeSignService { utils::hextrunc(&device_id) ); - let mut state = self.state.lock().await; - state.activate_device(&device_id); + self.state.activate_device(&device_id); let task_id = Uuid::from_slice(&task_id).unwrap(); - if let Err(err) = state.acknowledge_task(&task_id, &device_id).await { + if let Err(err) = self.state.acknowledge_task(&task_id, &device_id).await { error!( "Couldn't acknowledge task {} for device {}: {}", task_id, @@ -490,12 +459,12 @@ impl MeeSign for MeeSignService { let device_id = request .peer_certs() - .and_then(|certs| certs.get(0).map(cert_to_id)) + .and_then(|certs| certs.first().map(cert_to_id)) .unwrap(); let (tx, rx) = mpsc::channel(8); - self.state.lock().await.add_subscriber(device_id, tx); + self.state.add_subscriber(device_id, tx); Ok(Response::new(Box::pin(ReceiverStream::new(rx)))) } @@ -584,7 +553,7 @@ pub fn cert_to_id(cert: impl AsRef<[u8]>) -> Vec { sha2::Sha256::digest(cert).to_vec() } -pub async fn run_grpc(state: Arc>, addr: &str, port: u16) -> Result<(), String> { +pub async fn run_grpc(state: Arc, addr: &str, port: u16) -> Result<(), String> { let addr = format!("{}:{}", addr, port) .parse() .map_err(|_| String::from("Unable to parse server address"))?; diff --git a/src/interfaces/timer.rs b/src/interfaces/timer.rs index 367ab2d..4ed86b6 100644 --- a/src/interfaces/timer.rs +++ b/src/interfaces/timer.rs @@ -3,28 +3,23 @@ use crate::state::State; use crate::utils; use log::debug; -use tokio::sync::MutexGuard; -use tokio::{sync::Mutex, time}; +use tokio::time; use tonic::codegen::Arc; -pub async fn run_timer(state: Arc>) -> Result<(), String> { +pub async fn run_timer(state: Arc) -> Result<(), String> { let mut interval = time::interval(time::Duration::from_secs(1)); loop { interval.tick().await; - let mut state = state.lock().await; - check_tasks(&mut state).await.unwrap(); - check_subscribers(&mut state).await; + check_tasks(&state).await.unwrap(); + check_subscribers(&state).await; } } -async fn check_tasks(state: &mut MutexGuard<'_, State>) -> Result<(), Error> { - for task_id in state.get_tasks_for_restart().await? { - state.restart_task(&task_id).await?; - } - Ok(()) +async fn check_tasks(state: &State) -> Result<(), Error> { + state.restart_stale_tasks().await } -async fn check_subscribers(state: &mut MutexGuard<'_, State>) { +async fn check_subscribers(state: &State) { let mut remove = Vec::new(); for subscriber in state.get_subscribers().iter() { let (device_id, tx) = subscriber.pair(); @@ -35,7 +30,7 @@ async fn check_subscribers(state: &mut MutexGuard<'_, State>) { ); remove.push(device_id.clone()); } else { - state.activate_device(&device_id); + state.activate_device(device_id); } } for device_id in remove { diff --git a/src/main.rs b/src/main.rs index afe6509..4bd1f32 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,16 +9,16 @@ use openssl::x509::X509; use persistence::Repository; use crate::state::State; -use tokio::{sync::Mutex, try_join}; +use tokio::try_join; use tonic::codegen::Arc; mod communicator; mod error; -mod group; mod interfaces; mod persistence; mod protocols; mod state; +mod task_store; mod tasks; mod utils; @@ -61,37 +61,6 @@ mod proto { } } - impl From for crate::persistence::TaskType { - fn from(task_type: TaskType) -> Self { - match task_type { - TaskType::Group => Self::Group, - TaskType::SignChallenge => Self::SignChallenge, - TaskType::SignPdf => Self::SignPdf, - TaskType::Decrypt => Self::Decrypt, - } - } - } - - impl Into for crate::persistence::TaskType { - fn into(self) -> TaskType { - match self { - Self::Group => TaskType::Group, - Self::SignChallenge => TaskType::SignChallenge, - Self::SignPdf => TaskType::SignPdf, - Self::Decrypt => TaskType::Decrypt, - } - } - } - - impl Into for crate::persistence::DeviceKind { - fn into(self) -> DeviceKind { - match self { - Self::User => DeviceKind::User, - Self::Bot => DeviceKind::Bot, - } - } - } - impl Task { pub fn created( id: Vec, @@ -113,6 +82,26 @@ mod proto { attempt, } } + pub fn declined( + id: Vec, + r#type: i32, + accept: u32, + reject: u32, + request: Option>, + attempt: u32, + ) -> Self { + Self { + id, + r#type, + state: task::TaskState::Failed.into(), + round: 0, + accept, + reject, + data: vec!["Task declined".to_string().into_bytes()], + request, + attempt, + } + } pub fn running( id: Vec, r#type: i32, @@ -155,9 +144,6 @@ mod proto { pub fn failed( id: Vec, r#type: i32, - round: u32, - accept: u32, - reject: u32, reason: String, request: Option>, attempt: u32, @@ -166,9 +152,9 @@ mod proto { id, r#type, state: task::TaskState::Failed.into(), - round, - accept, - reject, + round: u32::MAX, + accept: u32::MAX, + reject: 0, data: vec![reason.into_bytes()], request, attempt, @@ -246,8 +232,11 @@ async fn main() -> Result<(), String> { .await .expect("Coudln't init postgres repo"); repo.apply_migrations().expect("Couldn't apply migrations"); + let state = State::restore(Arc::new(repo)) + .await + .expect("Couldn't initialize State"); // TODO: remove mutex when DB done - let state = Arc::new(Mutex::new(State::new(Arc::new(repo)))); + let state = Arc::new(state); let grpc = interfaces::grpc::run_grpc(state.clone(), &args.addr, args.port); let timer = interfaces::timer::run_timer(state); @@ -261,7 +250,6 @@ mod cli { use crate::proto::MeeSignClient; use crate::{Args, CA_CERT}; use clap::Subcommand; - use meesign_crypto; use std::str::FromStr; use std::time::SystemTime; use tonic::transport::{Certificate, Channel, ClientTlsConfig, Uri}; diff --git a/src/persistence.rs b/src/persistence.rs index e3708b7..7de732c 100644 --- a/src/persistence.rs +++ b/src/persistence.rs @@ -7,13 +7,8 @@ mod models; mod repository; mod schema; -pub use enums::DeviceKind; -pub use enums::TaskType; +pub use enums::{DeviceKind, KeyType, ProtocolType, TaskType}; pub use error::PersistenceError; -pub use models::Device; -pub use models::Group; -pub use models::Participant; -pub use models::Task; +pub use models::{Device, Group, Participant, Task}; pub use repository::utils::NameValidator; -pub use repository::PgPool; pub use repository::Repository; diff --git a/src/persistence/enums.rs b/src/persistence/enums.rs index 8a06e97..f769806 100644 --- a/src/persistence/enums.rs +++ b/src/persistence/enums.rs @@ -24,7 +24,7 @@ impl From for ProtocolType { } } -#[derive(Debug, DbEnum, Clone, PartialEq, Eq, Serialize)] +#[derive(Copy, Debug, DbEnum, Clone, PartialEq, Eq, Serialize)] #[ExistingTypePath = "crate::persistence::schema::sql_types::TaskType"] #[DbValueStyle = "PascalCase"] pub enum TaskType { @@ -44,7 +44,7 @@ pub enum TaskType { // Decrypted, // } -#[derive(Debug, DbEnum, Clone, PartialEq, Eq)] +#[derive(Copy, Debug, DbEnum, Clone, PartialEq, Eq)] #[ExistingTypePath = "crate::persistence::schema::sql_types::DeviceKind"] #[DbValueStyle = "PascalCase"] pub enum DeviceKind { @@ -52,7 +52,7 @@ pub enum DeviceKind { Bot, } -#[derive(Debug, DbEnum, Clone, PartialEq, Eq, Serialize)] +#[derive(Copy, Debug, DbEnum, Clone, PartialEq, Eq, Serialize)] #[ExistingTypePath = "crate::persistence::schema::sql_types::TaskState"] #[DbValueStyle = "PascalCase"] pub enum TaskState { @@ -62,7 +62,7 @@ pub enum TaskState { Failed, } -#[derive(Debug, Clone, PartialEq, Eq, DbEnum, Serialize)] +#[derive(Copy, Debug, Clone, PartialEq, Eq, DbEnum, Serialize)] #[cfg_attr(test, derive(PartialOrd, Ord))] #[ExistingTypePath = "crate::persistence::schema::sql_types::KeyType"] #[DbValueStyle = "PascalCase"] @@ -82,6 +82,38 @@ impl From for KeyType { } } +impl From for proto::KeyType { + fn from(value: KeyType) -> Self { + match value { + KeyType::SignPdf => proto::KeyType::SignPdf, + KeyType::SignChallenge => proto::KeyType::SignChallenge, + KeyType::Decrypt => proto::KeyType::Decrypt, + } + } +} + +impl From for TaskType { + fn from(task_type: proto::TaskType) -> Self { + match task_type { + proto::TaskType::Group => Self::Group, + proto::TaskType::SignChallenge => Self::SignChallenge, + proto::TaskType::SignPdf => Self::SignPdf, + proto::TaskType::Decrypt => Self::Decrypt, + } + } +} + +impl From for proto::TaskType { + fn from(task_type: crate::persistence::TaskType) -> Self { + match task_type { + TaskType::Group => Self::Group, + TaskType::SignChallenge => Self::SignChallenge, + TaskType::SignPdf => Self::SignPdf, + TaskType::Decrypt => Self::Decrypt, + } + } +} + impl From for proto::ProtocolType { fn from(value: ProtocolType) -> Self { match value { @@ -93,12 +125,11 @@ impl From for proto::ProtocolType { } } -impl From for proto::KeyType { - fn from(value: KeyType) -> Self { - match value { - KeyType::SignPdf => proto::KeyType::SignPdf, - KeyType::SignChallenge => proto::KeyType::SignChallenge, - KeyType::Decrypt => proto::KeyType::Decrypt, +impl From for proto::DeviceKind { + fn from(device_kind: DeviceKind) -> Self { + match device_kind { + DeviceKind::User => Self::User, + DeviceKind::Bot => Self::Bot, } } } diff --git a/src/persistence/models.rs b/src/persistence/models.rs index c5e991a..9b46088 100644 --- a/src/persistence/models.rs +++ b/src/persistence/models.rs @@ -83,15 +83,6 @@ pub struct NewGroupParticipant<'a> { pub shares: i32, } -#[derive(Queryable, Selectable)] -#[diesel(table_name = group_participant)] -#[diesel(check_for_backend(diesel::pg::Pg))] -pub struct GroupParticipant { - pub device_id: Vec, - pub group_id: Vec, - pub shares: i32, -} - #[derive(Insertable)] #[diesel(table_name=task_participant)] pub struct NewTaskParticipant<'a> { @@ -102,6 +93,14 @@ pub struct NewTaskParticipant<'a> { pub acknowledgment: Option, } +#[derive(Queryable, Selectable, Insertable)] +#[diesel(table_name=active_task_participant)] +pub struct ActiveTaskParticipant { + pub task_id: Uuid, + pub device_id: Vec, + pub active_shares: i32, +} + #[derive(Queryable, Serialize, Clone, Eq, PartialEq)] pub struct Task { pub id: Uuid, diff --git a/src/persistence/repository.rs b/src/persistence/repository.rs index 45e2cbb..abd3488 100644 --- a/src/persistence/repository.rs +++ b/src/persistence/repository.rs @@ -6,23 +6,21 @@ use super::{ models::{Device, Group, Participant, Task}, }; +use self::group::{add_group, get_device_groups, get_groups}; use self::{ - device::{add_device, get_devices, get_devices_with_ids}, + device::{add_device, get_devices}, group::get_group, task::{ - get_active_device_tasks, get_task_acknowledgements, get_task_decisions, get_tasks, - increment_task_attempt_count, set_task_acknowledgement, set_task_decision, - set_task_group_certificates_sent, set_task_result, set_task_round, + get_active_device_tasks, get_restart_candidates, get_task_acknowledgements, + get_task_active_shares, get_task_decisions, get_task_models, get_tasks, + increment_task_attempt_count, set_task_acknowledgement, set_task_active_shares, + set_task_decision, set_task_group_certificates_sent, set_task_result, set_task_round, }, }; use self::{ - device::{get_group_participants, get_task_participants, get_tasks_participants}, + device::{get_group_participants, get_tasks_participants}, task::{create_group_task, create_task}, }; -use self::{ - group::{add_group, get_device_groups, get_groups}, - task::get_task, -}; use diesel::{Connection, PgConnection}; use diesel_async::AsyncPgConnection; @@ -34,7 +32,7 @@ use diesel_async::{ pooled_connection::AsyncDieselConnectionManager, scoped_futures::ScopedFutureExt, }; use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness}; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::env; use std::sync::Arc; @@ -103,14 +101,6 @@ impl Repository { get_devices(connection).await } - pub async fn get_devices_with_ids( - &self, - device_ids: &[&[u8]], - ) -> Result, PersistenceError> { - let connection = &mut self.get_async_connection().await?; - get_devices_with_ids(connection, device_ids).await - } - pub async fn get_group_participants( &self, group_id: &[u8], @@ -119,14 +109,6 @@ impl Repository { get_group_participants(connection, group_id).await } - pub async fn get_task_participants( - &self, - task_id: &Uuid, - ) -> Result, PersistenceError> { - let connection = &mut self.get_async_connection().await?; - get_task_participants(connection, task_id).await - } - pub async fn get_tasks_participants( &self, task_ids: &[Uuid], @@ -190,17 +172,8 @@ impl Repository { get_device_groups(connection, identifier).await } - #[allow(unused_variables)] - pub async fn does_group_contain_device( - &self, - group_id: &[u8], - device_id: &[u8], - ) -> Result { - todo!() - } - /* Tasks */ - pub async fn create_group_task<'a>( + pub async fn create_group_task( &self, id: Option<&Uuid>, participants: &[(&[u8], u32)], @@ -231,7 +204,7 @@ impl Repository { .await } - pub async fn create_threshold_task<'a>( + pub async fn create_threshold_task( &self, id: Option<&Uuid>, group_id: &[u8], @@ -268,26 +241,28 @@ impl Repository { .await } - pub async fn get_task(&self, task_id: &Uuid) -> Result, PersistenceError> { + pub async fn get_task_models(&self, task_ids: &[Uuid]) -> Result, PersistenceError> { let connection = &mut self.get_async_connection().await?; - let task = get_task(connection, task_id).await?; - Ok(task) + let task_models = get_task_models(connection, task_ids).await?; + Ok(task_models) } - pub async fn get_tasks(&self) -> Result, PersistenceError> { + pub async fn get_tasks(&self) -> Result, PersistenceError> { let connection = &mut self.get_async_connection().await?; let tasks = get_tasks(connection).await?; Ok(tasks) } - pub async fn get_tasks_for_restart(&self) -> Result, PersistenceError> { - todo!() + pub async fn get_restart_candidates(&self) -> Result, PersistenceError> { + let connection = &mut self.get_async_connection().await?; + let tasks = get_restart_candidates(connection).await?; + Ok(tasks) } pub async fn get_active_device_tasks( &self, identifier: &[u8], - ) -> Result, PersistenceError> { + ) -> Result, PersistenceError> { let connection = &mut self.get_async_connection().await?; get_active_device_tasks(connection, identifier).await } @@ -322,11 +297,28 @@ impl Repository { pub async fn get_task_acknowledgements( &self, task_id: &Uuid, - ) -> Result, bool>, PersistenceError> { + ) -> Result>, PersistenceError> { let connection = &mut self.get_async_connection().await?; get_task_acknowledgements(connection, task_id).await } + pub async fn set_task_active_shares( + &self, + task_id: &Uuid, + active_shares: &HashMap, u32>, + ) -> Result<(), PersistenceError> { + let connection = &mut self.get_async_connection().await?; + set_task_active_shares(connection, task_id, active_shares).await + } + + pub async fn get_task_active_shares( + &self, + task_id: &Uuid, + ) -> Result, u32>, PersistenceError> { + let connection = &mut self.get_async_connection().await?; + get_task_active_shares(connection, task_id).await + } + pub async fn set_task_round( &self, task_id: &Uuid, diff --git a/src/persistence/repository/device.rs b/src/persistence/repository/device.rs index 48e2c3b..73e9a79 100644 --- a/src/persistence/repository/device.rs +++ b/src/persistence/repository/device.rs @@ -19,25 +19,6 @@ where Ok(device::table.load(connection).await?) } -pub async fn get_devices_with_ids( - connection: &mut Conn, - device_ids: &[&[u8]], -) -> Result, PersistenceError> -where - Conn: AsyncConnection, -{ - let device_map: Vec = device::table - .filter(device::id.eq_any(device_ids)) - .order_by(device::id) - .load(connection) - .await?; - let devices = device_ids - .iter() - .filter_map(|&id| device_map.iter().find(|dev| &dev.id == id).cloned()) - .collect(); - Ok(devices) -} - pub async fn get_group_participants( connection: &mut Conn, group_id: &[u8], @@ -60,28 +41,6 @@ where Ok(devices) } -pub async fn get_task_participants( - connection: &mut Conn, - task_id: &Uuid, -) -> Result, PersistenceError> -where - Conn: AsyncConnection, -{ - let devices = task_participant::table - .inner_join(device::table) - .filter(task_participant::task_id.eq(task_id)) - .select((Device::as_returning(), task_participant::shares)) - .load::<(Device, i32)>(connection) - .await? - .into_iter() - .map(|(device, shares)| Participant { - device, - shares: shares as u32, - }) - .collect(); - Ok(devices) -} - pub async fn get_tasks_participants( connection: &mut Conn, task_ids: &[Uuid], @@ -122,21 +81,21 @@ where Conn: AsyncConnection, { if !name.is_name_valid() { - return Err(PersistenceError::InvalidArgumentError(format!( - "Invalid device name: {name}" - ))); + return Err(PersistenceError::InvalidArgumentError( + "Invalid device name: {name}".to_string(), + )); } if identifier.is_empty() { - return Err(PersistenceError::InvalidArgumentError(format!( - "Empty identifier" - ))); + return Err(PersistenceError::InvalidArgumentError( + "Empty identifier".to_string(), + )); } if certificate.is_empty() { - return Err(PersistenceError::InvalidArgumentError(format!( - "Empty certificate" - ))); + return Err(PersistenceError::InvalidArgumentError( + "Empty certificate".to_string(), + )); } let new_device = NewDevice { diff --git a/src/persistence/repository/group.rs b/src/persistence/repository/group.rs index 039f8db..787f94b 100644 --- a/src/persistence/repository/group.rs +++ b/src/persistence/repository/group.rs @@ -68,7 +68,7 @@ where Ok(groups) } -pub async fn add_group<'a, Conn>( +pub async fn add_group( connection: &mut Conn, group_task_id: &Uuid, id: &[u8], @@ -77,7 +77,7 @@ pub async fn add_group<'a, Conn>( protocol: ProtocolType, key_type: KeyType, certificate: Option<&[u8]>, - note: Option<&'a str>, + note: Option<&str>, ) -> Result where Conn: AsyncConnection, @@ -85,13 +85,13 @@ where let threshold: i32 = threshold.try_into()?; if id.is_empty() { - return Err(PersistenceError::InvalidArgumentError(format!( - "Empty identifier" - ))); + return Err(PersistenceError::InvalidArgumentError( + "Empty identifier".to_string(), + )); } let new_group = NewGroup { id, - threshold: threshold as i32, + threshold, protocol, name, certificate, diff --git a/src/persistence/repository/task.rs b/src/persistence/repository/task.rs index 739f07e..e94a176 100644 --- a/src/persistence/repository/task.rs +++ b/src/persistence/repository/task.rs @@ -2,16 +2,16 @@ use diesel::result::Error::NotFound; use diesel::{pg::Pg, QueryDsl}; use diesel::{BoolExpressionMethods, ExpressionMethods, NullableExpressionMethods}; use diesel_async::{AsyncConnection, RunQueryDsl}; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use uuid::Uuid; use super::utils::NameValidator; use crate::persistence::models::{NewTaskResult, Task}; -use crate::persistence::schema::{task_participant, task_result}; +use crate::persistence::schema::{active_task_participant, task_participant, task_result}; use crate::persistence::{ enums::{KeyType, ProtocolType, TaskState, TaskType}, error::PersistenceError, - models::{NewTask, NewTaskParticipant}, + models::{ActiveTaskParticipant, NewTask, NewTaskParticipant}, schema::task, }; @@ -61,7 +61,7 @@ where .await?; let new_task_participants: Vec = participants - .into_iter() + .iter() .map(|(device_id, shares)| NewTaskParticipant { device_id, task_id: &task_id, @@ -129,7 +129,7 @@ where .await?; let new_task_participants: Vec = participants - .into_iter() + .iter() .map(|(device_id, shares)| NewTaskParticipant { device_id, task_id: &task_id, @@ -173,7 +173,7 @@ macro_rules! task_model_columns { }; } -pub async fn get_task( +async fn get_task( connection: &mut Conn, task_id: &Uuid, ) -> Result, PersistenceError> @@ -194,13 +194,29 @@ where Ok(task) } -pub async fn get_tasks(connection: &mut Conn) -> Result, PersistenceError> +pub async fn get_tasks(connection: &mut Conn) -> Result, PersistenceError> +where + Conn: AsyncConnection, +{ + let tasks = task::table + .select(task::id) + .order_by(task::id.asc()) + .load(connection) + .await?; + Ok(tasks) +} + +pub async fn get_task_models( + connection: &mut Conn, + task_ids: &[Uuid], +) -> Result, PersistenceError> where Conn: AsyncConnection, { let tasks = task::table .left_outer_join(task_result::table) .select(task_model_columns!()) + .filter(task::id.eq_any(task_ids)) .order_by(task::id.asc()) .load(connection) .await?; @@ -209,15 +225,15 @@ where pub async fn get_active_device_tasks( connection: &mut Conn, - identifier: &[u8], -) -> Result, PersistenceError> + device_id: &[u8], +) -> Result, PersistenceError> where Conn: AsyncConnection, { let tasks = task::table .left_outer_join(task_result::table) .inner_join(task_participant::table) - .filter(task_participant::device_id.eq(identifier)) + .filter(task_participant::device_id.eq(device_id)) .filter( task_result::is_successful .is_null() @@ -225,7 +241,28 @@ where .is_null() .or(task_participant::acknowledgment.eq(false))), ) - .select(task_model_columns!()) + .select(task::id) + .order_by(task::id.asc()) + .load(connection) + .await?; + Ok(tasks) +} + +pub async fn get_restart_candidates( + connection: &mut Conn, +) -> Result, PersistenceError> +where + Conn: AsyncConnection, +{ + let tasks = task::table + .left_outer_join(task_result::table) + .filter(task_result::is_successful.is_null()) + .filter( + task::group_certificates_sent + .eq(Some(true)) + .or(task::protocol_round.ge(1)), + ) + .select(task::id) .order_by(task::id.asc()) .load(connection) .await?; @@ -303,7 +340,7 @@ where pub async fn get_task_acknowledgements( connection: &mut Conn, task_id: &Uuid, -) -> Result, bool>, PersistenceError> +) -> Result>, PersistenceError> where Conn: AsyncConnection, { @@ -316,11 +353,63 @@ where .load::<(Vec, Option)>(connection) .await? .into_iter() - .map(|(device_id, acknowledgement)| (device_id, acknowledgement.unwrap_or(false))) + .filter_map(|(device_id, acknowledgement)| match acknowledgement { + Some(true) => Some(device_id), + _ => None, + }) .collect(); Ok(acknowledgements) } +pub async fn get_task_active_shares( + connection: &mut Conn, + task_id: &Uuid, +) -> Result, u32>, PersistenceError> +where + Conn: AsyncConnection, +{ + let active_shares = active_task_participant::table + .select(( + active_task_participant::device_id, + active_task_participant::active_shares, + )) + .filter(active_task_participant::task_id.eq(task_id)) + .load::<(Vec, i32)>(connection) + .await? + .into_iter() + .map(|(device_id, active_shares)| (device_id, active_shares as u32)) + .collect(); + Ok(active_shares) +} + +pub async fn set_task_active_shares( + connection: &mut Conn, + task_id: &Uuid, + active_shares: &HashMap, u32>, +) -> Result<(), PersistenceError> +where + Conn: AsyncConnection, +{ + let active_task_participants: Vec<_> = active_shares + .iter() + .map(|(device_id, shares)| ActiveTaskParticipant { + task_id: *task_id, + device_id: device_id.clone(), + active_shares: *shares as i32, + }) + .collect(); + diesel::insert_into(active_task_participant::table) + .values(&active_task_participants) + .on_conflict(( + active_task_participant::task_id, + active_task_participant::device_id, + )) + .do_nothing() + .execute(connection) + .await?; + Ok(()) +} + pub async fn set_task_result( connection: &mut Conn, task_id: &Uuid, diff --git a/src/persistence/schema.rs b/src/persistence/schema.rs index 723d5ce..0fafb9b 100644 --- a/src/persistence/schema.rs +++ b/src/persistence/schema.rs @@ -22,6 +22,14 @@ pub mod sql_types { pub struct TaskType; } +diesel::table! { + active_task_participant (task_id, device_id) { + task_id -> Uuid, + device_id -> Bytea, + active_shares -> Int4, + } +} + diesel::table! { use diesel::sql_types::*; use super::sql_types::DeviceKind; @@ -102,6 +110,8 @@ diesel::table! { } } +diesel::joinable!(active_task_participant -> device (device_id)); +diesel::joinable!(active_task_participant -> task (task_id)); diesel::joinable!(group_participant -> device (device_id)); diesel::joinable!(group_participant -> group (group_id)); diesel::joinable!(task -> group (group_id)); @@ -110,6 +120,7 @@ diesel::joinable!(task_participant -> task (task_id)); diesel::joinable!(task_result -> task (task_id)); diesel::allow_tables_to_appear_in_same_query!( + active_task_participant, device, group, group_participant, diff --git a/src/protocols/elgamal.rs b/src/protocols/elgamal.rs index a2db18e..d6e8a7c 100644 --- a/src/protocols/elgamal.rs +++ b/src/protocols/elgamal.rs @@ -30,7 +30,6 @@ impl ElgamalGroup { impl Protocol for ElgamalGroup { fn initialize(&mut self, communicator: &mut Communicator, _: &[u8]) { - communicator.set_active_devices(None); let parties = self.parties; let threshold = self.threshold; communicator.send_all(|idx| { @@ -96,7 +95,6 @@ impl ElgamalDecrypt { impl Protocol for ElgamalDecrypt { fn initialize(&mut self, communicator: &mut Communicator, data: &[u8]) { - communicator.set_active_devices(None); let participant_indices = communicator.get_protocol_indices(); communicator.send_all(|idx| { (ProtocolInit { diff --git a/src/protocols/frost.rs b/src/protocols/frost.rs index f2d3fd9..598a385 100644 --- a/src/protocols/frost.rs +++ b/src/protocols/frost.rs @@ -30,7 +30,6 @@ impl FROSTGroup { impl Protocol for FROSTGroup { fn initialize(&mut self, communicator: &mut Communicator, _: &[u8]) { - communicator.set_active_devices(None); let parties = self.parties; let threshold = self.threshold; communicator.send_all(|idx| { @@ -96,7 +95,6 @@ impl FROSTSign { impl Protocol for FROSTSign { fn initialize(&mut self, communicator: &mut Communicator, data: &[u8]) { - communicator.set_active_devices(None); let participant_indices = communicator.get_protocol_indices(); communicator.send_all(|idx| { (ProtocolInit { diff --git a/src/protocols/gg18.rs b/src/protocols/gg18.rs index cfaae03..832bc60 100644 --- a/src/protocols/gg18.rs +++ b/src/protocols/gg18.rs @@ -30,7 +30,6 @@ impl GG18Group { impl Protocol for GG18Group { fn initialize(&mut self, communicator: &mut Communicator, _: &[u8]) { - communicator.set_active_devices(None); let parties = self.parties; let threshold = self.threshold; communicator.send_all(|idx| { @@ -96,7 +95,6 @@ impl GG18Sign { impl Protocol for GG18Sign { fn initialize(&mut self, communicator: &mut Communicator, data: &[u8]) { - communicator.set_active_devices(None); let participant_indices = communicator.get_protocol_indices(); communicator.send_all(|idx| { (ProtocolInit { diff --git a/src/protocols/musig2.rs b/src/protocols/musig2.rs index 31fd004..3edbdd7 100644 --- a/src/protocols/musig2.rs +++ b/src/protocols/musig2.rs @@ -25,7 +25,6 @@ impl MuSig2Group { impl Protocol for MuSig2Group { fn initialize(&mut self, communicator: &mut Communicator, _: &[u8]) { - communicator.set_active_devices(None); let parties = self.parties; communicator.send_all(|idx| { (ProtocolGroupInit { @@ -90,7 +89,6 @@ impl MuSig2Sign { impl Protocol for MuSig2Sign { fn initialize(&mut self, communicator: &mut Communicator, data: &[u8]) { - communicator.set_active_devices(None); let participant_indices = communicator.get_protocol_indices(); communicator.send_all(|idx| { (ProtocolInit { diff --git a/src/state.rs b/src/state.rs index 8b86e9a..d5e6119 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,45 +1,52 @@ use dashmap::DashMap; use log::{debug, error, warn}; use std::collections::HashMap; -use tokio::sync::RwLock; use uuid::Uuid; -use crate::communicator::Communicator; use crate::error::Error; use crate::persistence::{ - Device, DeviceKind, Group, NameValidator, Participant, PersistenceError, Repository, - Task as TaskModel, TaskType, + Device, DeviceKind, Group, KeyType, NameValidator, Participant, PersistenceError, ProtocolType, + Repository, TaskType, +}; +use crate::proto; +use crate::task_store::TaskStore; +use crate::tasks::{ + DecisionUpdate, RoundUpdate, RunningTaskContext, Task, TaskInfo, TaskResult, VotingTask, }; -use crate::proto::{self, KeyType, ProtocolType}; -use crate::tasks::decrypt::DecryptTask; -use crate::tasks::group::GroupTask; -use crate::tasks::sign::SignTask; -use crate::tasks::sign_pdf::SignPDFTask; -use crate::tasks::{DecisionUpdate, RestartUpdate, RoundUpdate, Task, TaskResult}; use crate::{get_timestamp, utils}; +use meesign_crypto::proto::ClientMessage; +use prost::Message as _; +use rand::{prelude::IteratorRandom, thread_rng}; use tokio::sync::mpsc::Sender; use tonic::codegen::Arc; use tonic::Status; pub struct State { - // tasks: HashMap>, - subscribers: DashMap, Sender>>, + devices: DashMap, Device>, + subscribers: DashMap, Sender>>, repo: Arc, - communicators: DashMap>>, + task_store: TaskStore, task_last_updates: DashMap, device_last_activations: DashMap, u64>, } impl State { - pub fn new(repo: Arc) -> Self { - State { - // tasks: HashMap::new(), + pub async fn restore(repo: Arc) -> Result { + let devices = repo + .get_devices() + .await? + .into_iter() + .map(|dev| (dev.id.clone(), dev)) + .collect(); + let state = State { + devices, subscribers: DashMap::new(), - repo, - communicators: DashMap::default(), + repo: repo.clone(), + task_store: TaskStore::new(repo), task_last_updates: DashMap::new(), device_last_activations: DashMap::new(), - } + }; + Ok(state) } pub async fn add_device( @@ -49,20 +56,22 @@ impl State { kind: &DeviceKind, certificate: &[u8], ) -> Result { - Ok(self + let device = self .get_repo() .add_device(identifier, name, kind, certificate) - .await?) + .await?; + self.devices.insert(device.id.clone(), device.clone()); + Ok(device) } pub async fn add_group_task( - &mut self, + &self, name: &str, device_ids: &[&[u8]], threshold: u32, - protocol: ProtocolType, + protocol_type: ProtocolType, key_type: KeyType, note: Option, - ) -> Result { + ) -> Result { if !name.is_name_valid() { error!("Group request with invalid group name {}", name); return Err(Error::GeneralProtocolError(format!( @@ -75,37 +84,53 @@ impl State { } let device_ids: Vec<&[u8]> = shares.keys().cloned().collect(); let participants = self - .get_repo() .get_devices_with_ids(&device_ids) - .await? .into_iter() .map(|device| { let shares = shares[device.id.as_slice()]; Participant { device, shares } }) .collect(); - let task = Box::new(GroupTask::try_new( - name, + let task_info = TaskInfo { + id: Uuid::new_v4(), + name: name.to_string(), + task_type: TaskType::Group, + protocol_type, + key_type, participants, + attempts: 0, + }; + let accept_threshold = task_info.total_shares(); + let request = (proto::GroupRequest { + device_ids: device_ids.into_iter().map(Vec::from).collect(), + name: task_info.name.clone(), threshold, - protocol, - key_type, - note, - )?) as Box; - - // TODO: group ID? - let task_id = self.add_task(task, &[], key_type, protocol).await?; + protocol: protocol_type as i32, + key_type: key_type as i32, + note: note.clone(), + }) + .encode_to_vec(); + let running_task_context = RunningTaskContext::Group { + threshold, + note: note.clone(), + }; + let task = VotingTask { + task_info, + decisions: HashMap::new(), + accept_threshold, + request, + running_task_context, + }; - self.send_updates(&task_id).await?; - Ok(task_id) + self.add_task(task).await } pub async fn add_sign_task( - &mut self, + &self, group_id: &[u8], name: &str, data: &[u8], - ) -> Result { + ) -> Result { let group = self.get_repo().get_group(group_id).await?; let Some(group) = group else { warn!( @@ -115,16 +140,26 @@ impl State { return Err(Error::GeneralProtocolError("Invalid group_id".into())); }; let participants = self.repo.get_group_participants(group_id).await?; - let group = crate::group::Group::from_model(group, participants); - let task = match group.key_type() { - KeyType::SignPdf => { - let task = SignPDFTask::try_new(group.clone(), name.to_string(), data.to_vec())?; - Box::new(task) as Box - } - KeyType::SignChallenge => { - let task = SignTask::try_new(group.clone(), name.to_string(), data.to_vec())?; - Box::new(task) as Box - } + let group_id = group.id.clone(); + let key_type = group.key_type; + let protocol_type = group.protocol; + let accept_threshold = group.threshold as u32; + let data = data.to_vec(); + let request = proto::SignRequest { + group_id: group_id.clone(), + name: name.to_string(), + data: data.clone(), + } + .encode_to_vec(); + let (task_type, running_task_context) = match key_type { + KeyType::SignPdf => ( + TaskType::SignPdf, + RunningTaskContext::SignPdf { group, data }, + ), + KeyType::SignChallenge => ( + TaskType::SignChallenge, + RunningTaskContext::SignChallenge { group, data }, + ), KeyType::Decrypt => { warn!( "Signing request made for decryption group group_id={}", @@ -135,22 +170,33 @@ impl State { )); } }; + let task_info = TaskInfo { + id: Uuid::new_v4(), + name: name.to_string(), + task_type, + protocol_type, + key_type, + participants, + attempts: 0, + }; + let task = VotingTask { + task_info, + decisions: HashMap::new(), + accept_threshold, + request, + running_task_context, + }; - let task_id = self - .add_task(task, group.identifier(), group.key_type(), group.protocol()) - .await?; - self.send_updates(&task_id).await?; - - Ok(task_id) + self.add_task(task).await } pub async fn add_decrypt_task( - &mut self, + &self, group_id: &[u8], name: &str, data: &[u8], data_type: &str, - ) -> Result { + ) -> Result { let group: Option = self.get_repo().get_group(group_id).await?; let Some(group) = group else { warn!( @@ -160,17 +206,20 @@ impl State { return Err(Error::GeneralProtocolError("Invalid group_id".into())); }; let participants = self.repo.get_group_participants(group_id).await?; - let group = crate::group::Group::from_model(group, participants); - let task = match group.key_type() { - KeyType::Decrypt => { - let task = DecryptTask::try_new( - group.clone(), - name.to_string(), - data.to_vec(), - data_type.to_string(), - )?; - Box::new(task) as Box - } + let group_id = group.id.clone(); + let key_type = group.key_type; + let protocol_type = group.protocol; + let accept_threshold = group.threshold as u32; + let data = data.to_vec(); + let request = proto::DecryptRequest { + group_id: group_id.clone(), + name: name.to_string(), + data: data.clone(), + data_type: data_type.to_string(), + } + .encode_to_vec(); + let running_task_context = match key_type { + KeyType::Decrypt => RunningTaskContext::Decrypt { group, data }, KeyType::SignPdf | KeyType::SignChallenge => { warn!( "Decryption request made for a signing group group_id={}", @@ -181,73 +230,45 @@ impl State { )); } }; + let task_info = TaskInfo { + id: Uuid::new_v4(), + name: name.to_string(), + task_type: TaskType::Decrypt, + protocol_type, + key_type, + participants, + attempts: 0, + }; + let task = VotingTask { + task_info, + decisions: HashMap::new(), + accept_threshold, + request, + running_task_context, + }; - let task_id = self - .add_task(task, group.identifier(), group.key_type(), group.protocol()) - .await?; - self.send_updates(&task_id).await?; - Ok(task_id) + self.add_task(task).await } - async fn add_task( - &mut self, - task: Box, - group_id: &[u8], - key_type: KeyType, - protocol_type: ProtocolType, - ) -> Result { - let task_participants = task.get_participants(); - let participant_ids_shares: Vec<(&[u8], u32)> = task_participants - .iter() - .map(|participant| (participant.device.id.as_slice(), participant.shares)) - .collect(); - let created_task = match task.get_type() { - crate::proto::TaskType::Group => { - self.get_repo() - .create_group_task( - Some(task.get_id()), - &participant_ids_shares, - task.get_threshold(), - protocol_type.into(), - key_type.into(), - task.get_request(), - None, // TODO: missing note - ) - .await? - } - task_type => { - self.get_repo() - .create_threshold_task( - Some(task.get_id()), - group_id, - &participant_ids_shares, - task.get_threshold(), - "name", - task.get_data().unwrap(), - task.get_request(), - task_type.into(), - key_type.into(), - protocol_type.into(), - ) - .await? - } - }; - if let Some(_communicator) = self - .communicators - .insert(created_task.id, task.get_communicator()) - { - // TODO: create a new "internal error" error variant - error!( - "A communicator for task with id {} already exists!", - created_task.id - ); - return Err(Error::GeneralProtocolError("Data inconsistency".into())); - }; - Ok(created_task.id) + async fn add_task(&self, task: VotingTask) -> Result { + let task_id = task.task_info.id; + self.task_store.persist_task(task).await?; + let task = self.task_store.get_task(&task_id).await?; + self.send_updates(&task).await?; + let formatted = task.format(None, None); + Ok(formatted) } - pub async fn get_active_device_tasks(&self, device: &[u8]) -> Result, Error> { - Ok(self.get_repo().get_active_device_tasks(device).await?) + pub async fn get_formatted_active_device_tasks( + &self, + device_id: &[u8], + ) -> Result, Error> { + let task_ids = self.repo.get_active_device_tasks(device_id).await?; + let mut tasks = Vec::new(); + for task in self.task_store.get_tasks(task_ids).await? { + tasks.push(task.await.format(Some(device_id), None)); + } + Ok(tasks) } pub fn activate_device(&self, device_id: &[u8]) { @@ -255,27 +276,32 @@ impl State { .insert(device_id.to_vec(), get_timestamp()); } - pub async fn device_exists(&self, device_id: &[u8]) -> Result { - // TODO: Optimize query / cache devices in State - let devices = self.repo.get_devices_with_ids(&[device_id]).await?; - Ok(devices.len() == 1 && devices[0].id == device_id) + pub fn device_exists(&self, device_id: &[u8]) -> bool { + self.devices.contains_key(device_id) } - pub async fn get_devices(&self) -> Result, Error> { - let devices = self - .get_repo() - .get_devices() - .await? - .into_iter() - .map(|dev| { - let last_active = *self - .device_last_activations - .entry(dev.id.clone()) - .or_insert(get_timestamp()); - (dev, last_active) + pub fn get_devices(&self) -> Vec<(Device, u64)> { + self.devices + .iter() + .map(|entry| { + let last_active = self.get_device_last_activation(entry.key()); + (entry.value().clone(), last_active) }) - .collect(); - Ok(devices) + .collect() + } + + fn get_device_last_activation(&self, device_id: &[u8]) -> u64 { + *self + .device_last_activations + .entry(device_id.to_vec()) + .or_insert(get_timestamp()) // TODO: Assume inactive device? + } + + fn get_devices_with_ids(&self, device_ids: &[&[u8]]) -> Vec { + device_ids + .iter() + .filter_map(|device_id| self.devices.get(*device_id).map(|dev| dev.clone())) + .collect() } pub async fn get_device_groups(&self, device: &[u8]) -> Result, Error> { @@ -286,28 +312,47 @@ impl State { Ok(self.get_repo().get_groups().await?) } - pub async fn get_tasks(&self) -> Result, Error> { - Ok(self.repo.get_tasks().await?) + pub async fn get_formatted_tasks(&self) -> Result, Error> { + let task_ids = self.repo.get_tasks().await?; + let mut tasks = Vec::new(); + for task in self.task_store.get_tasks(task_ids).await? { + tasks.push(task.await.format(None, None)); + } + Ok(tasks) } - pub async fn get_task(&self, task_id: &Uuid) -> Result { - let Some(task_model) = self.repo.get_task(task_id).await? else { - return Err(Error::GeneralProtocolError("Invalid task id".into())); + pub async fn get_formatted_voting_task( + &self, + task_id: &Uuid, + device_id: Option<&[u8]>, + ) -> Result { + let task = &*self.task_store.get_task(task_id).await?; + let request = if let Task::Voting(task) = task { + task.request.clone() + } else { + return Err(Error::GeneralProtocolError( + "Queried task is not in voting phase".into(), + )); }; - Ok(task_model) + let task = task.format(device_id, Some(request)); + Ok(task) } pub async fn update_task( - &mut self, + &self, task_id: &Uuid, device: &[u8], - data: &Vec>, + data: Vec, attempt: u32, ) -> Result<(), Error> { - let task_model = self.get_task(task_id).await?; - let mut task = self.task_from_task_model(task_model).await?; + let task_entry = &mut *self.task_store.get_task_mut(task_id).await?; + let Task::Running(task) = task_entry else { + return Err(Error::GeneralProtocolError( + "Cannot update non-running task".into(), + )); + }; self.set_task_last_update(task_id); - if attempt != task.get_attempts() { + if attempt != task.task_info().attempts { warn!( "Stale update discarded task_id={} device_id={} attempt={}", utils::hextrunc(task_id.as_bytes()), @@ -317,82 +362,105 @@ impl State { return Err(Error::GeneralProtocolError("Stale update".to_string())); } - match task.update(device, data).await? { + match task.update(device, data)? { RoundUpdate::Listen => {} RoundUpdate::GroupCertificatesSent => unreachable!(), RoundUpdate::NextRound(round) => { self.repo.set_task_round(task_id, round).await?; - self.send_updates(task_id).await?; + self.send_updates(task_entry).await?; } - RoundUpdate::Failed(reason) => { - self.repo.set_task_result(task_id, &Err(reason)).await?; - self.send_updates(task_id).await?; + RoundUpdate::Failed(task) => { + self.repo + .set_task_result(task_id, &Err(task.reason.clone())) + .await?; + *task_entry = Task::Failed(task); + self.send_updates(task_entry).await?; } - RoundUpdate::Finished(round, result) => { + RoundUpdate::Finished(round, task) => { self.repo.set_task_round(task_id, round).await?; + let result = &task.result; let result_bytes = result.as_bytes().to_vec(); if let TaskResult::GroupEstablished(group) = result { self.repo .add_group( - group.identifier(), + &group.id, task_id, - group.name(), - group.threshold(), - group.protocol().into(), - group.key_type().into(), - group.certificate().map(|v| v.as_ref()), - group.note(), + &group.name, + group.threshold as u32, + group.protocol, + group.key_type, + group.certificate.as_deref(), + group.note.as_deref(), ) .await?; } self.repo .set_task_result(task_id, &Ok(result_bytes)) .await?; - // NOTE: Updates must be sent after the group is persisted - self.send_updates(task_id).await?; + *task_entry = Task::Finished(task); + self.send_updates(task_entry).await?; } } Ok(()) } pub async fn decide_task( - &mut self, + &self, task_id: &Uuid, - device: &[u8], - decision: bool, + device_id: &[u8], + accept: bool, ) -> Result<(), Error> { - let task_model = self.get_task(task_id).await?; - let mut task = self.task_from_task_model(task_model).await?; + let task_entry = &mut *self.task_store.get_task_mut(task_id).await?; + let Task::Voting(task) = task_entry else { + return Err(Error::GeneralProtocolError( + "Cannot decide non-voting task".into(), + )); + }; self.set_task_last_update(task_id); - let decision_update = task.decide(device, decision).await?; + let decision_update = task.decide(device_id, accept).await?; self.repo - .set_task_decision(task_id, device, decision) + .set_task_decision(task_id, device_id, accept) .await?; match decision_update { DecisionUpdate::Undecided => {} - DecisionUpdate::Accepted(round_update) => { + DecisionUpdate::Accepted => { log::info!( "Task approved task_id={}", utils::hextrunc(task_id.as_bytes()) ); - match round_update { + + let active_shares = self.choose_active_shares(task).await?; + + let mut task = task + .running_task_context + .clone() + .create_running_task(task, active_shares)?; + + match task.initialize()? { RoundUpdate::Listen => unreachable!(), RoundUpdate::Finished(_, _) => unreachable!(), RoundUpdate::GroupCertificatesSent => { self.repo .set_task_group_certificates_sent(task_id, Some(true)) .await?; + *task_entry = Task::Running(task); + self.send_updates(task_entry).await?; } RoundUpdate::NextRound(round) => { self.repo.set_task_round(task_id, round).await?; + *task_entry = Task::Running(task); + self.send_updates(task_entry).await?; } - RoundUpdate::Failed(reason) => { - self.repo.set_task_result(task_id, &Err(reason)).await?; + RoundUpdate::Failed(task) => { + self.repo + .set_task_result(task_id, &Err(task.reason.clone())) + .await?; + *task_entry = Task::Failed(task); + self.send_updates(task_entry).await?; } } - self.send_updates(task_id).await?; } - DecisionUpdate::Declined => { + DecisionUpdate::Declined(task) => { log::info!( "Task declined task_id={}", utils::hextrunc(task_id.as_bytes()) @@ -400,84 +468,154 @@ impl State { self.repo .set_task_result(task_id, &Err("Task declined".into())) .await?; - self.send_updates(task_id).await?; + *task_entry = Task::Declined(task); + self.send_updates(task_entry).await?; } - } + }; Ok(()) } - pub async fn acknowledge_task(&mut self, task_id: &Uuid, device: &[u8]) -> Result<(), Error> { - let task_model = self.get_task(task_id).await?; - let mut task = self.task_from_task_model(task_model).await?; - task.acknowledge(device).await; - self.repo.set_task_acknowledgement(task_id, device).await?; - Ok(()) - } + /// Picks which shares shall participate in the protocol. + /// Considers only those devices which accepted participation. + /// If enough devices are available, additionaly filters by response latency. + /// Returns a mapping of protocol indices to device ids. + /// The clients expect that out of a participant's [0..n] shares, + /// exactly the first [0..k] will be chosen. + async fn choose_active_shares(&self, task: &VotingTask) -> Result, Error> { + // NOTE: Threshold tasks need to use indices from group establishment, that is, + // the indices assigned to all task participants. Since we don't store + // any such index mapping, we generate it from a sorted list of devices. + let mut all_participants = task.task_info.participants.clone(); + all_participants.sort_by(|a, b| a.device.id.cmp(&b.device.id)); + let first_share_indices: HashMap, u32> = all_participants + .into_iter() + .scan(0, |idx, p| { + let first_share = *idx; + *idx += p.shares; + Some((p.device.id.clone(), first_share)) + }) + .collect(); - pub async fn restart_task(&mut self, task_id: &Uuid) -> Result { - let task_model = self.get_task(task_id).await?; - let mut task = self.task_from_task_model(task_model).await?; - self.set_task_last_update(task_id); + let accepting_participants: Vec<&Participant> = task + .task_info + .participants + .iter() + .filter(|p| task.device_accepted(&p.device.id)) + .collect(); + let latest_acceptable_time = get_timestamp() - 5; + let connected_participants: Vec<&Participant> = accepting_participants + .iter() + .filter(|p| { + let last_active_time = self.get_device_last_activation(&p.device.id); + last_active_time > latest_acceptable_time + }) + .copied() + .collect(); - match task.restart().await? { - RestartUpdate::AlreadyFinished => Ok(false), - RestartUpdate::Voting => Ok(false), - RestartUpdate::Started(round_update) => { - self.repo.increment_task_attempt_count(task_id).await?; - match round_update { - RoundUpdate::Listen => unreachable!(), - RoundUpdate::Finished(_, _) => unreachable!(), - RoundUpdate::GroupCertificatesSent => { - self.repo - .set_task_group_certificates_sent(task_id, Some(true)) - .await?; - } - RoundUpdate::NextRound(round) => { - self.repo.set_task_round(task_id, round).await?; - } - RoundUpdate::Failed(reason) => { - self.repo.set_task_result(task_id, &Err(reason)).await?; - } - } - self.send_updates(task_id).await?; - Ok(true) - } + let total_connected_shares: u32 = connected_participants.iter().map(|p| p.shares).sum(); + let candidates = if total_connected_shares >= task.accept_threshold { + connected_participants + } else { + accepting_participants + }; + + let chosen_devices = candidates + .into_iter() + .flat_map(|p| std::iter::repeat_n(&p.device, p.shares as usize)) + .choose_multiple(&mut thread_rng(), task.accept_threshold as usize); + + let mut active_shares = HashMap::new(); + for device in &chosen_devices { + *active_shares.entry(device.id.clone()).or_default() += 1; } + self.repo + .set_task_active_shares(&task.task_info.id, &active_shares) + .await?; + + let active_devices = chosen_devices + .into_iter() + .scan(first_share_indices, |share_indices, device| { + let share_index = share_indices[&device.id]; + *share_indices.get_mut(&device.id).unwrap() += 1; + Some((share_index, device.clone())) + }) + .collect(); + + Ok(active_devices) } - pub async fn get_tasks_for_restart(&self) -> Result, Error> { - let task_models = self.get_tasks().await?; - let tasks = self.tasks_from_task_models(task_models.clone()).await?; - let now = get_timestamp(); + pub async fn acknowledge_task(&self, task_id: &Uuid, device: &[u8]) -> Result<(), Error> { + let task_entry = &mut *self.task_store.get_task_mut(task_id).await?; + let Task::Finished(task) = task_entry else { + return Err(Error::GeneralProtocolError( + "Cannot acknowledge unfinished task".into(), + )); + }; + task.acknowledge(device); + self.repo.set_task_acknowledgement(task_id, device).await?; + Ok(()) + } - let mut restarts = Vec::new(); - for (task_model, task) in task_models.into_iter().zip(tasks.into_iter()) { - let task_id = &task_model.id; - let result = task_model - .result - .map(|res| res.try_into_result()) - .transpose()?; - let last_update = self.get_task_last_update(task_id); - let stale = result.is_none() - && task.is_approved().await - && now.saturating_sub(last_update) > 30; - if stale { - debug!("Stale task detected task_id={:?}", utils::hextrunc(task_id)); - restarts.push(*task_id); + pub async fn restart_stale_tasks(&self) -> Result<(), Error> { + let now = get_timestamp(); + let task_ids = self + .repo + .get_restart_candidates() + .await? + .into_iter() + .filter(|task_id| { + let last_update = self.get_task_last_update(task_id); + now.saturating_sub(last_update) > 30 + }) + .collect(); + let tasks = self.task_store.get_tasks_mut(task_ids).await?; + for task_entry in tasks { + let task_entry = &mut *task_entry.await; + let Task::Running(task) = task_entry else { + return Err(PersistenceError::DataInconsistencyError( + "non-running task in candidates for restart".into(), + ) + .into()); + }; + let task_id = &task.task_info().id.clone(); + debug!("Stale task detected task_id={:?}", utils::hextrunc(task_id)); + self.set_task_last_update(task_id); + + self.repo.increment_task_attempt_count(task_id).await?; + match task.restart()? { + RoundUpdate::Listen => unreachable!(), + RoundUpdate::Finished(_, _) => unreachable!(), + RoundUpdate::GroupCertificatesSent => { + self.repo + .set_task_group_certificates_sent(task_id, Some(true)) + .await?; + self.send_updates(task_entry).await?; + } + RoundUpdate::NextRound(round) => { + self.repo.set_task_round(task_id, round).await?; + self.send_updates(task_entry).await?; + } + RoundUpdate::Failed(task) => { + self.repo + .set_task_result(task_id, &Err(task.reason.clone())) + .await?; + *task_entry = Task::Failed(task); + self.send_updates(task_entry).await?; + } } } - Ok(restarts) + Ok(()) } pub fn add_subscriber( - &mut self, + &self, device_id: Vec, tx: Sender>, ) { self.subscribers.insert(device_id, tx); } - pub fn remove_subscriber(&mut self, device_id: &Vec) { + pub fn remove_subscriber(&self, device_id: &Vec) { self.subscribers.remove(device_id); debug!( "Removing subscriber device_id={}", @@ -489,19 +627,13 @@ impl State { &self.subscribers } - pub async fn send_updates(&mut self, task_id: &Uuid) -> Result<(), Error> { - let Some(task_model) = self.repo.get_task(task_id).await? else { - return Err(Error::GeneralProtocolError("Invalid task id".into())); - }; - let participants = self.repo.get_task_participants(task_id).await?; + pub async fn send_updates(&self, task: &Task) -> Result<(), Error> { let mut remove = Vec::new(); - for participant in participants { + for participant in &task.task_info().participants { let device_id = participant.device.identifier(); if let Some(tx) = self.subscribers.get(device_id) { - let result = tx.try_send(Ok(self - .format_task(task_model.clone(), Some(device_id), None) - .await?)); + let result = tx.try_send(Ok(task.format(Some(device_id), None))); if result.is_err() { debug!( @@ -524,214 +656,14 @@ impl State { &self.repo } - async fn get_communicator( - &self, - task_model: &TaskModel, - mut participants: Vec, - ) -> Result>, Error> { - use dashmap::mapref::entry::Entry; - let communicator = match self.communicators.entry(task_model.id.clone()) { - Entry::Occupied(entry) => entry.into_ref(), - Entry::Vacant(entry) => { - participants.sort_by(|a, b| a.device.identifier().cmp(b.device.identifier())); - let decisions = self.repo.get_task_decisions(&task_model.id).await?; - let acknowledgements = self.repo.get_task_acknowledgements(&task_model.id).await?; - let threshold = match task_model.task_type { - TaskType::Group => participants.iter().map(|p| p.shares).sum(), - _ => task_model.threshold as u32, - }; - - entry.insert(Arc::new(RwLock::new(Communicator::new( - participants, - threshold, - task_model.protocol_type.into(), - decisions, - acknowledgements, - )))) - } - } - .clone(); - Ok(communicator) - } - - async fn task_from_task_model( - &self, - task_model: TaskModel, - ) -> Result, Error> { - let task = self - .tasks_from_task_models(vec![task_model]) - .await? - .pop() - .unwrap(); - Ok(task) - } - - async fn tasks_from_task_models( - &self, - task_models: Vec, - ) -> Result>, Error> { - // NOTE: Sorted and unique task models (strict inequality) - assert!(task_models.windows(2).all(|w| w[0].id < w[1].id)); - - let task_ids: Vec<_> = task_models.iter().map(|task| task.id.clone()).collect(); - - let task_id_participant_pairs = self.get_repo().get_tasks_participants(&task_ids).await?; - - let mut task_id_participants: HashMap<_, Vec<_>> = HashMap::new(); - - for (task_id, device) in task_id_participant_pairs { - task_id_participants - .entry(task_id) - .or_default() - .push(device); - } - - // NOTE: When hydrating many tasks, `future::join_all` would cause - // a deadlock with `repository::get_async_connection`. - let mut tasks = Vec::new(); - for task in task_models { - let participants = task_id_participants.remove(&task.id).unwrap(); - let communicator = self.get_communicator(&task, participants.clone()).await?; - - let task = match task.task_type { - TaskType::Group => { - self.group_task_from_model(task, communicator, participants) - .await? - } - TaskType::SignChallenge | TaskType::SignPdf | TaskType::Decrypt => { - self.threshold_task_from_model(task, communicator, participants) - .await? - } - }; - - tasks.push(task) - } - Ok(tasks) - } - - async fn group_task_from_model( - &self, - task_model: TaskModel, - communicator: Arc>, - participants: Vec, - ) -> Result, Error> { - let task_result = task_model.result.clone(); - let group = if let Some(task_result) = task_result { - match task_result.try_into_result()? { - Ok(group_id) => { - let Some(group_model) = self.repo.get_group(&group_id).await? else { - return Err(Error::PersistenceError( - PersistenceError::DataInconsistencyError( - "Group task result references nonexistent group".into(), - ), - )); - }; - let group = crate::group::Group::from_model(group_model, participants.clone()); - Some(group) - } - Err(_) => None, - } - } else { - None - }; - - assert_eq!(task_model.task_type, TaskType::Group); - - let task = Box::new(GroupTask::from_model( - task_model, - participants, - communicator, - group, - )?); - Ok(task) - } - - async fn threshold_task_from_model( - &self, - task_model: TaskModel, - communicator: Arc>, - participants: Vec, - ) -> Result, Error> { - let Some(group_id) = &task_model.group_id else { - return Err(Error::PersistenceError( - PersistenceError::DataInconsistencyError( - "Threshold task is missing a group".into(), - ), - )); - }; - let Some(group_model) = self.repo.get_group(group_id).await? else { - return Err(Error::PersistenceError( - PersistenceError::DataInconsistencyError( - "Threshold task references nonexistent group".into(), - ), - )); - }; - let group = crate::group::Group::from_model(group_model, participants); - - let task: Box = match task_model.task_type { - TaskType::Group => unreachable!(), - TaskType::SignPdf => { - Box::new(SignPDFTask::from_model(task_model, communicator, group)?) - } - TaskType::SignChallenge => { - Box::new(SignTask::from_model(task_model, communicator, group)?) - } - TaskType::Decrypt => { - Box::new(DecryptTask::from_model(task_model, communicator, group)?) - } - }; - Ok(task) - } - pub fn set_task_last_update(&self, task_id: &Uuid) { - self.task_last_updates - .insert(task_id.clone(), get_timestamp()); + self.task_last_updates.insert(*task_id, get_timestamp()); } pub fn get_task_last_update(&self, task_id: &Uuid) -> u64 { *self .task_last_updates - .entry(task_id.clone()) + .entry(*task_id) .or_insert(get_timestamp()) } - - pub async fn format_task( - &self, - task_model: TaskModel, - device_id: Option<&[u8]>, - request: Option>, - ) -> Result { - let request = request.map(Vec::from); - let id = task_model.id.as_bytes().to_vec(); - let r#type: proto::TaskType = task_model.task_type.clone().into(); - let r#type = r#type.into(); - let attempt = task_model.attempt_count as u32; - let task = match task_model.result.clone() { - None => { - let task = self.task_from_task_model(task_model).await?; - if !task.is_approved().await { - let (accept, reject) = task.get_decisions().await; - proto::Task::created(id, r#type, accept, reject, request, attempt) - } else { - let round = task.get_round() as u32; - let data = if let Some(device_id) = device_id { - task.get_work(device_id).await - } else { - Vec::new() - }; - proto::Task::running(id, r#type, round, data, request, attempt) - } - } - Some(result) => match result.try_into_result()? { - Ok(result) => proto::Task::finished(id, r#type, result, request, attempt), - Err(reason) => { - let round = task_model.protocol_round as u32; - let task = self.task_from_task_model(task_model).await?; - let (accept, reject) = task.get_decisions().await; - proto::Task::failed(id, r#type, round, accept, reject, reason, request, attempt) - } - }, - }; - Ok(task) - } } diff --git a/src/task_store.rs b/src/task_store.rs new file mode 100644 index 0000000..bcabb40 --- /dev/null +++ b/src/task_store.rs @@ -0,0 +1,491 @@ +use dashmap::mapref::entry::Entry; +use dashmap::DashMap; +use std::collections::{HashMap, HashSet}; +use std::future::Future; +use std::ops::{Deref, DerefMut}; +use std::sync::Arc; +use tokio::sync::{OwnedRwLockWriteGuard, RwLock}; +use uuid::Uuid; + +use crate::communicator::Communicator; +use crate::error::Error; +use crate::persistence::{ + Group, Participant, PersistenceError, Repository, Task as TaskModel, TaskType, +}; +use crate::tasks::{ + decrypt::DecryptTask, group::GroupTask, sign::SignTask, sign_pdf::SignPDFTask, DeclinedTask, + FailedTask, FinishedTask, RunningTask, RunningTaskContext, Task, TaskInfo, TaskResult, + VotingTask, +}; + +/// A lazily populated `Task` cache. +/// +/// All `get*` methods first ensure the `Task` is cached, +/// then a reference to it is returned. +pub struct TaskStore { + // NOTE: DashMap locking applies to its internal shards, we must protect Tasks across awaits using tokio's RwLock. + task_cache: DashMap>>, + repo: Arc, +} + +impl TaskStore { + /// Creates an empty `TaskStore`. + pub fn new(repo: Arc) -> Self { + Self { + task_cache: DashMap::new(), + repo, + } + } + + /// Caches the provided `task` and persists it into the DB. + /// If a `Task` with the same `task_id` already exists in the cache, it is returned. + pub async fn persist_task(&self, task: VotingTask) -> Result, Error> { + let participant_ids_shares: Vec<_> = task + .task_info + .participants + .iter() + .map(|participant| (participant.device.id.as_slice(), participant.shares)) + .collect(); + + match &task.running_task_context { + RunningTaskContext::Group { threshold, note } => { + self.persist_group_task( + &task, + &participant_ids_shares, + *threshold, + note.as_deref(), + ) + .await?; + } + RunningTaskContext::SignChallenge { group, data } => { + self.persist_threshold_task( + &task, + TaskType::SignChallenge, + &participant_ids_shares, + group, + data, + ) + .await?; + } + RunningTaskContext::SignPdf { group, data } => { + self.persist_threshold_task( + &task, + TaskType::SignPdf, + &participant_ids_shares, + group, + data, + ) + .await?; + } + RunningTaskContext::Decrypt { group, data, .. } => { + self.persist_threshold_task( + &task, + TaskType::Decrypt, + &participant_ids_shares, + group, + data, + ) + .await?; + } + } + + let evicted_task = self + .task_cache + .insert(task.task_info.id, Arc::new(RwLock::new(Task::Voting(task)))) + .map(|evicted_task| Arc::into_inner(evicted_task).unwrap().into_inner()); + Ok(evicted_task) + } + + /// Returns an iterator of mutable `Task` references. + /// Returns an error if any of the provided `task_ids` + /// does not reference an existing `Task`. + pub async fn get_tasks_mut( + &self, + task_ids: Vec, + ) -> Result< + impl Iterator>> + use<'_>, + Error, + > { + self.get_tasks_write_guards(task_ids).await + } + + /// Returns an iterator of shared `Task` references. + /// Returns an error if any of the provided `task_ids` + /// does not reference an existing `Task`. + pub async fn get_tasks( + &self, + task_ids: Vec, + ) -> Result< + impl Iterator>> + use<'_>, + Error, + > { + let iterator = self + .get_tasks_write_guards(task_ids) + .await? + .map(|write_guard| async { write_guard.await.downgrade() }); + Ok(iterator) + } + + /// Returns a mutable reference to a `Task`. + /// Returns an error if the provided `task_id` + /// does not reference an existing `Task`. + pub async fn get_task_mut( + &self, + task_id: &Uuid, + ) -> Result, Error> { + let task = self + .get_tasks_write_guards(vec![*task_id]) + .await? + .next() + .unwrap() + .await; + Ok(task) + } + + /// Returns a shared reference to a `Task`. + /// Returns an error if the provided `task_id` + /// does not reference an existing `Task`. + pub async fn get_task(&self, task_id: &Uuid) -> Result, Error> { + let task = self + .get_tasks_write_guards(vec![*task_id]) + .await? + .next() + .unwrap() + .await + .downgrade(); + Ok(task) + } + + async fn persist_group_task( + &self, + task: &VotingTask, + participant_ids_shares: &[(&[u8], u32)], + threshold: u32, + note: Option<&str>, + ) -> Result<(), Error> { + self.repo + .create_group_task( + Some(&task.task_info.id), + participant_ids_shares, + threshold, + task.task_info.protocol_type, + task.task_info.key_type, + &task.request, + note, + ) + .await?; + Ok(()) + } + async fn persist_threshold_task( + &self, + task: &VotingTask, + task_type: TaskType, + participant_ids_shares: &[(&[u8], u32)], + group: &Group, + data: &[u8], + ) -> Result<(), Error> { + self.repo + .create_threshold_task( + Some(&task.task_info.id), + &group.id, + participant_ids_shares, + group.threshold as u32, + "name", // TODO: Fix name checks + data, + &task.request, + task_type, + task.task_info.key_type, + task.task_info.protocol_type, + ) + .await?; + Ok(()) + } + + async fn ensure_cached_tasks(&self, task_ids: impl Iterator) -> Result<(), Error> { + let uncached_task_ids: Vec<_> = task_ids + .filter(|task_id| !self.task_cache.contains_key(task_id)) + .collect(); + let tasks = self.hydrate_tasks(&uncached_task_ids).await?; + for task in tasks { + let task_id = task.task_info().id; + if let Entry::Vacant(entry) = self.task_cache.entry(task_id) { + entry.insert(Arc::new(RwLock::new(task))); + } + } + Ok(()) + } + + async fn get_tasks_write_guards( + &self, + task_ids: Vec, + ) -> Result< + impl Iterator>> + use<'_>, + Error, + > { + let task_ids: HashSet = task_ids.into_iter().collect(); + self.ensure_cached_tasks(task_ids.iter().cloned()).await?; + let iterator = self + .task_cache + .iter() + .filter(move |kv| task_ids.contains(kv.key())) + .map(|kv| kv.clone().write_owned()); + Ok(iterator) + } + + async fn hydrate_tasks(&self, task_ids: &[Uuid]) -> Result, Error> { + let task_models = self.repo.get_task_models(task_ids).await?; + if task_models.len() != task_ids.len() { + return Err(Error::GeneralProtocolError("Invalid task id(s)".into())); + } + let task_id_participant_pairs = self.repo.get_tasks_participants(task_ids).await?; + + let mut task_id_participants: HashMap<_, Vec<_>> = HashMap::new(); + + for (task_id, device) in task_id_participant_pairs { + task_id_participants + .entry(task_id) + .or_default() + .push(device); + } + + // NOTE: When hydrating many tasks, `future::join_all` would cause + // a deadlock with `repository::get_async_connection`. + let mut tasks = Vec::new(); + for task_model in task_models { + let participants = task_id_participants.remove(&task_model.id).unwrap(); + let task_info = TaskInfo { + id: task_model.id, + name: "".into(), // TODO: Persist "name" in TaskModel + task_type: task_model.task_type, + protocol_type: task_model.protocol_type, + key_type: task_model.key_type, + participants, + attempts: task_model.attempt_count as u32, + }; + let task = match task_model.task_type { + TaskType::Group => self.group_task_from_model(task_info, task_model).await?, + TaskType::SignChallenge | TaskType::SignPdf | TaskType::Decrypt => { + self.threshold_task_from_model(task_info, task_model) + .await? + } + }; + + tasks.push(task) + } + Ok(tasks) + } + + async fn group_task_from_model( + &self, + task_info: TaskInfo, + task_model: TaskModel, + ) -> Result { + assert_eq!(task_model.task_type, TaskType::Group); + + let task_result = task_model.result.clone(); + let task = if let Some(task_result) = task_result { + match task_result.try_into_result()? { + Ok(group_id) => { + let Some(group) = self.repo.get_group(&group_id).await? else { + return Err(PersistenceError::DataInconsistencyError( + "Group task result references nonexistent group".into(), + ) + .into()); + }; + let result = TaskResult::GroupEstablished(group); + let acknowledgements = + self.repo.get_task_acknowledgements(&task_model.id).await?; + Task::Finished(FinishedTask { + task_info, + result, + acknowledgements, + }) + } + Err(reason) => { + let decisions = self.repo.get_task_decisions(&task_model.id).await?; + let (accepts, rejects) = VotingTask::accepts_rejects(&decisions); + if rejects > 0 { + Task::Declined(DeclinedTask { + task_info, + accepts, + rejects, + }) + } else { + Task::Failed(FailedTask { task_info, reason }) + } + } + } + } else { + let decisions = self.repo.get_task_decisions(&task_model.id).await?; + let (accepts, _) = VotingTask::accepts_rejects(&decisions); + let accept_threshold = task_info.total_shares(); + if accepts < accept_threshold { + let running_task_context = RunningTaskContext::Group { + threshold: task_model.threshold as u32, + note: task_model.note, + }; + let task = VotingTask { + task_info, + decisions, + accept_threshold, + request: task_model.request, + running_task_context, + }; + Task::Voting(task) + } else { + let communicator = self + .hydrate_communicator(&task_model, task_info.participants.clone()) + .await?; + let task = Box::new(GroupTask::from_model(task_info, task_model, communicator)?); + Task::Running(task) + } + }; + Ok(task) + } + + async fn threshold_task_from_model( + &self, + task_info: TaskInfo, + task_model: TaskModel, + ) -> Result { + let task_result = task_model.result.clone(); + let task = if let Some(task_result) = task_result { + match task_result.try_into_result()? { + Ok(data) => { + let result = match task_model.task_type { + TaskType::Group => unreachable!(), + TaskType::SignPdf => TaskResult::SignedPdf(data), + TaskType::SignChallenge => TaskResult::Signed(data), + TaskType::Decrypt => TaskResult::Decrypted(data), + }; + let acknowledgements = + self.repo.get_task_acknowledgements(&task_model.id).await?; + Task::Finished(FinishedTask { + task_info, + result, + acknowledgements, + }) + } + Err(reason) => { + let decisions = self.repo.get_task_decisions(&task_model.id).await?; + let (accepts, rejects) = VotingTask::accepts_rejects(&decisions); + let total_shares = task_info.total_shares(); + let reject_threshold = total_shares - task_model.threshold as u32 + 1; + if rejects >= reject_threshold { + Task::Declined(DeclinedTask { + task_info, + accepts, + rejects, + }) + } else { + Task::Failed(FailedTask { task_info, reason }) + } + } + } + } else { + let Some(group_id) = &task_model.group_id else { + return Err(PersistenceError::DataInconsistencyError( + "Threshold task is missing a group".into(), + ) + .into()); + }; + let Some(group) = self.repo.get_group(group_id).await? else { + return Err(PersistenceError::DataInconsistencyError( + "Threshold task references nonexistent group".into(), + ) + .into()); + }; + + let decisions = self.repo.get_task_decisions(&task_model.id).await?; + let (accepts, _) = VotingTask::accepts_rejects(&decisions); + let accept_threshold = group.threshold as u32; + if accepts < accept_threshold { + let data = task_model + .task_data + .ok_or(PersistenceError::DataInconsistencyError( + "Threshold task has no task data".into(), + ))?; + let running_task_context = match task_model.task_type { + TaskType::Group => unreachable!(), + TaskType::SignPdf => RunningTaskContext::SignPdf { group, data }, + TaskType::SignChallenge => RunningTaskContext::SignChallenge { group, data }, + TaskType::Decrypt => RunningTaskContext::Decrypt { group, data }, + }; + let task = VotingTask { + task_info, + decisions, + accept_threshold, + request: task_model.request, + running_task_context, + }; + Task::Voting(task) + } else { + let communicator = self + .hydrate_communicator(&task_model, task_info.participants.clone()) + .await?; + let task: Box = match task_model.task_type { + TaskType::Group => unreachable!(), + TaskType::SignPdf => Box::new(SignPDFTask::from_model( + task_info, + task_model, + communicator, + group, + )?), + TaskType::SignChallenge => Box::new(SignTask::from_model( + task_info, + task_model, + communicator, + group, + )?), + TaskType::Decrypt => Box::new(DecryptTask::from_model( + task_info, + task_model, + communicator, + group, + )?), + }; + Task::Running(task) + } + }; + Ok(task) + } + + async fn hydrate_communicator( + &self, + task_model: &TaskModel, + mut all_participants: Vec, + ) -> Result { + let threshold = match task_model.task_type { + TaskType::Group => all_participants.iter().map(|p| p.shares).sum(), + _ => task_model.threshold as u32, + }; + + let active_shares = self.repo.get_task_active_shares(&task_model.id).await?; + + all_participants.sort_by(|a, b| a.device.id.cmp(&b.device.id)); + let first_share_indices: HashMap, u32> = all_participants + .iter() + .scan(0, |idx, p| { + let first_share = *idx; + *idx += p.shares; + Some((p.device.id.clone(), first_share)) + }) + .collect(); + let active_shares = all_participants + .iter() + .filter_map(|p| active_shares.get(&p.device.id).map(|shares| (p, shares))) + .flat_map(|(p, shares)| std::iter::repeat_n(&p.device, *shares as usize)) + .scan(first_share_indices, |share_indices, device| { + let share_index = share_indices[&device.id]; + *share_indices.get_mut(&device.id).unwrap() += 1; + Some((share_index, device.clone())) + }) + .collect(); + + Ok(Communicator::new( + threshold, + task_model.protocol_type.into(), + active_shares, + )) + } +} diff --git a/src/tasks/decrypt.rs b/src/tasks/decrypt.rs index 39bb4c6..710ed62 100644 --- a/src/tasks/decrypt.rs +++ b/src/tasks/decrypt.rs @@ -1,93 +1,46 @@ -use std::sync::Arc; - use crate::communicator::Communicator; use crate::error::Error; -use crate::group::Group; -use crate::persistence::{Participant, PersistenceError, Task as TaskModel}; -use crate::proto::{DecryptRequest, TaskType}; +use crate::persistence::{Group, PersistenceError, Task as TaskModel}; use crate::protocols::elgamal::ElgamalDecrypt; use crate::protocols::{create_threshold_protocol, Protocol}; -use crate::tasks::{DecisionUpdate, RestartUpdate, RoundUpdate, Task, TaskResult}; +use crate::tasks::{FailedTask, FinishedTask, RoundUpdate, RunningTask, TaskInfo, TaskResult}; use crate::utils; -use async_trait::async_trait; use log::info; -use meesign_crypto::proto::{ClientMessage, Message as _}; -use prost::Message as _; -use tokio::sync::RwLock; -use uuid::Uuid; +use meesign_crypto::proto::ClientMessage; pub struct DecryptTask { - id: Uuid, + task_info: TaskInfo, group: Group, - communicator: Arc>, - result: Option, String>>, + communicator: Communicator, pub(super) data: Vec, pub(super) protocol: Box, - request: Vec, - pub(super) attempts: u32, } impl DecryptTask { - pub fn try_new( + pub fn new( + task_info: TaskInfo, group: Group, - name: String, data: Vec, - data_type: String, - ) -> Result { - let mut participants: Vec = group.participants().to_vec(); - participants.sort_by(|a, b| a.device.identifier().cmp(b.device.identifier())); - - let decisions = participants - .iter() - .map(|p| (p.device.identifier().clone(), 0)) - .collect(); - let acknowledgements = participants - .iter() - .map(|p| (p.device.identifier().clone(), false)) - .collect(); - - let communicator = Arc::new(RwLock::new(Communicator::new( - participants, - group.threshold(), - group.protocol(), - decisions, - acknowledgements, - ))); - - let request = (DecryptRequest { - group_id: group.identifier().to_vec(), - name, - data: data.clone(), - data_type, - }) - .encode_to_vec(); - - let id = Uuid::new_v4(); - Ok(DecryptTask { - id, + communicator: Communicator, + ) -> Self { + DecryptTask { + task_info, group, communicator, - result: None, data, protocol: Box::new(ElgamalDecrypt::new()), - request, - attempts: 0, - }) + } } pub fn from_model( + task_info: TaskInfo, task_model: TaskModel, - communicator: Arc>, + communicator: Communicator, group: Group, ) -> Result { - let result = task_model - .result - .map(|res| res.try_into_result()) - .transpose()?; - let protocol = create_threshold_protocol( - group.protocol(), - group.key_type(), + group.protocol.into(), + group.key_type.into(), task_model.protocol_round as u16, )?; let data = task_model @@ -96,237 +49,123 @@ impl DecryptTask { "Task data not set for a sign task".into(), ))?; let task = Self { - id: task_model.id, + task_info, group, communicator, - result, data, protocol, - request: task_model.request, - attempts: task_model.attempt_count as u32, }; Ok(task) } - pub(super) async fn start_task(&mut self) -> Result { - assert!(self.communicator.read().await.accept_count() >= self.group.threshold()); - self.protocol - .initialize(&mut *self.communicator.write().await, &self.data); + pub(super) fn start_task(&mut self) -> Result { + self.protocol.initialize(&mut self.communicator, &self.data); Ok(RoundUpdate::NextRound(self.protocol.round())) } - pub(super) async fn advance_task(&mut self) -> Result { - self.protocol.advance(&mut *self.communicator.write().await); + pub(super) fn advance_task(&mut self) -> Result { + self.protocol.advance(&mut self.communicator); Ok(RoundUpdate::NextRound(self.protocol.round())) } - pub(super) async fn finalize_task(&mut self) -> Result { - let decrypted = self - .protocol - .finalize(&mut *self.communicator.write().await); + pub(super) fn finalize_task(&mut self) -> Result { + let decrypted = self.protocol.finalize(&mut self.communicator); if decrypted.is_none() { let reason = "Task failed (data not output)".to_string(); - self.set_result(Err(reason.clone())); - return Ok(RoundUpdate::Failed(reason)); + return Ok(RoundUpdate::Failed(FailedTask { + task_info: self.task_info.clone(), + reason, + })); } let decrypted = decrypted.unwrap(); info!( "Data decrypted by group_id={}", - utils::hextrunc(self.group.identifier()) + utils::hextrunc(&self.group.id) ); - self.set_result(Ok(decrypted.clone())); - - self.communicator.write().await.clear_input(); + self.communicator.clear_input(); Ok(RoundUpdate::Finished( self.protocol.round(), - TaskResult::Decrypted(decrypted), + FinishedTask::new(self.task_info.clone(), TaskResult::Decrypted(decrypted)), )) } - pub(super) async fn next_round(&mut self) -> Result { + pub(super) fn next_round(&mut self) -> Result { if self.protocol.round() == 0 { - self.start_task().await + self.start_task() } else if self.protocol.round() < self.protocol.last_round() { - self.advance_task().await + self.advance_task() } else { - self.finalize_task().await + self.finalize_task() } } - pub(super) async fn update_internal( + pub(super) fn update_internal( &mut self, device_id: &[u8], - data: &Vec>, + messages: Vec, ) -> Result { - if self.communicator.read().await.accept_count() < self.group.threshold() { - return Err(Error::GeneralProtocolError( - "Not enough agreements to proceed with the protocol.".into(), - )); - } - - if !self.waiting_for(device_id).await { + if !self.waiting_for(device_id) { return Err(Error::GeneralProtocolError( "Wasn't waiting for a message from this ID.".into(), )); } - let messages = data - .iter() - .map(|d| ClientMessage::decode(d.as_slice())) - .collect::, _>>() - .map_err(|_| Error::GeneralProtocolError("Expected ClientMessage".into()))?; - - self.communicator - .write() - .await - .receive_messages(device_id, messages); + self.communicator.receive_messages(device_id, messages); - if self.communicator.read().await.round_received() - && self.protocol.round() <= self.protocol.last_round() + if self.communicator.round_received() && self.protocol.round() <= self.protocol.last_round() { return Ok(true); } Ok(false) } - // TODO: deduplicate across sign and decrypt - pub(super) async fn decide_internal( - &mut self, - device_id: &[u8], - decision: bool, - ) -> Option { - self.communicator.write().await.decide(device_id, decision); - - if self.result.is_none() && self.protocol.round() == 0 { - if self.communicator.read().await.reject_count() >= self.group.reject_threshold() { - self.set_result(Err("Task declined".to_string())); - return Some(false); - } else if self.communicator.read().await.accept_count() >= self.group.threshold() { - return Some(true); - } - } - None - } - - fn set_result(&mut self, result: Result, String>) { - self.result = Some(result); - } - fn increment_attempt_count(&mut self) { - self.attempts += 1; + self.task_info.attempts += 1; } } -#[async_trait] -impl Task for DecryptTask { - fn get_type(&self) -> TaskType { - TaskType::Decrypt +impl RunningTask for DecryptTask { + fn task_info(&self) -> &TaskInfo { + &self.task_info } - async fn get_work(&self, device_id: &[u8]) -> Vec> { - if !self.waiting_for(device_id).await { + fn get_work(&self, device_id: &[u8]) -> Vec> { + if !self.waiting_for(device_id) { return Vec::new(); } - self.communicator.read().await.get_messages(device_id) + self.communicator.get_messages(device_id) } fn get_round(&self) -> u16 { self.protocol.round() } - async fn get_decisions(&self) -> (u32, u32) { - ( - self.communicator.read().await.accept_count(), - self.communicator.read().await.reject_count(), - ) + fn initialize(&mut self) -> Result { + self.start_task() } - async fn update( + fn update( &mut self, device_id: &[u8], - data: &Vec>, + messages: Vec, ) -> Result { - let round_update = if self.update_internal(device_id, data).await? { - self.next_round().await? + let round_update = if self.update_internal(device_id, messages)? { + self.next_round()? } else { RoundUpdate::Listen }; Ok(round_update) } - async fn restart(&mut self) -> Result { - if self.result.is_some() { - return Ok(RestartUpdate::AlreadyFinished); - } - - if self.is_approved().await { - self.increment_attempt_count(); - let round_update = self.start_task().await?; - Ok(RestartUpdate::Started(round_update)) - } else { - Ok(RestartUpdate::Voting) - } - } - - async fn is_approved(&self) -> bool { - self.communicator.read().await.accept_count() >= self.group.threshold() - } - - fn get_participants(&self) -> &Vec { - &self.group.participants() - } - - async fn waiting_for(&self, device: &[u8]) -> bool { - if self.protocol.round() == 0 { - return !self.communicator.read().await.device_decided(device); - } else if self.protocol.round() >= self.protocol.last_round() { - return !self.communicator.read().await.device_acknowledged(device); - } - - self.communicator.read().await.waiting_for(device) - } - - async fn decide(&mut self, device_id: &[u8], decision: bool) -> Result { - let result = self.decide_internal(device_id, decision).await; - let decision_update = match result { - Some(true) => { - let round_update = self.next_round().await?; - DecisionUpdate::Accepted(round_update) - } - Some(false) => DecisionUpdate::Declined, - None => DecisionUpdate::Undecided, - }; - Ok(decision_update) - } - - async fn acknowledge(&mut self, device_id: &[u8]) { - self.communicator.write().await.acknowledge(device_id); - } - - fn get_request(&self) -> &[u8] { - &self.request - } - - fn get_attempts(&self) -> u32 { - self.attempts - } - - fn get_id(&self) -> &Uuid { - &self.id - } - - fn get_communicator(&self) -> Arc> { - self.communicator.clone() - } - - fn get_threshold(&self) -> u32 { - self.group.threshold() + fn restart(&mut self) -> Result { + self.increment_attempt_count(); + self.start_task() } - fn get_data(&self) -> Option<&[u8]> { - Some(&self.data) + fn waiting_for(&self, device: &[u8]) -> bool { + self.communicator.waiting_for(device) } } diff --git a/src/tasks/group.rs b/src/tasks/group.rs index a61a9ab..1f07c76 100644 --- a/src/tasks/group.rs +++ b/src/tasks/group.rs @@ -1,53 +1,41 @@ use crate::communicator::Communicator; use crate::error::Error; -use crate::group::Group; -use crate::persistence::Participant; -use crate::persistence::PersistenceError; -use crate::persistence::Task as TaskModel; -use crate::proto::{KeyType, ProtocolType, TaskType}; +use crate::persistence::{Group, PersistenceError, Task as TaskModel}; +use crate::proto::ProtocolType; use crate::protocols::{create_keygen_protocol, Protocol}; -use crate::tasks::{DecisionUpdate, RestartUpdate, RoundUpdate, Task, TaskResult}; +use crate::tasks::{FailedTask, FinishedTask, RoundUpdate, RunningTask, TaskInfo, TaskResult}; use crate::utils; -use async_trait::async_trait; use log::{info, warn}; use meesign_crypto::proto::{ClientMessage, Message as _, ServerMessage}; -use prost::Message as _; use std::collections::HashMap; use std::io::Read; use std::process::{Command, Stdio}; -use std::sync::Arc; -use tokio::sync::RwLock; -use uuid::Uuid; pub struct GroupTask { - name: String, - id: Uuid, + task_info: TaskInfo, threshold: u32, - key_type: KeyType, - participants: Vec, - communicator: Arc>, - result: Option>, + communicator: Communicator, protocol: Box, - request: Vec, - attempts: u32, note: Option, certificates_sent: bool, // TODO: remove the field completely } impl GroupTask { pub fn try_new( - name: &str, - mut participants: Vec, + mut task_info: TaskInfo, threshold: u32, - protocol_type: ProtocolType, - key_type: KeyType, note: Option, + communicator: Communicator, ) -> Result { - let id = Uuid::new_v4(); + let total_shares = task_info.total_shares(); - let total_shares: u32 = participants.iter().map(|p| p.shares).sum(); - - let protocol = create_keygen_protocol(protocol_type, key_type, total_shares, threshold, 0)?; + let protocol = create_keygen_protocol( + task_info.protocol_type.into(), + task_info.key_type.into(), + total_shares, + threshold, + 0, + )?; if total_shares < 1 { warn!("Invalid number of parties {}", total_shares); @@ -58,149 +46,74 @@ impl GroupTask { return Err("Invalid input".into()); } - participants.sort_by(|a, b| a.device.identifier().cmp(b.device.identifier())); - - let group_task_threshold = total_shares; - - let decisions = participants - .iter() - .map(|p| (p.device.identifier().clone(), 0)) - .collect(); - let acknowledgements = participants - .iter() - .map(|p| (p.device.identifier().clone(), false)) - .collect(); - - let communicator = Arc::new(RwLock::new(Communicator::new( - participants.clone(), - group_task_threshold, - protocol.get_type(), - decisions, - acknowledgements, - ))); - - let device_ids = participants - .iter() - .flat_map(|p| std::iter::repeat(p.device.identifier().to_vec()).take(p.shares as usize)) - .collect(); - let request = (crate::proto::GroupRequest { - device_ids, - name: String::from(name), - threshold, - protocol: protocol.get_type() as i32, - key_type: key_type as i32, - note: note.to_owned(), - }) - .encode_to_vec(); + task_info + .participants + .sort_by(|a, b| a.device.identifier().cmp(b.device.identifier())); Ok(GroupTask { - name: name.into(), - id, + task_info, threshold, - participants: participants.to_vec(), - key_type, communicator, - result: None, protocol, - request, - attempts: 0, note, certificates_sent: false, }) } pub fn from_model( + task_info: TaskInfo, model: TaskModel, - participants: Vec, - communicator: Arc>, - group: Option, + communicator: Communicator, ) -> Result { - let total_shares = participants.iter().map(|p| p.shares).sum(); - let protocol = create_keygen_protocol( model.protocol_type.into(), - model.key_type.clone().into(), - total_shares, + model.key_type.into(), + task_info.total_shares(), model.threshold as u32, model.protocol_round as u16, )?; - - // TODO: refactor - let result = model.result.map(|res| res.try_into_result()).transpose()?; - let result = match result { - Some(Ok(group_id)) => { - let Some(group) = group else { - return Err(Error::PersistenceError( - PersistenceError::DataInconsistencyError( - "Established group is missing".into(), - ), - )); - }; - assert_eq!(group_id, group.identifier()); - Some(Ok(group)) - } - Some(Err(err)) => Some(Err(err)), - None => None, - }; - let name = if let Some(Ok(group)) = &result { - group.name().into() - } else { - "".into() // TODO add field to the task table - }; let Some(certificates_sent) = model.group_certificates_sent else { - return Err(Error::PersistenceError( - PersistenceError::DataInconsistencyError( - "certificates_sent flag missing in group task".into(), - ), - )); + return Err(PersistenceError::DataInconsistencyError( + "certificates_sent flag missing in group task".into(), + ) + .into()); }; Ok(Self { - name, - id: model.id, + task_info, threshold: model.threshold as u32, - key_type: model.key_type.into(), - participants, communicator, - result, protocol, - request: model.request, - attempts: model.attempt_count as u32, note: model.note, certificates_sent, // TODO: remove the field completely }) } - fn set_result(&mut self, result: Result) { - self.result = Some(result); - } - fn increment_attempt_count(&mut self) { - self.attempts += 1; + self.task_info.attempts += 1; } - async fn start_task(&mut self) -> Result { - self.protocol - .initialize(&mut *self.communicator.write().await, &[]); + fn start_task(&mut self) -> Result { + self.protocol.initialize(&mut self.communicator, &[]); Ok(RoundUpdate::NextRound(self.protocol.round())) } - async fn advance_task(&mut self) -> Result { - self.protocol.advance(&mut *self.communicator.write().await); + fn advance_task(&mut self) -> Result { + self.protocol.advance(&mut self.communicator); Ok(RoundUpdate::NextRound(self.protocol.round())) } - async fn finalize_task(&mut self) -> Result { - let identifier = self - .protocol - .finalize(&mut *self.communicator.write().await); + fn finalize_task(&mut self) -> Result { + let identifier = self.protocol.finalize(&mut self.communicator); let Some(identifier) = identifier else { let reason = "Task failed (group key not output)".to_string(); - self.set_result(Err(reason.clone())); - return Ok(RoundUpdate::Failed(reason)); + return Ok(RoundUpdate::Failed(FailedTask { + task_info: self.task_info.clone(), + reason, + })); }; // TODO let certificate = if self.protocol.get_type() == ProtocolType::Gg18 { - Some(issue_certificate(&self.name, &identifier)?) + Some(issue_certificate(&self.task_info.name, &identifier)?) } else { None }; @@ -208,54 +121,56 @@ impl GroupTask { info!( "Group established group_id={} devices={:?}", utils::hextrunc(&identifier), - self.participants + self.task_info + .participants .iter() .map(|p| (utils::hextrunc(p.device.identifier()), p.shares)) .collect::>() ); - let group = Group::new( - identifier.clone(), - self.name.clone(), - self.threshold, - self.participants.clone(), - self.protocol.get_type(), - self.key_type, + let group = Group { + id: identifier.clone(), + name: self.task_info.name.clone(), + threshold: self.threshold as i32, + protocol: self.protocol.get_type().into(), + key_type: self.task_info.key_type, certificate, - self.note.clone(), - ); - - self.set_result(Ok(group.clone())); + note: self.note.clone(), + participant_ids_shares: self + .task_info + .participants + .iter() + .map(|p| (p.device.id.clone(), p.shares)) + .collect(), + }; - self.communicator.write().await.clear_input(); + self.communicator.clear_input(); Ok(RoundUpdate::Finished( self.protocol.round(), - TaskResult::GroupEstablished(group), + FinishedTask::new(self.task_info.clone(), TaskResult::GroupEstablished(group)), )) } - async fn next_round(&mut self) -> Result { + fn next_round(&mut self) -> Result { if !self.certificates_sent { - self.send_certificates().await + self.send_certificates() } else if self.protocol.round() == 0 { - self.start_task().await + self.start_task() } else if self.protocol.round() < self.protocol.last_round() { - self.advance_task().await + self.advance_task() } else { - self.finalize_task().await + self.finalize_task() } } - async fn send_certificates(&mut self) -> Result { - self.communicator.write().await.set_active_devices(None); - + fn send_certificates(&mut self) -> Result { let certs: HashMap> = { - let communicator_read = self.communicator.read().await; - self.participants + self.task_info + .participants .iter() .flat_map(|p| { let cert = &p.device.certificate; - communicator_read + self.communicator .identifier_to_indices(p.device.identifier()) .into_iter() .zip(std::iter::repeat(cert).cloned()) @@ -269,24 +184,23 @@ impl GroupTask { } .encode_to_vec(); - self.communicator.write().await.send_all(|_| certs.clone()); + self.communicator.send_all(|_| certs.clone()); self.certificates_sent = true; Ok(RoundUpdate::GroupCertificatesSent) } } -#[async_trait] -impl Task for GroupTask { - fn get_type(&self) -> TaskType { - TaskType::Group +impl RunningTask for GroupTask { + fn task_info(&self) -> &TaskInfo { + &self.task_info } - async fn get_work(&self, device_id: &[u8]) -> Vec> { - if !self.waiting_for(device_id).await { + fn get_work(&self, device_id: &[u8]) -> Vec> { + if !self.waiting_for(device_id) { return Vec::new(); } - self.communicator.read().await.get_messages(device_id) + self.communicator.get_messages(device_id) } fn get_round(&self) -> u16 { @@ -297,130 +211,41 @@ impl Task for GroupTask { } } - async fn get_decisions(&self) -> (u32, u32) { - let communicator = self.communicator.read().await; - (communicator.accept_count(), communicator.reject_count()) + fn initialize(&mut self) -> Result { + self.send_certificates() } - async fn update( + fn update( &mut self, device_id: &[u8], - data: &Vec>, + messages: Vec, ) -> Result { - let total_shares: u32 = self.participants.iter().map(|p| p.shares).sum(); - if self.communicator.read().await.accept_count() != total_shares { - return Err(Error::GeneralProtocolError( - "Not enough agreements to proceed with the protocol.".into(), - )); - } - - if !self.waiting_for(device_id).await { + if !self.waiting_for(device_id) { return Err(Error::GeneralProtocolError( "Wasn't waiting for a message from this ID.".into(), )); } - assert_eq!(self.certificates_sent, true); + assert!(self.certificates_sent); - let messages = data - .iter() - .map(|d| ClientMessage::decode(d.as_slice())) - .collect::, _>>() - .map_err(|_| Error::GeneralProtocolError("Expected ClientMessage.".into()))?; + self.communicator.receive_messages(device_id, messages); - self.communicator - .write() - .await - .receive_messages(device_id, messages); - - if self.communicator.read().await.round_received() - && self.protocol.round() <= self.protocol.last_round() + if self.communicator.round_received() && self.protocol.round() <= self.protocol.last_round() { - return self.next_round().await; + return self.next_round(); } Ok(RoundUpdate::Listen) } - async fn restart(&mut self) -> Result { - if self.result.is_some() { - return Ok(RestartUpdate::AlreadyFinished); - } - - if self.is_approved().await { - self.increment_attempt_count(); - // TODO: Should this instead be the certificate exchange round? - let round_update = self.start_task().await?; - Ok(RestartUpdate::Started(round_update)) - } else { - Ok(RestartUpdate::Voting) - } - } - - async fn is_approved(&self) -> bool { - let total_shares: u32 = self.participants.iter().map(|p| p.shares).sum(); - self.communicator.read().await.accept_count() == total_shares - } - - fn get_participants(&self) -> &Vec { - &self.participants - } - - async fn waiting_for(&self, device: &[u8]) -> bool { - let communicator = self.communicator.write().await; - if !self.certificates_sent && self.protocol.round() == 0 { - return !communicator.device_decided(device); - } else if self.protocol.round() >= self.protocol.last_round() { - return !communicator.device_acknowledged(device); - } - - communicator.waiting_for(device) - } - - async fn decide(&mut self, device_id: &[u8], decision: bool) -> Result { - self.communicator.write().await.decide(device_id, decision); - let decision_update = if self.result.is_none() && self.protocol.round() == 0 { - if self.communicator.read().await.reject_count() > 0 { - self.set_result(Err("Task declined".to_string())); - DecisionUpdate::Declined - } else if self.is_approved().await { - let round_update = self.next_round().await?; - DecisionUpdate::Accepted(round_update) - } else { - DecisionUpdate::Undecided - } - } else { - DecisionUpdate::Undecided - }; - Ok(decision_update) - } - - async fn acknowledge(&mut self, device_id: &[u8]) { - self.communicator.write().await.acknowledge(device_id); - } - - fn get_request(&self) -> &[u8] { - &self.request - } - - fn get_attempts(&self) -> u32 { - self.attempts - } - - fn get_id(&self) -> &Uuid { - &self.id - } - - fn get_communicator(&self) -> Arc> { - self.communicator.clone() - } - - fn get_threshold(&self) -> u32 { - self.threshold + fn restart(&mut self) -> Result { + self.increment_attempt_count(); + // TODO: Should this instead be the certificate exchange round? + self.start_task() } - fn get_data(&self) -> Option<&[u8]> { - None + fn waiting_for(&self, device: &[u8]) -> bool { + self.communicator.waiting_for(device) } } diff --git a/src/tasks/mod.rs b/src/tasks/mod.rs index 59e1354..83098c9 100644 --- a/src/tasks/mod.rs +++ b/src/tasks/mod.rs @@ -3,41 +3,31 @@ pub(crate) mod group; pub(crate) mod sign; pub(crate) mod sign_pdf; -use std::sync::Arc; - -use async_trait::async_trait; -use tokio::sync::RwLock; +use meesign_crypto::proto::ClientMessage; +use std::collections::{HashMap, HashSet}; use uuid::Uuid; use crate::communicator::Communicator; use crate::error::Error; -use crate::group::Group; -use crate::persistence::Participant; +use crate::persistence::{Device, Group, KeyType, Participant, ProtocolType, TaskType}; +use crate::proto; #[must_use = "updates must be persisted"] pub enum RoundUpdate { Listen, GroupCertificatesSent, - NextRound(u16), // round number - Finished(u16, TaskResult), // round number, result - Failed(String), // failure reason + NextRound(u16), // round number + Finished(u16, FinishedTask), // round number, finished task + Failed(FailedTask), } #[must_use = "updates must be persisted"] pub enum DecisionUpdate { Undecided, - Accepted(RoundUpdate), - Declined, -} - -#[must_use = "updates must be persisted"] -pub enum RestartUpdate { - AlreadyFinished, - Voting, - Started(RoundUpdate), + Accepted, + Declined(DeclinedTask), } -#[derive(Clone)] pub enum TaskResult { GroupEstablished(Group), Signed(Vec), @@ -48,7 +38,7 @@ pub enum TaskResult { impl TaskResult { pub fn as_bytes(&self) -> &[u8] { match self { - TaskResult::GroupEstablished(group) => group.identifier(), + TaskResult::GroupEstablished(group) => &group.id, TaskResult::Signed(data) => data, TaskResult::SignedPdf(data) => data, TaskResult::Decrypted(data) => data, @@ -56,34 +46,250 @@ impl TaskResult { } } -#[async_trait] -pub trait Task: Send + Sync { - fn get_type(&self) -> crate::proto::TaskType; - async fn get_work(&self, device_id: &[u8]) -> Vec>; - fn get_round(&self) -> u16; - async fn get_decisions(&self) -> (u32, u32); - /// Update protocol state with `data` from `device_id` - async fn update(&mut self, device_id: &[u8], data: &Vec>) - -> Result; +#[must_use] +pub struct VotingTask { + pub task_info: TaskInfo, + pub decisions: HashMap, i8>, + pub accept_threshold: u32, + pub request: Vec, + pub running_task_context: RunningTaskContext, +} +impl VotingTask { + pub async fn decide( + &mut self, + device_id: &[u8], + accept: bool, + ) -> Result { + let shares = self + .task_info + .participants + .iter() + .find(|p| p.device.id == device_id) + .ok_or(Error::GeneralProtocolError( + "Invalid task participant id".into(), + ))? + .shares as i8; + let vote = if accept { shares } else { -shares }; - /// Attempt to restart protocol in task - async fn restart(&mut self) -> Result; + // TODO: Check if this device has already decided + self.decisions.insert(device_id.to_vec(), vote); + + let (accepts, rejects) = Self::accepts_rejects(&self.decisions); + + let decision_update = if accepts >= self.accept_threshold { + DecisionUpdate::Accepted + } else if rejects >= self.reject_threshold() { + DecisionUpdate::Declined(DeclinedTask { + task_info: self.task_info.clone(), + accepts, + rejects, + }) + } else { + DecisionUpdate::Undecided + }; + Ok(decision_update) + } + pub fn accepts_rejects(decisions: &HashMap, i8>) -> (u32, u32) { + let mut accepts = 0; + let mut rejects = 0; + for &vote in decisions.values() { + if vote > 0 { + accepts += (vote as i32).unsigned_abs(); + } + if vote < 0 { + rejects += (vote as i32).unsigned_abs(); + } + } + (accepts, rejects) + } + pub fn reject_threshold(&self) -> u32 { + self.task_info.total_shares() - self.accept_threshold + 1 + } + pub fn device_accepted(&self, device_id: &[u8]) -> bool { + self.decisions.get(device_id) > Some(&0) + } +} +#[must_use] +pub struct DeclinedTask { + pub task_info: TaskInfo, + pub accepts: u32, + pub rejects: u32, +} +#[must_use] +pub struct FinishedTask { + pub task_info: TaskInfo, + pub result: TaskResult, + pub acknowledgements: HashSet>, +} +impl FinishedTask { + pub fn new(task_info: TaskInfo, result: TaskResult) -> Self { + Self { + task_info, + result, + acknowledgements: HashSet::new(), + } + } + pub fn acknowledge(&mut self, device_id: &[u8]) { + // TODO: Check if device_id is a participant + self.acknowledgements.insert(device_id.to_vec()); + } +} +#[must_use] +pub struct FailedTask { + pub task_info: TaskInfo, + pub reason: String, +} +#[must_use] +pub enum Task { + Voting(VotingTask), + Running(Box), + Declined(DeclinedTask), + Finished(FinishedTask), + Failed(FailedTask), +} - /// True if the task has been approved - async fn is_approved(&self) -> bool; +impl Task { + pub fn task_info(&self) -> &TaskInfo { + match self { + Task::Voting(task) => &task.task_info, + Task::Running(task) => task.task_info(), + Task::Declined(task) => &task.task_info, + Task::Finished(task) => &task.task_info, + Task::Failed(task) => &task.task_info, + } + } + pub fn format(&self, device_id: Option<&[u8]>, request: Option>) -> proto::Task { + let task_info = self.task_info(); + let id = task_info.id.as_bytes().to_vec(); + let r#type: proto::TaskType = task_info.task_type.into(); + let r#type = r#type.into(); + let attempt = task_info.attempts; + match self { + Task::Voting(task) => { + let (accept, reject) = VotingTask::accepts_rejects(&task.decisions); + proto::Task::created(id, r#type, accept, reject, request, attempt) + } + Task::Running(task) => { + let round = task.get_round() as u32; + let data = if let Some(device_id) = device_id { + task.get_work(device_id) + } else { + Vec::new() + }; + proto::Task::running(id, r#type, round, data, request, attempt) + } + Task::Finished(task) => proto::Task::finished( + id, + r#type, + task.result.as_bytes().to_vec(), + request, + attempt, + ), + Task::Failed(task) => { + proto::Task::failed(id, r#type, task.reason.clone(), request, attempt) + } + Task::Declined(task) => { + proto::Task::declined(id, r#type, task.accepts, task.rejects, request, attempt) + } + } + } +} - fn get_participants(&self) -> &Vec; - async fn waiting_for(&self, device_id: &[u8]) -> bool; +#[derive(Clone)] +pub struct TaskInfo { + pub id: Uuid, + pub name: String, + pub task_type: TaskType, + pub protocol_type: ProtocolType, + pub key_type: KeyType, + pub participants: Vec, + pub attempts: u32, +} +impl TaskInfo { + pub fn total_shares(&self) -> u32 { + self.participants.iter().map(|p| p.shares).sum() + } +} - /// Store `decision` by `device_id` - async fn decide(&mut self, device_id: &[u8], decision: bool) -> Result; +pub trait RunningTask: Send + Sync { + fn task_info(&self) -> &TaskInfo; - async fn acknowledge(&mut self, device_id: &[u8]); - fn get_request(&self) -> &[u8]; + fn get_work(&self, device_id: &[u8]) -> Vec>; - fn get_attempts(&self) -> u32; - fn get_id(&self) -> &Uuid; - fn get_communicator(&self) -> Arc>; - fn get_threshold(&self) -> u32; - fn get_data(&self) -> Option<&[u8]>; + fn get_round(&self) -> u16; + + fn initialize(&mut self) -> Result; + + /// Update protocol state with `messages` from `device_id` + fn update( + &mut self, + device_id: &[u8], + messages: Vec, + ) -> Result; + + /// Attempt to restart protocol in task + fn restart(&mut self) -> Result; + + fn waiting_for(&self, device_id: &[u8]) -> bool; +} + +#[derive(Clone)] +pub enum RunningTaskContext { + Group { + threshold: u32, + note: Option, + }, + SignChallenge { + group: Group, + data: Vec, + }, + SignPdf { + group: Group, + data: Vec, + }, + Decrypt { + group: Group, + data: Vec, + }, +} +impl RunningTaskContext { + pub fn create_running_task( + self, + voting_task: &VotingTask, + active_shares: HashMap, + ) -> Result, Error> { + let task_info = voting_task.task_info.clone(); + let communicator = Communicator::new( + voting_task.accept_threshold, + task_info.protocol_type.into(), + active_shares, + ); + let task: Box = match self { + Self::Group { threshold, note } => Box::new(group::GroupTask::try_new( + task_info, + threshold, + note, + communicator, + )?), + Self::SignChallenge { group, data } => Box::new(sign::SignTask::try_new( + task_info, + group, + data, + communicator, + )?), + Self::SignPdf { group, data } => Box::new(sign_pdf::SignPDFTask::try_new( + task_info, + group, + data, + communicator, + )?), + Self::Decrypt { group, data } => Box::new(decrypt::DecryptTask::new( + task_info, + group, + data, + communicator, + )), + }; + Ok(task) + } } diff --git a/src/tasks/sign.rs b/src/tasks/sign.rs index 2036349..3943472 100644 --- a/src/tasks/sign.rs +++ b/src/tasks/sign.rs @@ -1,100 +1,59 @@ -use std::sync::Arc; - use crate::communicator::Communicator; use crate::error::Error; -use crate::group::Group; -use crate::persistence::{Participant, PersistenceError, Task as TaskModel}; -use crate::proto::{ProtocolType, SignRequest, TaskType}; +use crate::persistence::{Group, PersistenceError, ProtocolType, Task as TaskModel}; use crate::protocols::frost::FROSTSign; use crate::protocols::gg18::GG18Sign; use crate::protocols::musig2::MuSig2Sign; use crate::protocols::{create_threshold_protocol, Protocol}; -use crate::tasks::{DecisionUpdate, RestartUpdate, RoundUpdate, Task, TaskResult}; +use crate::tasks::{FailedTask, FinishedTask, RoundUpdate, RunningTask, TaskInfo, TaskResult}; use crate::utils; -use async_trait::async_trait; use log::{info, warn}; -use meesign_crypto::proto::{ClientMessage, Message as _}; -use prost::Message as _; -use tokio::sync::RwLock; -use uuid::Uuid; +use meesign_crypto::proto::ClientMessage; pub struct SignTask { - id: Uuid, + pub(super) task_info: TaskInfo, group: Group, - communicator: Arc>, - result: Option, String>>, + communicator: Communicator, pub(super) data: Vec, preprocessed: Option>, pub(super) protocol: Box, - request: Vec, - pub(super) attempts: u32, } impl SignTask { - pub fn try_new(group: Group, name: String, data: Vec) -> Result { - let mut participants: Vec = group.participants().to_vec(); - participants.sort_by(|a, b| a.device.identifier().cmp(b.device.identifier())); - let protocol_type = group.protocol(); - - let decisions = participants - .iter() - .map(|p| (p.device.identifier().clone(), 0)) - .collect(); - let acknowledgements = participants - .iter() - .map(|p| (p.device.identifier().clone(), false)) - .collect(); - - let communicator = Arc::new(RwLock::new(Communicator::new( - participants, - group.threshold(), - protocol_type, - decisions, - acknowledgements, - ))); - - let request = (SignRequest { - group_id: group.identifier().to_vec(), - name, - data: data.clone(), - }) - .encode_to_vec(); - - let id = Uuid::new_v4(); + pub fn try_new( + task_info: TaskInfo, + group: Group, + data: Vec, + communicator: Communicator, + ) -> Result { + let protocol_type = group.protocol; Ok(SignTask { - id, + task_info, group, communicator, - result: None, data, preprocessed: None, protocol: match protocol_type { ProtocolType::Gg18 => Box::new(GG18Sign::new()), ProtocolType::Frost => Box::new(FROSTSign::new()), ProtocolType::Musig2 => Box::new(MuSig2Sign::new()), - _ => { - warn!("Protocol type {:?} does not support signing", protocol_type); + other => { + warn!("Protocol type {:?} does not support signing", other); return Err("Unsupported protocol type for signing".into()); } }, - request, - attempts: 0, }) } pub fn from_model( + task_info: TaskInfo, task_model: TaskModel, - communicator: Arc>, + communicator: Communicator, group: Group, ) -> Result { - let result = task_model - .result - .map(|res| res.try_into_result()) - .transpose()?; - // TODO refactor let protocol = create_threshold_protocol( - group.protocol(), - group.key_type(), + group.protocol.into(), + group.key_type.into(), task_model.protocol_round as u16, )?; @@ -104,15 +63,12 @@ impl SignTask { "Task data not set for a sign task".into(), ))?; let task = Self { - id: task_model.id, + task_info, group, communicator, - result, data, protocol, preprocessed: task_model.preprocessed, - request: task_model.request, - attempts: task_model.attempt_count as u32, }; Ok(task) } @@ -126,227 +82,117 @@ impl SignTask { self.preprocessed = Some(preprocessed); } - pub(super) async fn start_task(&mut self) -> Result { - assert!(self.communicator.read().await.accept_count() >= self.group.threshold()); + pub(super) fn start_task(&mut self) -> Result { self.protocol.initialize( - &mut *self.communicator.write().await, + &mut self.communicator, self.preprocessed.as_ref().unwrap_or(&self.data), ); Ok(RoundUpdate::NextRound(self.protocol.round())) } - pub(super) async fn advance_task(&mut self) -> Result { - self.protocol.advance(&mut *self.communicator.write().await); + pub(super) fn advance_task(&mut self) -> Result { + self.protocol.advance(&mut self.communicator); Ok(RoundUpdate::NextRound(self.protocol.round())) } - pub(super) async fn finalize_task(&mut self) -> Result { - let signature = self - .protocol - .finalize(&mut *self.communicator.write().await); + pub(super) fn finalize_task(&mut self) -> Result { + let signature = self.protocol.finalize(&mut self.communicator); if signature.is_none() { let reason = "Task failed (signature not output)".to_string(); - self.set_result(Err(reason.clone())); - return Ok(RoundUpdate::Failed(reason)); + return Ok(RoundUpdate::Failed(FailedTask { + task_info: self.task_info.clone(), + reason, + })); } let signature = signature.unwrap(); info!( "Signature created by group_id={}", - utils::hextrunc(self.group.identifier()) + utils::hextrunc(&self.group.id) ); - self.set_result(Ok(signature.clone())); - - self.communicator.write().await.clear_input(); + self.communicator.clear_input(); Ok(RoundUpdate::Finished( self.protocol.round(), - TaskResult::Signed(signature), + FinishedTask::new(self.task_info.clone(), TaskResult::Signed(signature)), )) } - pub(super) async fn next_round(&mut self) -> Result { + pub(super) fn next_round(&mut self) -> Result { if self.protocol.round() == 0 { - self.start_task().await + self.start_task() } else if self.protocol.round() < self.protocol.last_round() { - self.advance_task().await + self.advance_task() } else { - self.finalize_task().await + self.finalize_task() } } - pub(super) async fn update_internal( + pub(super) fn update_internal( &mut self, device_id: &[u8], - data: &Vec>, + messages: Vec, ) -> Result { - if self.communicator.read().await.accept_count() < self.group.threshold() { - return Err(Error::GeneralProtocolError( - "Not enough agreements to proceed with the protocol.".into(), - )); - } - - if !self.waiting_for(device_id).await { + if !self.waiting_for(device_id) { return Err(Error::GeneralProtocolError( "Wasn't waiting for a message from this ID.".into(), )); } - let messages = data - .iter() - .map(|d| ClientMessage::decode(d.as_slice())) - .collect::, _>>() - .map_err(|_| Error::GeneralProtocolError("Expected ClientMessage.".into()))?; - - self.communicator - .write() - .await - .receive_messages(device_id, messages); + self.communicator.receive_messages(device_id, messages); - if self.communicator.read().await.round_received() - && self.protocol.round() <= self.protocol.last_round() + if self.communicator.round_received() && self.protocol.round() <= self.protocol.last_round() { return Ok(true); } Ok(false) } - // TODO: deduplicate across sign and decrypt - pub(super) async fn decide_internal( - &mut self, - device_id: &[u8], - decision: bool, - ) -> Option { - self.communicator.write().await.decide(device_id, decision); - - if self.result.is_none() && self.protocol.round() == 0 { - if self.communicator.read().await.reject_count() >= self.group.reject_threshold() { - self.set_result(Err("Task declined".to_string())); - return Some(false); - } else if self.communicator.read().await.accept_count() >= self.group.threshold() { - return Some(true); - } - } - None - } - - fn set_result(&mut self, result: Result, String>) { - self.result = Some(result); - } - pub fn increment_attempt_count(&mut self) { - self.attempts += 1; + self.task_info.attempts += 1; } } -#[async_trait] -impl Task for SignTask { - fn get_type(&self) -> TaskType { - TaskType::SignChallenge +impl RunningTask for SignTask { + fn task_info(&self) -> &TaskInfo { + &self.task_info } - async fn get_work(&self, device_id: &[u8]) -> Vec> { - if !self.waiting_for(device_id).await { + fn get_work(&self, device_id: &[u8]) -> Vec> { + if !self.waiting_for(device_id) { return Vec::new(); } - self.communicator.read().await.get_messages(device_id) + self.communicator.get_messages(device_id) } fn get_round(&self) -> u16 { self.protocol.round() } - async fn get_decisions(&self) -> (u32, u32) { - ( - self.communicator.read().await.accept_count(), - self.communicator.read().await.reject_count(), - ) + fn initialize(&mut self) -> Result { + self.start_task() } - async fn update( + fn update( &mut self, device_id: &[u8], - data: &Vec>, + messages: Vec, ) -> Result { - let round_update = if self.update_internal(device_id, data).await? { - self.next_round().await? + let round_update = if self.update_internal(device_id, messages)? { + self.next_round()? } else { RoundUpdate::Listen }; Ok(round_update) } - async fn restart(&mut self) -> Result { - if self.result.is_some() { - return Ok(RestartUpdate::AlreadyFinished); - } - - if self.is_approved().await { - self.increment_attempt_count(); - let round_update = self.start_task().await?; - Ok(RestartUpdate::Started(round_update)) - } else { - Ok(RestartUpdate::Voting) - } - } - - async fn is_approved(&self) -> bool { - self.communicator.read().await.accept_count() >= self.group.threshold() - } - - fn get_participants(&self) -> &Vec { - &self.group.participants() - } - - async fn waiting_for(&self, device: &[u8]) -> bool { - if self.protocol.round() == 0 { - return !self.communicator.read().await.device_decided(device); - } else if self.protocol.round() >= self.protocol.last_round() { - return !self.communicator.read().await.device_acknowledged(device); - } - - self.communicator.read().await.waiting_for(device) - } - - async fn decide(&mut self, device_id: &[u8], decision: bool) -> Result { - let result = self.decide_internal(device_id, decision).await; - let decision_update = match result { - Some(true) => { - let round_update = self.next_round().await?; - DecisionUpdate::Accepted(round_update) - } - Some(false) => DecisionUpdate::Declined, - None => DecisionUpdate::Undecided, - }; - Ok(decision_update) - } - - async fn acknowledge(&mut self, device_id: &[u8]) { - self.communicator.write().await.acknowledge(device_id); - } - - fn get_request(&self) -> &[u8] { - &self.request - } - - fn get_attempts(&self) -> u32 { - self.attempts - } - - fn get_id(&self) -> &Uuid { - &self.id - } - - fn get_communicator(&self) -> Arc> { - self.communicator.clone() - } - - fn get_threshold(&self) -> u32 { - self.get_group().threshold() + fn restart(&mut self) -> Result { + self.increment_attempt_count(); + self.start_task() } - fn get_data(&self) -> Option<&[u8]> { - Some(&self.data) + fn waiting_for(&self, device: &[u8]) -> bool { + self.communicator.waiting_for(device) } } diff --git a/src/tasks/sign_pdf.rs b/src/tasks/sign_pdf.rs index cf228c1..8c4c62b 100644 --- a/src/tasks/sign_pdf.rs +++ b/src/tasks/sign_pdf.rs @@ -1,19 +1,16 @@ use crate::communicator::Communicator; use crate::error::Error; -use crate::group::Group; -use crate::persistence::{Participant, Task as TaskModel}; -use crate::proto::TaskType; +use crate::persistence::{Group, Task as TaskModel}; use crate::tasks::sign::SignTask; -use crate::tasks::{DecisionUpdate, RestartUpdate, RoundUpdate, Task, TaskResult}; -use async_trait::async_trait; +use crate::tasks::{FailedTask, RoundUpdate, RunningTask, TaskInfo, TaskResult}; use lazy_static::lazy_static; use log::{error, info, warn}; +use meesign_crypto::proto::ClientMessage; use std::collections::HashMap; use std::io::{Read, Write}; use std::process::{Child, Command, Stdio}; -use std::sync::{Arc, Mutex}; +use std::sync::Mutex; use tempfile::NamedTempFile; -use tokio::sync::RwLock; use uuid::Uuid; lazy_static! { @@ -22,53 +19,56 @@ lazy_static! { pub struct SignPDFTask { sign_task: SignTask, - result: Option, String>>, } impl SignPDFTask { - pub fn try_new(group: Group, name: String, data: Vec) -> Result { - if data.len() > 8 * 1024 * 1024 || name.len() > 256 || name.chars().any(|x| x.is_control()) + pub fn try_new( + task_info: TaskInfo, + group: Group, + data: Vec, + communicator: Communicator, + ) -> Result { + if data.len() > 8 * 1024 * 1024 + || task_info.name.len() > 256 + || task_info.name.chars().any(|x| x.is_control()) { - warn!("Invalid input name={} len={}", name, data.len()); + warn!("Invalid input name={} len={}", task_info.name, data.len()); return Err("Invalid input".to_string()); } - let sign_task = SignTask::try_new(group, name, data)?; + let sign_task = SignTask::try_new(task_info, group, data, communicator)?; - Ok(SignPDFTask { - sign_task, - result: None, - }) + Ok(SignPDFTask { sign_task }) } pub fn from_model( + task_info: TaskInfo, model: TaskModel, - communicator: Arc>, + communicator: Communicator, group: Group, ) -> Result { - let result = model - .result - .clone() - .map(|res| res.try_into_result()) - .transpose()?; - let sign_task = SignTask::from_model(model, communicator, group)?; - Ok(Self { sign_task, result }) + let sign_task = SignTask::from_model(task_info, model, communicator, group)?; + Ok(Self { sign_task }) } - async fn start_task(&mut self) -> Result { + fn start_task(&mut self) -> Result { let file = NamedTempFile::new(); if file.is_err() { error!("Could not create temporary file"); let reason = "Task failed (server error)".to_string(); - self.set_result(Err(reason.clone())); - return Ok(RoundUpdate::Failed(reason)); + return Ok(RoundUpdate::Failed(FailedTask { + task_info: self.task_info().clone(), + reason, + })); } let mut file = file.unwrap(); if file.write_all(&self.sign_task.data).is_err() { error!("Could not write in temporary file"); let reason = "Task failed (server error)".to_string(); - self.set_result(Err(reason.clone())); - return Ok(RoundUpdate::Failed(reason)); + return Ok(RoundUpdate::Failed(FailedTask { + task_info: self.task_info().clone(), + reason, + })); } let pdfhelper = Command::new("java") @@ -83,167 +83,112 @@ impl SignPDFTask { if pdfhelper.is_err() { error!("Could not start PDFHelper"); let reason = "Task failed (server error)".to_string(); - self.set_result(Err(reason.clone())); - return Ok(RoundUpdate::Failed(reason)); + return Ok(RoundUpdate::Failed(FailedTask { + task_info: self.task_info().clone(), + reason, + })); } let mut pdfhelper = pdfhelper.unwrap(); let hash = request_hash( &mut pdfhelper, - self.sign_task.get_group().certificate().unwrap(), + self.sign_task.get_group().certificate.as_ref().unwrap(), ); if hash.is_empty() { let reason = "Task failed (invalid PDF)".to_string(); - self.set_result(Err(reason.clone())); - return Ok(RoundUpdate::Failed(reason)); + return Ok(RoundUpdate::Failed(FailedTask { + task_info: self.task_info().clone(), + reason, + })); } PDF_HELPERS .lock() .unwrap() - .insert(self.get_id().clone(), pdfhelper); + .insert(self.task_info().id, pdfhelper); self.sign_task.set_preprocessed(hash); - self.sign_task.start_task().await + self.sign_task.start_task() } - async fn advance_task(&mut self) -> Result { - self.sign_task.advance_task().await + fn advance_task(&mut self) -> Result { + self.sign_task.advance_task() } - async fn finalize_task(&mut self) -> Result { - let round_update = match self.sign_task.finalize_task().await? { - RoundUpdate::Finished(round, TaskResult::Signed(signature)) => { - let mut pdfhelper = PDF_HELPERS.lock().unwrap().remove(self.get_id()).unwrap(); + fn finalize_task(&mut self) -> Result { + let round_update = match self.sign_task.finalize_task()? { + RoundUpdate::Finished(round, mut task) => { + let mut pdfhelper = PDF_HELPERS + .lock() + .unwrap() + .remove(&self.task_info().id) + .unwrap(); + let TaskResult::Signed(signature) = task.result else { + unreachable!() + }; let signed = include_signature(&mut pdfhelper, &signature); info!( "PDF signed by group_id={}", - hex::encode(self.sign_task.get_group().identifier()) + hex::encode(&self.sign_task.get_group().id) ); - - self.set_result(Ok(signed.clone())); - RoundUpdate::Finished(round, TaskResult::SignedPdf(signed)) + task.result = TaskResult::SignedPdf(signed); + RoundUpdate::Finished(round, task) } other => other, }; Ok(round_update) } - async fn next_round(&mut self) -> Result { + fn next_round(&mut self) -> Result { if self.sign_task.protocol.round() == 0 { - self.start_task().await + self.start_task() } else if self.sign_task.protocol.round() < self.sign_task.protocol.last_round() { - self.advance_task().await + self.advance_task() } else { - self.finalize_task().await + self.finalize_task() } } - - fn set_result(&mut self, result: Result, String>) { - self.result = Some(result); - } } -#[async_trait] -impl Task for SignPDFTask { - fn get_type(&self) -> TaskType { - TaskType::SignPdf +impl RunningTask for SignPDFTask { + fn task_info(&self) -> &TaskInfo { + &self.sign_task.task_info } - async fn get_work(&self, device_id: &[u8]) -> Vec> { - self.sign_task.get_work(device_id).await + fn get_work(&self, device_id: &[u8]) -> Vec> { + self.sign_task.get_work(device_id) } fn get_round(&self) -> u16 { self.sign_task.get_round() } - async fn get_decisions(&self) -> (u32, u32) { - self.sign_task.get_decisions().await + fn initialize(&mut self) -> Result { + self.start_task() } - async fn update( + fn update( &mut self, device_id: &[u8], - data: &Vec>, + messages: Vec, ) -> Result { - let round_update = if self.sign_task.update_internal(device_id, data).await? { - self.next_round().await? + let round_update = if self.sign_task.update_internal(device_id, messages)? { + self.next_round()? } else { RoundUpdate::Listen }; Ok(round_update) } - async fn restart(&mut self) -> Result { - if self.result.is_some() { - return Ok(RestartUpdate::AlreadyFinished); - } - - if self.is_approved().await { - if let Some(mut pdfhelper) = PDF_HELPERS.lock().unwrap().remove(self.get_id()) { - pdfhelper.kill().unwrap(); - } - self.sign_task.increment_attempt_count(); - let round_update = self.start_task().await?; - Ok(RestartUpdate::Started(round_update)) - } else { - Ok(RestartUpdate::Voting) + fn restart(&mut self) -> Result { + if let Some(mut pdfhelper) = PDF_HELPERS.lock().unwrap().remove(&self.task_info().id) { + pdfhelper.kill().unwrap(); } + self.sign_task.increment_attempt_count(); + self.start_task() } - async fn is_approved(&self) -> bool { - self.sign_task.is_approved().await - } - - fn get_participants(&self) -> &Vec { - self.sign_task.get_participants() - } - - async fn waiting_for(&self, device: &[u8]) -> bool { - self.sign_task.waiting_for(device).await - } - - async fn decide(&mut self, device_id: &[u8], decision: bool) -> Result { - let result = self.sign_task.decide_internal(device_id, decision).await; - let decision_update = match result { - Some(true) => { - let round_update = self.next_round().await?; - DecisionUpdate::Accepted(round_update) - } - Some(false) => { - self.set_result(Err("Task declined".into())); - DecisionUpdate::Declined - } - None => DecisionUpdate::Undecided, - }; - Ok(decision_update) - } - - async fn acknowledge(&mut self, device_id: &[u8]) { - self.sign_task.acknowledge(device_id).await; - } - - fn get_request(&self) -> &[u8] { - self.sign_task.get_request() - } - - fn get_attempts(&self) -> u32 { - self.sign_task.get_attempts() - } - - fn get_id(&self) -> &Uuid { - self.sign_task.get_id() - } - - fn get_communicator(&self) -> Arc> { - self.sign_task.get_communicator() - } - - fn get_threshold(&self) -> u32 { - self.sign_task.get_threshold() - } - fn get_data(&self) -> Option<&[u8]> { - self.sign_task.get_data() + fn waiting_for(&self, device: &[u8]) -> bool { + self.sign_task.waiting_for(device) } }