@@ -24,7 +24,7 @@ type StdResult<T> = Result<T, StdError>;
24
24
/// single signature for this message from which a multi signature will be
25
25
/// generated if possible.
26
26
#[ allow( dead_code) ]
27
- #[ derive( Debug , Clone ) ]
27
+ #[ derive( Debug , Clone , PartialEq , Eq ) ]
28
28
pub struct OpenMessage {
29
29
/// OpenMessage unique identifier
30
30
pub open_message_id : Uuid ,
@@ -252,6 +252,45 @@ impl<'client> Provider<'client> for InsertOpenMessageProvider<'client> {
252
252
}
253
253
}
254
254
255
+ struct UpdateOpenMessageProvider < ' client > {
256
+ connection : & ' client Connection ,
257
+ }
258
+ impl < ' client > UpdateOpenMessageProvider < ' client > {
259
+ pub fn new ( connection : & ' client Connection ) -> Self {
260
+ Self { connection }
261
+ }
262
+
263
+ fn get_update_condition ( & self , open_message : & OpenMessage ) -> StdResult < WhereCondition > {
264
+ let expression = "(open_message_id, epoch_setting_id, beacon, signed_entity_type_id, protocol_message, is_certified) values (?*, ?*, ?*, ?*, ?*, ?*)" ;
265
+ let beacon_str = open_message. signed_entity_type . get_json_beacon ( ) ?;
266
+ let parameters = vec ! [
267
+ Value :: String ( open_message. open_message_id. to_string( ) ) ,
268
+ Value :: Integer ( open_message. epoch. 0 as i64 ) ,
269
+ Value :: String ( beacon_str) ,
270
+ Value :: Integer ( open_message. signed_entity_type. index( ) as i64 ) ,
271
+ Value :: String ( serde_json:: to_string( & open_message. protocol_message) ?) ,
272
+ Value :: Integer ( open_message. is_certified as i64 ) ,
273
+ ] ;
274
+
275
+ Ok ( WhereCondition :: new ( expression, parameters) )
276
+ }
277
+ }
278
+
279
+ impl < ' client > Provider < ' client > for UpdateOpenMessageProvider < ' client > {
280
+ type Entity = OpenMessage ;
281
+
282
+ fn get_connection ( & ' client self ) -> & ' client Connection {
283
+ self . connection
284
+ }
285
+
286
+ fn get_definition ( & self , condition : & str ) -> String {
287
+ let aliases = SourceAlias :: new ( & [ ( "{:open_message:}" , "open_message" ) ] ) ;
288
+ let projection = Self :: Entity :: get_projection ( ) . expand ( aliases) ;
289
+
290
+ format ! ( "replace into open_message {condition} returning {projection}" )
291
+ }
292
+ }
293
+
255
294
struct DeleteOpenMessageProvider < ' client > {
256
295
connection : & ' client Connection ,
257
296
}
@@ -330,6 +369,18 @@ impl OpenMessageRepository {
330
369
. ok_or_else ( || panic ! ( "Inserting an open_message should not return nothing." ) )
331
370
}
332
371
372
+ /// Updates an [OpenMessage] in the database.
373
+ pub async fn update_open_message ( & self , open_message : & OpenMessage ) -> StdResult < OpenMessage > {
374
+ let lock = self . connection . lock ( ) . await ;
375
+ let provider = UpdateOpenMessageProvider :: new ( & lock) ;
376
+ let filters = provider. get_update_condition ( open_message) ?;
377
+ let mut cursor = provider. find ( filters) ?;
378
+
379
+ cursor
380
+ . next ( )
381
+ . ok_or_else ( || panic ! ( "Updating an open_message should not return nothing." ) )
382
+ }
383
+
333
384
/// Remove all the [OpenMessage] for the given Epoch in the database.
334
385
/// It returns the number of messages removed.
335
386
pub async fn clean_epoch ( & self , epoch : Epoch ) -> StdResult < usize > {
@@ -482,7 +533,7 @@ from open_message
482
533
on open_message.open_message_id = single_signature.open_message_id
483
534
where {condition}
484
535
group by open_message.open_message_id
485
- order by open_message.rowid desc
536
+ order by open_message.created_at desc, open_message. rowid desc
486
537
"#
487
538
)
488
539
}
@@ -593,6 +644,37 @@ mod tests {
593
644
assert ! ( !params[ 4 ] . as_string( ) . unwrap( ) . is_empty( ) ) ;
594
645
}
595
646
647
+ #[ test]
648
+ fn update_provider_condition ( ) {
649
+ let connection = Connection :: open ( ":memory:" ) . unwrap ( ) ;
650
+ let provider = UpdateOpenMessageProvider :: new ( & connection) ;
651
+ let open_message = OpenMessage {
652
+ open_message_id : Uuid :: new_v4 ( ) ,
653
+ epoch : Epoch ( 12 ) ,
654
+ signed_entity_type : SignedEntityType :: dummy ( ) ,
655
+ protocol_message : ProtocolMessage :: new ( ) ,
656
+ is_certified : true ,
657
+ created_at : NaiveDateTime :: default ( ) ,
658
+ } ;
659
+ let ( expr, params) = provider
660
+ . get_update_condition ( & open_message)
661
+ . unwrap ( )
662
+ . expand ( ) ;
663
+
664
+ assert_eq ! ( "(open_message_id, epoch_setting_id, beacon, signed_entity_type_id, protocol_message, is_certified) values (?1, ?2, ?3, ?4, ?5, ?6)" . to_string( ) , expr) ;
665
+ assert_eq ! (
666
+ vec![
667
+ Value :: String ( open_message. open_message_id. to_string( ) ) ,
668
+ Value :: Integer ( open_message. epoch. 0 as i64 ) ,
669
+ Value :: String ( open_message. signed_entity_type. get_json_beacon( ) . unwrap( ) ) ,
670
+ Value :: Integer ( open_message. signed_entity_type. index( ) as i64 ) ,
671
+ Value :: String ( serde_json:: to_string( & open_message. protocol_message) . unwrap( ) ) ,
672
+ Value :: Integer ( open_message. is_certified as i64 ) ,
673
+ ] ,
674
+ params
675
+ ) ;
676
+ }
677
+
596
678
#[ test]
597
679
fn delete_provider_epoch_condition ( ) {
598
680
let connection = Connection :: open ( ":memory:" ) . unwrap ( ) ;
@@ -657,6 +739,30 @@ mod tests {
657
739
assert_eq ! ( open_message. epoch, message. epoch) ;
658
740
}
659
741
742
+ #[ tokio:: test]
743
+ async fn repository_update_open_message ( ) {
744
+ let connection = get_connection ( ) . await ;
745
+ let repository = OpenMessageRepository :: new ( connection. clone ( ) ) ;
746
+ let epoch = Epoch ( 1 ) ;
747
+ let open_message = repository
748
+ . create_open_message (
749
+ epoch,
750
+ & SignedEntityType :: CardanoImmutableFilesFull ( Beacon :: default ( ) ) ,
751
+ & ProtocolMessage :: new ( ) ,
752
+ )
753
+ . await
754
+ . unwrap ( ) ;
755
+
756
+ let mut open_message_updated = open_message;
757
+ open_message_updated. is_certified = true ;
758
+ let open_message_saved = repository
759
+ . update_open_message ( & open_message_updated)
760
+ . await
761
+ . unwrap ( ) ;
762
+
763
+ assert_eq ! ( open_message_updated, open_message_saved) ;
764
+ }
765
+
660
766
#[ tokio:: test]
661
767
async fn repository_clean_open_message ( ) {
662
768
let connection = get_connection ( ) . await ;
0 commit comments