Skip to content

Commit f6d97e9

Browse files
committed
add foreign key support to sqlite
1 parent f28fab4 commit f6d97e9

File tree

4 files changed

+172
-36
lines changed

4 files changed

+172
-36
lines changed

mithril-aggregator/src/database/migration.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,24 @@ insert into certificate (certificate_id,
128128
from certificate_temp as c;
129129
create index epoch_index ON certificate(epoch);
130130
drop table certificate_temp;
131+
"#,
132+
),
133+
// Migration 5
134+
// Add the `open_message` table
135+
SqlMigration::new(
136+
5,
137+
r#"
138+
create table open_message (
139+
open_message_id text not null,
140+
epoch_setting_id int not null,
141+
beacon text not null,
142+
signed_entity_type_id int not null,
143+
message text not null,
144+
created_at text not null default current_timestamp,
145+
primary key (open_message_id),
146+
foreign key (epoch_setting_id) references epoch_setting (epoch_setting_id),
147+
foreign key (signed_entity_type_id) references signed_entity_type (signed_entity_type_id)
148+
);
131149
"#,
132150
),
133151
]

mithril-aggregator/src/database/provider/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@ mod stake_pool;
66

77
pub use certificate::*;
88
pub use epoch_setting::*;
9+
pub use open_message::*;
910
pub use stake_pool::*;

mithril-aggregator/src/database/provider/open_message.rs

Lines changed: 134 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
1-
use chrono::NaiveDateTime;
2-
use mithril_common::entities::Beacon;
3-
use mithril_common::entities::Epoch;
41
use mithril_common::StdError;
52

6-
use mithril_common::sqlite::Provider;
7-
use mithril_common::sqlite::SourceAlias;
83
use mithril_common::{
9-
entities::SignedEntityType,
4+
entities::{Beacon, Epoch, SignedEntityType},
105
sqlite::{HydrationError, Projection, SqLiteEntity, WhereCondition},
6+
sqlite::{Provider, SourceAlias},
117
};
8+
9+
use chrono::NaiveDateTime;
1210
use sqlite::Row;
1311
use sqlite::{Connection, Value};
12+
13+
use std::sync::Arc;
14+
15+
use tokio::sync::Mutex;
1416
use uuid::Uuid;
1517

1618
type StdResult<T> = Result<T, StdError>;
@@ -20,6 +22,7 @@ type StdResult<T> = Result<T, StdError>;
2022
/// An open message is a message open for signatures. Every signer may send a
2123
/// single signature for this message from which a multi signature will be
2224
/// generated if possible.
25+
#[allow(dead_code)]
2326
pub struct OpenMessage {
2427
/// OpenMessage unique identifier
2528
open_message_id: Uuid,
@@ -53,9 +56,9 @@ impl SqLiteEntity for OpenMessage {
5356
))
5457
})?;
5558
let message = row.get::<String, _>(4);
56-
let epoch_settings_id = row.get::<i64, _>(1);
57-
let epoch_val = u64::try_from(epoch_settings_id)
58-
.map_err(|e| panic!("Integer field open_message.epoch_settings_id (value={epoch_settings_id}) is incompatible with u64 Epoch representation. Error = {e}"))?;
59+
let epoch_setting_id = row.get::<i64, _>(1);
60+
let epoch_val = u64::try_from(epoch_setting_id)
61+
.map_err(|e| panic!("Integer field open_message.epoch_setting_id (value={epoch_setting_id}) is incompatible with u64 Epoch representation. Error = {e}"))?;
5962

