Skip to content

Commit f1afb93

Browse files
committed
feat(anda_engine_server): support middleware
1 parent b08e474 commit f1afb93

File tree

4 files changed

+181
-30
lines changed

4 files changed

+181
-30
lines changed

anda_engine_server/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name = "anda_engine_server"
33
description = "A http server to serve multiple Anda engines."
44
repository = "https://github.com/ldclabs/anda/tree/main/anda_engine_server"
55
publish = true
6-
version = "0.9.3"
6+
version = "0.9.4"
77
edition.workspace = true
88
keywords.workspace = true
99
categories.workspace = true

anda_engine_server/src/handler.rs

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,6 @@ pub struct AppState {
2626
pub(crate) engines: Arc<BTreeMap<Principal, Engine>>,
2727
pub(crate) default_engine: Principal,
2828
pub(crate) start_time_ms: u64,
29-
pub(crate) api_key: Option<String>,
30-
}
31-
32-
impl AppState {
33-
pub fn check_api_key(&self, headers: &http::HeaderMap) -> Result<(), String> {
34-
if let Some(expected_key) = &self.api_key {
35-
match headers.get("x-api-key") {
36-
Some(provided_key) if provided_key == expected_key => Ok(()),
37-
_ => Err("missing or invalid x-api-key in headers".to_string()),
38-
}
39-
} else {
40-
Ok(())
41-
}
42-
}
4329
}
4430

