Skip to content

Commit 17a7692

Browse files
committed
Fix open message missing update feature
1 parent 9f02c0f commit 17a7692

File tree

1 file changed

+108
-2
lines changed

1 file changed

+108
-2
lines changed

mithril-aggregator/src/database/provider/open_message.rs

Lines changed: 108 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ type StdResult<T> = Result<T, StdError>;
2424
/// single signature for this message from which a multi signature will be
2525
/// generated if possible.
2626
#[allow(dead_code)]
27-
#[derive(Debug, Clone)]
27+
#[derive(Debug, Clone, PartialEq, Eq)]
2828
pub struct OpenMessage {
2929
/// OpenMessage unique identifier
3030
pub open_message_id: Uuid,
@@ -252,6 +252,45 @@ impl<'client> Provider<'client> for InsertOpenMessageProvider<'client> {
252252
}
253253
}
254254

255+
struct UpdateOpenMessageProvider<'client> {
256+
connection: &'client Connection,
257+
}
258+
impl<'client> UpdateOpenMessageProvider<'client> {
259+
pub fn new(connection: &'client Connection) -> Self {
260+
Self { connection }
261+
}
262+
263+
fn get_update_condition(&self, open_message: &OpenMessage) -> StdResult<WhereCondition> {
264+
let expression = "(open_message_id, epoch_setting_id, beacon, signed_entity_type_id, protocol_message, is_certified) values (?*, ?*, ?*, ?*, ?*, ?*)";
265+
let beacon_str = open_message.signed_entity_type.get_json_beacon()?;
266+
let parameters = vec![
267+
Value::String(open_message.open_message_id.to_string()),
268+
Value::Integer(open_message.epoch.0 as i64),
269+
Value::String(beacon_str),
270+
Value::Integer(open_message.signed_entity_type.index() as i64),
271+
Value::String(serde_json::to_string(&open_message.protocol_message)?),
272+
Value::Integer(open_message.is_certified as i64),
273+
];
274+
275+
Ok(WhereCondition::new(expression, parameters))
276+
}
277+
}
278+
279+
impl<'client> Provider<'client> for UpdateOpenMessageProvider<'client> {
280+
type Entity = OpenMessage;
281+
282+
fn get_connection(&'client self) -> &'client Connection {
283+
self.connection
284+
}
285+
286+
fn get_definition(&self, condition: &str) -> String {
287+
let aliases = SourceAlias::new(&[("{:open_message:}", "open_message")]);
288+
let projection = Self::Entity::get_projection().expand(aliases);
289+
290+
format!("replace into open_message {condition} returning {projection}")
291+
}
292+
}
293+
255294
struct DeleteOpenMessageProvider<'client> {
256295
connection: &'client Connection,
257296
}
@@ -330,6 +369,18 @@ impl OpenMessageRepository {
330369
.ok_or_else(|| panic!("Inserting an open_message should not return nothing."))
331370
}
332371

372+
/// Updates an [OpenMessage] in the database.
373+
pub async fn update_open_message(&self, open_message: &OpenMessage) -> StdResult<OpenMessage> {
374+
let lock = self.connection.lock().await;
375+
let provider = UpdateOpenMessageProvider::new(&lock);
376+
let filters = provider.get_update_condition(open_message)?;
377+
let mut cursor = provider.find(filters)?;
378+
379+
cursor
380+
.next()
381+
.ok_or_else(|| panic!("Updating an open_message should not return nothing."))
382+
}
383+
333384
/// Remove all the [OpenMessage] for the given Epoch in the database.
334385
/// It returns the number of messages removed.
335386
pub async fn clean_epoch(&self, epoch: Epoch) -> StdResult<usize> {
@@ -482,7 +533,7 @@ from open_message
482533
on open_message.open_message_id = single_signature.open_message_id
483534
where {condition}
484535
group by open_message.open_message_id
485-
order by open_message.rowid desc
536+
order by open_message.created_at desc, open_message.rowid desc
486537
"#
487538
)
488539
}
@@ -593,6 +644,37 @@ mod tests {
593644
assert!(!params[4].as_string().unwrap().is_empty());
594645
}
595646

647+
#[test]
648+
fn update_provider_condition() {
649+
let connection = Connection::open(":memory:").unwrap();
650+
let provider = UpdateOpenMessageProvider::new(&connection);
651+
let open_message = OpenMessage {
652+
open_message_id: Uuid::new_v4(),
653+
epoch: Epoch(12),
654+
signed_entity_type: SignedEntityType::dummy(),
655+
protocol_message: ProtocolMessage::new(),
656+
is_certified: true,
657+
created_at: NaiveDateTime::default(),
658+
};
659+
let (expr, params) = provider
660+
.get_update_condition(&open_message)
661+
.unwrap()
662+
.expand();
663+
664+
assert_eq!("(open_message_id, epoch_setting_id, beacon, signed_entity_type_id, protocol_message, is_certified) values (?1, ?2, ?3, ?4, ?5, ?6)".to_string(), expr);
665+
assert_eq!(
666+
vec![
667+
Value::String(open_message.open_message_id.to_string()),
668+
Value::Integer(open_message.epoch.0 as i64),
669+
Value::String(open_message.signed_entity_type.get_json_beacon().unwrap()),
670+
Value::Integer(open_message.signed_entity_type.index() as i64),
671+
Value::String(serde_json::to_string(&open_message.protocol_message).unwrap()),
672+
Value::Integer(open_message.is_certified as i64),
673+
],
674+
params
675+
);
676+
}
677+
596678
#[test]
597679
fn delete_provider_epoch_condition() {
598680
let connection = Connection::open(":memory:").unwrap();
@@ -657,6 +739,30 @@ mod tests {
657739
assert_eq!(open_message.epoch, message.epoch);
658740
}
659741

742+
#[tokio::test]
743+
async fn repository_update_open_message() {
744+
let connection = get_connection().await;
745+
let repository = OpenMessageRepository::new(connection.clone());
746+
let epoch = Epoch(1);
747+
let open_message = repository
748+
.create_open_message(
749+
epoch,
750+
&SignedEntityType::CardanoImmutableFilesFull(Beacon::default()),
751+
&ProtocolMessage::new(),
752+
)
753+
.await
754+
.unwrap();
755+
756+
let mut open_message_updated = open_message;
757+
open_message_updated.is_certified = true;
758+
let open_message_saved = repository
759+
.update_open_message(&open_message_updated)
760+
.await
761+
.unwrap();
762+
763+
assert_eq!(open_message_updated, open_message_saved);
764+
}
765+
660766
#[tokio::test]
661767
async fn repository_clean_open_message() {
662768
let connection = get_connection().await;

0 commit comments

Comments
 (0)