Skip to content

Commit 6287609

Browse files
committed
fixup
- Remove the `Impl` suffix in `PostgresBackendImpl`. - Delete the `DbConnectionType` enum; pass a generic parameter everywhere with the connection information. - Further DRY connection setup code.
1 parent 2e1ba6c commit 6287609

File tree

2 files changed

+90
-131
lines changed

2 files changed

+90
-131
lines changed

rust/impls/src/postgres_store.rs

Lines changed: 83 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@ use bb8_postgres::bb8::Pool;
1111
use bb8_postgres::PostgresConnectionManager;
1212
use bytes::Bytes;
1313
use chrono::Utc;
14-
use native_tls::{Certificate, TlsConnector};
14+
use native_tls::TlsConnector;
1515
use postgres_native_tls::MakeTlsConnector;
1616
use std::cmp::min;
1717
use std::io::{self, Error, ErrorKind};
1818
use tokio_postgres::tls::{MakeTlsConnect, TlsConnect};
1919
use tokio_postgres::{error, Client, NoTls, Socket, Transaction};
2020

21+
pub use native_tls::Certificate;
22+
2123
pub(crate) struct VssDbRecord {
2224
pub(crate) user_token: String,
2325
pub(crate) store_id: String,
@@ -46,7 +48,7 @@ pub const LIST_KEY_VERSIONS_MAX_PAGE_SIZE: i32 = 100;
4648
pub const MAX_PUT_REQUEST_ITEM_COUNT: usize = 1000;
4749

4850
/// A [PostgreSQL](https://www.postgresql.org/) based backend implementation for VSS.
49-
pub struct PostgresBackendImpl<T>
51+
pub struct PostgresBackend<T>
5052
where
5153
T: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
5254
<T as MakeTlsConnect<Socket>>::Stream: Send + Sync,
@@ -56,67 +58,42 @@ where
5658
pool: Pool<PostgresConnectionManager<T>>,
5759
}
5860

59-
/// A postgres backend with plain connections to the database
60-
pub type PostgresBackendImplPlain = PostgresBackendImpl<NoTls>;
61+
/// A postgres backend with plaintext connections to the database
62+
pub type PostgresPlaintextBackend = PostgresBackend<NoTls>;
6163

6264
/// A postgres backend with TLS connections to the database
63-
pub type PostgresBackendImplTls = PostgresBackendImpl<MakeTlsConnector>;
64-
65-
enum DbConnectionType {
66-
Plain,
67-
Tls(Option<String>),
68-
}
65+
pub type PostgresTlsBackend = PostgresBackend<MakeTlsConnector>;
6966

