Skip to content

Commit b500883

Browse files
committed
feat: add sqlx-based MySQL storage backend
1 parent 2957bc2 commit b500883

File tree

6 files changed

+534
-1
lines changed

6 files changed

+534
-1
lines changed

Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ zeroize = { version = "1.7.0", features = ["zeroize_derive"] }
5454
diesel = { version = "2.1.4", features = ["mysql", "r2d2"], optional = true }
5555
r2d2 = { version = "0.8.9", optional = true }
5656
r2d2-diesel = { version = "1.0.0", optional = true }
57+
sqlx = { version = "0.8", optional = true, default-features = true, features = ["macros", "mysql", "postgres", "sqlite", "any", "runtime-tokio", "chrono", "uuid", "json", "bigdecimal"] }
5758
bcrypt = "0.15"
5859
url = "2.5"
5960
ureq = { version = "2.10", features = ["json"] }
@@ -95,7 +96,8 @@ openssl-sys = { git = "https://github.com/Tongsuo-Project/rust-tongsuo.git" }
9596
toml = "0.8.19"
9697

9798
[features]
98-
default = ["crypto_adaptor_openssl"]
99+
default = ["crypto_adaptor_openssl", "storage_sqlx"]
100+
storage_sqlx = ["sqlx"]
99101
storage_mysql = ["diesel", "r2d2", "r2d2-diesel"]
100102
crypto_adaptor_openssl = ["dep:openssl", "dep:openssl-sys"]
101103
crypto_adaptor_tongsuo = ["dep:openssl", "dep:openssl-sys"]

src/errors.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,12 @@ pub enum RvError {
301301
source: tokio::task::JoinError,
302302
},
303303

304+
#[error("Some sqlx error happened")]
305+
SqlxError {
306+
#[from]
307+
source: sqlx::Error,
308+
},
309+
304310
#[error("Some string utf8 error happened, {:?}", .source)]
305311
StringUtf8Error {
306312
#[from]

src/storage/mod.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ pub mod barrier_view;
2727
#[cfg(feature = "storage_mysql")]
2828
pub mod mysql;
2929
pub mod physical;
30+
#[cfg(all(not(feature = "sync_handler"), feature = "storage_sqlx"))]
31+
pub mod sqlx;
3032

3133
/// A trait that abstracts core methods for all storage barrier types.
3234
#[maybe_async::maybe_async]
@@ -87,6 +89,11 @@ pub fn new_backend(t: &str, conf: &HashMap<String, Value>) -> Result<Arc<dyn Bac
8789
let backend = mysql::mysql_backend::MysqlBackend::new(conf)?;
8890
Ok(Arc::new(backend))
8991
}
92+
#[cfg(all(not(feature = "sync_handler"), feature = "storage_sqlx"))]
93+
"sqlx" => {
94+
let backend = sqlx::SqlxBackend::new(conf)?;
95+
Ok(Arc::new(backend))
96+
}
9097
"mock" => Ok(Arc::new(physical::mock::MockBackend::new())),
9198
_ => Err(RvError::ErrPhysicalTypeInvalid),
9299
}

