Skip to content

Commit 9c74f8f

Browse files
committed
Add migrations tests
1 parent eace62d commit 9c74f8f

File tree

1 file changed

+148
-19
lines changed

1 file changed

+148
-19
lines changed

rust/impls/src/postgres_store.rs

Lines changed: 148 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,18 @@ const VALUE_COLUMN: &str = "value";
2828
const VERSION_COLUMN: &str = "version";
2929

3030
const DB_VERSION_COLUMN: &str = "db_version";
31+
#[cfg(test)]
32+
const MIGRATION_LOG_COLUMN: &str = "upgrade_from";
3133

3234
const CHECK_DB_STMT: &str = "SELECT 1 FROM pg_database WHERE datname = $1";
3335
const INIT_DB_CMD: &str = "CREATE DATABASE";
36+
#[cfg(test)]
37+
const DROP_DB_CMD: &str = "DROP DATABASE";
3438
const GET_VERSION_STMT: &str = "SELECT db_version FROM vss_db_version;";
3539
const UPDATE_VERSION_STMT: &str = "UPDATE vss_db_version SET db_version=$1;";
3640
const LOG_MIGRATION_STMT: &str = "INSERT INTO vss_db_upgrades VALUES($1);";
41+
#[cfg(test)]
42+
const GET_MIGRATION_LOG_STMT: &str = "SELECT upgrade_from FROM vss_db_upgrades;";
3743

3844
const MIGRATIONS: &[&str] = &[
3945
"CREATE TABLE vss_db_version (db_version INTEGER);",
@@ -52,6 +58,8 @@ const MIGRATIONS: &[&str] = &[
5258
PRIMARY KEY (user_token, store_id, key)
5359
);",
5460
];
61+
#[cfg(test)]
62+
const DUMMY_MIGRATION: &str = "SELECT 1 WHERE FALSE;";
5563

5664
/// The maximum number of key versions that can be returned in a single page.
5765
///
@@ -102,6 +110,31 @@ async fn initialize_vss_database(postgres_endpoint: &str, db_name: &str) -> Resu
102110
Ok(())
103111
}
104112

