Skip to content

Commit f873b73

Browse files
committed
fix bug in router
wasn't able to match on _just_ the prefix
1 parent 3afc533 commit f873b73

File tree

1 file changed

+39
-23
lines changed

1 file changed

+39
-23
lines changed

rama-http/src/service/web/router.rs

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::convert::Infallible;
1+
use std::{convert::Infallible, sync::Arc};
22

33
use crate::{
44
Request, Response,
@@ -139,15 +139,20 @@ where
139139
where
140140
I: IntoEndpointService<State, T>,
141141
{
142-
let path = format!("{}/{}", prefix.trim_end_matches(['/']), "{*nest}");
143-
let matcher = HttpMatcher::custom(true);
142+
let path = format!("{}/{}", prefix.trim().trim_end_matches(['/']), "{*nest}");
143+
let nested = Arc::new(service.into_endpoint_service().boxed());
144144

145145
let nested_router_service = NestedRouterService {
146-
prefix: prefix.to_owned(),
147-
nested: service.into_endpoint_service().boxed(),
146+
prefix: Arc::from(prefix),
147+
nested,
148148
};
149149

150-
self.match_route(&path, matcher, nested_router_service)
150+
self.match_route(
151+
prefix,
152+
HttpMatcher::custom(true),
153+
nested_router_service.clone(),
154+
)
155+
.match_route(&path, HttpMatcher::custom(true), nested_router_service)
151156
}
152157

153158
/// add a route to the router with it's matcher and service.
@@ -162,6 +167,11 @@ where
162167
{
163168
let service = service.into_endpoint_service().boxed();
164169

170+
let mut path = path.trim().trim_end_matches('/');
171+
if path.is_empty() {
172+
path = "/"
173+
}
174+
165175
if let Ok(matched) = self.routes.at_mut(path) {
166176
matched.value.push((matcher, service));
167177
} else {
@@ -183,17 +193,11 @@ where
183193
}
184194
}
185195

196+
#[derive(Debug, Clone)]
186197
struct NestedRouterService<State> {
187-
prefix: String,
188-
nested: BoxService<State, Request, Response, Infallible>,
189-
}
190-
191-
impl<State: std::fmt::Debug> std::fmt::Debug for NestedRouterService<State> {
192-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
193-
f.debug_struct("NestedRouterService")
194-
.field("prefix", &self.prefix)
195-
.finish()
196-
}
198+
#[expect(unused)]
199+
prefix: Arc<str>,
200+
nested: Arc<BoxService<State, Request, Response, Infallible>>,
197201
}
198202

199203
impl<State> Service<State, Request> for NestedRouterService<State>
@@ -210,7 +214,7 @@ where
210214
) -> Result<Self::Response, Self::Error> {
211215
let params: UriParams = match ctx.remove::<UriParams>() {
212216
Some(params) => {
213-
let nested_path = params.get("nest").unwrap();
217+
let nested_path = params.get("nest").unwrap_or_default();
214218

215219
let filtered_params: UriParams =
216220
params.iter().filter(|(key, _)| *key != "nest").collect();
@@ -311,7 +315,7 @@ mod tests {
311315
})
312316
}
313317

314-
fn get_users_servic() -> impl Service<(), Request, Response = Response, Error = Infallible> {
318+
fn get_users_service() -> impl Service<(), Request, Response = Response, Error = Infallible> {
315319
service_fn(|_ctx, _req| async {
316320
Ok(Response::builder()
317321
.status(200)
@@ -383,7 +387,7 @@ mod tests {
383387
async fn test_router() {
384388
let router = Router::new()
385389
.get("/", root_service())
386-
.get("/users", get_users_servic())
390+
.get("/users", get_users_service())
387391
.post("/users", create_user_service())
388392
.get("/users/{user_id}", get_user_service())
389393
.delete("/users/{user_id}", delete_user_service())
@@ -452,12 +456,14 @@ mod tests {
452456
#[tokio::test]
453457
async fn test_router_nest() {
454458
let api_router = Router::new()
455-
.get("/users", get_users_servic())
459+
.get("/users", get_users_service())
456460
.post("/users", create_user_service())
457461
.delete("/users/{user_id}", delete_user_service())
458462
.sub(
459463
"/users/{user_id}",
460-
Router::new().get("/orders/{order_id}", get_user_order_service()),
464+
Router::new()
465+
.get("/", get_user_service())
466+
.get("/orders/{order_id}", get_user_order_service()),
461467
);
462468

463469
let app = Router::new()
@@ -474,6 +480,12 @@ mod tests {
474480
"Delete User: 123",
475481
StatusCode::OK,
476482
),
483+
(
484+
Method::GET,
485+
"/api/users/123",
486+
"Get User: 123",
487+
StatusCode::OK,
488+
),
477489
(
478490
Method::GET,
479491
"/api/users/123/orders/456",
@@ -493,9 +505,13 @@ mod tests {
493505
.unwrap();
494506

495507
let res = app.serve(Context::default(), req).await.unwrap();
496-
assert_eq!(res.status(), expected_status);
508+
assert_eq!(
509+
res.status(),
510+
expected_status,
511+
"method: {method} ; path = {path}"
512+
);
497513
let body = res.into_body().collect().await.unwrap().to_bytes();
498-
assert_eq!(body, expected_body);
514+
assert_eq!(body, expected_body, "method: {method} ; path = {path}");
499515
}
500516
}
501517
}

0 commit comments

Comments
 (0)