6063
let signed_entity_type_id = usize::try_from(row.get::<i64, _>(3)).map_err(|e| {
6164
panic!(
@@ -94,10 +97,14 @@ impl SqLiteEntity for OpenMessage {
9497

9598
fn get_projection() -> Projection {
9699
Projection::from(&[
97-
("open_message_id", "{:open_message:}.open_message_id", "int"),
98100
(
99-
"epoch_settings_id",
100-
"{:open_message:}.epoch_settings_id",
101+
"open_message_id",
102+
"{:open_message:}.open_message_id",
103+
"text",
104+
),
105+
(
106+
"epoch_setting_id",
107+
"{:open_message:}.epoch_setting_id",
101108
"int",
102109
),
103110
("beacon", "{:open_message:}.beacon", "text"),
@@ -123,7 +130,7 @@ impl<'client> OpenMessageProvider<'client> {
123130

124131
fn get_epoch_condition(&self, epoch: Epoch) -> WhereCondition {
125132
WhereCondition::new(
126-
"epoch_settings_id = ?*",
133+
"epoch_setting_id = ?*",
127134
vec![Value::Integer(epoch.0 as i64)],
128135
)
129136
}
@@ -138,6 +145,8 @@ impl<'client> OpenMessageProvider<'client> {
138145
)
139146
}
140147

148+
// Useful in test and probably in the future.
149+
#[allow(dead_code)]
141150
fn get_open_message_id_condition(&self, open_message_id: &str) -> WhereCondition {
142151
WhereCondition::new(
143152
"open_message_id = ?*",
@@ -176,7 +185,7 @@ impl<'client> InsertOpenMessageProvider<'client> {
176185
signed_entity_type: &SignedEntityType,
177186
message: &str,
178187
) -> StdResult<WhereCondition> {
179-
let expression = "(open_message_id, epoch_settings_id, beacon, signed_entity_type_id, message) values (?*, ?*, ?*, ?*, ?*)";
188+
let expression = "(open_message_id, epoch_setting_id, beacon, signed_entity_type_id, message) values (?*, ?*, ?*, ?*, ?*)";
180189
let parameters = vec![
181190
Value::String(Uuid::new_v4().to_string()),
182191
Value::Integer(epoch.0 as i64),
@@ -215,7 +224,7 @@ impl<'client> DeleteOpenMessageProvider<'client> {
215224

216225
fn get_epoch_condition(&self, epoch: Epoch) -> WhereCondition {
217226
WhereCondition::new(
218-
"epoch_settings_id = ?*",
227+
"epoch_setting_id = ?*",
219228
vec![Value::Integer(epoch.0 as i64)],
220229
)
221230
}
@@ -236,18 +245,28 @@ impl<'client> Provider<'client> for DeleteOpenMessageProvider<'client> {
236245
}
237246
}
238247

239-
pub struct OpenMessageRepository<'client> {
240-
connection: &'client Connection,
248+
/// ## Open message repository
249+
///
250+
/// This is a business oriented layer to perform actions on the database through
251+
/// providers.
252+
pub struct OpenMessageRepository {
253+
connection: Arc<Mutex<Connection>>,
241254
}
242255

243-
impl<'client> OpenMessageRepository<'client> {
256+
impl OpenMessageRepository {
257+
/// Instanciate service
258+
pub fn new(connection: Arc<Mutex<Connection>>) -> Self {
259+
Self { connection }
260+
}
261+
244262
/// Return the latest [OpenMessage] for the given Epoch and [SignedEntityType].
245-
pub fn get_open_message(
263+
pub async fn get_open_message(
246264
&self,
247265
epoch: Epoch,
248266
signed_entity_type: &SignedEntityType,
249267
) -> StdResult<Option<OpenMessage>> {
250-
let provider = OpenMessageProvider::new(self.connection);
268+
let lock = self.connection.lock().await;
269+
let provider = OpenMessageProvider::new(&lock);
251270
let filters = provider
252271
.get_epoch_condition(epoch)
253272
.and_where(provider.get_signed_entity_type_condition(signed_entity_type));
@@ -257,14 +276,15 @@ impl<'client> OpenMessageRepository<'client> {
257276
}
258277

259278
/// Create a new [OpenMessage] in the database.
260-
pub fn create_open_message(
279+
pub async fn create_open_message(
261280
&self,
262281
epoch: Epoch,
263282
beacon: &Beacon,
264283
signed_entity_type: &SignedEntityType,
265284
message: &str,
266285
) -> StdResult<OpenMessage> {
267-
let provider = InsertOpenMessageProvider::new(self.connection);
286+
let lock = self.connection.lock().await;
287+
let provider = InsertOpenMessageProvider::new(&lock);
268288
let filters = provider.get_insert_condition(epoch, beacon, signed_entity_type, message)?;
269289
let mut cursor = provider.find(filters)?;
270290

@@ -274,19 +294,23 @@ impl<'client> OpenMessageRepository<'client> {
274294
}
275295

276296
/// Remove all the [OpenMessage] for the given Epoch in the database.
277-
pub fn clean_epoch(&self, epoch: Epoch) -> StdResult<()> {
278-
let provider = DeleteOpenMessageProvider::new(self.connection);
297+
/// It returns the number of messages removed.
298+
pub async fn clean_epoch(&self, epoch: Epoch) -> StdResult<usize> {
299+
let lock = self.connection.lock().await;
300+
let provider = DeleteOpenMessageProvider::new(&lock);
279301
let filters = provider.get_epoch_condition(epoch);
280-
let _ = provider.find(filters)?;
302+
let cursor = provider.find(filters)?;
281303

282-
Ok(())
304+
Ok(cursor.count())
283305
}
284306
}
285307

286308
#[cfg(test)]
287309
mod tests {
288310
use mithril_common::sqlite::SourceAlias;
289311

312+
use crate::{dependency_injection::DependenciesBuilder, Configuration};
313+
290314
use super::*;
291315

292316
#[test]
@@ -295,7 +319,7 @@ mod tests {
295319
let aliases = SourceAlias::new(&[("{:open_message:}", "open_message")]);
296320

297321
assert_eq!(
298-
"open_message.open_message_id as open_message_id, open_message.epoch_settings_id as epoch_settings_id, open_message.beacon as beacon, open_message.signed_entity_type_id as signed_entity_type_id, open_message.message as message, open_message.created_at as created_at".to_string(),
322+
"open_message.open_message_id as open_message_id, open_message.epoch_setting_id as epoch_setting_id, open_message.beacon as beacon, open_message.signed_entity_type_id as signed_entity_type_id, open_message.message as message, open_message.created_at as created_at".to_string(),
299323
projection.expand(aliases)
300324
)
301325
}
@@ -306,7 +330,7 @@ mod tests {
306330
let provider = OpenMessageProvider::new(&connection);
307331
let (expr, params) = provider.get_epoch_condition(Epoch(12)).expand();
308332

309-
assert_eq!("epoch_settings_id = ?1".to_string(), expr);
333+
assert_eq!("epoch_setting_id = ?1".to_string(), expr);
310334
assert_eq!(vec![Value::Integer(12)], params,);
311335
}
312336

@@ -353,7 +377,7 @@ mod tests {
353377
.unwrap()
354378
.expand();
355379

356-
assert_eq!("(open_message_id, epoch_settings_id, beacon, signed_entity_type_id, message) values (?1, ?2, ?3, ?4, ?5)".to_string(), expr);
380+
assert_eq!("(open_message_id, epoch_setting_id, beacon, signed_entity_type_id, message) values (?1, ?2, ?3, ?4, ?5)".to_string(), expr);
357381
assert_eq!(Value::Integer(12), params[1]);
358382
assert_eq!(
359383
Value::String(r#"{"network":"","epoch":0,"immutable_file_number":0}"#.to_string()),
@@ -369,7 +393,87 @@ mod tests {
369393
let provider = DeleteOpenMessageProvider::new(&connection);
370394
let (expr, params) = provider.get_epoch_condition(Epoch(12)).expand();
371395

372-
assert_eq!("epoch_settings_id = ?1".to_string(), expr);
396+
assert_eq!("epoch_setting_id = ?1".to_string(), expr);
373397
assert_eq!(vec![Value::Integer(12)], params,);
374398
}
399+
400+
async fn get_connection() -> Arc<Mutex<Connection>> {
401+
let config = Configuration::new_sample();
402+
let mut builder = DependenciesBuilder::new(config);
403+
let connection = builder.get_sqlite_connection().await.unwrap();
404+
{
405+
let lock = connection.lock().await;
406+
lock.execute(r#"insert into epoch_setting(epoch_setting_id, protocol_parameters) values (1, '{"k": 100, "m": 5, "phi": 0.65 }');"#).unwrap();
407+
}
408+
409+
connection
410+
}
411+
412+
#[tokio::test]
413+
async fn repository_create_open_message() {
414+
let connection = get_connection().await;
415+
let repository = OpenMessageRepository::new(connection.clone());
416+
let open_message = repository
417+
.create_open_message(
418+
Epoch(1),
419+
&Beacon::default(),
420+
&SignedEntityType::CardanoImmutableFilesFull,
421+
"this is a message",
422+
)
423+
.await
424+
.unwrap();
425+
426+
assert_eq!(Epoch(1), open_message.epoch);
427+
assert_eq!("this is a message".to_string(), open_message.message);
428+
assert_eq!(
429+
SignedEntityType::CardanoImmutableFilesFull,
430+
open_message.signed_entity_type
431+
);
432+
433+
let message = {
434+
let lock = connection.lock().await;
435+
let provider = OpenMessageProvider::new(&lock);
436+
let mut cursor = provider
437+
.find(WhereCondition::new(
438+
"open_message_id = ?*",
439+
vec![Value::String(open_message.open_message_id.to_string())],
440+
))
441+
.unwrap();
442+
443+
cursor.next().expect(&format!(
444+
"OpenMessage ID='{}' should exist in the database.",
445+
open_message.open_message_id.to_string()
446+
))
447+
};
448+
449+
assert_eq!(open_message.message, message.message);
450+
assert_eq!(open_message.epoch, message.epoch);
451+
}
452+
453+
#[tokio::test]
454+
async fn repository_clean_open_message() {
455+
let connection = get_connection().await;
456+
let repository = OpenMessageRepository::new(connection.clone());
457+
let _ = repository
458+
.create_open_message(
459+
Epoch(1),
460+
&Beacon::default(),
461+
&SignedEntityType::CardanoImmutableFilesFull,
462+
"this is a message",
463+
)
464+
.await
465+
.unwrap();
466+
let _ = repository
467+
.create_open_message(
468+
Epoch(1),
469+
&Beacon::default(),
470+
&SignedEntityType::MithrilStakeDistribution,
471+
"this is a stake distribution",
472+
)
473+
.await
474+
.unwrap();
475+
let count = repository.clean_epoch(Epoch(1)).await.unwrap();
476+
477+
assert_eq!(2, count);
478+
}
375479
}

mithril-aggregator/src/dependency_injection/builder.rs

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -184,16 +184,19 @@ impl DependenciesBuilder {
184184
}
185185

186186
async fn build_sqlite_connection(&self) -> Result<Arc<Mutex<Connection>>> {
187-
let connection = match self.configuration.environment {
187+
let path = match self.configuration.environment {
188188
ExecutionEnvironment::Production => {
189-
Connection::open(self.configuration.get_sqlite_dir().join(SQLITE_FILE))
189+
self.configuration.get_sqlite_dir().join(SQLITE_FILE)
190190
}
191-
_ => Connection::open(":memory:"),
191+
_ => ":memory:".into(),
192192
};
193-
let connection = connection
194-
.map(|conn| Arc::new(Mutex::new(conn)))
193+
let connection = Connection::open(&path)
194+
.map(|c| Arc::new(Mutex::new(c)))
195195
.map_err(|e| DependenciesBuilderError::Initialization {
196-
message: "Could not initialize SQLite driver.".to_string(),
196+
message: format!(
197+
"SQLite initialization: could not open connection with string '{}'.",
198+
path.display()
199+
),
197200
error: Some(Box::new(e)),
198201
})?;
199202
// Check database migrations
@@ -207,6 +210,16 @@ impl DependenciesBuilder {
207210
db_checker.add_migration(migration);
208211
}
209212

213+
// configure session
214+
connection
215+
.lock()
216+
.await
217+
.execute("pragma foreign_keys=true")
218+
.map_err(|e| DependenciesBuilderError::Initialization {
219+
message: "SQLite initialization: could not enable FOREIGN KEY support.".to_string(),
220+
error: Some(e.into()),
221+
})?;
222+
210223
db_checker.apply().await?;
211224

212225
Ok(connection)

0 commit comments

Comments
 (0)