4531
/// GET /.well-known/information
@@ -112,10 +98,6 @@ pub async fn anda_engine(
11298
Path(id): Path<String>,
11399
ct: ContentWithSHA3<RPCRequest>,
114100
) -> impl IntoResponse {
115-
if let Err(err) = app.check_api_key(&headers) {
116-
return (StatusCode::UNAUTHORIZED, err).into_response();
117-
}
118-
119101
let id = if &id == "default" {
120102
app.default_engine
121103
} else if let Ok(id) = Principal::from_text(&id) {

anda_engine_server/src/lib.rs

Lines changed: 77 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@ use tokio::signal;
88
use tokio_util::sync::CancellationToken;
99

1010
mod handler;
11+
mod middleware;
1112
mod types;
1213

1314
use handler::*;
15+
pub use middleware::{ApiKeyMiddleware, HttpMiddleware};
1416

1517
const APP_NAME: &str = env!("CARGO_PKG_NAME");
1618
const APP_VERSION: &str = env!("CARGO_PKG_VERSION");
@@ -22,7 +24,7 @@ pub struct ServerBuilder {
2224
origin: String,
2325
engines: BTreeMap<Principal, Engine>,
2426
default_engine: Option<Principal>,
25-
api_key: Option<String>,
27+
middlewares: Vec<Arc<dyn HttpMiddleware>>,
2628
}
2729

2830
impl Default for ServerBuilder {
@@ -43,7 +45,7 @@ impl ServerBuilder {
4345
origin: "https://localhost:8443".to_string(),
4446
engines: BTreeMap::new(),
4547
default_engine: None,
46-
api_key: None,
48+
middlewares: Vec::new(),
4749
}
4850
}
4951

@@ -67,11 +69,6 @@ impl ServerBuilder {
6769
self
6870
}
6971

70-
pub fn with_api_key(mut self, api_key: String) -> Self {
71-
self.api_key = Some(api_key);
72-
self
73-
}
74-
7572
pub fn with_engines(
7673
mut self,
7774
mut engines: BTreeMap<Principal, Engine>,
@@ -86,6 +83,67 @@ impl ServerBuilder {
8683
self
8784
}
8885

86+
/// Register a router middleware.
87+
///
88+
/// This is the low-level API. The middleware will be applied to the internal
89+
/// axum `Router` (typically via `router.layer(...)`). Middlewares are applied
90+
/// in the order they are added.
91+
///
92+
/// More details: https://docs.rs/axum/latest/axum/middleware/index.html#ordering
93+
///
94+
/// If you want a middleware that looks like `axum::middleware::from_fn`
95+
/// (i.e. can operate on `(req, next)`), prefer [`with_request_middleware`].
96+
///
97+
/// Example:
98+
/// ```ignore
99+
/// let server = ServerBuilder::new()
100+
/// .with_middleware(|router| {
101+
/// router.layer(axum::middleware::from_fn(|req, next| async move {
102+
/// // custom auth / param checks here
103+
/// next.run(req).await
104+
/// }))
105+
/// });
106+
/// ```
107+
pub fn with_middleware<M>(mut self, middleware: M) -> Self
108+
where
109+
M: HttpMiddleware,
110+
{
111+
self.middlewares.push(Arc::new(middleware));
112+
self
113+
}
114+
115+
/// Register a request middleware like `axum::middleware::from_fn`.
116+
///
117+
/// The middleware function runs for every incoming request, and can decide
118+
/// to short-circuit with a response or call `next.run(req)`.
119+
///
120+
/// Example:
121+
/// ```ignore
122+
/// use axum::http::StatusCode;
123+
/// use axum::response::IntoResponse;
124+
///
125+
/// let server = ServerBuilder::new()
126+
/// .with_request_middleware(|req, next| async move {
127+
/// // custom auth / param checks here
128+
/// if req.headers().get("x-allow").is_none() {
129+
/// return (StatusCode::UNAUTHORIZED, "missing x-allow").into_response();
130+
/// }
131+
///
132+
/// next.run(req).await
133+
/// });
134+
/// ```
135+
pub fn with_request_middleware<F, Fut>(self, f: F) -> Self
136+
where
137+
F: Fn(axum::extract::Request, axum::middleware::Next) -> Fut
138+
+ Clone
139+
+ Send
140+
+ Sync
141+
+ 'static,
142+
Fut: Future<Output = axum::response::Response> + Send + 'static,
143+
{
144+
self.with_middleware(middleware::RequestFnMiddleware::new(f))
145+
}
146+
89147
pub async fn serve(
90148
self,
91149
signal: impl Future<Output = ()> + Send + 'static,
@@ -105,18 +163,26 @@ impl ServerBuilder {
105163
engines: Arc::new(self.engines),
106164
default_engine,
107165
start_time_ms: unix_ms(),
108-
api_key: self.api_key,
109166
};
110-
let app = Router::new()
167+
168+
// Build a router that is still "missing" an `AppState`.
169+
// We'll provide the state at the end (after applying middlewares) so we
170+
// end up with a `Router<()>` that can be passed to `axum::serve`.
171+
let mut app: Router<AppState> = Router::new()
111172
.route("/", routing::get(get_information))
112173
.route("/.well-known/information", routing::get(get_information))
113174
.route("/.well-known/agents", routing::get(get_information))
114175
.route(
115176
"/.well-known/agents/{id}",
116177
routing::get(get_engine_information),
117178
)
118-
.route("/{*id}", routing::post(anda_engine))
119-
.with_state(state);
179+
.route("/{*id}", routing::post(anda_engine));
180+
181+
for middleware in &self.middlewares {
182+
app = middleware.apply(app);
183+
}
184+
185+
let app = app.with_state(state);
120186

121187
let addr: SocketAddr = self.addr.parse()?;
122188
let listener = create_reuse_port_listener(addr).await?;
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
use std::future::Future;
2+
3+
use axum::{
4+
Router,
5+
extract::Request,
6+
http::StatusCode,
7+
middleware::Next,
8+
response::{IntoResponse, Response},
9+
};
10+
11+
use crate::handler::AppState;
12+
13+
/// Object-safe middleware trait for applying HTTP middleware to the server `Router`.
14+
///
15+
/// This is intentionally type-erased so callers can register arbitrary axum/tower
16+
/// middleware without turning `ServerBuilder` into a giant generic type.
17+
pub trait HttpMiddleware: Send + Sync + 'static {
18+
fn apply(&self, router: Router<AppState>) -> Router<AppState>;
19+
}
20+
21+
impl<F> HttpMiddleware for F
22+
where
23+
F: Fn(Router<AppState>) -> Router<AppState> + Send + Sync + 'static,
24+
{
25+
fn apply(&self, router: Router<AppState>) -> Router<AppState> {
26+
(self)(router)
27+
}
28+
}
29+
30+
/// Middleware built from a function like axum::middleware::from_fn.
31+
#[derive(Clone)]
32+
pub struct RequestFnMiddleware<F> {
33+
f: F,
34+
}
35+
36+
impl<F> RequestFnMiddleware<F> {
37+
pub fn new(f: F) -> Self {
38+
Self { f }
39+
}
40+
}
41+
42+
impl<F, Fut> HttpMiddleware for RequestFnMiddleware<F>
43+
where
44+
F: Fn(Request, Next) -> Fut + Clone + Send + Sync + 'static,
45+
Fut: Future<Output = Response> + Send + 'static,
46+
{
47+
fn apply(&self, router: Router<AppState>) -> Router<AppState> {
48+
router.layer(axum::middleware::from_fn(self.f.clone()))
49+
}
50+
}
51+
52+
/// A simple API key middleware that validates `x-api-key` on every request.
53+
///
54+
/// Use `exempt_path` to allow unauthenticated endpoints (e.g. health/info routes).
55+
#[derive(Clone, Default)]
56+
pub struct ApiKeyMiddleware {
57+
expected_key: String,
58+
exempt_paths: Vec<String>,
59+
}
60+
61+
impl ApiKeyMiddleware {
62+
pub fn new(expected_key: impl Into<String>) -> Self {
63+
Self {
64+
expected_key: expected_key.into(),
65+
exempt_paths: Vec::new(),
66+
}
67+
}
68+
69+
pub fn exempt_path(mut self, path: impl Into<String>) -> Self {
70+
self.exempt_paths.push(path.into());
71+
self
72+
}
73+
}
74+
75+
impl HttpMiddleware for ApiKeyMiddleware {
76+
fn apply(&self, router: Router<AppState>) -> Router<AppState> {
77+
let expected_key = self.expected_key.clone();
78+
let exempt_paths = self.exempt_paths.clone();
79+
80+
router.layer(axum::middleware::from_fn(
81+
move |req: Request, next: Next| {
82+
let expected_key = expected_key.clone();
83+
let exempt_paths = exempt_paths.clone();
84+
85+
async move {
86+
let path = req.uri().path();
87+
if exempt_paths.iter().any(|p| p == path) {
88+
return next.run(req).await;
89+
}
90+
91+
match req.headers().get("x-api-key") {
92+
Some(provided_key) if provided_key == &expected_key => next.run(req).await,
93+
_ => (
94+
StatusCode::UNAUTHORIZED,
95+
"missing or invalid x-api-key in headers",
96+
)
97+
.into_response(),
98+
}
99+
}
100+
},
101+
))
102+
}
103+
}

0 commit comments

Comments
 (0)