70-
async fn make_postgres_db_connection(
71-
postgres_endpoint: &str, connection_type: DbConnectionType,
72-
) -> Result<Client, Error> {
67+
async fn make_postgres_db_connection<T>(postgres_endpoint: &str, tls: T) -> Result<Client, Error>
68+
where
69+
T: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
70+
T::Stream: Send + Sync,
71+
T::TlsConnect: Send,
72+
<<T as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
73+
{
7374
let dsn = format!("{}/{}", postgres_endpoint, "postgres");
74-
match connection_type {
75-
DbConnectionType::Plain => {
76-
let (client, connection) = tokio_postgres::connect(&dsn, NoTls)
77-
.await
78-
.map_err(|e| Error::new(ErrorKind::Other, format!("Connection error: {}", e)))?;
79-
// Connection must be driven on a separate task, and will resolve when the client is dropped
80-
tokio::spawn(async move {
81-
if let Err(e) = connection.await {
82-
eprintln!("Connection error: {}", e);
83-
}
84-
});
85-
Ok(client)
86-
},
87-
DbConnectionType::Tls(ca_file) => {
88-
let mut builder = TlsConnector::builder();
89-
if let Some(ca) = ca_file {
90-
let cert = std::fs::read(ca).map_err(|e| {
91-
Error::new(ErrorKind::Other, format!("Error reading certificate file: {}", e))
92-
})?;
93-
let cert = Certificate::from_pem(&cert).map_err(|e| {
94-
Error::new(ErrorKind::Other, format!("Error loading certificate file: {}", e))
95-
})?;
96-
builder.add_root_certificate(cert);
97-
}
98-
let connector = builder.build().map_err(|e| {
99-
Error::new(ErrorKind::Other, format!("Error building tls connector: {}", e))
100-
})?;
101-
let connector = MakeTlsConnector::new(connector);
102-
let (client, connection) = tokio_postgres::connect(&dsn, connector)
103-
.await
104-
.map_err(|e| Error::new(ErrorKind::Other, format!("Connection error: {}", e)))?;
105-
// Connection must be driven on a separate task, and will resolve when the client is dropped
106-
tokio::spawn(async move {
107-
if let Err(e) = connection.await {
108-
eprintln!("Connection error: {}", e);
109-
}
110-
});
111-
Ok(client)
112-
},
113-
}
75+
let (client, connection) = tokio_postgres::connect(&dsn, tls)
76+
.await
77+
.map_err(|e| Error::new(ErrorKind::Other, format!("Connection error: {}", e)))?;
78+
// Connection must be driven on a separate task, and will resolve when the client is dropped
79+
tokio::spawn(async move {
80+
if let Err(e) = connection.await {
81+
eprintln!("Connection error: {}", e);
82+
}
83+
});
84+
Ok(client)
11485
}
11586

116-
async fn initialize_vss_database(
117-
postgres_endpoint: &str, db_name: &str, connection_type: DbConnectionType,
118-
) -> Result<(), Error> {
119-
let client = make_postgres_db_connection(&postgres_endpoint, connection_type).await?;
87+
async fn initialize_vss_database<T>(
88+
postgres_endpoint: &str, db_name: &str, tls: T,
89+
) -> Result<(), Error>
90+
where
91+
T: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
92+
T::Stream: Send + Sync,
93+
T::TlsConnect: Send,
94+
<<T as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
95+
{
96+
let client = make_postgres_db_connection(&postgres_endpoint, tls).await?;
12097

12198
let num_rows = client.execute(CHECK_DB_STMT, &[&db_name]).await.map_err(|e| {
12299
Error::new(
@@ -136,10 +113,14 @@ async fn initialize_vss_database(
136113
}
137114

138115
#[cfg(test)]
139-
async fn drop_database(
140-
postgres_endpoint: &str, db_name: &str, connection_type: DbConnectionType,
141-
) -> Result<(), Error> {
142-
let client = make_postgres_db_connection(&postgres_endpoint, connection_type).await?;
116+
async fn drop_database<T>(postgres_endpoint: &str, db_name: &str, tls: T) -> Result<(), Error>
117+
where
118+
T: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
119+
T::Stream: Send + Sync,
120+
T::TlsConnect: Send,
121+
<<T as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
122+
{
123+
let client = make_postgres_db_connection(&postgres_endpoint, tls).await?;
143124

144125
let drop_database_statement = format!("{} {};", DROP_DB_CMD, db_name);
145126
let num_rows = client.execute(&drop_database_statement, &[]).await.map_err(|e| {
@@ -150,61 +131,42 @@ async fn drop_database(
150131
Ok(())
151132
}
152133

153-
impl PostgresBackendImplTls {
154-
/// Constructs a [`PostgresBackendImpl`] using `dsn` for PostgreSQL connection information, with a tls connection pool.
134+
impl PostgresPlaintextBackend {
135+
/// Constructs a [`PostgresPlaintextBackend`] using `postgres_endpoint` for PostgreSQL connection information.
136+
pub async fn new(postgres_endpoint: &str, db_name: &str) -> Result<Self, Error> {
137+
PostgresBackend::new_internal(postgres_endpoint, db_name, NoTls).await
138+
}
139+
}
140+
141+
impl PostgresTlsBackend {
142+
/// Constructs a [`PostgresTlsBackend`] using `postgres_endpoint` for PostgreSQL connection information.
155143
pub async fn new(
156-
postgres_endpoint: &str, db_name: &str, ca_file: Option<String>,
144+
postgres_endpoint: &str, db_name: &str, additional_certificate: Option<Certificate>,
157145
) -> Result<Self, Error> {
158-
initialize_vss_database(postgres_endpoint, db_name, DbConnectionType::Tls(ca_file.clone()))
159-
.await?;
160-
161-
let vss_dsn = format!("{}/{}", postgres_endpoint, db_name);
162146
let mut builder = TlsConnector::builder();
163-
if let Some(ca) = ca_file {
164-
let cert = std::fs::read(ca).map_err(|e| {
165-
Error::new(ErrorKind::Other, format!("Error reading certificate file: {}", e))
166-
})?;
167-
let cert = Certificate::from_pem(&cert).map_err(|e| {
168-
Error::new(ErrorKind::Other, format!("Error loading certificate file: {}", e))
169-
})?;
147+
if let Some(cert) = additional_certificate {
170148
builder.add_root_certificate(cert);
171149
}
172150
let connector = builder.build().map_err(|e| {
173151
Error::new(ErrorKind::Other, format!("Error building tls connector: {}", e))
174152
})?;
175-
let connector = MakeTlsConnector::new(connector);
176-
let manager =
177-
PostgresConnectionManager::new_from_stringlike(vss_dsn, connector).map_err(|e| {
178-
Error::new(
179-
ErrorKind::Other,
180-
format!("Failed to create PostgresConnectionManager: {}", e),
181-
)
182-
})?;
183-
// By default, Pool maintains 0 long-running connections, so returning a pool
184-
// here is no guarantee that Pool established a connection to the database.
185-
//
186-
// See Builder::min_idle to increase the long-running connection count.
187-
let pool = Pool::builder()
188-
.build(manager)
153+
PostgresBackend::new_internal(postgres_endpoint, db_name, MakeTlsConnector::new(connector))
189154
.await
190-
.map_err(|e| Error::new(ErrorKind::Other, format!("Failed to build Pool: {}", e)))?;
191-
let postgres_backend = PostgresBackendImpl { pool };
192-
193-
#[cfg(not(test))]
194-
postgres_backend.migrate_vss_database(MIGRATIONS).await?;
195-
196-
Ok(postgres_backend)
197155
}
198156
}
199157

200-
impl PostgresBackendImplPlain {
201-
/// Constructs a [`PostgresBackendImpl`] using `dsn` for PostgreSQL connection information, with a plain connection pool.
202-
pub async fn new(postgres_endpoint: &str, db_name: &str) -> Result<Self, Error> {
203-
initialize_vss_database(postgres_endpoint, db_name, DbConnectionType::Plain).await?;
204-
158+
impl<T> PostgresBackend<T>
159+
where
160+
T: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
161+
T::Stream: Send + Sync,
162+
T::TlsConnect: Send,
163+
<<T as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
164+
{
165+
async fn new_internal(postgres_endpoint: &str, db_name: &str, tls: T) -> Result<Self, Error> {
166+
initialize_vss_database(postgres_endpoint, db_name, tls.clone()).await?;
205167
let vss_dsn = format!("{}/{}", postgres_endpoint, db_name);
206168
let manager =
207-
PostgresConnectionManager::new_from_stringlike(vss_dsn, NoTls).map_err(|e| {
169+
PostgresConnectionManager::new_from_stringlike(vss_dsn, tls).map_err(|e| {
208170
Error::new(
209171
ErrorKind::Other,
210172
format!("Failed to create PostgresConnectionManager: {}", e),
@@ -218,22 +180,14 @@ impl PostgresBackendImplPlain {
218180
.build(manager)
219181
.await
220182
.map_err(|e| Error::new(ErrorKind::Other, format!("Failed to build Pool: {}", e)))?;
221-
let postgres_backend = PostgresBackendImpl { pool };
183+
let postgres_backend = PostgresBackend { pool };
222184

223185
#[cfg(not(test))]
224186
postgres_backend.migrate_vss_database(MIGRATIONS).await?;
225187

226188
Ok(postgres_backend)
227189
}
228-
}
229190

230-
impl<T> PostgresBackendImpl<T>
231-
where
232-
T: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
233-
T::Stream: Send + Sync,
234-
T::TlsConnect: Send,
235-
<<T as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
236-
{
237191
async fn migrate_vss_database(&self, migrations: &[&str]) -> Result<(usize, usize), Error> {
238192
let mut conn = self.pool.get().await.map_err(|e| {
239193
Error::new(ErrorKind::Other, format!("Failed to fetch a connection from Pool: {}", e))
@@ -481,7 +435,7 @@ where
481435
}
482436

483437
#[async_trait]
484-
impl<T> KvStore for PostgresBackendImpl<T>
438+
impl<T> KvStore for PostgresBackend<T>
485439
where
486440
T: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
487441
T::Stream: Send + Sync,
@@ -688,30 +642,31 @@ where
688642

689643
#[cfg(test)]
690644
mod tests {
691-
use super::{drop_database, DbConnectionType, DUMMY_MIGRATION, MIGRATIONS};
692-
use crate::postgres_store::PostgresBackendImplPlain;
645+
use super::{drop_database, DUMMY_MIGRATION, MIGRATIONS};
646+
use crate::postgres_store::PostgresPlaintextBackend;
693647
use api::define_kv_store_tests;
694648
use tokio::sync::OnceCell;
649+
use tokio_postgres::NoTls;
695650

696651
const POSTGRES_ENDPOINT: &str = "postgresql://postgres:postgres@localhost:5432";
697652
const MIGRATIONS_START: usize = 0;
698653
const MIGRATIONS_END: usize = MIGRATIONS.len();
699654

700655
static START: OnceCell<()> = OnceCell::const_new();
701656

702-
define_kv_store_tests!(PostgresKvStoreTest, PostgresBackendImplPlain, {
657+
define_kv_store_tests!(PostgresKvStoreTest, PostgresPlaintextBackend, {
703658
let db_name = "postgres_kv_store_tests";
704659
START
705660
.get_or_init(|| async {
706-
let _ = drop_database(POSTGRES_ENDPOINT, db_name, DbConnectionType::Plain).await;
661+
let _ = drop_database(POSTGRES_ENDPOINT, db_name, NoTls).await;
707662
let store =
708-
PostgresBackendImplPlain::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
663+
PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
709664
let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
710665
assert_eq!(start, MIGRATIONS_START);
711666
assert_eq!(end, MIGRATIONS_END);
712667
})
713668
.await;
714-
let store = PostgresBackendImplPlain::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
669+
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
715670
let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
716671
assert_eq!(start, MIGRATIONS_END);
717672
assert_eq!(end, MIGRATIONS_END);
@@ -724,35 +679,35 @@ mod tests {
724679
#[should_panic(expected = "We do not allow downgrades")]
725680
async fn panic_on_downgrade() {
726681
let db_name = "panic_on_downgrade_test";
727-
let _ = drop_database(POSTGRES_ENDPOINT, db_name, DbConnectionType::Plain).await;
682+
let _ = drop_database(POSTGRES_ENDPOINT, db_name, NoTls).await;
728683
{
729684
let mut migrations = MIGRATIONS.to_vec();
730685
migrations.push(DUMMY_MIGRATION);
731-
let store = PostgresBackendImplPlain::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
686+
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
732687
let (start, end) = store.migrate_vss_database(&migrations).await.unwrap();
733688
assert_eq!(start, MIGRATIONS_START);
734689
assert_eq!(end, MIGRATIONS_END + 1);
735690
};
736691
{
737-
let store = PostgresBackendImplPlain::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
692+
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
738693
let _ = store.migrate_vss_database(MIGRATIONS).await.unwrap();
739694
};
740695
}
741696

742697
#[tokio::test]
743698
async fn new_migrations_increments_upgrades() {
744699
let db_name = "new_migrations_increments_upgrades_test";
745-
let _ = drop_database(POSTGRES_ENDPOINT, db_name, DbConnectionType::Plain).await;
700+
let _ = drop_database(POSTGRES_ENDPOINT, db_name, NoTls).await;
746701
{
747-
let store = PostgresBackendImplPlain::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
702+
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
748703
let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
749704
assert_eq!(start, MIGRATIONS_START);
750705
assert_eq!(end, MIGRATIONS_END);
751706
assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START]);
752707
assert_eq!(store.get_schema_version().await, MIGRATIONS_END);
753708
};
754709
{
755-
let store = PostgresBackendImplPlain::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
710+
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
756711
let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
757712
assert_eq!(start, MIGRATIONS_END);
758713
assert_eq!(end, MIGRATIONS_END);
@@ -763,7 +718,7 @@ mod tests {
763718
let mut migrations = MIGRATIONS.to_vec();
764719
migrations.push(DUMMY_MIGRATION);
765720
{
766-
let store = PostgresBackendImplPlain::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
721+
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
767722
let (start, end) = store.migrate_vss_database(&migrations).await.unwrap();
768723
assert_eq!(start, MIGRATIONS_END);
769724
assert_eq!(end, MIGRATIONS_END + 1);
@@ -774,7 +729,7 @@ mod tests {
774729
migrations.push(DUMMY_MIGRATION);
775730
migrations.push(DUMMY_MIGRATION);
776731
{
777-
let store = PostgresBackendImplPlain::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
732+
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
778733
let (start, end) = store.migrate_vss_database(&migrations).await.unwrap();
779734
assert_eq!(start, MIGRATIONS_END + 1);
780735
assert_eq!(end, MIGRATIONS_END + 3);
@@ -786,13 +741,13 @@ mod tests {
786741
};
787742

788743
{
789-
let store = PostgresBackendImplPlain::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
744+
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
790745
let list = store.get_upgrades_list().await;
791746
assert_eq!(list, [MIGRATIONS_START, MIGRATIONS_END, MIGRATIONS_END + 1]);
792747
let version = store.get_schema_version().await;
793748
assert_eq!(version, MIGRATIONS_END + 3);
794749
}
795750

796-
drop_database(POSTGRES_ENDPOINT, db_name, DbConnectionType::Plain).await.unwrap();
751+
drop_database(POSTGRES_ENDPOINT, db_name, NoTls).await.unwrap();
797752
}
798753
}

0 commit comments

Comments
 (0)