Skip to content

Commit 98e8421

Browse files
committed
core: Replacing async_session with tower_sessions
1 parent 36f6b4c commit 98e8421

File tree

8 files changed

+366
-462
lines changed

8 files changed

+366
-462
lines changed

Cargo.lock

Lines changed: 156 additions & 204 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

backend/Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ edition = "2021"
77
anyhow = "1.0.95"
88
askama = { version = "0.12.1", features = ["with-axum"] }
99
askama_axum = "0.4.0"
10-
async-session = "3.0.0"
1110
async-trait = "0.1.83"
1211
axum = { version = "0.8.1", features = ["macros"] }
1312
axum-extra = { version = "0.10.0", features = ["cookie", "multipart", "typed-header"] }
1413
dotenv = "0.15.0"
1514
headers = "0.4.0"
15+
http = "1.2.0"
1616
oauth2 = "4.4.2"
1717
once_cell = "1.20.2"
1818
openidconnect = { version = "3.5.0", features = ["reqwest"] }
@@ -27,6 +27,8 @@ sqlx = { version = "0.8.2", features = ["runtime-tokio-native-tls", "sqlite", "p
2727
time = "0.3.37"
2828
tokio = { version = "1.42.0", features = ["rt", "rt-multi-thread", "macros", "signal"] }
2929
tower-http = { version = "0.6.2", features = ["fs", "trace"] }
30+
tower-sessions = { version = "0.14.0", features = ["memory-store"] }
31+
#tower-sessions-memory-store = { version = "0.14.0", features = [] }
3032
tracing = { version = "0.1.41", features = ["std", "log"] }
3133
tracing-opentelemetry = { version = "0.28.0", features = [] }
3234
tracing-subscriber = { version = "0.3.19", features = ["env-filter", "registry", "fmt"] }

backend/src/html.rs

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use crate::model::*;
2-
use crate::{AppError, ServerImpl, SessionUser};
2+
use crate::session::SkjeraSessionData;
3+
use crate::{AppError, ServerImpl};
34
use anyhow::Context;
45
use askama_axum::Template;
56
use axum::extract::{Path, State};
@@ -43,7 +44,7 @@ struct MeTemplate<'a> {
4344
#[tracing::instrument]
4445
pub async fn get_me(
4546
State(app): State<ServerImpl>,
46-
user: SessionUser,
47+
user: SkjeraSessionData,
4748
) -> Result<Html<String>, AppError> {
4849
let me = app
4950
.employee_dao
@@ -79,7 +80,7 @@ pub(crate) struct MeForm {
7980
#[tracing::instrument]
8081
pub async fn post_me(
8182
State(app): State<ServerImpl>,
82-
user: SessionUser,
83+
user: SkjeraSessionData,
8384
Form(input): Form<MeForm>,
8485
) -> Result<Redirect, AppError> {
8586
debug!("form: {:?}", input);
@@ -109,7 +110,7 @@ pub async fn post_me(
109110

110111
pub async fn delete_some_account(
111112
State(app): State<ServerImpl>,
112-
user: SessionUser,
113+
user: SkjeraSessionData,
113114
Path(some_account_id): Path<SomeAccountId>,
114115
) -> Result<Redirect, AppError> {
115116
info!(
@@ -138,7 +139,7 @@ pub(crate) struct AddSomeAccountForm {
138139

139140
pub async fn add_some_account(
140141
State(app): State<ServerImpl>,
141-
user: SessionUser,
142+
user: SkjeraSessionData,
142143
Form(input): Form<AddSomeAccountForm>,
143144
) -> Result<Redirect, AppError> {
144145
let _span = span!(Level::INFO, "add_some_account");
@@ -193,7 +194,7 @@ impl EmployeeTemplate {
193194
#[tracing::instrument]
194195
pub async fn employee(
195196
State(app): State<ServerImpl>,
196-
_user: SessionUser,
197+
_user: SkjeraSessionData,
197198
Path(employee_id): Path<EmployeeId>,
198199
) -> Result<Html<String>, AppError> {
199200
let employee = app
@@ -210,21 +211,22 @@ pub async fn employee(
210211
#[derive(Template)]
211212
#[template(path = "hello.html"/*, print = "all"*/)]
212213
struct HelloTemplate {
213-
pub user: Option<SessionUser>,
214+
pub user: SkjeraSessionData,
214215
pub google_auth_url: Option<String>,
215216
pub employees: Option<Vec<Employee>>,
216217
}
217218

218-
#[tracing::instrument]
219219
pub async fn hello_world(
220220
State(app): State<ServerImpl>,
221-
user: Option<SessionUser>,
221+
session: SkjeraSessionData,
222222
) -> Result<Html<String>, AppError> {
223+
let _span = span!(Level::INFO, "hello_world");
224+
223225
let scope = "openid profile email";
224226

225227
let mut employees = None::<Vec<Employee>>;
226228
let mut url = None::<String>;
227-
if user.is_some() {
229+
if session.authenticated() {
228230
employees = Some(app.employee_dao.employees().await?);
229231
} else {
230232
let u = Url::parse_with_params(
@@ -241,7 +243,7 @@ pub async fn hello_world(
241243
}
242244

243245
let template = HelloTemplate {
244-
user,
246+
user: session,
245247
google_auth_url: url,
246248
employees,
247249
};

backend/src/main.rs

Lines changed: 21 additions & 161 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,17 @@ mod macros;
44
mod meta;
55
mod model;
66
mod oauth;
7+
mod session;
78
#[cfg(any())]
89
mod skjera;
910
mod slack;
1011
mod web;
1112

1213
use crate::model::*;
1314
use crate::slack::SlackConnect;
14-
use anyhow::{anyhow, Error};
15-
use async_session::{MemoryStore, Session, SessionStore};
16-
use axum::extract::{FromRef, FromRequestParts, OptionalFromRequestParts};
17-
use axum::http::request::Parts;
18-
use axum::http::{header, StatusCode};
15+
use axum::http::StatusCode;
1916
use axum::response::{IntoResponse, Redirect, Response};
20-
use axum::RequestPartsExt;
21-
use axum_extra::typed_header::TypedHeaderRejectionReason;
22-
use axum_extra::TypedHeader;
23-
use headers;
2417
use oauth2::basic::BasicClient;
25-
use oauth2::{CsrfToken, PkceCodeVerifier};
26-
use openidconnect::Nonce;
2718
use opentelemetry::trace::TracerProvider as _;
2819
use opentelemetry::{global, KeyValue};
2920
use opentelemetry_appender_tracing::layer;
@@ -32,20 +23,18 @@ use opentelemetry_sdk::logs::LoggerProvider;
3223
use opentelemetry_sdk::trace::TracerProvider;
3324
use opentelemetry_sdk::{runtime, Resource};
3425
use reqwest::Client as ReqwestClient;
35-
use serde::{Deserialize, Serialize};
3626
use sqlx::postgres::PgConnectOptions;
3727
use std::env;
3828
use std::process::exit;
3929
use tokio::net::TcpListener;
4030
use tokio::signal;
4131
use tower_http::trace::TraceLayer;
32+
use tower_sessions::cookie::SameSite::Lax;
33+
use tower_sessions::{MemoryStore, SessionManagerLayer, SessionStore};
4234
use tracing::{debug, info, warn};
4335
use tracing_subscriber::prelude::*;
4436
use tracing_subscriber::EnvFilter;
4537

46-
pub(crate) static COOKIE_NAME: &str = "SESSION";
47-
const USER_SESSION_KEY: &'static str = "user";
48-
4938
#[tokio::main]
5039
async fn main() {
5140
// We don't care if there is a problem here
@@ -115,7 +104,6 @@ async fn main() {
115104
cfg,
116105
basic_client,
117106
employee_dao: EmployeeDao::new(pool),
118-
store: MemoryStore::new(),
119107
slack_connect,
120108
};
121109

@@ -125,7 +113,13 @@ async fn main() {
125113
// info!(name: "my-event-name", target: "my-system", event_id = 20, user_name = "otel", user_email = "otel@opentelemetry.io", message = "This is an example message");
126114
// });
127115

128-
start_server(server_impl, "0.0.0.0:8080").await;
116+
let session_store = MemoryStore::default();
117+
let session_layer = SessionManagerLayer::new(session_store)
118+
.with_secure(true)
119+
.with_http_only(true)
120+
.with_same_site(Lax);
121+
122+
start_server(server_impl, session_layer, "0.0.0.0:8080").await;
129123

130124
providers.0.shutdown().unwrap();
131125
providers.1.shutdown().unwrap();
@@ -198,21 +192,20 @@ struct ServerImpl {
198192
ctx: ReqwestClient,
199193
cfg: Config,
200194
basic_client: BasicClient,
201-
store: MemoryStore,
202195
pub employee_dao: EmployeeDao,
203196
pub slack_connect: Option<SlackConnect>,
204197
}
205198

206-
impl FromRef<ServerImpl> for MemoryStore {
207-
fn from_ref(state: &ServerImpl) -> Self {
208-
state.store.clone()
209-
}
210-
}
211-
212-
async fn start_server(server_impl: ServerImpl, addr: &str) {
213-
let app = web::create_router(server_impl);
214-
215-
let app = app.layer(TraceLayer::new_for_http());
199+
async fn start_server<SS>(
200+
server_impl: ServerImpl,
201+
session_layer: SessionManagerLayer<SS>,
202+
addr: &str,
203+
) where
204+
SS: SessionStore + Clone,
205+
{
206+
let app = web::create_router(server_impl)
207+
.layer(session_layer)
208+
.layer(TraceLayer::new_for_http());
216209

217210
// Run the server with graceful shutdown
218211
let listener = TcpListener::bind(addr).await.unwrap();
@@ -332,136 +325,3 @@ impl IntoResponse for AuthRedirect {
332325
Redirect::temporary("/").into_response()
333326
}
334327
}
335-
336-
#[derive(Debug, Deserialize, Serialize)]
337-
pub struct SessionUser {
338-
employee: EmployeeId,
339-
email: String,
340-
name: String,
341-
slack_connect: Option<SlackConnectData>,
342-
}
343-
344-
#[derive(Debug, Deserialize, Serialize)]
345-
pub struct SlackConnectData {
346-
csrf_token: CsrfToken,
347-
nonce: Nonce,
348-
pkce_verifier: PkceCodeVerifier,
349-
}
350-
351-
impl SessionUser {
352-
pub(crate) fn with_slack_connect(
353-
self,
354-
csrf_token: CsrfToken,
355-
nonce: Nonce,
356-
pkce_verifier: PkceCodeVerifier,
357-
) -> Self {
358-
SessionUser {
359-
slack_connect: Some(SlackConnectData {
360-
csrf_token,
361-
nonce,
362-
pkce_verifier,
363-
}),
364-
..self
365-
}
366-
}
367-
}
368-
369-
async fn load_session_from_parts<S>(parts: &mut Parts, state: &S) -> anyhow::Result<Session>
370-
where
371-
MemoryStore: FromRef<S>,
372-
S: Send + Sync,
373-
{
374-
let cookies = parts.extract::<TypedHeader<headers::Cookie>>().await?;
375-
376-
load_session(cookies, state).await
377-
}
378-
379-
async fn load_session<S>(cookies: TypedHeader<headers::Cookie>, state: &S) -> anyhow::Result<Session>
380-
where
381-
MemoryStore: FromRef<S>,
382-
S: Send + Sync,
383-
{
384-
let cookie = cookies
385-
.get(COOKIE_NAME)
386-
.ok_or(anyhow!("cookie not found"))?
387-
.to_string();
388-
389-
let store = MemoryStore::from_ref(state);
390-
391-
match store.load_session(cookie).await? {
392-
Some(session) => Ok(session),
393-
_ => Err(anyhow!("Could not load session")),
394-
}
395-
}
396-
397-
impl<S> OptionalFromRequestParts<S> for SessionUser
398-
where
399-
MemoryStore: FromRef<S>,
400-
S: Send + Sync,
401-
{
402-
type Rejection = ();
403-
404-
async fn from_request_parts(
405-
parts: &mut Parts,
406-
state: &S,
407-
) -> Result<Option<Self>, Self::Rejection> {
408-
let session = load_session_from_parts(parts, state).await;
409-
410-
let user = session.and_then(|session| {
411-
session
412-
.get::<SessionUser>(USER_SESSION_KEY)
413-
.ok_or(anyhow!("no user in session"))
414-
});
415-
416-
Ok(user.ok())
417-
}
418-
}
419-
420-
impl<S> FromRequestParts<S> for SessionUser
421-
where
422-
MemoryStore: FromRef<S>,
423-
S: Send + Sync,
424-
{
425-
type Rejection = AuthRedirect;
426-
427-
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
428-
let session = load_session_from_parts(parts, state).await;
429-
430-
let user = session.and_then(|session| {
431-
session
432-
.get::<SessionUser>(USER_SESSION_KEY)
433-
.ok_or(anyhow!("no user in session"))
434-
});
435-
436-
user.map_err(|_| AuthRedirect)
437-
438-
// let store = MemoryStore::from_ref(state);
439-
//
440-
// async move {
441-
// let cookies = parts
442-
// .extract::<TypedHeader<headers::Cookie>>()
443-
// .await
444-
// .map_err(|e| match *e.name() {
445-
// header::COOKIE => match e.reason() {
446-
// TypedHeaderRejectionReason::Missing => AuthRedirect,
447-
// _ => panic!("unexpected error getting Cookie header(s): {e}"),
448-
// },
449-
// _ => panic!("unexpected error getting cookies: {e}"),
450-
// })?;
451-
// let session_cookie = cookies.get(COOKIE_NAME).ok_or(AuthRedirect)?;
452-
//
453-
// let session = store
454-
// .load_session(session_cookie.to_string())
455-
// .await
456-
// .unwrap()
457-
// .ok_or(AuthRedirect)?;
458-
//
459-
// let user = session
460-
// .get::<SessionUser>(USER_SESSION_KEY)
461-
// .ok_or(AuthRedirect)?;
462-
//
463-
// Ok(user)
464-
// }
465-
// .await
466-
}
467-
}

0 commit comments

Comments
 (0)