@@ -8,9 +8,11 @@ use tokio::signal;
88use tokio_util:: sync:: CancellationToken ;
99
1010mod handler;
11+ mod middleware;
1112mod types;
1213
1314use handler:: * ;
15+ pub use middleware:: { ApiKeyMiddleware , HttpMiddleware } ;
1416
1517const APP_NAME : & str = env ! ( "CARGO_PKG_NAME" ) ;
1618const APP_VERSION : & str = env ! ( "CARGO_PKG_VERSION" ) ;
@@ -22,7 +24,7 @@ pub struct ServerBuilder {
2224 origin : String ,
2325 engines : BTreeMap < Principal , Engine > ,
2426 default_engine : Option < Principal > ,
25- api_key : Option < String > ,
27+ middlewares : Vec < Arc < dyn HttpMiddleware > > ,
2628}
2729
2830impl Default for ServerBuilder {
@@ -43,7 +45,7 @@ impl ServerBuilder {
4345 origin : "https://localhost:8443" . to_string ( ) ,
4446 engines : BTreeMap :: new ( ) ,
4547 default_engine : None ,
46- api_key : None ,
48+ middlewares : Vec :: new ( ) ,
4749 }
4850 }
4951
@@ -67,11 +69,6 @@ impl ServerBuilder {
6769 self
6870 }
6971
70- pub fn with_api_key ( mut self , api_key : String ) -> Self {
71- self . api_key = Some ( api_key) ;
72- self
73- }
74-
7572 pub fn with_engines (
7673 mut self ,
7774 mut engines : BTreeMap < Principal , Engine > ,
@@ -86,6 +83,67 @@ impl ServerBuilder {
8683 self
8784 }
8885
86+ /// Register a router middleware.
87+ ///
88+ /// This is the low-level API. The middleware will be applied to the internal
89+ /// axum `Router` (typically via `router.layer(...)`). Middlewares are applied
90+ /// in the order they are added.
91+ ///
92+ /// More details: https://docs.rs/axum/latest/axum/middleware/index.html#ordering
93+ ///
94+ /// If you want a middleware that looks like `axum::middleware::from_fn`
95+ /// (i.e. can operate on `(req, next)`), prefer [`with_request_middleware`].
96+ ///
97+ /// Example:
98+ /// ```ignore
99+ /// let server = ServerBuilder::new()
100+ /// .with_middleware(|router| {
101+ /// router.layer(axum::middleware::from_fn(|req, next| async move {
102+ /// // custom auth / param checks here
103+ /// next.run(req).await
104+ /// }))
105+ /// });
106+ /// ```
107+ pub fn with_middleware < M > ( mut self , middleware : M ) -> Self
108+ where
109+ M : HttpMiddleware ,
110+ {
111+ self . middlewares . push ( Arc :: new ( middleware) ) ;
112+ self
113+ }
114+
115+ /// Register a request middleware like `axum::middleware::from_fn`.
116+ ///
117+ /// The middleware function runs for every incoming request, and can decide
118+ /// to short-circuit with a response or call `next.run(req)`.
119+ ///
120+ /// Example:
121+ /// ```ignore
122+ /// use axum::http::StatusCode;
123+ /// use axum::response::IntoResponse;
124+ ///
125+ /// let server = ServerBuilder::new()
126+ /// .with_request_middleware(|req, next| async move {
127+ /// // custom auth / param checks here
128+ /// if req.headers().get("x-allow").is_none() {
129+ /// return (StatusCode::UNAUTHORIZED, "missing x-allow").into_response();
130+ /// }
131+ ///
132+ /// next.run(req).await
133+ /// });
134+ /// ```
135+ pub fn with_request_middleware < F , Fut > ( self , f : F ) -> Self
136+ where
137+ F : Fn ( axum:: extract:: Request , axum:: middleware:: Next ) -> Fut
138+ + Clone
139+ + Send
140+ + Sync
141+ + ' static ,
142+ Fut : Future < Output = axum:: response:: Response > + Send + ' static ,
143+ {
144+ self . with_middleware ( middleware:: RequestFnMiddleware :: new ( f) )
145+ }
146+
89147 pub async fn serve (
90148 self ,
91149 signal : impl Future < Output = ( ) > + Send + ' static ,
@@ -105,18 +163,26 @@ impl ServerBuilder {
105163 engines : Arc :: new ( self . engines ) ,
106164 default_engine,
107165 start_time_ms : unix_ms ( ) ,
108- api_key : self . api_key ,
109166 } ;
110- let app = Router :: new ( )
167+
168+ // Build a router that is still "missing" an `AppState`.
169+ // We'll provide the state at the end (after applying middlewares) so we
170+ // end up with a `Router<()>` that can be passed to `axum::serve`.
171+ let mut app: Router < AppState > = Router :: new ( )
111172 . route ( "/" , routing:: get ( get_information) )
112173 . route ( "/.well-known/information" , routing:: get ( get_information) )
113174 . route ( "/.well-known/agents" , routing:: get ( get_information) )
114175 . route (
115176 "/.well-known/agents/{id}" ,
116177 routing:: get ( get_engine_information) ,
117178 )
118- . route ( "/{*id}" , routing:: post ( anda_engine) )
119- . with_state ( state) ;
179+ . route ( "/{*id}" , routing:: post ( anda_engine) ) ;
180+
181+ for middleware in & self . middlewares {
182+ app = middleware. apply ( app) ;
183+ }
184+
185+ let app = app. with_state ( state) ;
120186
121187 let addr: SocketAddr = self . addr . parse ( ) ?;
122188 let listener = create_reuse_port_listener ( addr) . await ?;
0 commit comments