//! Database backend abstraction for the catalog.
//!
//! This module provides a generic database backend that works with any sqlx database for
//! which [`CatalogDatabase`] is implemented (currently Postgres and SQLite). It abstracts
//! over database-specific differences such as parameter binding syntax while providing a
//! unified API for catalog operations.
//!
//! # Architecture
//!
//! The [`CatalogBackend`] struct is parameterized by a database type that implements
//! [`CatalogDatabase`]. This trait extends sqlx's `Database` trait with the additional
//! functionality needed for cross-database compatibility.
//!
//! # Example
//!
//! ```ignore
//! use rivetdb::catalog::backend::CatalogBackend;
//! use sqlx::SqlitePool;
//!
//! let pool = SqlitePool::connect("sqlite::memory:").await?;
//! let backend = CatalogBackend::new(pool);
//! let connections = backend.list_connections().await?;
//! ```

use crate::catalog::manager::{ConnectionInfo, TableInfo};
use anyhow::{anyhow, Result};
use sqlx::{
    query, query_as, query_scalar, ColumnIndex, Database, Decode, Encode, Executor, FromRow,
    IntoArguments, Pool, Postgres, Sqlite, Type,
};
use std::borrow::Cow;

/// Extension trait for sqlx databases that provides catalog-specific functionality.
///
/// This trait handles differences in SQL syntax between database backends,
/// particularly parameter binding syntax (e.g., `$1` for Postgres vs `?` for SQLite).
pub trait CatalogDatabase: Database {
    /// Returns the parameter placeholder for the given 1-based index.
    ///
    /// - Postgres uses `$1`, `$2`, etc.
    /// - SQLite uses `?` for all parameters (the index is ignored).
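    ///
    /// # Example
    ///
    /// The placeholders produced by the two provided implementations
    /// (illustrative; see the `Postgres` and `Sqlite` impls below):
    ///
    /// ```ignore
    /// assert_eq!(<Postgres as CatalogDatabase>::bind_param(2), "$2");
    /// assert_eq!(<Sqlite as CatalogDatabase>::bind_param(2), "?");
    /// ```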
    fn bind_param(index: usize) -> Cow<'static, str>;
}

impl CatalogDatabase for Postgres {
    fn bind_param(index: usize) -> Cow<'static, str> {
        Cow::Owned(format!("${}", index))
    }
}

impl CatalogDatabase for Sqlite {
    fn bind_param(_: usize) -> Cow<'static, str> {
        Cow::Borrowed("?")
    }
}

/// Generic database backend for catalog operations.
///
/// Wraps a sqlx connection pool and provides methods for managing connections
/// and tables in the catalog. Works with any database that implements [`CatalogDatabase`].
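///
/// # Example
///
/// A construction sketch against Postgres (the connection URL is a placeholder):
///
/// ```ignore
/// use sqlx::PgPool;
///
/// let pool = PgPool::connect("postgres://localhost/catalog").await?;
/// let backend = CatalogBackend::new(pool);
/// ```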
pub struct CatalogBackend<DB: CatalogDatabase> {
    pool: Pool<DB>,
}

impl<DB: CatalogDatabase> CatalogBackend<DB> {
    /// Creates a new backend with the given connection pool.
    pub fn new(pool: Pool<DB>) -> Self {
        Self { pool }
    }

    /// Returns a reference to the underlying connection pool.
    pub fn pool(&self) -> &Pool<DB> {
        &self.pool
    }
}

impl<DB> CatalogBackend<DB>
where
    DB: CatalogDatabase,
    ConnectionInfo: for<'r> FromRow<'r, DB::Row>,
    TableInfo: for<'r> FromRow<'r, DB::Row>,
    for<'q> &'q str: Encode<'q, DB> + Type<DB>,
    for<'q> i32: Encode<'q, DB> + Type<DB>,
    for<'r> i32: Decode<'r, DB>,
    for<'q> <DB as Database>::Arguments<'q>: IntoArguments<'q, DB> + Send,
    for<'c> &'c Pool<DB>: Executor<'c, Database = DB>,
    usize: ColumnIndex<DB::Row>,
{
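    /// Lists all registered connections, ordered by name.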
    pub async fn list_connections(&self) -> Result<Vec<ConnectionInfo>> {
        query_as::<DB, ConnectionInfo>(
            "SELECT id, name, source_type, config_json FROM connections ORDER BY name",
        )
        .fetch_all(&self.pool)
        .await
        .map_err(Into::into)
    }

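    /// Registers a new connection and returns its catalog id.
    ///
    /// The row is inserted first and the id is then looked up by name, so this
    /// assumes connection names are unique in the `connections` table.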
    pub async fn add_connection(
        &self,
        name: &str,
        source_type: &str,
        config_json: &str,
    ) -> Result<i32> {
        let insert_sql = format!(
            "INSERT INTO connections (name, source_type, config_json) VALUES ({}, {}, {})",
            DB::bind_param(1),
            DB::bind_param(2),
            DB::bind_param(3)
        );

        query(&insert_sql)
            .bind(name)
            .bind(source_type)
            .bind(config_json)
            .execute(&self.pool)
            .await?;

        let select_sql = format!(
            "SELECT id FROM connections WHERE name = {}",
            DB::bind_param(1)
        );

        query_scalar::<DB, i32>(&select_sql)
            .bind(name)
            .fetch_one(&self.pool)
            .await
            .map_err(Into::into)
    }

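    /// Looks up a connection by name, returning `None` if it does not exist.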
    pub async fn get_connection(&self, name: &str) -> Result<Option<ConnectionInfo>> {
        let sql = format!(
            "SELECT id, name, source_type, config_json FROM connections WHERE name = {}",
            DB::bind_param(1)
        );

        query_as::<DB, ConnectionInfo>(&sql)
            .bind(name)
            .fetch_optional(&self.pool)
            .await
            .map_err(Into::into)
    }

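    /// Registers a table for a connection and returns its catalog id.
    ///
    /// Uses an upsert (`ON CONFLICT ... DO UPDATE`), so re-adding an existing
    /// table refreshes its stored Arrow schema instead of failing.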
    pub async fn add_table(
        &self,
        connection_id: i32,
        schema_name: &str,
        table_name: &str,
        arrow_schema_json: &str,
    ) -> Result<i32> {
        let insert_sql = format!(
            "INSERT INTO tables (connection_id, schema_name, table_name, arrow_schema_json) \
             VALUES ({}, {}, {}, {}) \
             ON CONFLICT (connection_id, schema_name, table_name) \
             DO UPDATE SET arrow_schema_json = excluded.arrow_schema_json",
            DB::bind_param(1),
            DB::bind_param(2),
            DB::bind_param(3),
            DB::bind_param(4)
        );

        query(&insert_sql)
            .bind(connection_id)
            .bind(schema_name)
            .bind(table_name)
            .bind(arrow_schema_json)
            .execute(&self.pool)
            .await?;

        let select_sql = format!(
            "SELECT id FROM tables WHERE connection_id = {} AND schema_name = {} AND table_name = {}",
            DB::bind_param(1),
            DB::bind_param(2),
            DB::bind_param(3),
        );

        query_scalar::<DB, i32>(&select_sql)
            .bind(connection_id)
            .bind(schema_name)
            .bind(table_name)
            .fetch_one(&self.pool)
            .await
            .map_err(Into::into)
    }

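    /// Lists catalog tables, optionally filtered by connection id; pass `None`
    /// to list tables across all connections.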
    pub async fn list_tables(&self, connection_id: Option<i32>) -> Result<Vec<TableInfo>> {
        let mut sql = String::from(
            "SELECT id, connection_id, schema_name, table_name, parquet_path, state_path, \
             CAST(last_sync AS TEXT) as last_sync, arrow_schema_json \
             FROM tables",
        );

        if connection_id.is_some() {
            sql.push_str(" WHERE connection_id = ");
            sql.push_str(DB::bind_param(1).as_ref());
        }

        sql.push_str(" ORDER BY schema_name, table_name");

        let mut stmt = query_as::<DB, TableInfo>(&sql);
        if let Some(conn_id) = connection_id {
            stmt = stmt.bind(conn_id);
        }

        stmt.fetch_all(&self.pool).await.map_err(Into::into)
    }

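    /// Looks up a table by connection id, schema name, and table name,
    /// returning `None` if it is not registered.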
    pub async fn get_table(
        &self,
        connection_id: i32,
        schema_name: &str,
        table_name: &str,
    ) -> Result<Option<TableInfo>> {
        let sql = format!(
            "SELECT id, connection_id, schema_name, table_name, parquet_path, state_path, \
             CAST(last_sync AS TEXT) as last_sync, arrow_schema_json \
             FROM tables WHERE connection_id = {} AND schema_name = {} AND table_name = {}",
            DB::bind_param(1),
            DB::bind_param(2),
            DB::bind_param(3),
        );

        query_as::<DB, TableInfo>(&sql)
            .bind(connection_id)
            .bind(schema_name)
            .bind(table_name)
            .fetch_optional(&self.pool)
            .await
            .map_err(Into::into)
    }

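    /// Records a completed sync for a table by storing its parquet and state
    /// paths and setting `last_sync` to the current timestamp.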
    pub async fn update_table_sync(
        &self,
        table_id: i32,
        parquet_path: &str,
        state_path: &str,
    ) -> Result<()> {
        let sql = format!(
            "UPDATE tables SET parquet_path = {}, state_path = {}, last_sync = CURRENT_TIMESTAMP \
             WHERE id = {}",
            DB::bind_param(1),
            DB::bind_param(2),
            DB::bind_param(3),
        );

        query(&sql)
            .bind(parquet_path)
            .bind(state_path)
            .bind(table_id)
            .execute(&self.pool)
            .await?;

        Ok(())
    }

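    /// Clears a table's cached parquet/state paths and sync timestamp.
    ///
    /// Returns the table's metadata as it was before the update, or an error
    /// if the table is not found.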
    pub async fn clear_table_cache_metadata(
        &self,
        connection_id: i32,
        schema_name: &str,
        table_name: &str,
    ) -> Result<TableInfo> {
        let table = self
            .get_table(connection_id, schema_name, table_name)
            .await?
            .ok_or_else(|| anyhow!("Table '{}.{}' not found", schema_name, table_name))?;

        let sql = format!(
            "UPDATE tables SET parquet_path = NULL, state_path = NULL, last_sync = NULL WHERE id = {}",
            DB::bind_param(1)
        );

        query(&sql).bind(table.id).execute(&self.pool).await?;

        Ok(table)
    }

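    /// Clears cached parquet/state paths and sync timestamps for every table
    /// belonging to the named connection.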
    pub async fn clear_connection_cache_metadata(&self, name: &str) -> Result<()> {
        let connection = self
            .get_connection(name)
            .await?
            .ok_or_else(|| anyhow!("Connection '{}' not found", name))?;

        let sql = format!(
            "UPDATE tables SET parquet_path = NULL, state_path = NULL, last_sync = NULL \
             WHERE connection_id = {}",
            DB::bind_param(1)
        );

        query(&sql).bind(connection.id).execute(&self.pool).await?;

        Ok(())
    }

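    /// Deletes a connection and all of its tables from the catalog.
    ///
    /// The table rows are removed explicitly before the connection row itself.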
    pub async fn delete_connection(&self, name: &str) -> Result<()> {
        let connection = self
            .get_connection(name)
            .await?
            .ok_or_else(|| anyhow!("Connection '{}' not found", name))?;

        let delete_tables_sql = format!(
            "DELETE FROM tables WHERE connection_id = {}",
            DB::bind_param(1)
        );

        query(&delete_tables_sql)
            .bind(connection.id)
            .execute(&self.pool)
            .await?;

        let delete_connection_sql =
            format!("DELETE FROM connections WHERE id = {}", DB::bind_param(1));

        query(&delete_connection_sql)
            .bind(connection.id)
            .execute(&self.pool)
            .await?;

        Ok(())
    }
}