|
| 1 | +use actix_web::web::Data; |
| 2 | +use actix_web::{FromRequest, HttpMessage}; |
| 3 | +use derive_more::{Deref, DerefMut}; |
| 4 | +use diesel::connection::{ |
| 5 | + AnsiTransactionManager, Connection, ConnectionSealed, DefaultLoadingMode, |
| 6 | + LoadConnection, SimpleConnection, TransactionManager, |
| 7 | +}; |
| 8 | +use diesel::pg::{Pg, PgQueryBuilder}; |
| 9 | +use diesel::r2d2::{ConnectionManager, PooledConnection}; |
| 10 | +use diesel::PgConnection; |
| 11 | +use diesel::RunQueryDsl; |
| 12 | + |
| 13 | +use crate::service::types::{AppState, SchemaName}; |
| 14 | + |
| 15 | +pub struct TransactionManagerImpl; |
| 16 | + |
| 17 | +impl TransactionManager<ConnectionImpl> for TransactionManagerImpl { |
| 18 | + type TransactionStateData = <AnsiTransactionManager as TransactionManager< |
| 19 | + PgConnection, |
| 20 | + >>::TransactionStateData; |
| 21 | + |
| 22 | + fn begin_transaction(conn: &mut ConnectionImpl) -> diesel::prelude::QueryResult<()> { |
| 23 | + AnsiTransactionManager::begin_transaction(&mut *conn.conn)?; |
| 24 | + let result = diesel::sql_query("SELECT set_config('search_path', $1, true)") |
| 25 | + .bind::<diesel::sql_types::Text, _>(&conn.namespace) |
| 26 | + .execute(&mut *conn.conn)?; |
| 27 | + log::info!("{:?}", result); |
| 28 | + Ok(()) |
| 29 | + } |
| 30 | + |
| 31 | + fn rollback_transaction( |
| 32 | + conn: &mut ConnectionImpl, |
| 33 | + ) -> diesel::prelude::QueryResult<()> { |
| 34 | + AnsiTransactionManager::rollback_transaction(&mut *conn.conn) |
| 35 | + } |
| 36 | + |
| 37 | + fn commit_transaction(conn: &mut ConnectionImpl) -> diesel::prelude::QueryResult<()> { |
| 38 | + AnsiTransactionManager::commit_transaction(&mut *conn.conn) |
| 39 | + } |
| 40 | + |
| 41 | + fn transaction_manager_status_mut( |
| 42 | + conn: &mut ConnectionImpl, |
| 43 | + ) -> &mut diesel::connection::TransactionManagerStatus { |
| 44 | + AnsiTransactionManager::transaction_manager_status_mut(&mut *conn.conn) |
| 45 | + } |
| 46 | +} |
| 47 | + |
| 48 | +pub struct ConnectionImpl { |
| 49 | + namespace: String, |
| 50 | + conn: PooledConnection<ConnectionManager<PgConnection>>, |
| 51 | +} |
| 52 | + |
| 53 | +impl ConnectionImpl { |
| 54 | + pub fn new( |
| 55 | + namespace: String, |
| 56 | + mut conn: PooledConnection<ConnectionManager<PgConnection>>, |
| 57 | + ) -> Self { |
| 58 | + conn.set_prepared_statement_cache_size(diesel::connection::CacheSize::Disabled); |
| 59 | + ConnectionImpl { namespace, conn } |
| 60 | + } |
| 61 | + |
| 62 | + pub fn set_namespace(&mut self, namespace: String) { |
| 63 | + self.namespace = namespace; |
| 64 | + } |
| 65 | + |
| 66 | + pub fn from_request_override( |
| 67 | + req: &actix_web::HttpRequest, |
| 68 | + schema_name: String, |
| 69 | + ) -> Result<Self, actix_web::Error> { |
| 70 | + let app_state = match req.app_data::<Data<AppState>>() { |
| 71 | + Some(state) => state, |
| 72 | + None => { |
| 73 | + log::info!( |
| 74 | + "DbConnection-FromRequest: Unable to get app_data from request" |
| 75 | + ); |
| 76 | + return Err(actix_web::error::ErrorInternalServerError("")); |
| 77 | + } |
| 78 | + }; |
| 79 | + |
| 80 | + match app_state.db_pool.get() { |
| 81 | + Ok(conn) => Ok(ConnectionImpl::new(schema_name, conn)), |
| 82 | + Err(e) => { |
| 83 | + log::info!("Unable to get db connection from pool, error: {e}"); |
| 84 | + Err(actix_web::error::ErrorInternalServerError("")) |
| 85 | + } |
| 86 | + } |
| 87 | + } |
| 88 | +} |
| 89 | + |
| 90 | +impl ConnectionSealed for ConnectionImpl {} |
| 91 | + |
| 92 | +impl SimpleConnection for ConnectionImpl { |
| 93 | + fn batch_execute(&mut self, query: &str) -> diesel::prelude::QueryResult<()> { |
| 94 | + self.conn.batch_execute(query) |
| 95 | + } |
| 96 | +} |
| 97 | + |
| 98 | +impl Connection for ConnectionImpl { |
| 99 | + type Backend = Pg; |
| 100 | + type TransactionManager = TransactionManagerImpl; |
| 101 | + |
| 102 | + // NOTE: this function will never be used, so namespace here doesn't matter |
| 103 | + fn establish(database_url: &str) -> diesel::prelude::ConnectionResult<Self> { |
| 104 | + let conn = PooledConnection::establish(database_url)?; |
| 105 | + Ok(ConnectionImpl { |
| 106 | + namespace: String::new(), |
| 107 | + conn, |
| 108 | + }) |
| 109 | + } |
| 110 | + |
| 111 | + fn execute_returning_count<T>( |
| 112 | + &mut self, |
| 113 | + source: &T, |
| 114 | + ) -> diesel::prelude::QueryResult<usize> |
| 115 | + where |
| 116 | + T: diesel::query_builder::QueryFragment<Self::Backend> |
| 117 | + + diesel::query_builder::QueryId, |
| 118 | + { |
| 119 | + log::info!("{:?}", source.to_sql(&mut PgQueryBuilder::default(), &Pg)); |
| 120 | + self.transaction::<usize, diesel::result::Error, _>(|conn| { |
| 121 | + (*conn.conn).execute_returning_count(source) |
| 122 | + }) |
| 123 | + } |
| 124 | + |
| 125 | + fn transaction_state(&mut self,) -> &mut<Self::TransactionManager as diesel::connection::TransactionManager<Self>>::TransactionStateData{ |
| 126 | + self.conn.transaction_state() |
| 127 | + } |
| 128 | + |
| 129 | + fn set_prepared_statement_cache_size(&mut self, size: diesel::connection::CacheSize) { |
| 130 | + self.conn.set_prepared_statement_cache_size(size) |
| 131 | + } |
| 132 | + |
| 133 | + fn set_instrumentation( |
| 134 | + &mut self, |
| 135 | + instrumentation: impl diesel::connection::Instrumentation, |
| 136 | + ) { |
| 137 | + self.conn.set_instrumentation(instrumentation) |
| 138 | + } |
| 139 | + |
| 140 | + fn instrumentation(&mut self) -> &mut dyn diesel::connection::Instrumentation { |
| 141 | + self.conn.instrumentation() |
| 142 | + } |
| 143 | +} |
| 144 | + |
| 145 | +impl LoadConnection<DefaultLoadingMode> for ConnectionImpl { |
| 146 | + type Cursor<'conn, 'query> = |
| 147 | + <PgConnection as LoadConnection<DefaultLoadingMode>>::Cursor<'conn, 'query>; |
| 148 | + type Row<'conn, 'query> = |
| 149 | + <PgConnection as LoadConnection<DefaultLoadingMode>>::Row<'conn, 'query>; |
| 150 | + |
| 151 | + fn load<'conn, 'query, T>( |
| 152 | + &'conn mut self, |
| 153 | + source: T, |
| 154 | + ) -> diesel::prelude::QueryResult<Self::Cursor<'conn, 'query>> |
| 155 | + where |
| 156 | + T: diesel::query_builder::Query |
| 157 | + + diesel::query_builder::QueryFragment<Self::Backend> |
| 158 | + + diesel::query_builder::QueryId |
| 159 | + + 'query, |
| 160 | + Self::Backend: diesel::expression::QueryMetadata<T::SqlType>, |
| 161 | + { |
| 162 | + self.transaction::<Self::Cursor<'conn, 'query>, diesel::result::Error, _>( |
| 163 | + |conn| { |
| 164 | + log::info!("{:?}", source.to_sql(&mut PgQueryBuilder::default(), &Pg)); |
| 165 | + <PgConnection as LoadConnection<DefaultLoadingMode>>::load::<T>( |
| 166 | + &mut *conn.conn, |
| 167 | + source, |
| 168 | + ) |
| 169 | + }, |
| 170 | + ) |
| 171 | + } |
| 172 | +} |
| 173 | + |
| 174 | +impl FromRequest for ConnectionImpl { |
| 175 | + type Error = actix_web::Error; |
| 176 | + type Future = std::future::Ready<Result<ConnectionImpl, Self::Error>>; |
| 177 | + |
| 178 | + fn from_request( |
| 179 | + req: &actix_web::HttpRequest, |
| 180 | + _: &mut actix_web::dev::Payload, |
| 181 | + ) -> Self::Future { |
| 182 | + let schema_name = req.extensions().get::<SchemaName>().cloned().unwrap().0; |
| 183 | + std::future::ready(ConnectionImpl::from_request_override(req, schema_name)) |
| 184 | + } |
| 185 | +} |
| 186 | + |
| 187 | +#[derive(Deref, DerefMut)] |
| 188 | +pub struct PublicConnection(pub ConnectionImpl); |
| 189 | +impl FromRequest for PublicConnection { |
| 190 | + type Error = actix_web::Error; |
| 191 | + type Future = std::future::Ready<Result<PublicConnection, Self::Error>>; |
| 192 | + |
| 193 | + fn from_request( |
| 194 | + req: &actix_web::HttpRequest, |
| 195 | + _: &mut actix_web::dev::Payload, |
| 196 | + ) -> Self::Future { |
| 197 | + std::future::ready( |
| 198 | + ConnectionImpl::from_request_override(req, String::from("public")) |
| 199 | + .map(|conn| PublicConnection(conn)), |
| 200 | + ) |
| 201 | + } |
| 202 | +} |
| 203 | + |
| 204 | +#[derive(Deref, DerefMut)] |
| 205 | +pub struct SuperpositionConnection(pub ConnectionImpl); |
| 206 | +impl FromRequest for SuperpositionConnection { |
| 207 | + type Error = actix_web::Error; |
| 208 | + type Future = std::future::Ready<Result<SuperpositionConnection, Self::Error>>; |
| 209 | + |
| 210 | + fn from_request( |
| 211 | + req: &actix_web::HttpRequest, |
| 212 | + _: &mut actix_web::dev::Payload, |
| 213 | + ) -> Self::Future { |
| 214 | + std::future::ready( |
| 215 | + ConnectionImpl::from_request_override(req, String::from("superposition")) |
| 216 | + .map(|conn| SuperpositionConnection(conn)), |
| 217 | + ) |
| 218 | + } |
| 219 | +} |
0 commit comments