@@ -12,7 +12,7 @@ use chrono::Utc;
1212use std:: cmp:: min;
1313use std:: io;
1414use std:: io:: { Error , ErrorKind } ;
15- use tokio_postgres:: { NoTls , Transaction } ;
15+ use tokio_postgres:: { error , NoTls , Transaction } ;
1616
1717pub ( crate ) struct VssDbRecord {
1818 pub ( crate ) user_token : String ,
@@ -27,6 +27,32 @@ const KEY_COLUMN: &str = "key";
2727const VALUE_COLUMN : & str = "value" ;
2828const VERSION_COLUMN : & str = "version" ;
2929
30+ const DB_VERSION_COLUMN : & str = "db_version" ;
31+
32+ const CHECK_DB_STMT : & str = "SELECT 1 FROM pg_database WHERE datname = $1" ;
33+ const INIT_DB_CMD : & str = "CREATE DATABASE" ;
34+ const GET_VERSION_STMT : & str = "SELECT db_version FROM vss_db_version;" ;
35+ const UPDATE_VERSION_STMT : & str = "UPDATE vss_db_version SET db_version=$1;" ;
36+ const LOG_MIGRATION_STMT : & str = "INSERT INTO vss_db_upgrades VALUES($1);" ;
37+
38+ const MIGRATIONS : & [ & str ] = & [
39+ "CREATE TABLE vss_db_version (db_version INTEGER);" ,
40+ "INSERT INTO vss_db_version VALUES(1);" ,
41+ "CREATE TABLE vss_db_upgrades (upgrade_from INTEGER);" ,
42+ // We do not complain if the table already exists, as a previous version of VSS could have already created
43+ // this table
44+ "CREATE TABLE IF NOT EXISTS vss_db (
45+ user_token character varying(120) NOT NULL CHECK (user_token <> ''),
46+ store_id character varying(120) NOT NULL CHECK (store_id <> ''),
47+ key character varying(600) NOT NULL,
48+ value bytea NULL,
49+ version bigint NOT NULL,
50+ created_at TIMESTAMP WITH TIME ZONE,
51+ last_updated_at TIMESTAMP WITH TIME ZONE,
52+ PRIMARY KEY (user_token, store_id, key)
53+ );" ,
54+ ] ;
55+
3056/// The maximum number of key versions that can be returned in a single page.
3157///
3258/// This constant helps control memory and bandwidth usage for list operations,
@@ -46,17 +72,149 @@ pub struct PostgresBackendImpl {
4672 pool : Pool < PostgresConnectionManager < NoTls > > ,
4773}
4874
75+ async fn initialize_vss_database ( postgres_endpoint : & str , db_name : & str ) -> Result < ( ) , Error > {
76+ let postgres_dsn = format ! ( "{}/{}" , postgres_endpoint, "postgres" ) ;
77+ let ( client, connection) = tokio_postgres:: connect ( & postgres_dsn, NoTls )
78+ . await
79+ . map_err ( |e| Error :: new ( ErrorKind :: Other , format ! ( "Connection error: {}" , e) ) ) ?;
80+ // Connection must be driven on a separate task, and will resolve when the client is dropped
81+ tokio:: spawn ( async move {
82+ if let Err ( e) = connection. await {
83+ eprintln ! ( "Connection error: {}" , e) ;
84+ }
85+ } ) ;
86+
87+ let num_rows = client. execute ( CHECK_DB_STMT , & [ & db_name] ) . await . map_err ( |e| {
88+ Error :: new (
89+ ErrorKind :: Other ,
90+ format ! ( "Failed to check presence of database {}: {}" , db_name, e) ,
91+ )
92+ } ) ?;
93+
94+ if num_rows == 0 {
95+ let stmt = format ! ( "{} {}" , INIT_DB_CMD , db_name) ;
96+ client. execute ( & stmt, & [ ] ) . await . map_err ( |e| {
97+ Error :: new ( ErrorKind :: Other , format ! ( "Failed to create database {}: {}" , db_name, e) )
98+ } ) ?;
99+ println ! ( "Created database {}" , db_name) ;
100+ }
101+
102+ Ok ( ( ) )
103+ }
104+
49105impl PostgresBackendImpl {
50106 /// Constructs a [`PostgresBackendImpl`] using `dsn` for PostgreSQL connection information.
51- pub async fn new ( dsn : & str ) -> Result < Self , Error > {
52- let manager = PostgresConnectionManager :: new_from_stringlike ( dsn, NoTls ) . map_err ( |e| {
53- Error :: new ( ErrorKind :: Other , format ! ( "Connection manager error: {}" , e) )
54- } ) ?;
107+ pub async fn new ( postgres_endpoint : & str , db_name : & str ) -> Result < Self , Error > {
108+ initialize_vss_database ( postgres_endpoint, db_name) . await ?;
109+
110+ let vss_dsn = format ! ( "{}/{}" , postgres_endpoint, db_name) ;
111+ let manager =
112+ PostgresConnectionManager :: new_from_stringlike ( vss_dsn, NoTls ) . map_err ( |e| {
113+ Error :: new (
114+ ErrorKind :: Other ,
115+ format ! ( "Failed to create PostgresConnectionManager: {}" , e) ,
116+ )
117+ } ) ?;
118+ // By default, Pool maintains 0 long-running connections, so returning a pool
119+ // here is no guarantee that Pool established a connection to the database.
120+ //
121+ // See Builder::min_idle to increase the long-running connection count.
55122 let pool = Pool :: builder ( )
56123 . build ( manager)
57124 . await
58- . map_err ( |e| Error :: new ( ErrorKind :: Other , format ! ( "Pool build error: {}" , e) ) ) ?;
59- Ok ( PostgresBackendImpl { pool } )
125+ . map_err ( |e| Error :: new ( ErrorKind :: Other , format ! ( "Failed to build Pool: {}" , e) ) ) ?;
126+ let postgres_backend = PostgresBackendImpl { pool } ;
127+
128+ postgres_backend. migrate_vss_database ( ) . await ?;
129+
130+ Ok ( postgres_backend)
131+ }
132+
133+ async fn migrate_vss_database ( & self ) -> Result < ( ) , Error > {
134+ let mut conn = self . pool . get ( ) . await . map_err ( |e| {
135+ Error :: new (
136+ ErrorKind :: Other ,
137+ format ! ( "Failed to fetch a connection from Pool: {}" , e) ,
138+ )
139+ } ) ?;
140+
141+ // Get the next migration to be applied.
142+ let migration_start = match conn. query_one ( GET_VERSION_STMT , & [ ] ) . await {
143+ Ok ( row) => {
144+ let i: i32 = row. get ( DB_VERSION_COLUMN ) ;
145+ usize:: try_from ( i) . expect ( "The column should always contain unsigned integers" )
146+ } ,
147+ Err ( e) => {
148+ // If the table is not defined, start at migration 0
149+ if let Some ( & error:: SqlState :: UNDEFINED_TABLE ) = e. code ( ) {
150+ 0
151+ } else {
152+ return Err ( Error :: new (
153+ ErrorKind :: Other ,
154+ format ! ( "Failed to query the version of the database schema: {}" , e) ,
155+ ) ) ;
156+ }
157+ } ,
158+ } ;
159+
160+ let tx = conn
161+ . transaction ( )
162+ . await
163+ . map_err ( |e| Error :: new ( ErrorKind :: Other , format ! ( "Transaction start error: {}" , e) ) ) ?;
164+
165+ if migration_start == MIGRATIONS . len ( ) {
166+ // No migrations needed, we are done
167+ return Ok ( ( ) ) ;
168+ } else if migration_start > MIGRATIONS . len ( ) {
169+ panic ! ( "We do not allow downgrades" ) ;
170+ }
171+
172+ println ! ( "Applying migration(s) {} through {}" , migration_start, MIGRATIONS . len( ) - 1 ) ;
173+
174+ for ( idx, & stmt) in ( & MIGRATIONS [ migration_start..] ) . iter ( ) . enumerate ( ) {
175+ let _num_rows = tx. execute ( stmt, & [ ] ) . await . map_err ( |e| {
176+ Error :: new (
177+ ErrorKind :: Other ,
178+ format ! (
179+ "Database migration no {} with stmt {} failed: {}" ,
180+ migration_start + idx,
181+ stmt,
182+ e
183+ ) ,
184+ )
185+ } ) ?;
186+ }
187+
188+ let num_rows = tx
189+ . execute (
190+ LOG_MIGRATION_STMT ,
191+ & [ & i32:: try_from ( migration_start) . expect ( "Read from an i32 further above" ) ] ,
192+ )
193+ . await
194+ . map_err ( |e| {
195+ Error :: new ( ErrorKind :: Other , format ! ( "Failed to log database migration: {}" , e) )
196+ } ) ?;
197+ assert_eq ! ( num_rows, 1 , "LOG_MIGRATION_STMT should only add one row at a time" ) ;
198+
199+ let next_migration_start =
200+ i32:: try_from ( MIGRATIONS . len ( ) ) . expect ( "Length is definitely smaller than i32::MAX" ) ;
201+ let num_rows =
202+ tx. execute ( UPDATE_VERSION_STMT , & [ & next_migration_start] ) . await . map_err ( |e| {
203+ Error :: new (
204+ ErrorKind :: Other ,
205+ format ! ( "Failed to update the version of the schema: {}" , e) ,
206+ )
207+ } ) ?;
208+ assert_eq ! (
209+ num_rows, 1 ,
210+ "UPDATE_VERSION_STMT should only update the unique row in the version table"
211+ ) ;
212+
213+ tx. commit ( ) . await . map_err ( |e| {
214+ Error :: new ( ErrorKind :: Other , format ! ( "Transaction commit error: {}" , e) )
215+ } ) ?;
216+
217+ Ok ( ( ) )
60218 }
61219
62220 fn build_vss_record ( & self , user_token : String , store_id : String , kv : KeyValue ) -> VssDbRecord {
@@ -413,7 +571,7 @@ mod tests {
413571 define_kv_store_tests ! (
414572 PostgresKvStoreTest ,
415573 PostgresBackendImpl ,
416- PostgresBackendImpl :: new( "postgresql://postgres:postgres@localhost:5432/ postgres" )
574+ PostgresBackendImpl :: new( "postgresql://postgres:postgres@localhost:5432" , " postgres")
417575 . await
418576 . unwrap( )
419577 ) ;
0 commit comments