Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 50 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ tokio-postgres = "0.7.16"
tokio = { version = "1.50.0", features = ["rt-multi-thread", "macros", "default"]}
async-recursion = "1.1.1"
bb8 = "0.9.1"
rusqlite = { version = "0.31", features = ["bundled"] }
log = "0.4.29"

[dev-dependencies]
Expand All @@ -39,3 +40,4 @@ predicates = "3.1.4"
tempfile = "3.27.0"
test_utils = { path="test-utils" }
pretty_assertions = "1.4.1"
rusqlite = { version = "0.31", features = ["bundled"] }
21 changes: 18 additions & 3 deletions src/common/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,9 @@ impl Config {
.or_else(|| default_config.map(|x| x.db_host.clone()))
};

let db_host = match (db_url.is_some(), db_host_chain()) {
let is_sqlite = matches!(db_type, DatabaseType::Sqlite);

let db_host = match (db_url.is_some() || is_sqlite, db_host_chain()) {
(true, Some(v)) => v,
(true, None) => String::new(),
(false, Some(v)) => v,
Expand All @@ -254,7 +256,7 @@ impl Config {
.or_else(|| default_config.map(|x| x.db_port))
};

let db_port = match (db_url.is_some(), db_port_chain()) {
let db_port = match (db_url.is_some() || is_sqlite, db_port_chain()) {
(true, Some(v)) => v,
(true, None) => 0,
(false, Some(v)) => v,
Expand All @@ -275,7 +277,7 @@ impl Config {
.or_else(|| default_config.map(|x| x.db_user.clone()))
};

let db_user = match (db_url.is_some(), db_user_chain()) {
let db_user = match (db_url.is_some() || is_sqlite, db_user_chain()) {
(true, Some(v)) => v,
(true, None) => String::new(),
(false, Some(v)) => v,
Expand Down Expand Up @@ -381,6 +383,19 @@ impl Config {
.to_string()
}

/// Returns the file path for a SQLite database connection.
/// If DB_URL is provided, it's used directly. Otherwise DB_NAME is used as the file path.
pub fn get_sqlite_path(&self, conn: &DbConnectionConfig) -> String {
if let Some(db_url) = &conn.db_url {
return db_url.to_owned();
}

conn
.db_name
.clone()
.unwrap_or_else(|| panic!("DB_NAME (file path) is required for SQLite connections"))
}

pub fn get_postgres_cred(&self, conn: &DbConnectionConfig) -> String {
// If custom DB_URL is provided, use it directly
if let Some(db_url) = &conn.db_url {
Expand Down
10 changes: 4 additions & 6 deletions src/common/dotenv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,10 @@ impl Dotenv {
Dotenv {
db_type: match Self::get_var("DB_TYPE") {
None => None,
Some(val) => {
if val == "mysql" {
Some(DatabaseType::Mysql)
} else {
Some(DatabaseType::Postgres)
}
Some(val) => match val.as_str() {
"mysql" => Some(DatabaseType::Mysql),
"sqlite" => Some(DatabaseType::Sqlite),
_ => Some(DatabaseType::Postgres),
}
},
db_user: Self::get_var("DB_USER"),
Expand Down
15 changes: 15 additions & 0 deletions src/common/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::common::types::DatabaseType;
use crate::core::connection::{DBConn, DBConnections};
use crate::core::mysql::pool::MySqlConnectionManager;
use crate::core::postgres::pool::PostgresConnectionManager;
use crate::core::sqlite::pool::SqliteConnectionManager;
use crate::ts_generator::information_schema::DBSchema;
use clap::Parser;
use std::sync::LazyLock;
Expand Down Expand Up @@ -49,6 +50,20 @@ pub static DB_CONN_CACHE: LazyLock<HashMap<String, Arc<Mutex<DBConn>>>> = LazyLo
DBConn::MySQLPooledConn(Mutex::new(pool))
})
}),
DatabaseType::Sqlite => task::block_in_place(|| {
Handle::current().block_on(async {
let sqlite_path = CONFIG.get_sqlite_path(connection_config);
let manager = SqliteConnectionManager::new(sqlite_path, connection.to_string());
let pool = bb8::Pool::builder()
.max_size(connection_config.pool_size)
.connection_timeout(std::time::Duration::from_secs(connection_config.connection_timeout))
.build(manager)
.await
.expect(&ERR_DB_CONNECTION_ISSUE);

DBConn::SqliteConn(Mutex::new(pool))
})
}),
DatabaseType::Postgres => task::block_in_place(|| {
Handle::current().block_on(async {
let postgres_cred = CONFIG.get_postgres_cred(connection_config);
Expand Down
1 change: 1 addition & 0 deletions src/common/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub enum FileExtension {
pub enum DatabaseType {
Postgres,
Mysql,
Sqlite,
}

#[derive(ValueEnum, Debug, Clone, Serialize, Deserialize)]
Expand Down
5 changes: 5 additions & 0 deletions src/core/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::common::types::DatabaseType;
use crate::common::SQL;
use crate::core::mysql::prepare as mysql_explain;
use crate::core::postgres::prepare as postgres_explain;
use crate::core::sqlite::prepare as sqlite_explain;
use crate::ts_generator::types::ts_query::TsQuery;
use bb8::Pool;
use std::collections::HashMap;
Expand All @@ -11,6 +12,7 @@ use tokio::sync::Mutex;

use super::mysql::pool::MySqlConnectionManager;
use super::postgres::pool::PostgresConnectionManager;
use super::sqlite::pool::SqliteConnectionManager;
use crate::common::errors::DB_CONN_FROM_LOCAL_CACHE_ERROR;
use color_eyre::Result;
use swc_common::errors::Handler;
Expand All @@ -19,6 +21,7 @@ use swc_common::errors::Handler;
pub enum DBConn {
MySQLPooledConn(Mutex<Pool<MySqlConnectionManager>>),
PostgresConn(Mutex<Pool<PostgresConnectionManager>>),
SqliteConn(Mutex<Pool<SqliteConnectionManager>>),
}

impl DBConn {
Expand All @@ -31,6 +34,7 @@ impl DBConn {
let (explain_failed, ts_query) = match &self {
DBConn::MySQLPooledConn(_conn) => mysql_explain::prepare(self, sql, should_generate_types, handler).await?,
DBConn::PostgresConn(_conn) => postgres_explain::prepare(self, sql, should_generate_types, handler).await?,
DBConn::SqliteConn(_conn) => sqlite_explain::prepare(self, sql, should_generate_types, handler).await?,
};

Ok((explain_failed, ts_query))
Expand All @@ -41,6 +45,7 @@ impl DBConn {
match self {
DBConn::MySQLPooledConn(_) => DatabaseType::Mysql,
DBConn::PostgresConn(_) => DatabaseType::Postgres,
DBConn::SqliteConn(_) => DatabaseType::Sqlite,
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ pub mod connection;
pub mod execute;
pub mod mysql;
pub mod postgres;
pub mod sqlite;
2 changes: 2 additions & 0 deletions src/core/sqlite/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pub mod pool;
pub mod prepare;
82 changes: 82 additions & 0 deletions src/core/sqlite/pool.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
use rusqlite::Connection;
use std::sync::{Arc, Mutex};
use tokio::task;

/// A connection manager for SQLite that wraps rusqlite's synchronous Connection
/// behind an Arc<Mutex<>> for thread-safe access with bb8 connection pooling.
#[derive(Clone, Debug)]
pub struct SqliteConnectionManager {
db_path: String,
connection_name: String,
}

/// Wrapper around rusqlite::Connection to make it Send + Sync for bb8
pub struct SqliteConnection {
pub conn: Arc<Mutex<Connection>>,
}

// Safety: rusqlite::Connection is not Send by default, but we protect it with Mutex
// and only access it via spawn_blocking
unsafe impl Send for SqliteConnection {}
unsafe impl Sync for SqliteConnection {}

impl SqliteConnectionManager {
pub fn new(db_path: String, connection_name: String) -> Self {
Self {
db_path,
connection_name,
}
}
}

#[derive(Debug)]
pub struct SqlitePoolError(pub String);

impl std::fmt::Display for SqlitePoolError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "SQLite pool error: {}", self.0)
}
}

impl std::error::Error for SqlitePoolError {}

impl bb8::ManageConnection for SqliteConnectionManager {
type Connection = SqliteConnection;
type Error = SqlitePoolError;

async fn connect(&self) -> Result<Self::Connection, Self::Error> {
let db_path = self.db_path.clone();
let connection_name = self.connection_name.clone();

let conn = task::spawn_blocking(move || {
Connection::open(&db_path).unwrap_or_else(|err| {
panic!(
"Failed to open SQLite database at '{}' for connection '{}': {}",
db_path, connection_name, err
)
})
})
.await
.map_err(|e| SqlitePoolError(format!("Failed to spawn blocking task: {e}")))?;

Ok(SqliteConnection {
conn: Arc::new(Mutex::new(conn)),
})
}

async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
let inner = conn.conn.clone();
task::spawn_blocking(move || {
let conn = inner.lock().unwrap();
conn
.execute_batch("SELECT 1")
.map_err(|e| SqlitePoolError(format!("SQLite connection validation failed: {e}")))
})
.await
.map_err(|e| SqlitePoolError(format!("Failed to spawn blocking task: {e}")))?
}

fn has_broken(&self, _conn: &mut Self::Connection) -> bool {
false
}
}
Loading
Loading