@@ -11,13 +11,15 @@ use bb8_postgres::bb8::Pool;
1111use bb8_postgres:: PostgresConnectionManager ;
1212use bytes:: Bytes ;
1313use chrono:: Utc ;
14- use native_tls:: { Certificate , TlsConnector } ;
14+ use native_tls:: TlsConnector ;
1515use postgres_native_tls:: MakeTlsConnector ;
1616use std:: cmp:: min;
1717use std:: io:: { self , Error , ErrorKind } ;
1818use tokio_postgres:: tls:: { MakeTlsConnect , TlsConnect } ;
1919use tokio_postgres:: { error, Client , NoTls , Socket , Transaction } ;
2020
21+ pub use native_tls:: Certificate ;
22+
2123pub ( crate ) struct VssDbRecord {
2224 pub ( crate ) user_token : String ,
2325 pub ( crate ) store_id : String ,
@@ -46,7 +48,7 @@ pub const LIST_KEY_VERSIONS_MAX_PAGE_SIZE: i32 = 100;
4648pub const MAX_PUT_REQUEST_ITEM_COUNT : usize = 1000 ;
4749
4850/// A [PostgreSQL](https://www.postgresql.org/) based backend implementation for VSS.
49- pub struct PostgresBackendImpl < T >
51+ pub struct PostgresBackend < T >
5052where
5153 T : MakeTlsConnect < Socket > + Clone + Send + Sync + ' static ,
5254 <T as MakeTlsConnect < Socket > >:: Stream : Send + Sync ,
@@ -56,67 +58,42 @@ where
5658 pool : Pool < PostgresConnectionManager < T > > ,
5759}
5860
59- /// A postgres backend with plain connections to the database
60- pub type PostgresBackendImplPlain = PostgresBackendImpl < NoTls > ;
61+ /// A postgres backend with plaintext connections to the database
62+ pub type PostgresPlaintextBackend = PostgresBackend < NoTls > ;
6163
6264/// A postgres backend with TLS connections to the database
63- pub type PostgresBackendImplTls = PostgresBackendImpl < MakeTlsConnector > ;
64-
65- enum DbConnectionType {
66- Plain ,
67- Tls ( Option < String > ) ,
68- }
65+ pub type PostgresTlsBackend = PostgresBackend < MakeTlsConnector > ;
6966
70- async fn make_postgres_db_connection (
71- postgres_endpoint : & str , connection_type : DbConnectionType ,
72- ) -> Result < Client , Error > {
67+ async fn make_postgres_db_connection < T > ( postgres_endpoint : & str , tls : T ) -> Result < Client , Error >
68+ where
69+ T : MakeTlsConnect < Socket > + Clone + Send + Sync + ' static ,
70+ T :: Stream : Send + Sync ,
71+ T :: TlsConnect : Send ,
72+ <<T as MakeTlsConnect < Socket > >:: TlsConnect as TlsConnect < Socket > >:: Future : Send ,
73+ {
7374 let dsn = format ! ( "{}/{}" , postgres_endpoint, "postgres" ) ;
74- match connection_type {
75- DbConnectionType :: Plain => {
76- let ( client, connection) = tokio_postgres:: connect ( & dsn, NoTls )
77- . await
78- . map_err ( |e| Error :: new ( ErrorKind :: Other , format ! ( "Connection error: {}" , e) ) ) ?;
79- // Connection must be driven on a separate task, and will resolve when the client is dropped
80- tokio:: spawn ( async move {
81- if let Err ( e) = connection. await {
82- eprintln ! ( "Connection error: {}" , e) ;
83- }
84- } ) ;
85- Ok ( client)
86- } ,
87- DbConnectionType :: Tls ( ca_file) => {
88- let mut builder = TlsConnector :: builder ( ) ;
89- if let Some ( ca) = ca_file {
90- let cert = std:: fs:: read ( ca) . map_err ( |e| {
91- Error :: new ( ErrorKind :: Other , format ! ( "Error reading certificate file: {}" , e) )
92- } ) ?;
93- let cert = Certificate :: from_pem ( & cert) . map_err ( |e| {
94- Error :: new ( ErrorKind :: Other , format ! ( "Error loading certificate file: {}" , e) )
95- } ) ?;
96- builder. add_root_certificate ( cert) ;
97- }
98- let connector = builder. build ( ) . map_err ( |e| {
99- Error :: new ( ErrorKind :: Other , format ! ( "Error building tls connector: {}" , e) )
100- } ) ?;
101- let connector = MakeTlsConnector :: new ( connector) ;
102- let ( client, connection) = tokio_postgres:: connect ( & dsn, connector)
103- . await
104- . map_err ( |e| Error :: new ( ErrorKind :: Other , format ! ( "Connection error: {}" , e) ) ) ?;
105- // Connection must be driven on a separate task, and will resolve when the client is dropped
106- tokio:: spawn ( async move {
107- if let Err ( e) = connection. await {
108- eprintln ! ( "Connection error: {}" , e) ;
109- }
110- } ) ;
111- Ok ( client)
112- } ,
113- }
75+ let ( client, connection) = tokio_postgres:: connect ( & dsn, tls)
76+ . await
77+ . map_err ( |e| Error :: new ( ErrorKind :: Other , format ! ( "Connection error: {}" , e) ) ) ?;
78+ // Connection must be driven on a separate task, and will resolve when the client is dropped
79+ tokio:: spawn ( async move {
80+ if let Err ( e) = connection. await {
81+ eprintln ! ( "Connection error: {}" , e) ;
82+ }
83+ } ) ;
84+ Ok ( client)
11485}
11586
116- async fn initialize_vss_database (
117- postgres_endpoint : & str , db_name : & str , connection_type : DbConnectionType ,
118- ) -> Result < ( ) , Error > {
119- let client = make_postgres_db_connection ( & postgres_endpoint, connection_type) . await ?;
87+ async fn initialize_vss_database < T > (
88+ postgres_endpoint : & str , db_name : & str , tls : T ,
89+ ) -> Result < ( ) , Error >
90+ where
91+ T : MakeTlsConnect < Socket > + Clone + Send + Sync + ' static ,
92+ T :: Stream : Send + Sync ,
93+ T :: TlsConnect : Send ,
94+ <<T as MakeTlsConnect < Socket > >:: TlsConnect as TlsConnect < Socket > >:: Future : Send ,
95+ {
96+ let client = make_postgres_db_connection ( & postgres_endpoint, tls) . await ?;
12097
12198 let num_rows = client. execute ( CHECK_DB_STMT , & [ & db_name] ) . await . map_err ( |e| {
12299 Error :: new (
@@ -136,10 +113,14 @@ async fn initialize_vss_database(
136113}
137114
138115#[ cfg( test) ]
139- async fn drop_database (
140- postgres_endpoint : & str , db_name : & str , connection_type : DbConnectionType ,
141- ) -> Result < ( ) , Error > {
142- let client = make_postgres_db_connection ( & postgres_endpoint, connection_type) . await ?;
116+ async fn drop_database < T > ( postgres_endpoint : & str , db_name : & str , tls : T ) -> Result < ( ) , Error >
117+ where
118+ T : MakeTlsConnect < Socket > + Clone + Send + Sync + ' static ,
119+ T :: Stream : Send + Sync ,
120+ T :: TlsConnect : Send ,
121+ <<T as MakeTlsConnect < Socket > >:: TlsConnect as TlsConnect < Socket > >:: Future : Send ,
122+ {
123+ let client = make_postgres_db_connection ( & postgres_endpoint, tls) . await ?;
143124
144125 let drop_database_statement = format ! ( "{} {};" , DROP_DB_CMD , db_name) ;
145126 let num_rows = client. execute ( & drop_database_statement, & [ ] ) . await . map_err ( |e| {
@@ -150,61 +131,42 @@ async fn drop_database(
150131 Ok ( ( ) )
151132}
152133
153- impl PostgresBackendImplTls {
154- /// Constructs a [`PostgresBackendImpl`] using `dsn` for PostgreSQL connection information, with a tls connection pool.
134+ impl PostgresPlaintextBackend {
135+ /// Constructs a [`PostgresPlaintextBackend`] using `postgres_endpoint` for PostgreSQL connection information.
136+ pub async fn new ( postgres_endpoint : & str , db_name : & str ) -> Result < Self , Error > {
137+ PostgresBackend :: new_internal ( postgres_endpoint, db_name, NoTls ) . await
138+ }
139+ }
140+
141+ impl PostgresTlsBackend {
142+ /// Constructs a [`PostgresTlsBackend`] using `postgres_endpoint` for PostgreSQL connection information.
155143 pub async fn new (
156- postgres_endpoint : & str , db_name : & str , ca_file : Option < String > ,
144+ postgres_endpoint : & str , db_name : & str , additional_certificate : Option < Certificate > ,
157145 ) -> Result < Self , Error > {
158- initialize_vss_database ( postgres_endpoint, db_name, DbConnectionType :: Tls ( ca_file. clone ( ) ) )
159- . await ?;
160-
161- let vss_dsn = format ! ( "{}/{}" , postgres_endpoint, db_name) ;
162146 let mut builder = TlsConnector :: builder ( ) ;
163- if let Some ( ca) = ca_file {
164- let cert = std:: fs:: read ( ca) . map_err ( |e| {
165- Error :: new ( ErrorKind :: Other , format ! ( "Error reading certificate file: {}" , e) )
166- } ) ?;
167- let cert = Certificate :: from_pem ( & cert) . map_err ( |e| {
168- Error :: new ( ErrorKind :: Other , format ! ( "Error loading certificate file: {}" , e) )
169- } ) ?;
147+ if let Some ( cert) = additional_certificate {
170148 builder. add_root_certificate ( cert) ;
171149 }
172150 let connector = builder. build ( ) . map_err ( |e| {
173151 Error :: new ( ErrorKind :: Other , format ! ( "Error building tls connector: {}" , e) )
174152 } ) ?;
175- let connector = MakeTlsConnector :: new ( connector) ;
176- let manager =
177- PostgresConnectionManager :: new_from_stringlike ( vss_dsn, connector) . map_err ( |e| {
178- Error :: new (
179- ErrorKind :: Other ,
180- format ! ( "Failed to create PostgresConnectionManager: {}" , e) ,
181- )
182- } ) ?;
183- // By default, Pool maintains 0 long-running connections, so returning a pool
184- // here is no guarantee that Pool established a connection to the database.
185- //
186- // See Builder::min_idle to increase the long-running connection count.
187- let pool = Pool :: builder ( )
188- . build ( manager)
153+ PostgresBackend :: new_internal ( postgres_endpoint, db_name, MakeTlsConnector :: new ( connector) )
189154 . await
190- . map_err ( |e| Error :: new ( ErrorKind :: Other , format ! ( "Failed to build Pool: {}" , e) ) ) ?;
191- let postgres_backend = PostgresBackendImpl { pool } ;
192-
193- #[ cfg( not( test) ) ]
194- postgres_backend. migrate_vss_database ( MIGRATIONS ) . await ?;
195-
196- Ok ( postgres_backend)
197155 }
198156}
199157
200- impl PostgresBackendImplPlain {
201- /// Constructs a [`PostgresBackendImpl`] using `dsn` for PostgreSQL connection information, with a plain connection pool.
202- pub async fn new ( postgres_endpoint : & str , db_name : & str ) -> Result < Self , Error > {
203- initialize_vss_database ( postgres_endpoint, db_name, DbConnectionType :: Plain ) . await ?;
204-
158+ impl < T > PostgresBackend < T >
159+ where
160+ T : MakeTlsConnect < Socket > + Clone + Send + Sync + ' static ,
161+ T :: Stream : Send + Sync ,
162+ T :: TlsConnect : Send ,
163+ <<T as MakeTlsConnect < Socket > >:: TlsConnect as TlsConnect < Socket > >:: Future : Send ,
164+ {
165+ async fn new_internal ( postgres_endpoint : & str , db_name : & str , tls : T ) -> Result < Self , Error > {
166+ initialize_vss_database ( postgres_endpoint, db_name, tls. clone ( ) ) . await ?;
205167 let vss_dsn = format ! ( "{}/{}" , postgres_endpoint, db_name) ;
206168 let manager =
207- PostgresConnectionManager :: new_from_stringlike ( vss_dsn, NoTls ) . map_err ( |e| {
169+ PostgresConnectionManager :: new_from_stringlike ( vss_dsn, tls ) . map_err ( |e| {
208170 Error :: new (
209171 ErrorKind :: Other ,
210172 format ! ( "Failed to create PostgresConnectionManager: {}" , e) ,
@@ -218,22 +180,14 @@ impl PostgresBackendImplPlain {
218180 . build ( manager)
219181 . await
220182 . map_err ( |e| Error :: new ( ErrorKind :: Other , format ! ( "Failed to build Pool: {}" , e) ) ) ?;
221- let postgres_backend = PostgresBackendImpl { pool } ;
183+ let postgres_backend = PostgresBackend { pool } ;
222184
223185 #[ cfg( not( test) ) ]
224186 postgres_backend. migrate_vss_database ( MIGRATIONS ) . await ?;
225187
226188 Ok ( postgres_backend)
227189 }
228- }
229190
230- impl < T > PostgresBackendImpl < T >
231- where
232- T : MakeTlsConnect < Socket > + Clone + Send + Sync + ' static ,
233- T :: Stream : Send + Sync ,
234- T :: TlsConnect : Send ,
235- <<T as MakeTlsConnect < Socket > >:: TlsConnect as TlsConnect < Socket > >:: Future : Send ,
236- {
237191 async fn migrate_vss_database ( & self , migrations : & [ & str ] ) -> Result < ( usize , usize ) , Error > {
238192 let mut conn = self . pool . get ( ) . await . map_err ( |e| {
239193 Error :: new ( ErrorKind :: Other , format ! ( "Failed to fetch a connection from Pool: {}" , e) )
@@ -481,7 +435,7 @@ where
481435}
482436
483437#[ async_trait]
484- impl < T > KvStore for PostgresBackendImpl < T >
438+ impl < T > KvStore for PostgresBackend < T >
485439where
486440 T : MakeTlsConnect < Socket > + Clone + Send + Sync + ' static ,
487441 T :: Stream : Send + Sync ,
@@ -688,30 +642,31 @@ where
688642
689643#[ cfg( test) ]
690644mod tests {
691- use super :: { drop_database, DbConnectionType , DUMMY_MIGRATION , MIGRATIONS } ;
692- use crate :: postgres_store:: PostgresBackendImplPlain ;
645+ use super :: { drop_database, DUMMY_MIGRATION , MIGRATIONS } ;
646+ use crate :: postgres_store:: PostgresPlaintextBackend ;
693647 use api:: define_kv_store_tests;
694648 use tokio:: sync:: OnceCell ;
649+ use tokio_postgres:: NoTls ;
695650
696651 const POSTGRES_ENDPOINT : & str = "postgresql://postgres:postgres@localhost:5432" ;
697652 const MIGRATIONS_START : usize = 0 ;
698653 const MIGRATIONS_END : usize = MIGRATIONS . len ( ) ;
699654
700655 static START : OnceCell < ( ) > = OnceCell :: const_new ( ) ;
701656
702- define_kv_store_tests ! ( PostgresKvStoreTest , PostgresBackendImplPlain , {
657+ define_kv_store_tests ! ( PostgresKvStoreTest , PostgresPlaintextBackend , {
703658 let db_name = "postgres_kv_store_tests" ;
704659 START
705660 . get_or_init( || async {
706- let _ = drop_database( POSTGRES_ENDPOINT , db_name, DbConnectionType :: Plain ) . await ;
661+ let _ = drop_database( POSTGRES_ENDPOINT , db_name, NoTls ) . await ;
707662 let store =
708- PostgresBackendImplPlain :: new( POSTGRES_ENDPOINT , db_name) . await . unwrap( ) ;
663+ PostgresPlaintextBackend :: new( POSTGRES_ENDPOINT , db_name) . await . unwrap( ) ;
709664 let ( start, end) = store. migrate_vss_database( MIGRATIONS ) . await . unwrap( ) ;
710665 assert_eq!( start, MIGRATIONS_START ) ;
711666 assert_eq!( end, MIGRATIONS_END ) ;
712667 } )
713668 . await ;
714- let store = PostgresBackendImplPlain :: new( POSTGRES_ENDPOINT , db_name) . await . unwrap( ) ;
669+ let store = PostgresPlaintextBackend :: new( POSTGRES_ENDPOINT , db_name) . await . unwrap( ) ;
715670 let ( start, end) = store. migrate_vss_database( MIGRATIONS ) . await . unwrap( ) ;
716671 assert_eq!( start, MIGRATIONS_END ) ;
717672 assert_eq!( end, MIGRATIONS_END ) ;
@@ -724,35 +679,35 @@ mod tests {
724679 #[ should_panic( expected = "We do not allow downgrades" ) ]
725680 async fn panic_on_downgrade ( ) {
726681 let db_name = "panic_on_downgrade_test" ;
727- let _ = drop_database ( POSTGRES_ENDPOINT , db_name, DbConnectionType :: Plain ) . await ;
682+ let _ = drop_database ( POSTGRES_ENDPOINT , db_name, NoTls ) . await ;
728683 {
729684 let mut migrations = MIGRATIONS . to_vec ( ) ;
730685 migrations. push ( DUMMY_MIGRATION ) ;
731- let store = PostgresBackendImplPlain :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
686+ let store = PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
732687 let ( start, end) = store. migrate_vss_database ( & migrations) . await . unwrap ( ) ;
733688 assert_eq ! ( start, MIGRATIONS_START ) ;
734689 assert_eq ! ( end, MIGRATIONS_END + 1 ) ;
735690 } ;
736691 {
737- let store = PostgresBackendImplPlain :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
692+ let store = PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
738693 let _ = store. migrate_vss_database ( MIGRATIONS ) . await . unwrap ( ) ;
739694 } ;
740695 }
741696
742697 #[ tokio:: test]
743698 async fn new_migrations_increments_upgrades ( ) {
744699 let db_name = "new_migrations_increments_upgrades_test" ;
745- let _ = drop_database ( POSTGRES_ENDPOINT , db_name, DbConnectionType :: Plain ) . await ;
700+ let _ = drop_database ( POSTGRES_ENDPOINT , db_name, NoTls ) . await ;
746701 {
747- let store = PostgresBackendImplPlain :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
702+ let store = PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
748703 let ( start, end) = store. migrate_vss_database ( MIGRATIONS ) . await . unwrap ( ) ;
749704 assert_eq ! ( start, MIGRATIONS_START ) ;
750705 assert_eq ! ( end, MIGRATIONS_END ) ;
751706 assert_eq ! ( store. get_upgrades_list( ) . await , [ MIGRATIONS_START ] ) ;
752707 assert_eq ! ( store. get_schema_version( ) . await , MIGRATIONS_END ) ;
753708 } ;
754709 {
755- let store = PostgresBackendImplPlain :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
710+ let store = PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
756711 let ( start, end) = store. migrate_vss_database ( MIGRATIONS ) . await . unwrap ( ) ;
757712 assert_eq ! ( start, MIGRATIONS_END ) ;
758713 assert_eq ! ( end, MIGRATIONS_END ) ;
@@ -763,7 +718,7 @@ mod tests {
763718 let mut migrations = MIGRATIONS . to_vec ( ) ;
764719 migrations. push ( DUMMY_MIGRATION ) ;
765720 {
766- let store = PostgresBackendImplPlain :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
721+ let store = PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
767722 let ( start, end) = store. migrate_vss_database ( & migrations) . await . unwrap ( ) ;
768723 assert_eq ! ( start, MIGRATIONS_END ) ;
769724 assert_eq ! ( end, MIGRATIONS_END + 1 ) ;
@@ -774,7 +729,7 @@ mod tests {
774729 migrations. push ( DUMMY_MIGRATION ) ;
775730 migrations. push ( DUMMY_MIGRATION ) ;
776731 {
777- let store = PostgresBackendImplPlain :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
732+ let store = PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
778733 let ( start, end) = store. migrate_vss_database ( & migrations) . await . unwrap ( ) ;
779734 assert_eq ! ( start, MIGRATIONS_END + 1 ) ;
780735 assert_eq ! ( end, MIGRATIONS_END + 3 ) ;
@@ -786,13 +741,13 @@ mod tests {
786741 } ;
787742
788743 {
789- let store = PostgresBackendImplPlain :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
744+ let store = PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
790745 let list = store. get_upgrades_list ( ) . await ;
791746 assert_eq ! ( list, [ MIGRATIONS_START , MIGRATIONS_END , MIGRATIONS_END + 1 ] ) ;
792747 let version = store. get_schema_version ( ) . await ;
793748 assert_eq ! ( version, MIGRATIONS_END + 3 ) ;
794749 }
795750
796- drop_database ( POSTGRES_ENDPOINT , db_name, DbConnectionType :: Plain ) . await . unwrap ( ) ;
751+ drop_database ( POSTGRES_ENDPOINT , db_name, NoTls ) . await . unwrap ( ) ;
797752 }
798753}
0 commit comments