src/storage/sqlx/mod.rs

Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
use std::{any::Any, collections::HashMap};
2+
3+
use serde::Deserialize;
4+
use serde_json::Value;
5+
6+
use crate::{
7+
errors::RvError,
8+
storage::{Backend, BackendEntry},
9+
utils::{db::strip_db_name, DatabaseName},
10+
};
11+
12+
pub struct SqlxBackend {
13+
pool: sqlx::AnyPool,
14+
table_name: String,
15+
db_scheme: String,
16+
lock: SqlxBackendLock,
17+
}
18+
19+
#[derive(Debug, Clone, Default, sqlx::FromRow, Deserialize)]
20+
pub struct SqlxBackendEntry {
21+
pub vault_key: String,
22+
pub vault_value: Vec<u8>,
23+
}
24+
25+
pub struct SqlxBackendLock {
26+
pool: sqlx::AnyPool,
27+
db_scheme: String,
28+
timeout_secs: i32,
29+
}
30+
31+
impl SqlxBackendLock {
32+
pub fn new(pool: &sqlx::AnyPool, db_scheme: &str, timeout_secs: i32) -> Self {
33+
Self { pool: pool.clone(), db_scheme: db_scheme.to_string(), timeout_secs }
34+
}
35+
36+
async fn lock(&self, lock_name: &str) -> Result<bool, RvError> {
37+
let result: bool = match self.db_scheme.as_str() {
38+
"mysql" => {
39+
let count: Option<i32> = sqlx::query_scalar("SELECT GET_LOCK(SHA1(?), ?) as result")
40+
.bind(lock_name)
41+
.bind(self.timeout_secs)
42+
.fetch_one(&self.pool)
43+
.await?;
44+
count.unwrap_or(0) == 1
45+
}
46+
"postgres" => {
47+
let ret: Option<bool> = sqlx::query_scalar("SELECT pg_advisory_lock(hashtext('?'))")
48+
.bind(lock_name)
49+
.fetch_one(&self.pool)
50+
.await?;
51+
ret.unwrap_or(false)
52+
}
53+
_ => {
54+
return Err(RvError::ErrDatabaseTypeInvalid);
55+
}
56+
};
57+
58+
Ok(result)
59+
}
60+
61+
async fn unlock(&self, lock_name: &str) -> Result<(), RvError> {
62+
match self.db_scheme.as_str() {
63+
"mysql" => {
64+
sqlx::query("SELECT RELEASE_LOCK(SHA1(?))").bind(lock_name).execute(&self.pool).await?;
65+
}
66+
"postgres" => {
67+
sqlx::query("SELECT pg_advisory_unlock(hashtext('?'))").bind(lock_name).execute(&self.pool).await?;
68+
}
69+
_ => {
70+
return Err(RvError::ErrDatabaseTypeInvalid);
71+
}
72+
};
73+
74+
Ok(())
75+
}
76+
}
77+
78+
#[maybe_async::must_be_async]
79+
impl Backend for SqlxBackend {
80+
async fn list(&self, prefix: &str) -> Result<Vec<String>, RvError> {
81+
if prefix.starts_with("/") {
82+
return Err(RvError::ErrPhysicalBackendPrefixInvalid);
83+
}
84+
85+
let results: Vec<SqlxBackendEntry> =
86+
sqlx::query_as(&format!("SELECT vault_key, vault_value FROM {} WHERE vault_key LIKE ?", self.table_name))
87+
.bind(format!("{prefix}%"))
88+
.fetch_all(&self.pool)
89+
.await?;
90+
91+
let mut keys: Vec<String> = Vec::new();
92+
for entry in results.iter() {
93+
let key = entry.vault_key.clone();
94+
let key = key.trim_start_matches(prefix);
95+
match key.find('/') {
96+
Some(i) => {
97+
let key = &key[0..i + 1];
98+
if !keys.contains(&key.to_string()) {
99+
keys.push(key.to_string());
100+
}
101+
}
102+
None => {
103+
keys.push(key.to_string());
104+
}
105+
}
106+
}
107+
Ok(keys)
108+
}
109+
110+
async fn get(&self, key: &str) -> Result<Option<BackendEntry>, RvError> {
111+
if key.starts_with("/") {
112+
return Err(RvError::ErrPhysicalBackendKeyInvalid);
113+
}
114+
115+
let result: Option<SqlxBackendEntry> =
116+
sqlx::query_as(&format!("SELECT vault_key, vault_value FROM {} WHERE vault_key = ?", self.table_name))
117+
.bind(key)
118+
.fetch_optional(&self.pool)
119+
.await?;
120+
121+
if let Some(entry) = result {
122+
return Ok(Some(BackendEntry { key: entry.vault_key, value: entry.vault_value }));
123+
}
124+
125+
Ok(None)
126+
}
127+
128+
async fn put(&self, entry: &BackendEntry) -> Result<(), RvError> {
129+
if entry.key.as_str().starts_with("/") {
130+
return Err(RvError::ErrPhysicalBackendKeyInvalid);
131+
}
132+
133+
let _ = self.lock.lock(&entry.key).await?;
134+
135+
let ret = sqlx::query(&format!(
136+
"INSERT INTO {} VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)",
137+
self.table_name
138+
))
139+
.bind(entry.key.as_str())
140+
.bind(entry.value.as_slice())
141+
.execute(&self.pool)
142+
.await;
143+
144+
self.lock.unlock(&entry.key).await?;
145+
146+
let _ = ret?;
147+
148+
Ok(())
149+
}
150+
151+
async fn delete(&self, key: &str) -> Result<(), RvError> {
152+
if key.starts_with("/") {
153+
return Err(RvError::ErrPhysicalBackendKeyInvalid);
154+
}
155+
156+
let _ = self.lock(key).await?;
157+
158+
let ret = sqlx::query(&format!("DELETE FROM {} WHERE vault_key = ?", self.table_name))
159+
.bind(key)
160+
.execute(&self.pool)
161+
.await;
162+
163+
self.lock.unlock(key).await?;
164+
165+
let _ = ret?;
166+
167+
Ok(())
168+
}
169+
170+
async fn lock(&self, _lock_name: &str) -> Result<Box<dyn Any>, RvError> {
171+
Ok(Box::new(SqlxBackendLock::new(&self.pool, self.db_scheme.as_str(), 1)))
172+
}
173+
}
174+
175+
impl SqlxBackend {
176+
async fn new_backend(conf: &HashMap<String, Value>) -> Result<Self, RvError> {
177+
let database_url =
178+
conf.get("database_url").and_then(|v| v.as_str()).ok_or(RvError::ErrDatabaseConnectionInfoInvalid)?;
179+
let table_name = conf.get("table").and_then(|v| v.as_str()).unwrap_or("vault");
180+
181+
let db_name = DatabaseName::from_url(database_url)?;
182+
let db_scheme = db_name.scheme().to_string();
183+
184+
let database_url_root = strip_db_name(database_url);
185+
186+
let pool = sqlx::AnyPool::connect(&database_url_root).await?;
187+
188+
match db_name {
189+
DatabaseName::MySql(database_name) => {
190+
let _ = sqlx::query(&format!("CREATE DATABASE IF NOT EXISTS `{database_name}`")).execute(&pool).await?;
191+
let _ = sqlx::query(&format!("CREATE TABLE IF NOT EXISTS `{database_name}.{table_name}` (vault_key varbinary(3072), vault_value mediumblob, PRIMARY KEY (vault_key))")).execute(&pool).await?;
192+
}
193+
_ => {
194+
return Err(RvError::ErrDatabaseTypeInvalid);
195+
}
196+
}
197+
198+
pool.close().await;
199+
200+
let pool = sqlx::AnyPool::connect_lazy(database_url)?;
201+
202+
let lock = SqlxBackendLock::new(&pool, db_scheme.as_str(), 1);
203+
Ok(SqlxBackend { pool, table_name: table_name.to_string(), db_scheme, lock })
204+
}
205+
206+
pub fn new(conf: &HashMap<String, Value>) -> Result<SqlxBackend, RvError> {
207+
let _database_url =
208+
conf.get("database_url").and_then(|v| v.as_str()).ok_or(RvError::ErrDatabaseConnectionInfoInvalid)?;
209+
210+
sqlx::any::install_default_drivers();
211+
212+
match tokio::runtime::Handle::try_current() {
213+
Ok(_handle) => std::thread::scope(|s| {
214+
let conf = conf.clone();
215+
let handle = s.spawn(move || {
216+
let rt = tokio::runtime::Runtime::new().unwrap();
217+
rt.block_on(async { Self::new_backend(&conf).await })
218+
});
219+
handle.join().unwrap()
220+
}),
221+
Err(_) => {
222+
let rt = tokio::runtime::Runtime::new()?;
223+
rt.block_on(async { Self::new_backend(conf).await })
224+
}
225+
}
226+
}
227+
}
228+
229+
#[cfg(all(test, not(feature = "sync_handler"), feature = "storage_sqlx"))]
230+
mod test {
231+
use std::sync::Arc;
232+
use std::{collections::HashMap, env};
233+
234+
use serde_json::Value;
235+
236+
use super::SqlxBackend;
237+
238+
use crate::errors::RvError;
239+
use crate::storage::test::{test_backend_curd, test_backend_list_prefix};
240+
use crate::test_utils::test_multi_routine;
241+
242+
async fn sqlx_table_clear(backend: &SqlxBackend) -> Result<(), RvError> {
243+
let _ = sqlx::query("TRUNCATE TABLE vault").execute(&backend.pool).await?;
244+
Ok(())
245+
}
246+
247+
#[tokio::test]
248+
async fn test_sqlx_backend() {
249+
let sqlx_pwd = env::var("CARGO_TEST_MYSQL_PASSWORD").unwrap_or("".into());
250+
let mut conf: HashMap<String, Value> = HashMap::new();
251+
conf.insert("database_url".to_string(), Value::String(format!("mysql://root:{sqlx_pwd}@127.0.0.1:3306/vault")));
252+
conf.insert("table".to_string(), Value::String("vault".to_string()));
253+
254+
let backend = SqlxBackend::new(&conf);
255+
256+
assert!(backend.is_ok());
257+
258+
let backend = backend.unwrap();
259+
260+
assert!(sqlx_table_clear(&backend).await.is_ok());
261+
262+
test_backend_curd(&backend).await;
263+
test_backend_list_prefix(&backend).await;
264+
}
265+
266+
#[tokio::test]
267+
async fn test_sqlx_backend_multi_routine() {
268+
let sqlx_pwd = env::var("CARGO_TEST_MYSQL_PASSWORD").unwrap_or("".into());
269+
let mut conf: HashMap<String, Value> = HashMap::new();
270+
conf.insert("database_url".to_string(), Value::String(format!("mysql://root:{sqlx_pwd}@127.0.0.1:3306/vault")));
271+
conf.insert("table".to_string(), Value::String("vault".to_string()));
272+
273+
let backend = SqlxBackend::new(&conf).unwrap();
274+
275+
test_multi_routine(Arc::new(backend));
276+
}
277+
}

0 commit comments

Comments
 (0)