1- //! PostgreSQL backend for cold storage.
1+ //! Unified SQL backend for cold storage.
2+ //!
3+ //! Supports both PostgreSQL and SQLite via [`sqlx::Any`]. The backend
4+ //! auto-detects the database type at construction time and runs the
5+ //! appropriate migration.
26
37use crate :: SqlColdError ;
48use crate :: convert:: {
@@ -13,37 +17,74 @@ use signet_cold::{
1317use signet_storage_types:: {
1418 ConfirmationMeta , DbSignetEvent , DbZenithHeader , Receipt , TransactionSigned ,
1519} ;
16- use sqlx:: { PgPool , Row } ;
20+ use sqlx:: { AnyPool , Row } ;
1721
18- /// PostgreSQL -based cold storage backend.
22+ /// SQL -based cold storage backend.
1923///
20- /// Uses an `sqlx::PgPool` for connection management and connection pooling.
24+ /// Uses [`sqlx::Any`] for database-agnostic access, supporting both
25+ /// PostgreSQL and SQLite through a single implementation. The backend
26+ /// is determined by the connection URL at construction time.
2127///
2228/// # Example
2329///
2430/// ```no_run
2531/// # async fn example() {
26- /// use signet_cold_sql::PostgresColdBackend;
27- /// use sqlx::PgPool;
32+ /// use signet_cold_sql::SqlColdBackend;
2833///
29- /// let pool = PgPool::connect("postgres://localhost/signet").await.unwrap();
30- /// let backend = PostgresColdBackend::new(pool).await.unwrap();
34+ /// // SQLite (in-memory)
35+ /// let backend = SqlColdBackend::connect("sqlite::memory:").await.unwrap();
36+ ///
37+ /// // PostgreSQL
38+ /// let backend = SqlColdBackend::connect("postgres://localhost/signet").await.unwrap();
3139/// # }
3240/// ```
3341#[ derive( Debug , Clone ) ]
34- pub struct PostgresColdBackend {
35- pool : PgPool ,
42+ pub struct SqlColdBackend {
43+ pool : AnyPool ,
3644}
3745
38- impl PostgresColdBackend {
39- /// Create a new PostgreSQL cold storage backend.
46+ impl SqlColdBackend {
47+ /// Create a new SQL cold storage backend from an existing [`AnyPool`] .
4048 ///
41- /// Creates all tables if they do not already exist.
42- pub async fn new ( pool : PgPool ) -> Result < Self , SqlColdError > {
43- sqlx:: raw_sql ( include_str ! ( "../migrations/001_initial_pg.sql" ) ) . execute ( & pool) . await ?;
49+ /// Auto-detects the database backend and creates all tables if they
50+ /// do not already exist. Callers must ensure
51+ /// [`sqlx::any::install_default_drivers`] has been called before
52+ /// constructing the pool.
53+ pub async fn new ( pool : AnyPool ) -> Result < Self , SqlColdError > {
54+ // Detect backend from a pooled connection.
55+ let conn = pool. acquire ( ) . await ?;
56+ let backend = conn. backend_name ( ) . to_owned ( ) ;
57+ drop ( conn) ;
58+
59+ let migration = match backend. as_str ( ) {
60+ "PostgreSQL" => include_str ! ( "../migrations/001_initial_pg.sql" ) ,
61+ "SQLite" => include_str ! ( "../migrations/001_initial.sql" ) ,
62+ other => {
63+ return Err ( SqlColdError :: Convert ( format ! (
64+ "unsupported database backend: {other}"
65+ ) ) ) ;
66+ }
67+ } ;
68+ // Execute via pool to ensure the migration uses the same
69+ // connection that subsequent queries will use.
70+ sqlx:: raw_sql ( migration) . execute ( & pool) . await ?;
4471 Ok ( Self { pool } )
4572 }
4673
74+ /// Connect to a database URL and create the backend.
75+ ///
76+ /// Installs the default sqlx drivers on the first call. The database
77+ /// type is inferred from the URL scheme (`sqlite:` or `postgres:`).
78+ ///
79+ /// For SQLite in-memory databases (`sqlite::memory:`), the pool is
80+ /// limited to one connection to ensure all operations share the same
81+ /// database.
82+ pub async fn connect ( url : & str ) -> Result < Self , SqlColdError > {
83+ sqlx:: any:: install_default_drivers ( ) ;
84+ let pool: AnyPool = sqlx:: pool:: PoolOptions :: new ( ) . max_connections ( 1 ) . connect ( url) . await ?;
85+ Self :: new ( pool) . await
86+ }
87+
4788 // ========================================================================
4889 // Specifier resolution
4990 // ========================================================================
@@ -231,7 +272,7 @@ impl PostgresColdBackend {
231272 block_number : rr. get ( "block_number" ) ,
232273 tx_index : rr. get ( "tx_index" ) ,
233274 tx_type : rr. get :: < i32 , _ > ( "tx_type" ) as i16 ,
234- success : rr. get :: < bool , _ > ( "success" ) ,
275+ success : rr. get :: < i32 , _ > ( "success" ) != 0 ,
235276 cumulative_gas_used : rr. get ( "cumulative_gas_used" ) ,
236277 } ;
237278
@@ -329,7 +370,7 @@ impl PostgresColdBackend {
329370 . bind ( tr. tx_index )
330371 . bind ( & tr. tx_hash )
331372 . bind ( tr. tx_type as i32 )
332- . bind ( tr. sig_y_parity )
373+ . bind ( tr. sig_y_parity as i32 )
333374 . bind ( & tr. sig_r )
334375 . bind ( & tr. sig_s )
335376 . bind ( tr. chain_id )
@@ -359,7 +400,7 @@ impl PostgresColdBackend {
359400 . bind ( rr. block_number )
360401 . bind ( rr. tx_index )
361402 . bind ( rr. tx_type as i32 )
362- . bind ( rr. success )
403+ . bind ( rr. success as i32 )
363404 . bind ( rr. cumulative_gas_used )
364405 . execute ( & mut * tx)
365406 . await ?;
@@ -468,13 +509,13 @@ impl PostgresColdBackend {
468509}
469510
470511/// Convert a sqlx row to a TxRow.
471- fn row_to_tx_row ( r : & sqlx:: postgres :: PgRow ) -> TxRow {
512+ fn row_to_tx_row ( r : & sqlx:: any :: AnyRow ) -> TxRow {
472513 TxRow {
473514 block_number : r. get ( "block_number" ) ,
474515 tx_index : r. get ( "tx_index" ) ,
475516 tx_hash : r. get ( "tx_hash" ) ,
476517 tx_type : r. get :: < i32 , _ > ( "tx_type" ) as i16 ,
477- sig_y_parity : r. get ( "sig_y_parity" ) ,
518+ sig_y_parity : r. get :: < i32 , _ > ( "sig_y_parity" ) != 0 ,
478519 sig_r : r. get ( "sig_r" ) ,
479520 sig_s : r. get ( "sig_s" ) ,
480521 chain_id : r. get ( "chain_id" ) ,
@@ -493,7 +534,7 @@ fn row_to_tx_row(r: &sqlx::postgres::PgRow) -> TxRow {
493534 }
494535}
495536
496- fn row_to_signet_event_row ( r : & sqlx:: postgres :: PgRow ) -> SignetEventRow {
537+ fn row_to_signet_event_row ( r : & sqlx:: any :: AnyRow ) -> SignetEventRow {
497538 SignetEventRow {
498539 block_number : r. get ( "block_number" ) ,
499540 event_index : r. get ( "event_index" ) ,
@@ -512,7 +553,7 @@ fn row_to_signet_event_row(r: &sqlx::postgres::PgRow) -> SignetEventRow {
512553 }
513554}
514555
515- fn row_to_zenith_header_row ( r : & sqlx:: postgres :: PgRow ) -> ZenithHeaderRow {
556+ fn row_to_zenith_header_row ( r : & sqlx:: any :: AnyRow ) -> ZenithHeaderRow {
516557 ZenithHeaderRow {
517558 block_number : r. get ( "block_number" ) ,
518559 host_block_number : r. get ( "host_block_number" ) ,
@@ -523,7 +564,7 @@ fn row_to_zenith_header_row(r: &sqlx::postgres::PgRow) -> ZenithHeaderRow {
523564 }
524565}
525566
526- impl ColdStorage for PostgresColdBackend {
567+ impl ColdStorage for SqlColdBackend {
527568 async fn get_header ( & self , spec : HeaderSpecifier ) -> ColdResult < Option < Header > > {
528569 let Some ( block_num) = self . resolve_header_spec ( spec) . await ? else {
529570 return Ok ( None ) ;
@@ -616,7 +657,7 @@ impl ColdStorage for PostgresColdBackend {
616657 block_number : rr. get ( "block_number" ) ,
617658 tx_index : tx_idx,
618659 tx_type : rr. get :: < i32 , _ > ( "tx_type" ) as i16 ,
619- success : rr. get :: < bool , _ > ( "success" ) ,
660+ success : rr. get :: < i32 , _ > ( "success" ) != 0 ,
620661 cumulative_gas_used : rr. get ( "cumulative_gas_used" ) ,
621662 } ;
622663
@@ -825,13 +866,18 @@ mod tests {
825866 use signet_cold:: conformance:: conformance;
826867
827868 #[ tokio:: test]
828- async fn pg_backend_conformance ( ) {
869+ async fn sqlite_conformance ( ) {
870+ let backend = SqlColdBackend :: connect ( "sqlite::memory:" ) . await . unwrap ( ) ;
871+ conformance ( & backend) . await . unwrap ( ) ;
872+ }
873+
874+ #[ tokio:: test]
875+ async fn pg_conformance ( ) {
829876 let Ok ( url) = std:: env:: var ( "DATABASE_URL" ) else {
830877 eprintln ! ( "skipping pg conformance: DATABASE_URL not set" ) ;
831878 return ;
832879 } ;
833- let pool = PgPool :: connect ( & url) . await . unwrap ( ) ;
834- let backend = PostgresColdBackend :: new ( pool) . await . unwrap ( ) ;
880+ let backend = SqlColdBackend :: connect ( & url) . await . unwrap ( ) ;
835881 conformance ( & backend) . await . unwrap ( ) ;
836882 }
837883}
0 commit comments