Skip to content

Commit cde2294

Browse files
committed
initial commit
0 parents  commit cde2294

File tree

4 files changed

+200
-0
lines changed

4 files changed

+200
-0
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
/target
2+
Cargo.lock

Cargo.toml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
[package]
2+
name = "async-sqlx-session"
3+
version = "0.1.0"
4+
authors = ["Jacob Rothstein <[email protected]>"]
5+
edition = "2018"
6+
7+
[features]
8+
default = ["sqlite"]
9+
sqlite = ["sqlx/sqlite"]
10+
11+
[dependencies]
12+
async-session = "1.0.2"
13+
sqlx = { version = "0.3.5" }
14+
chrono = "0.4.13"
15+
16+
[patch.crates-io]
17+
async-session = { git = "https://github.com/jbr/async-session", branch = "tide" }

src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#[cfg(feature = "sqlite")]
2+
mod sqlite;
3+
#[cfg(feature = "sqlite")]
4+
pub use sqlite::SqliteStore;

src/sqlite.rs

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
use async_session::{async_trait, base64, log, serde_json, Session, SessionStore};
2+
use chrono::Utc;
3+
use sqlx::prelude::*;
4+
use sqlx::sqlite::SqlitePool;
5+
6+
#[derive(Clone, Debug)]
7+
pub struct SqliteStore {
8+
client: SqlitePool,
9+
ttl: chrono::Duration,
10+
prefix: Option<String>,
11+
table_name: String,
12+
}
13+
14+
impl SqliteStore {
15+
pub fn from_client(client: SqlitePool) -> Self {
16+
Self {
17+
client,
18+
table_name: "async_sessions".into(),
19+
ttl: chrono::Duration::days(1),
20+
prefix: None,
21+
}
22+
}
23+
24+
pub async fn new(database_url: &str) -> Result<Self, sqlx::Error> {
25+
Ok(Self::from_client(SqlitePool::new(database_url).await?))
26+
}
27+
28+
pub fn with_table_name(mut self, table_name: String) -> Self {
29+
if table_name.chars().any(|c| c.is_ascii_alphanumeric()) {
30+
panic!(
31+
"table name must be alphanumeric, but {} was not",
32+
table_name
33+
);
34+
}
35+
36+
self.table_name = table_name;
37+
self
38+
}
39+
40+
pub fn with_ttl(mut self, ttl: std::time::Duration) -> Self {
41+
self.ttl = chrono::Duration::from_std(ttl).unwrap();
42+
self
43+
}
44+
45+
pub fn expiry(&self) -> i64 {
46+
(Utc::now() + self.ttl).timestamp()
47+
}
48+
49+
pub async fn migrate(&self) -> sqlx::Result<()> {
50+
log::info!("migrating sessions on `{}`", self.table_name);
51+
52+
let mut conn = self.client.acquire().await?;
53+
sqlx::query(&self.substitute_table_name(
54+
r#"
55+
CREATE TABLE IF NOT EXISTS %%TABLE_NAME%% (
56+
id TEXT PRIMARY KEY NOT NULL,
57+
expires DATETIME NOT NULL,
58+
session TEXT NOT NULL
59+
)
60+
"#,
61+
))
62+
.execute(&mut conn)
63+
.await?;
64+
Ok(())
65+
}
66+
67+
fn substitute_table_name(&self, query: &str) -> String {
68+
query.replace("%%TABLE_NAME%%", &self.table_name)
69+
}
70+
71+
async fn connection(&self) -> sqlx::Result<sqlx::pool::PoolConnection<sqlx::SqliteConnection>> {
72+
self.client.acquire().await
73+
}
74+
}
75+
76+
#[async_trait]
77+
impl SessionStore for SqliteStore {
78+
type Error = sqlx::Error;
79+
80+
async fn load_session(&self, cookie_value: String) -> Option<Session> {
81+
let id = Session::id_from_cookie_value(&cookie_value).ok()?;
82+
let mut connection = self.connection().await.ok()?;
83+
84+
let (session,): (String,) = sqlx::query_as(&self.substitute_table_name(
85+
r#"
86+
SELECT session FROM %%TABLE_NAME%%
87+
WHERE id = ? AND expires > ?
88+
"#,
89+
))
90+
.bind(&id)
91+
.bind(Utc::now().timestamp())
92+
.fetch_one(&mut connection)
93+
.await
94+
.ok()?;
95+
96+
serde_json::from_str(&session).ok()?
97+
}
98+
99+
async fn store_session(&self, mut session: Session) -> Option<String> {
100+
let id = session.id();
101+
let string = serde_json::to_string(&session).ok()?;
102+
let mut connection = self.connection().await.ok()?;
103+
104+
sqlx::query(&self.substitute_table_name(
105+
r#"
106+
INSERT INTO %%TABLE_NAME%%
107+
(id, expires, session) VALUES (?, ?, ?)
108+
ON CONFLICT(id) DO UPDATE SET
109+
expires = excluded.expires,
110+
session = excluded.session
111+
"#,
112+
))
113+
.bind(&id)
114+
.bind(self.expiry())
115+
.bind(&string)
116+
.execute(&mut connection)
117+
.await
118+
.unwrap();
119+
120+
session.take_cookie_value()
121+
}
122+
123+
async fn destroy_session(&self, session: Session) -> Result<(), Self::Error> {
124+
let id = session.id();
125+
let mut connection = self.connection().await?;
126+
sqlx::query(&self.substitute_table_name(
127+
r#"
128+
DELETE FROM %%TABLE_NAME%% WHERE id = ?
129+
"#,
130+
))
131+
.bind(&id)
132+
.execute(&mut connection)
133+
.await?;
134+
Ok(())
135+
}
136+
137+
async fn clear_store(&self) -> Result<(), Self::Error> {
138+
let mut connection = self.connection().await?;
139+
sqlx::query(&self.substitute_table_name(
140+
r#"
141+
DELETE FROM %%TABLE_NAME%%
142+
"#,
143+
))
144+
.execute(&mut connection)
145+
.await?;
146+
Ok(())
147+
}
148+
}
149+
150+
#[derive(Debug)]
151+
pub enum Error {
152+
SqlxError(sqlx::Error),
153+
SerdeError(serde_json::Error),
154+
}
155+
156+
impl std::fmt::Display for Error {
157+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158+
match self {
159+
Error::SqlxError(e) => e.fmt(f),
160+
Error::SerdeError(e) => e.fmt(f),
161+
}
162+
}
163+
}
164+
165+
impl From<serde_json::Error> for Error {
166+
fn from(e: serde_json::Error) -> Self {
167+
Self::SerdeError(e)
168+
}
169+
}
170+
171+
impl From<sqlx::Error> for Error {
172+
fn from(e: sqlx::Error) -> Self {
173+
Self::SqlxError(e)
174+
}
175+
}
176+
177+
impl std::error::Error for Error {}

0 commit comments

Comments
 (0)