Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 174 additions & 28 deletions actix-web/src/middleware/normalize.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,31 @@
//! For middleware documentation, see [`NormalizePath`].

use actix_http::uri::{PathAndQuery, Uri};
use std::{
future::Future,
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
};

use actix_service::{Service, Transform};
use actix_utils::future::{ready, Ready};
use bytes::Bytes;
use futures_core::ready;
use pin_project_lite::pin_project;
use regex::Regex;

use crate::{
body::EitherBody,
http::{
header,
uri::{PathAndQuery, Uri},
StatusCode,
},
service::{ServiceRequest, ServiceResponse},
Error,
Error, HttpResponse,
};

/// Determines the behavior of the [`NormalizePath`] middleware.
/// Determines the path rewriting behavior of the [`NormalizePath`] middleware.
///
/// The default is `TrailingSlash::Trim`.
#[non_exhaustive]
Expand Down Expand Up @@ -86,7 +100,13 @@ impl Default for TrailingSlash {
/// # })
/// ```
#[derive(Debug, Clone, Copy)]
pub struct NormalizePath(TrailingSlash);
pub struct NormalizePath {
/// Controls path normalization behavior.
trailing_slash_behavior: TrailingSlash,

/// Returns redirects for non-normalized paths if `Some`.
use_redirects: Option<StatusCode>,
}

impl Default for NormalizePath {
fn default() -> Self {
Expand All @@ -95,14 +115,20 @@ impl Default for NormalizePath {
in v4 from `Always` to `Trim`. Update your call to `NormalizePath::new(...)`."
);

Self(TrailingSlash::Trim)
Self {
trailing_slash_behavior: TrailingSlash::default(),
use_redirects: None,
}
}
}

impl NormalizePath {
/// Create new `NormalizePath` middleware with the specified trailing slash style.
pub fn new(trailing_slash_style: TrailingSlash) -> Self {
Self(trailing_slash_style)
pub fn new(behavior: TrailingSlash) -> Self {
Self {
trailing_slash_behavior: behavior,
use_redirects: None,
}
}

/// Constructs a new `NormalizePath` middleware with [trim](TrailingSlash::Trim) semantics.
Expand All @@ -111,42 +137,70 @@ impl NormalizePath {
pub fn trim() -> Self {
Self::new(TrailingSlash::Trim)
}

/// Configures middleware to respond to requests with non-normalized paths with a 307 redirect.
///
/// If configured
///
/// For example, a request with the path `/api//v1/foo/` would receive a response with a
/// `Location: /api/v1/foo` header (assuming `Trim` trailing slash behavior.)
///
/// To customize the status code, use [`use_redirects_with`](Self::use_redirects_with).
pub fn use_redirects(mut self) -> Self {
self.use_redirects = Some(StatusCode::TEMPORARY_REDIRECT);
self
}

/// Configures middleware to respond to requests with non-normalized paths with a redirect.
///
/// For example, a request with the path `/api//v1/foo/` would receive a 307 response with a
/// `Location: /api/v1/foo` header (assuming `Trim` trailing slash behavior.)
///
/// # Panics
/// Panics if `status_code` is not a redirect (300-399).
pub fn use_redirects_with(mut self, status_code: StatusCode) -> Self {
assert!(status_code.is_redirection());
self.use_redirects = Some(status_code);
self
}
}

impl<S, B> Transform<S, ServiceRequest> for NormalizePath
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static,
{
type Response = ServiceResponse<B>;
type Response = ServiceResponse<EitherBody<B, ()>>;
type Error = Error;
type Transform = NormalizePathNormalization<S>;
type Transform = NormalizePathService<S>;
type InitError = ();
type Future = Ready<Result<Self::Transform, Self::InitError>>;

fn new_transform(&self, service: S) -> Self::Future {
ready(Ok(NormalizePathNormalization {
ready(Ok(NormalizePathService {
service,
merge_slash: Regex::new("//+").unwrap(),
trailing_slash_behavior: self.0,
trailing_slash_behavior: self.trailing_slash_behavior,
use_redirects: self.use_redirects,
}))
}
}

pub struct NormalizePathNormalization<S> {
pub struct NormalizePathService<S> {
service: S,
merge_slash: Regex,
trailing_slash_behavior: TrailingSlash,
use_redirects: Option<StatusCode>,
}

impl<S, B> Service<ServiceRequest> for NormalizePathNormalization<S>
impl<S, B> Service<ServiceRequest> for NormalizePathService<S>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static,
{
type Response = ServiceResponse<B>;
type Response = ServiceResponse<EitherBody<B, ()>>;
type Error = Error;
type Future = S::Future;
type Future = NormalizePathFuture<S, B>;

actix_service::forward_ready!(service);

Expand Down Expand Up @@ -189,7 +243,7 @@ where
let query = parts.path_and_query.as_ref().and_then(|pq| pq.query());

let path = match query {
Some(q) => Bytes::from(format!("{}?{}", path, q)),
Some(query) => Bytes::from(format!("{}?{}", path, query)),
None => Bytes::copy_from_slice(path.as_bytes()),
};
parts.path_and_query = Some(PathAndQuery::from_maybe_shared(path).unwrap());
Expand All @@ -199,20 +253,87 @@ where
req.head_mut().uri = uri;
}
}
self.service.call(req)

match self.use_redirects {
Some(code) => {
let mut res = HttpResponse::with_body(code, ());
res.headers_mut().insert(
header::LOCATION,
req.head_mut().uri.to_string().parse().unwrap(),
);
NormalizePathFuture::redirect(req.into_response(res))
}

None => NormalizePathFuture::service(self.service.call(req)),
}
}
}

pin_project! {
pub struct NormalizePathFuture<S: Service<ServiceRequest>, B> {
#[pin] inner: Inner<S, B>,
}
}

impl<S: Service<ServiceRequest>, B> NormalizePathFuture<S, B> {
fn service(fut: S::Future) -> Self {
Self {
inner: Inner::Service {
fut,
_body: PhantomData,
},
}
}

fn redirect(res: ServiceResponse<()>) -> Self {
Self {
inner: Inner::Redirect { res: Some(res) },
}
}
}

pin_project! {
#[project = InnerProj]
enum Inner<S: Service<ServiceRequest>, B> {
Redirect { res: Option<ServiceResponse<()>>, },
Service {
#[pin] fut: S::Future,
_body: PhantomData<B>,
},
}
}

impl<S, B> Future for NormalizePathFuture<S, B>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
{
type Output = Result<ServiceResponse<EitherBody<B, ()>>, Error>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();

match this.inner.project() {
InnerProj::Redirect { res } => {
Poll::Ready(Ok(res.take().unwrap().map_into_right_body()))
}

InnerProj::Service { fut, .. } => {
let res = ready!(fut.poll(cx))?;
Poll::Ready(Ok(res.map_into_left_body()))
}
}
}
}

#[cfg(test)]
mod tests {
use actix_http::StatusCode;
use actix_service::IntoService;

use super::*;
use crate::{
dev::ServiceRequest,
guard::fn_guard,
test::{call_service, init_service, TestRequest},
test::{self, call_service, init_service, TestRequest},
web, App, HttpResponse,
};

Expand Down Expand Up @@ -256,7 +377,7 @@ mod tests {
async fn trim_trailing_slashes() {
let app = init_service(
App::new()
.wrap(NormalizePath(TrailingSlash::Trim))
.wrap(NormalizePath::new(TrailingSlash::Trim))
.service(web::resource("/").to(HttpResponse::Ok))
.service(web::resource("/v1/something").to(HttpResponse::Ok))
.service(
Expand Down Expand Up @@ -292,11 +413,13 @@ mod tests {
#[actix_rt::test]
async fn trim_root_trailing_slashes_with_query() {
let app = init_service(
App::new().wrap(NormalizePath(TrailingSlash::Trim)).service(
web::resource("/")
.guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
.to(HttpResponse::Ok),
),
App::new()
.wrap(NormalizePath::new(TrailingSlash::Trim))
.service(
web::resource("/")
.guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
.to(HttpResponse::Ok),
),
)
.await;

Expand All @@ -313,7 +436,7 @@ mod tests {
async fn ensure_trailing_slash() {
let app = init_service(
App::new()
.wrap(NormalizePath(TrailingSlash::Always))
.wrap(NormalizePath::new(TrailingSlash::Always))
.service(web::resource("/").to(HttpResponse::Ok))
.service(web::resource("/v1/something/").to(HttpResponse::Ok))
.service(
Expand Down Expand Up @@ -350,7 +473,7 @@ mod tests {
async fn ensure_root_trailing_slash_with_query() {
let app = init_service(
App::new()
.wrap(NormalizePath(TrailingSlash::Always))
.wrap(NormalizePath::new(TrailingSlash::Always))
.service(
web::resource("/")
.guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
Expand All @@ -372,7 +495,7 @@ mod tests {
async fn keep_trailing_slash_unchanged() {
let app = init_service(
App::new()
.wrap(NormalizePath(TrailingSlash::MergeOnly))
.wrap(NormalizePath::new(TrailingSlash::MergeOnly))
.service(web::resource("/").to(HttpResponse::Ok))
.service(web::resource("/v1/something").to(HttpResponse::Ok))
.service(web::resource("/v1/").to(HttpResponse::Ok))
Expand Down Expand Up @@ -486,4 +609,27 @@ mod tests {
let res = normalize.call(req).await.unwrap();
assert!(res.status().is_success());
}

#[actix_rt::test]
async fn should_return_redirects_when_configured() {
let normalize = NormalizePath::trim()
.use_redirects()
.new_transform(test::ok_service())
.await
.unwrap();

let req = TestRequest::with_uri("/v1/something/").to_srv_request();
let res = normalize.call(req).await.unwrap();
assert_eq!(res.status(), StatusCode::TEMPORARY_REDIRECT);

let normalize = NormalizePath::trim()
.use_redirects_with(StatusCode::PERMANENT_REDIRECT)
.new_transform(test::ok_service())
.await
.unwrap();

let req = TestRequest::with_uri("/v1/something/").to_srv_request();
let res = normalize.call(req).await.unwrap();
assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
}
}