113+
#[cfg(test)]
114+
async fn drop_database(postgres_endpoint: &str, db_name: &str) -> Result<(), Error> {
115+
let postgres_dsn = format!("{}/{}", postgres_endpoint, "postgres");
116+
let (client, connection) = tokio_postgres::connect(&postgres_dsn, NoTls)
117+
.await
118+
.map_err(|e| Error::new(ErrorKind::Other, format!("Connection error: {}", e)))?;
119+
// Connection must be driven on a separate task, and will resolve when the client is dropped
120+
tokio::spawn(async move {
121+
if let Err(e) = connection.await {
122+
eprintln!("Connection error: {}", e);
123+
}
124+
});
125+
126+
let drop_database_statement = format!("{} {};", DROP_DB_CMD, db_name);
127+
let num_rows = client.execute(&drop_database_statement, &[]).await.map_err(|e| {
128+
Error::new(
129+
ErrorKind::Other,
130+
format!("Failed to drop database {}: {}", db_name, e),
131+
)
132+
})?;
133+
assert_eq!(num_rows, 0);
134+
135+
Ok(())
136+
}
137+
105138
impl PostgresBackendImpl {
106139
/// Constructs a [`PostgresBackendImpl`] using `dsn` for PostgreSQL connection information.
107140
pub async fn new(postgres_endpoint: &str, db_name: &str) -> Result<Self, Error> {
@@ -125,12 +158,13 @@ impl PostgresBackendImpl {
125158
.map_err(|e| Error::new(ErrorKind::Other, format!("Failed to build Pool: {}", e)))?;
126159
let postgres_backend = PostgresBackendImpl { pool };
127160

128-
postgres_backend.migrate_vss_database().await?;
161+
#[cfg(not(test))]
162+
postgres_backend.migrate_vss_database(MIGRATIONS).await?;
129163

130164
Ok(postgres_backend)
131165
}
132166

133-
async fn migrate_vss_database(&self) -> Result<(), Error> {
167+
async fn migrate_vss_database(&self, migrations: &[&str]) -> Result<(usize, usize), Error> {
134168
let mut conn = self.pool.get().await.map_err(|e| {
135169
Error::new(
136170
ErrorKind::Other,
@@ -162,16 +196,16 @@ impl PostgresBackendImpl {
162196
.await
163197
.map_err(|e| Error::new(ErrorKind::Other, format!("Transaction start error: {}", e)))?;
164198

165-
if migration_start == MIGRATIONS.len() {
199+
if migration_start == migrations.len() {
166200
// No migrations needed, we are done
167-
return Ok(());
168-
} else if migration_start > MIGRATIONS.len() {
201+
return Ok((migration_start, migrations.len()));
202+
} else if migration_start > migrations.len() {
169203
panic!("We do not allow downgrades");
170204
}
171205

172-
println!("Applying migration(s) {} through {}", migration_start, MIGRATIONS.len() - 1);
206+
println!("Applying migration(s) {} through {}", migration_start, migrations.len() - 1);
173207

174-
for (idx, &stmt) in (&MIGRATIONS[migration_start..]).iter().enumerate() {
208+
for (idx, &stmt) in (&migrations[migration_start..]).iter().enumerate() {
175209
let _num_rows = tx.execute(stmt, &[]).await.map_err(|e| {
176210
Error::new(
177211
ErrorKind::Other,
@@ -197,7 +231,7 @@ impl PostgresBackendImpl {
197231
assert_eq!(num_rows, 1, "LOG_MIGRATION_STMT should only add one row at a time");
198232

199233
let next_migration_start =
200-
i32::try_from(MIGRATIONS.len()).expect("Length is definitely smaller than i32::MAX");
234+
i32::try_from(migrations.len()).expect("Length is definitely smaller than i32::MAX");
201235
let num_rows =
202236
tx.execute(UPDATE_VERSION_STMT, &[&next_migration_start]).await.map_err(|e| {
203237
Error::new(
@@ -214,7 +248,21 @@ impl PostgresBackendImpl {
214248
Error::new(ErrorKind::Other, format!("Transaction commit error: {}", e))
215249
})?;
216250

217-
Ok(())
251+
Ok((migration_start, migrations.len()))
252+
}
253+
254+
#[cfg(test)]
255+
async fn get_schema_version(&self) -> usize {
256+
let conn = self.pool.get().await.unwrap();
257+
let row = conn.query_one(GET_VERSION_STMT, &[]).await.unwrap();
258+
usize::try_from(row.get::<&str, i32>(DB_VERSION_COLUMN)).unwrap()
259+
}
260+
261+
#[cfg(test)]
262+
async fn get_upgrades_list(&self) -> Vec<usize> {
263+
let conn = self.pool.get().await.unwrap();
264+
let rows = conn.query(GET_MIGRATION_LOG_STMT, &[]).await.unwrap();
265+
rows.iter().map(|row| usize::try_from(row.get::<&str, i32>(MIGRATION_LOG_COLUMN)).unwrap()).collect()
218266
}
219267

220268
fn build_vss_record(&self, user_token: String, store_id: String, kv: KeyValue) -> VssDbRecord {
@@ -568,23 +616,104 @@ mod tests {
568616
use crate::postgres_store::PostgresBackendImpl;
569617
use api::define_kv_store_tests;
570618
use tokio::sync::OnceCell;
619+
use super::{MIGRATIONS, DUMMY_MIGRATION, drop_database};
620+
621+
const POSTGRES_ENDPOINT: &str = "postgresql://postgres:postgres@localhost:5432";
622+
const MIGRATIONS_START: usize = 0;
623+
const MIGRATIONS_END: usize = MIGRATIONS.len();
571624

572625
static START: OnceCell<()> = OnceCell::const_new();
573626

574627
define_kv_store_tests!(PostgresKvStoreTest, PostgresBackendImpl, {
628+
let db_name = "postgres_kv_store_tests";
575629
START
576630
.get_or_init(|| async {
577-
// Initialize the database once, and have other threads wait
578-
PostgresBackendImpl::new(
579-
"postgresql://postgres:postgres@localhost:5432",
580-
"postgres",
581-
)
582-
.await
583-
.unwrap();
631+
let _ = drop_database(POSTGRES_ENDPOINT, db_name).await;
632+
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
633+
let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
634+
assert_eq!(start, MIGRATIONS_START);
635+
assert_eq!(end, MIGRATIONS_END);
584636
})
585637
.await;
586-
PostgresBackendImpl::new("postgresql://postgres:postgres@localhost:5432", "postgres")
587-
.await
588-
.unwrap()
638+
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
639+
let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
640+
assert_eq!(start, MIGRATIONS_END);
641+
assert_eq!(end, MIGRATIONS_END);
642+
assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START]);
643+
assert_eq!(store.get_schema_version().await, MIGRATIONS_END);
644+
store
589645
});
646+
647+
#[tokio::test]
648+
#[should_panic(expected = "We do not allow downgrades")]
649+
async fn panic_on_downgrade() {
650+
let db_name = "panic_on_downgrade_test";
651+
let _ = drop_database(POSTGRES_ENDPOINT, db_name).await;
652+
{
653+
let mut migrations = MIGRATIONS.to_vec();
654+
migrations.push(DUMMY_MIGRATION);
655+
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
656+
let (start, end) = store.migrate_vss_database(&migrations).await.unwrap();
657+
assert_eq!(start, MIGRATIONS_START);
658+
assert_eq!(end, MIGRATIONS_END + 1);
659+
};
660+
{
661+
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
662+
let _ = store.migrate_vss_database(MIGRATIONS).await.unwrap();
663+
};
664+
}
665+
666+
#[tokio::test]
667+
async fn new_migrations_increments_upgrades() {
668+
let db_name = "new_migrations_increments_upgrades_test";
669+
let _ = drop_database(POSTGRES_ENDPOINT, db_name).await;
670+
{
671+
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
672+
let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
673+
assert_eq!(start, MIGRATIONS_START);
674+
assert_eq!(end, MIGRATIONS_END);
675+
assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START]);
676+
assert_eq!(store.get_schema_version().await, MIGRATIONS_END);
677+
};
678+
{
679+
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
680+
let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
681+
assert_eq!(start, MIGRATIONS_END);
682+
assert_eq!(end, MIGRATIONS_END);
683+
assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START]);
684+
assert_eq!(store.get_schema_version().await, MIGRATIONS_END);
685+
};
686+
687+
let mut migrations = MIGRATIONS.to_vec();
688+
migrations.push(DUMMY_MIGRATION);
689+
{
690+
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
691+
let (start, end) = store.migrate_vss_database(&migrations).await.unwrap();
692+
assert_eq!(start, MIGRATIONS_END);
693+
assert_eq!(end, MIGRATIONS_END + 1);
694+
assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START, MIGRATIONS_END]);
695+
assert_eq!(store.get_schema_version().await, MIGRATIONS_END + 1);
696+
};
697+
698+
migrations.push(DUMMY_MIGRATION);
699+
migrations.push(DUMMY_MIGRATION);
700+
{
701+
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
702+
let (start, end) = store.migrate_vss_database(&migrations).await.unwrap();
703+
assert_eq!(start, MIGRATIONS_END + 1);
704+
assert_eq!(end, MIGRATIONS_END + 3);
705+
assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START, MIGRATIONS_END, MIGRATIONS_END + 1]);
706+
assert_eq!(store.get_schema_version().await, MIGRATIONS_END + 3);
707+
};
708+
709+
{
710+
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
711+
let list = store.get_upgrades_list().await;
712+
assert_eq!(list, [MIGRATIONS_START, MIGRATIONS_END, MIGRATIONS_END + 1]);
713+
let version = store.get_schema_version().await;
714+
assert_eq!(version, MIGRATIONS_END + 3);
715+
}
716+
717+
drop_database(POSTGRES_ENDPOINT, db_name).await.unwrap();
718+
}
590719
}

0 commit comments

Comments
 (0)