Skip to content

Commit 95057fd

Browse files
authored
chore: Derive FromRef on AppContext (#1735)
This allows extractors to pull out fields of `AppContext` as [substates](https://docs.rs/axum/0.8.8/axum/extract/struct.State.html\#substates) in handlers through [`FromRef::from_ref`](https://docs.rs/axum/latest/axum/extract/trait.FromRef.html)
1 parent 888614f commit 95057fd

File tree

3 files changed

+98
-1
lines changed

3 files changed

+98
-1
lines changed

src/app.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use std::{
1010
};
1111

1212
use async_trait::async_trait;
13+
use axum::extract::FromRef;
1314
use axum::Router as AxumRouter;
1415
use dashmap::DashMap;
1516

@@ -249,7 +250,7 @@ impl<T: 'static + Send + Sync> std::ops::Deref for RefGuard<'_, T> {
249250
/// the web server to operate. It is typically used to store and manage shared
250251
/// resources and settings that are accessible throughout the application's
251252
/// lifetime.
252-
#[derive(Clone)]
253+
#[derive(Clone, FromRef)]
253254
#[allow(clippy::module_name_repetitions)]
254255
pub struct AppContext {
255256
/// The environment in which the application is running.

tests/controller/from_ref.rs

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
use axum::extract::FromRef;
2+
use loco_rs::{
3+
app::{AppContext, SharedStore},
4+
cache,
5+
prelude::*,
6+
tests_cfg,
7+
};
8+
use std::sync::Arc;
9+
10+
use crate::infra_cfg;
11+
12+
#[cfg(feature = "with-db")]
13+
use sea_orm::DatabaseConnection;
14+
15+
/// Tests that DatabaseConnection can be extracted from AppContext via FromRef
16+
#[cfg(feature = "with-db")]
17+
#[tokio::test]
18+
async fn can_extract_db_connection_from_app_context() {
19+
let ctx = tests_cfg::app::get_app_context().await;
20+
21+
#[allow(clippy::items_after_statements)]
22+
async fn action(State(ctx): State<AppContext>) -> Result<Response> {
23+
// Use FromRef to extract DatabaseConnection from AppContext
24+
let _db: DatabaseConnection = DatabaseConnection::from_ref(&ctx);
25+
format::json(serde_json::json!({"extracted": "db"}))
26+
}
27+
28+
let port = get_available_port().await;
29+
let handle = infra_cfg::server::start_with_route(ctx, "/", get(action), Some(port)).await;
30+
31+
let res = reqwest::get(get_base_url_port(port))
32+
.await
33+
.expect("Valid response");
34+
35+
assert_eq!(res.status(), 200);
36+
37+
let body: serde_json::Value = res.json().await.expect("JSON response");
38+
assert_eq!(body["extracted"], "db");
39+
40+
handle.abort();
41+
}
42+
43+
/// Tests that Arc<Cache> can be extracted from AppContext via FromRef
44+
#[tokio::test]
45+
async fn can_extract_cache_from_app_context() {
46+
let ctx = tests_cfg::app::get_app_context().await;
47+
48+
#[allow(clippy::items_after_statements)]
49+
async fn action(State(ctx): State<AppContext>) -> Result<Response> {
50+
// Use FromRef to extract Arc<Cache> from AppContext
51+
let _cache: Arc<cache::Cache> = Arc::from_ref(&ctx);
52+
format::json(serde_json::json!({"extracted": "cache"}))
53+
}
54+
55+
let port = get_available_port().await;
56+
let handle = infra_cfg::server::start_with_route(ctx, "/", get(action), Some(port)).await;
57+
58+
let res = reqwest::get(get_base_url_port(port))
59+
.await
60+
.expect("Valid response");
61+
62+
assert_eq!(res.status(), 200);
63+
64+
let body: serde_json::Value = res.json().await.expect("JSON response");
65+
assert_eq!(body["extracted"], "cache");
66+
67+
handle.abort();
68+
}
69+
70+
/// Tests that Arc<SharedStore> can be extracted from AppContext via FromRef
71+
#[tokio::test]
72+
async fn can_extract_shared_store_from_app_context() {
73+
let ctx = tests_cfg::app::get_app_context().await;
74+
75+
#[allow(clippy::items_after_statements)]
76+
async fn action(State(ctx): State<AppContext>) -> Result<Response> {
77+
// Use FromRef to extract Arc<SharedStore> from AppContext
78+
let _store: Arc<SharedStore> = Arc::from_ref(&ctx);
79+
format::json(serde_json::json!({"extracted": "shared_store"}))
80+
}
81+
82+
let port = get_available_port().await;
83+
let handle = infra_cfg::server::start_with_route(ctx, "/", get(action), Some(port)).await;
84+
85+
let res = reqwest::get(get_base_url_port(port))
86+
.await
87+
.expect("Valid response");
88+
89+
assert_eq!(res.status(), 200);
90+
91+
let body: serde_json::Value = res.json().await.expect("JSON response");
92+
assert_eq!(body["extracted"], "shared_store");
93+
94+
handle.abort();
95+
}

tests/controller/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
mod extractor;
2+
mod from_ref;
23
mod into_response;
34
mod middlewares;

0 commit comments

Comments
 (0)