Skip to content

Commit 9127da6

Browse files
authored
refactor!: use Arc internally in Database (#432)
1 parent 54e7894 commit 9127da6

File tree

11 files changed

+81
-159
lines changed

11 files changed

+81
-159
lines changed

cot/src/auth/db.rs

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
use std::any::Any;
77
use std::borrow::Cow;
88
use std::fmt::{Display, Formatter};
9-
use std::sync::Arc;
109

1110
use async_trait::async_trait;
1211
// Importing `Auto` from `cot` instead of `crate` so that the migration generator
@@ -192,10 +191,10 @@ impl DatabaseUser {
192191
/// use cot::auth::UserId;
193192
/// use cot::auth::db::DatabaseUser;
194193
/// use cot::common_types::Password;
194+
/// use cot::db::Database;
195195
/// use cot::html::Html;
196-
/// use cot::request::extractors::RequestDb;
197196
///
198-
/// async fn view(RequestDb(db): RequestDb) -> cot::Result<Html> {
197+
/// async fn view(db: Database) -> cot::Result<Html> {
199198
/// let user =
200199
/// DatabaseUser::create_user(&db, "testuser".to_string(), &Password::new("password123"))
201200
/// .await?;
@@ -210,7 +209,7 @@ impl DatabaseUser {
210209
/// # use cot::test::{TestDatabase, TestRequestBuilder};
211210
/// # let mut test_database = TestDatabase::new_sqlite().await?;
212211
/// # test_database.with_auth().run_migrations().await;
213-
/// # view(RequestDb(test_database.database())).await?;
212+
/// # view(test_database.database()).await?;
214213
/// # test_database.cleanup().await?;
215214
/// # Ok(())
216215
/// # }
@@ -284,10 +283,10 @@ impl DatabaseUser {
284283
/// use cot::auth::UserId;
285284
/// use cot::auth::db::DatabaseUser;
286285
/// use cot::common_types::Password;
286+
/// use cot::db::Database;
287287
/// use cot::html::Html;
288-
/// use cot::request::extractors::RequestDb;
289288
///
290-
/// async fn view(RequestDb(db): RequestDb) -> cot::Result<Html> {
289+
/// async fn view(db: Database) -> cot::Result<Html> {
291290
/// let user =
292291
/// DatabaseUser::create_user(&db, "testuser".to_string(), &Password::new("password123"))
293292
/// .await?;
@@ -327,10 +326,10 @@ impl DatabaseUser {
327326
/// use cot::auth::UserId;
328327
/// use cot::auth::db::DatabaseUser;
329328
/// use cot::common_types::Password;
329+
/// use cot::db::Database;
330330
/// use cot::html::Html;
331-
/// use cot::request::extractors::RequestDb;
332331
///
333-
/// async fn view(RequestDb(db): RequestDb) -> cot::Result<Html> {
332+
/// async fn view(db: Database) -> cot::Result<Html> {
334333
/// let user =
335334
/// DatabaseUser::create_user(&db, "testuser".to_string(), &Password::new("password123"))
336335
/// .await?;
@@ -469,7 +468,7 @@ impl DatabaseUserCredentials {
469468
/// [`DatabaseUserCredentials`] struct and ignores all other credential types.
470469
#[derive(Debug, Clone)]
471470
pub struct DatabaseUserBackend {
472-
database: Arc<Database>,
471+
database: Database,
473472
}
474473

475474
impl DatabaseUserBackend {
@@ -495,7 +494,7 @@ impl DatabaseUserBackend {
495494
/// }
496495
/// ```
497496
#[must_use]
498-
pub fn new(database: Arc<Database>) -> Self {
497+
pub fn new(database: Database) -> Self {
499498
Self { database }
500499
}
501500
}

cot/src/db.rs

Lines changed: 15 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ mod sea_query_db;
1818
use std::fmt::{Display, Formatter, Write};
1919
use std::hash::Hash;
2020
use std::str::FromStr;
21+
use std::sync::Arc;
2122

2223
use async_trait::async_trait;
2324
pub use cot_macros::{model, query};
@@ -789,10 +790,9 @@ pub trait SqlxValueRef<'r>: Sized {
789790
/// It is used to execute queries and interact with the database. The connection
790791
/// is established when the structure is created and closed when
791792
/// [`Self::close()`] is called.
792-
#[derive(Debug)]
793+
#[derive(Debug, Clone)]
793794
pub struct Database {
794-
_url: String,
795-
inner: DatabaseImpl,
795+
inner: Arc<DatabaseImpl>,
796796
}
797797

798798
#[derive(Debug)]
@@ -837,26 +837,23 @@ impl Database {
837837
if url.starts_with("sqlite:") {
838838
let inner = DatabaseSqlite::new(&url).await?;
839839
return Ok(Self {
840-
_url: url,
841-
inner: DatabaseImpl::Sqlite(inner),
840+
inner: Arc::new(DatabaseImpl::Sqlite(inner)),
842841
});
843842
}
844843

845844
#[cfg(feature = "postgres")]
846845
if url.starts_with("postgresql:") {
847846
let inner = DatabasePostgres::new(&url).await?;
848847
return Ok(Self {
849-
_url: url,
850-
inner: DatabaseImpl::Postgres(inner),
848+
inner: Arc::new(DatabaseImpl::Postgres(inner)),
851849
});
852850
}
853851

854852
#[cfg(feature = "mysql")]
855853
if url.starts_with("mysql:") {
856854
let inner = DatabaseMySql::new(&url).await?;
857855
return Ok(Self {
858-
_url: url,
859-
inner: DatabaseImpl::MySql(inner),
856+
inner: Arc::new(DatabaseImpl::MySql(inner)),
860857
});
861858
}
862859

@@ -886,7 +883,7 @@ impl Database {
886883
/// }
887884
/// ```
888885
pub async fn close(&self) -> Result<()> {
889-
match &self.inner {
886+
match &*self.inner {
890887
#[cfg(feature = "sqlite")]
891888
DatabaseImpl::Sqlite(inner) => inner.close().await,
892889
#[cfg(feature = "postgres")]
@@ -1124,7 +1121,7 @@ impl Database {
11241121
return Ok(());
11251122
}
11261123

1127-
let max_params = match self.inner {
1124+
let max_params = match &*self.inner {
11281125
// https://sqlite.org/limits.html#max_variable_number
11291126
// Assuming SQLite > 3.32.0 (2020-05-22)
11301127
#[cfg(feature = "sqlite")]
@@ -1471,7 +1468,7 @@ impl Database {
14711468
.collect::<Vec<_>>();
14721469
let values = SqlxValues(sea_query::Values(values));
14731470

1474-
let result = match &self.inner {
1471+
let result = match &*self.inner {
14751472
#[cfg(feature = "sqlite")]
14761473
DatabaseImpl::Sqlite(inner) => inner.raw_with(query, values).await?,
14771474
#[cfg(feature = "postgres")]
@@ -1487,7 +1484,7 @@ impl Database {
14871484
where
14881485
T: SqlxBinder + Send + Sync,
14891486
{
1490-
let result = match &self.inner {
1487+
let result = match &*self.inner {
14911488
#[cfg(feature = "sqlite")]
14921489
DatabaseImpl::Sqlite(inner) => inner.fetch_option(statement).await?.map(Row::Sqlite),
14931490
#[cfg(feature = "postgres")]
@@ -1502,7 +1499,7 @@ impl Database {
15021499
}
15031500

15041501
fn supports_returning(&self) -> bool {
1505-
match self.inner {
1502+
match &*self.inner {
15061503
#[cfg(feature = "sqlite")]
15071504
DatabaseImpl::Sqlite(_) => true,
15081505
#[cfg(feature = "postgres")]
@@ -1516,7 +1513,7 @@ impl Database {
15161513
where
15171514
T: SqlxBinder + Send + Sync,
15181515
{
1519-
let result = match &self.inner {
1516+
let result = match &*self.inner {
15201517
#[cfg(feature = "sqlite")]
15211518
DatabaseImpl::Sqlite(inner) => inner
15221519
.fetch_all(statement)
@@ -1547,7 +1544,7 @@ impl Database {
15471544
where
15481545
T: SqlxBinder + Send + Sync,
15491546
{
1550-
let result = match &self.inner {
1547+
let result = match &*self.inner {
15511548
#[cfg(feature = "sqlite")]
15521549
DatabaseImpl::Sqlite(inner) => inner.execute_statement(statement).await?,
15531550
#[cfg(feature = "postgres")]
@@ -1563,7 +1560,7 @@ impl Database {
15631560
&self,
15641561
statement: T,
15651562
) -> Result<StatementResult> {
1566-
let result = match &self.inner {
1563+
let result = match &*self.inner {
15671564
#[cfg(feature = "sqlite")]
15681565
DatabaseImpl::Sqlite(inner) => inner.execute_schema(statement).await?,
15691566
#[cfg(feature = "postgres")]
@@ -1578,7 +1575,7 @@ impl Database {
15781575

15791576
impl ColumnTypeMapper for Database {
15801577
fn sea_query_column_type_for(&self, column_type: ColumnType) -> sea_query::ColumnType {
1581-
match &self.inner {
1578+
match &*self.inner {
15821579
#[cfg(feature = "sqlite")]
15831580
DatabaseImpl::Sqlite(inner) => inner.sea_query_column_type_for(column_type),
15841581
#[cfg(feature = "postgres")]
@@ -1735,45 +1732,6 @@ impl DatabaseBackend for Database {
17351732
}
17361733
}
17371734

1738-
#[async_trait]
1739-
impl DatabaseBackend for std::sync::Arc<Database> {
1740-
async fn insert_or_update<T: Model>(&self, data: &mut T) -> Result<()> {
1741-
Database::insert_or_update(self, data).await
1742-
}
1743-
1744-
async fn insert<T: Model>(&self, data: &mut T) -> Result<()> {
1745-
Database::insert(self, data).await
1746-
}
1747-
1748-
async fn update<T: Model>(&self, data: &mut T) -> Result<()> {
1749-
Database::update(self, data).await
1750-
}
1751-
1752-
async fn bulk_insert<T: Model>(&self, data: &mut [T]) -> Result<()> {
1753-
Database::bulk_insert(self, data).await
1754-
}
1755-
1756-
async fn bulk_insert_or_update<T: Model>(&self, data: &mut [T]) -> Result<()> {
1757-
Database::bulk_insert_or_update(self, data).await
1758-
}
1759-
1760-
async fn query<T: Model>(&self, query: &Query<T>) -> Result<Vec<T>> {
1761-
Database::query(self, query).await
1762-
}
1763-
1764-
async fn get<T: Model>(&self, query: &Query<T>) -> Result<Option<T>> {
1765-
Database::get(self, query).await
1766-
}
1767-
1768-
async fn exists<T: Model>(&self, query: &Query<T>) -> Result<bool> {
1769-
Database::exists(self, query).await
1770-
}
1771-
1772-
async fn delete<T: Model>(&self, query: &Query<T>) -> Result<StatementResult> {
1773-
Database::delete(self, query).await
1774-
}
1775-
}
1776-
17771735
/// Result of a statement execution.
17781736
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
17791737
pub struct StatementResult {

cot/src/openapi.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -837,7 +837,7 @@ impl ApiOperationPart for Method {}
837837
impl ApiOperationPart for Session {}
838838
impl ApiOperationPart for Auth {}
839839
#[cfg(feature = "db")]
840-
impl ApiOperationPart for crate::request::extractors::RequestDb {}
840+
impl ApiOperationPart for crate::db::Database {}
841841

842842
impl<D: JsonSchema> ApiOperationPart for Json<D> {
843843
fn modify_api_operation(

cot/src/project.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,11 +1224,11 @@ impl Bootstrapper<WithApps> {
12241224
}
12251225

12261226
#[cfg(feature = "db")]
1227-
async fn init_database(config: &DatabaseConfig) -> cot::Result<Option<Arc<Database>>> {
1227+
async fn init_database(config: &DatabaseConfig) -> cot::Result<Option<Database>> {
12281228
match &config.url {
12291229
Some(url) => {
12301230
let database = Database::new(url.as_str()).await?;
1231-
Ok(Some(Arc::new(database)))
1231+
Ok(Some(database))
12321232
}
12331233
None => Ok(None),
12341234
}
@@ -1626,7 +1626,7 @@ impl BootstrapPhase for WithDatabase {
16261626
type Apps = <WithApps as BootstrapPhase>::Apps;
16271627
type Router = <WithApps as BootstrapPhase>::Router;
16281628
#[cfg(feature = "db")]
1629-
type Database = Option<Arc<Database>>;
1629+
type Database = Option<Database>;
16301630
type AuthBackend = <WithApps as BootstrapPhase>::AuthBackend;
16311631
#[cfg(feature = "cache")]
16321632
type Cache = ();
@@ -1649,7 +1649,7 @@ impl BootstrapPhase for WithCache {
16491649
type Apps = <WithApps as BootstrapPhase>::Apps;
16501650
type Router = <WithApps as BootstrapPhase>::Router;
16511651
#[cfg(feature = "db")]
1652-
type Database = Option<Arc<Database>>;
1652+
type Database = <WithDatabase as BootstrapPhase>::Database;
16531653
type AuthBackend = <WithApps as BootstrapPhase>::AuthBackend;
16541654
#[cfg(feature = "cache")]
16551655
type Cache = Cache;
@@ -1791,7 +1791,7 @@ impl ProjectContext<WithApps> {
17911791
#[must_use]
17921792
fn with_database(
17931793
self,
1794-
#[cfg(feature = "db")] database: Option<Arc<Database>>,
1794+
#[cfg(feature = "db")] database: Option<Database>,
17951795
) -> ProjectContext<WithDatabase> {
17961796
ProjectContext {
17971797
config: self.config,
@@ -1931,7 +1931,7 @@ impl<S: BootstrapPhase<Cache = Cache>> ProjectContext<S> {
19311931
}
19321932

19331933
#[cfg(feature = "db")]
1934-
impl<S: BootstrapPhase<Database = Option<Arc<Database>>>> ProjectContext<S> {
1934+
impl<S: BootstrapPhase<Database = Option<Database>>> ProjectContext<S> {
19351935
/// Returns the database for the project, if it is enabled.
19361936
///
19371937
/// # Examples
@@ -1952,7 +1952,7 @@ impl<S: BootstrapPhase<Database = Option<Arc<Database>>>> ProjectContext<S> {
19521952
/// ```
19531953
#[must_use]
19541954
#[cfg(feature = "db")]
1955-
pub fn try_database(&self) -> Option<&Arc<Database>> {
1955+
pub fn try_database(&self) -> Option<&Database> {
19561956
self.database.as_ref()
19571957
}
19581958

@@ -1980,7 +1980,7 @@ impl<S: BootstrapPhase<Database = Option<Arc<Database>>>> ProjectContext<S> {
19801980
#[cfg(feature = "db")]
19811981
#[must_use]
19821982
#[track_caller]
1983-
pub fn database(&self) -> &Arc<Database> {
1983+
pub fn database(&self) -> &Database {
19841984
self.try_database().expect(
19851985
"Database missing. Did you forget to add the database when configuring CotProject?",
19861986
)

cot/src/request.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ pub trait RequestExt: private::Sealed {
207207
/// ```
208208
#[cfg(feature = "db")]
209209
#[must_use]
210-
fn db(&self) -> &Arc<Database>;
210+
fn db(&self) -> &Database;
211211

212212
/// Get the content type of the request.
213213
///
@@ -318,7 +318,7 @@ impl RequestExt for Request {
318318
}
319319

320320
#[cfg(feature = "db")]
321-
fn db(&self) -> &Arc<Database> {
321+
fn db(&self) -> &Database {
322322
self.context().database()
323323
}
324324

@@ -378,7 +378,7 @@ impl RequestExt for RequestHead {
378378
}
379379

380380
#[cfg(feature = "db")]
381-
fn db(&self) -> &Arc<Database> {
381+
fn db(&self) -> &Database {
382382
self.context().database()
383383
}
384384

0 commit comments

Comments
 (0)