1 change: 1 addition & 0 deletions rust/impls/Cargo.toml
@@ -10,6 +10,7 @@ chrono = "0.4.38"
tokio-postgres = { version = "0.7.12", features = ["with-chrono-0_4"] }
bb8-postgres = "0.7"
bytes = "1.4.0"
tokio = { version = "1.38.0", default-features = false }

[dev-dependencies]
tokio = { version = "1.38.0", default-features = false, features = ["rt-multi-thread", "macros"] }
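Why `tokio` is promoted from a dev-dependency to a regular dependency: the database-initialization code added in `rust/impls/src/postgres_store.rs` below calls `tokio::spawn` to drive the `tokio_postgres` connection object. A minimal sketch of that pattern, mirroring the code in this PR (the helper name `connect_client` is illustrative, not part of the change):

```rust
use tokio_postgres::NoTls;

// The tokio-postgres `Connection` must be polled on its own task while the
// `Client` is in use; it resolves once the `Client` is dropped.
async fn connect_client(dsn: &str) -> Result<tokio_postgres::Client, tokio_postgres::Error> {
	let (client, connection) = tokio_postgres::connect(dsn, NoTls).await?;
	tokio::spawn(async move {
		if let Err(e) = connection.await {
			eprintln!("Connection error: {}", e);
		}
	});
	Ok(client)
}
```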
1 change: 1 addition & 0 deletions rust/impls/src/lib.rs
@@ -11,6 +11,7 @@
#![deny(rustdoc::private_intra_doc_links)]
#![deny(missing_docs)]

mod migrations;
/// Contains [PostgreSQL](https://www.postgresql.org/) based backend implementation for VSS.
pub mod postgres_store;

39 changes: 39 additions & 0 deletions rust/impls/src/migrations.rs
@@ -0,0 +1,39 @@
pub(crate) const DB_VERSION_COLUMN: &str = "db_version";
#[cfg(test)]
pub(crate) const MIGRATION_LOG_COLUMN: &str = "upgrade_from";

pub(crate) const CHECK_DB_STMT: &str = "SELECT 1 FROM pg_database WHERE datname = $1";
pub(crate) const INIT_DB_CMD: &str = "CREATE DATABASE";
#[cfg(test)]
const DROP_DB_CMD: &str = "DROP DATABASE";
pub(crate) const GET_VERSION_STMT: &str = "SELECT db_version FROM vss_db_version;";
pub(crate) const UPDATE_VERSION_STMT: &str = "UPDATE vss_db_version SET db_version=$1;";
pub(crate) const LOG_MIGRATION_STMT: &str = "INSERT INTO vss_db_upgrades VALUES($1);";
#[cfg(test)]
pub(crate) const GET_MIGRATION_LOG_STMT: &str = "SELECT upgrade_from FROM vss_db_upgrades;";

// APPEND-ONLY list of migration statements
//
// Each statement MUST be applied in order, and exactly once per database.
//
// We make an exception for the vss_db table creation statement, as users of VSS could have initialized the table
// themselves.
pub(crate) const MIGRATIONS: &[&str] = &[
"CREATE TABLE vss_db_version (db_version INTEGER);",
"INSERT INTO vss_db_version VALUES(1);",
// An append-only log of all the migrations performed on this database, useful for debugging and testing
"CREATE TABLE vss_db_upgrades (upgrade_from INTEGER);",
// We do not complain if the table already exists, as users of VSS could have already created this table
"CREATE TABLE IF NOT EXISTS vss_db (
user_token character varying(120) NOT NULL CHECK (user_token <> ''),
store_id character varying(120) NOT NULL CHECK (store_id <> ''),
key character varying(600) NOT NULL,
value bytea NULL,
version bigint NOT NULL,
created_at TIMESTAMP WITH TIME ZONE,
last_updated_at TIMESTAMP WITH TIME ZONE,
PRIMARY KEY (user_token, store_id, key)
);",
];
#[cfg(test)]
const DUMMY_MIGRATION: &str = "SELECT 1 WHERE FALSE;";
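Since `MIGRATIONS` is append-only and each index doubles as a schema version, shipping a schema change means appending exactly one statement; the stored `db_version` then tells `migrate_vss_database` where to resume. A hypothetical sketch of a future entry (the index statement is illustrative and not part of this PR):

```rust
pub(crate) const MIGRATIONS: &[&str] = &[
	// ... the four statements above (migrations 0 through 3) ...
	// Hypothetical migration 4: index per-user scans of vss_db.
	"CREATE INDEX vss_db_user_token_idx ON vss_db (user_token);",
];
```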
295 changes: 281 additions & 14 deletions rust/impls/src/postgres_store.rs
@@ -1,3 +1,5 @@
use crate::migrations::*;

use api::error::VssError;
use api::kv_store::{KvStore, GLOBAL_VERSION_KEY, INITIAL_RECORD_VERSION};
use api::types::{
@@ -12,7 +14,7 @@ use chrono::Utc;
use std::cmp::min;
use std::io;
use std::io::{Error, ErrorKind};
use tokio_postgres::{NoTls, Transaction};
use tokio_postgres::{error, NoTls, Transaction};

pub(crate) struct VssDbRecord {
pub(crate) user_token: String,
@@ -46,17 +48,189 @@ pub struct PostgresBackendImpl {
pool: Pool<PostgresConnectionManager<NoTls>>,
}

async fn initialize_vss_database(postgres_endpoint: &str, db_name: &str) -> Result<(), Error> {
let postgres_dsn = format!("{}/{}", postgres_endpoint, "postgres");
let (client, connection) = tokio_postgres::connect(&postgres_dsn, NoTls)
.await
.map_err(|e| Error::new(ErrorKind::Other, format!("Connection error: {}", e)))?;
// The connection must be driven on a separate task; it resolves once the client is dropped
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("Connection error: {}", e);
}
});

let num_rows = client.execute(CHECK_DB_STMT, &[&db_name]).await.map_err(|e| {
Error::new(
ErrorKind::Other,
format!("Failed to check presence of database {}: {}", db_name, e),
)
})?;

if num_rows == 0 {
let stmt = format!("{} {};", INIT_DB_CMD, db_name);
client.execute(&stmt, &[]).await.map_err(|e| {
Error::new(ErrorKind::Other, format!("Failed to create database {}: {}", db_name, e))
})?;
println!("Created database {}", db_name);
}

Ok(())
}

#[cfg(test)]
async fn drop_database(postgres_endpoint: &str, db_name: &str) -> Result<(), Error> {
let postgres_dsn = format!("{}/{}", postgres_endpoint, "postgres");
let (client, connection) = tokio_postgres::connect(&postgres_dsn, NoTls)
.await
.map_err(|e| Error::new(ErrorKind::Other, format!("Connection error: {}", e)))?;
// The connection must be driven on a separate task; it resolves once the client is dropped
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("Connection error: {}", e);
}
});

let drop_database_statement = format!("{} {};", DROP_DB_CMD, db_name);
let num_rows = client.execute(&drop_database_statement, &[]).await.map_err(|e| {
Error::new(
ErrorKind::Other,
format!("Failed to drop database {}: {}", db_name, e),
)
})?;
assert_eq!(num_rows, 0);

Ok(())
}

impl PostgresBackendImpl {
/// Constructs a [`PostgresBackendImpl`] using `dsn` for PostgreSQL connection information.
-	pub async fn new(dsn: &str) -> Result<Self, Error> {
-		let manager = PostgresConnectionManager::new_from_stringlike(dsn, NoTls).map_err(|e| {
-			Error::new(ErrorKind::Other, format!("Connection manager error: {}", e))
-		})?;
+	pub async fn new(postgres_endpoint: &str, db_name: &str) -> Result<Self, Error> {
+		initialize_vss_database(postgres_endpoint, db_name).await?;
+
+		let vss_dsn = format!("{}/{}", postgres_endpoint, db_name);
+		let manager =
+			PostgresConnectionManager::new_from_stringlike(vss_dsn, NoTls).map_err(|e| {
+				Error::new(
+					ErrorKind::Other,
+					format!("Failed to create PostgresConnectionManager: {}", e),
+				)
+			})?;
// By default, Pool maintains 0 long-running connections, so returning a pool
// here is no guarantee that Pool established a connection to the database.
//
// See Builder::min_idle to increase the long-running connection count.
let pool = Pool::builder()
.build(manager)
.await
-			.map_err(|e| Error::new(ErrorKind::Other, format!("Pool build error: {}", e)))?;
-		Ok(PostgresBackendImpl { pool })
+			.map_err(|e| Error::new(ErrorKind::Other, format!("Failed to build Pool: {}", e)))?;
+		let postgres_backend = PostgresBackendImpl { pool };

#[cfg(not(test))]
postgres_backend.migrate_vss_database(MIGRATIONS).await?;

Ok(postgres_backend)
}

async fn migrate_vss_database(&self, migrations: &[&str]) -> Result<(usize, usize), Error> {
let mut conn = self.pool.get().await.map_err(|e| {
Error::new(
ErrorKind::Other,
format!("Failed to fetch a connection from Pool: {}", e),
)
})?;

// Get the next migration to be applied.
let migration_start = match conn.query_one(GET_VERSION_STMT, &[]).await {
Ok(row) => {
let i: i32 = row.get(DB_VERSION_COLUMN);
usize::try_from(i).expect("The column should always contain non-negative integers")
},
Err(e) => {
// If the table is not defined, start at migration 0
if let Some(&error::SqlState::UNDEFINED_TABLE) = e.code() {
0
} else {
return Err(Error::new(
ErrorKind::Other,
format!("Failed to query the version of the database schema: {}", e),
));
}
},
};

let tx = conn
.transaction()
.await
.map_err(|e| Error::new(ErrorKind::Other, format!("Transaction start error: {}", e)))?;

if migration_start == migrations.len() {
// No migrations needed, we are done
return Ok((migration_start, migrations.len()));
} else if migration_start > migrations.len() {
panic!("We do not allow downgrades");
}

println!("Applying migration(s) {} through {}", migration_start, migrations.len() - 1);

for (idx, &stmt) in (&migrations[migration_start..]).iter().enumerate() {
let _num_rows = tx.execute(stmt, &[]).await.map_err(|e| {
Error::new(
ErrorKind::Other,
format!(
"Database migration no {} with stmt {} failed: {}",
migration_start + idx,
stmt,
e
),
)
})?;
}

let num_rows = tx
.execute(
LOG_MIGRATION_STMT,
&[&i32::try_from(migration_start).expect("Read from an i32 further above")],
)
.await
.map_err(|e| {
Error::new(ErrorKind::Other, format!("Failed to log database migration: {}", e))
})?;
assert_eq!(num_rows, 1, "LOG_MIGRATION_STMT should only add one row at a time");

let next_migration_start =
i32::try_from(migrations.len()).expect("Length is definitely smaller than i32::MAX");
let num_rows =
tx.execute(UPDATE_VERSION_STMT, &[&next_migration_start]).await.map_err(|e| {
Error::new(
ErrorKind::Other,
format!("Failed to update the version of the schema: {}", e),
)
})?;
assert_eq!(
num_rows, 1,
"UPDATE_VERSION_STMT should only update the unique row in the version table"
);

tx.commit().await.map_err(|e| {
Error::new(ErrorKind::Other, format!("Transaction commit error: {}", e))
})?;

Ok((migration_start, migrations.len()))
}

#[cfg(test)]
async fn get_schema_version(&self) -> usize {
let conn = self.pool.get().await.unwrap();
let row = conn.query_one(GET_VERSION_STMT, &[]).await.unwrap();
usize::try_from(row.get::<&str, i32>(DB_VERSION_COLUMN)).unwrap()
}

#[cfg(test)]
async fn get_upgrades_list(&self) -> Vec<usize> {
let conn = self.pool.get().await.unwrap();
let rows = conn.query(GET_MIGRATION_LOG_STMT, &[]).await.unwrap();
rows.iter().map(|row| usize::try_from(row.get::<&str, i32>(MIGRATION_LOG_COLUMN)).unwrap()).collect()
}

fn build_vss_record(&self, user_token: String, store_id: String, kv: KeyValue) -> VssDbRecord {
@@ -409,12 +583,105 @@ impl KvStore for PostgresBackendImpl {
mod tests {
use crate::postgres_store::PostgresBackendImpl;
use api::define_kv_store_tests;
use tokio::sync::OnceCell;
use super::{MIGRATIONS, DUMMY_MIGRATION, drop_database};

const POSTGRES_ENDPOINT: &str = "postgresql://postgres:postgres@localhost:5432";
const MIGRATIONS_START: usize = 0;
const MIGRATIONS_END: usize = MIGRATIONS.len();

static START: OnceCell<()> = OnceCell::const_new();

define_kv_store_tests!(PostgresKvStoreTest, PostgresBackendImpl, {
let db_name = "postgres_kv_store_tests";
START
.get_or_init(|| async {
let _ = drop_database(POSTGRES_ENDPOINT, db_name).await;
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
assert_eq!(start, MIGRATIONS_START);
assert_eq!(end, MIGRATIONS_END);
})
.await;
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
assert_eq!(start, MIGRATIONS_END);
assert_eq!(end, MIGRATIONS_END);
assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START]);
assert_eq!(store.get_schema_version().await, MIGRATIONS_END);
store
});

#[tokio::test]
#[should_panic(expected = "We do not allow downgrades")]
async fn panic_on_downgrade() {
let db_name = "panic_on_downgrade_test";
let _ = drop_database(POSTGRES_ENDPOINT, db_name).await;
{
let mut migrations = MIGRATIONS.to_vec();
migrations.push(DUMMY_MIGRATION);
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
let (start, end) = store.migrate_vss_database(&migrations).await.unwrap();
assert_eq!(start, MIGRATIONS_START);
assert_eq!(end, MIGRATIONS_END + 1);
};
{
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
let _ = store.migrate_vss_database(MIGRATIONS).await.unwrap();
};
}

-	define_kv_store_tests!(
-		PostgresKvStoreTest,
-		PostgresBackendImpl,
-		PostgresBackendImpl::new("postgresql://postgres:postgres@localhost:5432/postgres")
-			.await
-			.unwrap()
-	);
#[tokio::test]
async fn new_migrations_increments_upgrades() {
let db_name = "new_migrations_increments_upgrades_test";
let _ = drop_database(POSTGRES_ENDPOINT, db_name).await;
{
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
assert_eq!(start, MIGRATIONS_START);
assert_eq!(end, MIGRATIONS_END);
assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START]);
assert_eq!(store.get_schema_version().await, MIGRATIONS_END);
};
{
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
assert_eq!(start, MIGRATIONS_END);
assert_eq!(end, MIGRATIONS_END);
assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START]);
assert_eq!(store.get_schema_version().await, MIGRATIONS_END);
};

let mut migrations = MIGRATIONS.to_vec();
migrations.push(DUMMY_MIGRATION);
{
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
let (start, end) = store.migrate_vss_database(&migrations).await.unwrap();
assert_eq!(start, MIGRATIONS_END);
assert_eq!(end, MIGRATIONS_END + 1);
assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START, MIGRATIONS_END]);
assert_eq!(store.get_schema_version().await, MIGRATIONS_END + 1);
};

migrations.push(DUMMY_MIGRATION);
migrations.push(DUMMY_MIGRATION);
{
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
let (start, end) = store.migrate_vss_database(&migrations).await.unwrap();
assert_eq!(start, MIGRATIONS_END + 1);
assert_eq!(end, MIGRATIONS_END + 3);
assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START, MIGRATIONS_END, MIGRATIONS_END + 1]);
assert_eq!(store.get_schema_version().await, MIGRATIONS_END + 3);
};

{
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
let list = store.get_upgrades_list().await;
assert_eq!(list, [MIGRATIONS_START, MIGRATIONS_END, MIGRATIONS_END + 1]);
let version = store.get_schema_version().await;
assert_eq!(version, MIGRATIONS_END + 3);
}

drop_database(POSTGRES_ENDPOINT, db_name).await.unwrap();
}
}
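One behavior from the hunk above worth underscoring: a freshly built bb8 `Pool` holds zero connections, so `PostgresBackendImpl::new` returning `Ok` does not by itself prove the pooled DSN is reachable (it is the `initialize_vss_database` call that actually exercises the endpoint). If long-running connections are wanted, a sketch using bb8's `Builder` options (the values shown are illustrative, not part of this PR):

```rust
// Sketch only: keep at least one idle connection alive and cap the pool size.
let pool = Pool::builder()
	.min_idle(Some(1)) // maintain one long-running connection
	.max_size(16) // upper bound on concurrently open connections
	.build(manager)
	.await
	.map_err(|e| Error::new(ErrorKind::Other, format!("Failed to build Pool: {}", e)))?;
```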
7 changes: 6 additions & 1 deletion rust/server/src/main.rs
@@ -67,13 +67,18 @@ fn main() {
},
};
let authorizer = Arc::new(NoopAuthorizer {});
let postgresql_config = config.postgresql_config.expect("PostgreSQLConfig must be defined in config file.");
let endpoint = postgresql_config.to_postgresql_endpoint();
let db_name = postgresql_config.database;
let store = Arc::new(
-			PostgresBackendImpl::new(&config.postgresql_config.expect("PostgreSQLConfig must be defined in config file.").to_connection_string())
+			PostgresBackendImpl::new(&endpoint, &db_name)
.await
.unwrap(),
);
println!("Connected to PostgreSQL backend with DSN: {}/{}", endpoint, db_name);
let rest_svc_listener =
TcpListener::bind(&addr).await.expect("Failed to bind listening port");
println!("Listening for incoming connections on {}", addr);
loop {
tokio::select! {
res = rest_svc_listener.accept() => {
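With the constructor split into a bare endpoint plus a database name, the server no longer needs the target database to exist up front; `PostgresBackendImpl::new` creates it on first run and applies any pending migrations. A usage sketch (the endpoint matches the test constant used above; the database name `vss_db` and the credentials are illustrative):

```rust
let store = Arc::new(
	PostgresBackendImpl::new(
		"postgresql://postgres:postgres@localhost:5432", // endpoint, no database segment
		"vss_db", // created automatically if absent
	)
	.await
	.expect("failed to initialize the VSS database"),
);
```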