Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions migrations/2025-08-27-171451_initial_schema/up.sql
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ CREATE TABLE task (
"protocol_round" integer NOT NULL CHECK ("protocol_round" >= 0),
"attempt_count" integer NOT NULL CHECK ("attempt_count" >= 0),
"threshold" integer NOT NULL CHECK ("threshold" > 0),
"name" varchar NOT NULL,
"task_data" bytea,
"preprocessed" bytea,
"request" bytea NOT NULL,
Expand Down
7 changes: 3 additions & 4 deletions src/interfaces/grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use tonic::{Request, Response, Status};
use uuid::Uuid;

use crate::persistence::DeviceKind;
use crate::proto::{Group, KeyType, MeeSign, MeeSignServer, ProtocolType};
use crate::proto::{KeyType, MeeSign, MeeSignServer, ProtocolType};
use crate::state::State;
use crate::{proto as msg, utils, CA_CERT, CA_KEY};

Expand Down Expand Up @@ -261,21 +261,20 @@ impl MeeSign for MeeSignService {
.unwrap_or_else(|| "unknown".to_string());
debug!("GroupsRequest device_id={}", device_str);

// TODO: refactor, consider storing device IDS in the group model directly
let groups = if let Some(device_id) = device_id {
self.state.activate_device(&device_id);
self.state
.get_device_groups(&device_id)
.await?
.into_iter()
.map(Group::from_model)
.map(msg::Group::from_model)
.collect()
} else {
self.state
.get_groups()
.await?
.into_iter()
.map(Group::from_model)
.map(msg::Group::from_model)
.collect()
};

Expand Down
2 changes: 1 addition & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ mod proto {
let device_ids = model
.participant_ids_shares
.into_iter()
.map(|(device_id, _)| device_id)
.flat_map(|(device_id, shares)| std::iter::repeat_n(device_id, shares as usize))
.collect();
Self {
identifier: model.id,
Expand Down
2 changes: 2 additions & 0 deletions src/persistence/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ pub struct Task {
pub protocol_round: i32,
pub attempt_count: i32,
pub threshold: i32,
pub name: String,
pub task_data: Option<Vec<u8>>,
pub preprocessed: Option<Vec<u8>>,
pub request: Vec<u8>,
Expand All @@ -128,6 +129,7 @@ pub struct NewTask<'a> {
pub protocol_round: i32,
pub attempt_count: i32,
pub threshold: i32,
pub name: &'a str,
pub task_data: Option<&'a [u8]>,
pub preprocessed: Option<&'a [u8]>,
pub request: &'a [u8],
Expand Down
2 changes: 2 additions & 0 deletions src/persistence/repository.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ impl Repository {
id: Option<&Uuid>,
participants: &[(&[u8], u32)],
threshold: u32,
name: &str,
protocol_type: ProtocolType,
key_type: KeyType,
request: &[u8],
Expand All @@ -192,6 +193,7 @@ impl Repository {
id,
participants,
threshold,
name,
key_type,
protocol_type,
request,
Expand Down
3 changes: 3 additions & 0 deletions src/persistence/repository/group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ mod test {
None,
participants,
threshold,
GROUP_1_NAME,
KeyType::SignPdf,
ProtocolType::Gg18,
&[],
Expand Down Expand Up @@ -387,6 +388,7 @@ mod test {
None,
group_1_participants,
threshold,
GROUP_1_NAME,
KeyType::Decrypt,
ProtocolType::ElGamal,
&[],
Expand All @@ -411,6 +413,7 @@ mod test {
None,
group_2_participants,
threshold,
GROUP_2_NAME,
KeyType::SignChallenge,
ProtocolType::Frost,
&[],
Expand Down
10 changes: 10 additions & 0 deletions src/persistence/repository/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ where
protocol_round: 0,
attempt_count: 0,
threshold: threshold as i32,
name,
task_data,
preprocessed: None,
request,
Expand Down Expand Up @@ -89,6 +90,7 @@ pub async fn create_group_task<Conn>(
id: Option<&Uuid>,
participants: &[(&[u8], u32)],
threshold: u32,
name: &str,
key_type: KeyType,
protocol_type: ProtocolType,
request: &[u8],
Expand All @@ -97,6 +99,12 @@ pub async fn create_group_task<Conn>(
where
Conn: AsyncConnection<Backend = Pg>,
{
if !name.is_name_valid() {
return Err(PersistenceError::InvalidArgumentError(format!(
"Invalid group name {name}"
)));
}

let total_shares: u32 = participants.iter().map(|(_, shares)| shares).sum();
if !(1..=total_shares).contains(&threshold) {
return Err(PersistenceError::InvalidArgumentError(format!(
Expand All @@ -111,6 +119,7 @@ where
protocol_round: 0,
attempt_count: 0,
threshold,
name,
task_data: None,
preprocessed: None,
request,
Expand Down Expand Up @@ -158,6 +167,7 @@ macro_rules! task_model_columns {
task::protocol_round,
task::attempt_count,
task::threshold,
task::name,
task::task_data,
task::preprocessed,
task::request,
Expand Down
8 changes: 3 additions & 5 deletions src/persistence/repository/utils.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
const MAX_USERNAME_LENGTH: usize = 64;
const MAX_NAME_LENGTH: usize = 256;

pub trait NameValidator {
fn is_name_valid(&self) -> bool;
}

impl NameValidator for &str {
fn is_name_valid(&self) -> bool {
self.chars().count() <= MAX_USERNAME_LENGTH
&& !self
.chars()
.any(|x| x.is_ascii_punctuation() || x.is_control())
self.chars().count() <= MAX_NAME_LENGTH
&& !self.chars().any(|x| x.is_control())
&& !self.is_empty()
}
}
Expand Down
1 change: 1 addition & 0 deletions src/persistence/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ diesel::table! {
protocol_round -> Int4,
attempt_count -> Int4,
threshold -> Int4,
name -> Varchar,
task_data -> Nullable<Bytea>,
preprocessed -> Nullable<Bytea>,
request -> Bytea,
Expand Down
5 changes: 3 additions & 2 deletions src/task_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ impl TaskStore {
Some(&task.task_info.id),
participant_ids_shares,
threshold,
&task.task_info.name,
task.task_info.protocol_type,
task.task_info.key_type,
&task.request,
Expand All @@ -190,7 +191,7 @@ impl TaskStore {
&group.id,
participant_ids_shares,
group.threshold as u32,
"name", // TODO: Fix name checks
&task.task_info.name,
data,
&task.request,
task_type,
Expand Down Expand Up @@ -255,7 +256,7 @@ impl TaskStore {
let participants = task_id_participants.remove(&task_model.id).unwrap();
let task_info = TaskInfo {
id: task_model.id,
name: "".into(), // TODO: Persist "name" in TaskModel
name: task_model.name.clone(),
task_type: task_model.task_type,
protocol_type: task_model.protocol_type,
key_type: task_model.key_type,
Expand Down
Loading