Skip to content

Commit e2c1692

Browse files
authored
Merge pull request #224 from songokas/status-code-from-request
Allow returning a different status code based on a request
2 parents 4a70878 + 86596b7 commit e2c1692

File tree

4 files changed

+99
-8
lines changed

4 files changed

+99
-8
lines changed

src/mock.rs

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::diff;
22
use crate::matcher::{Matcher, PathAndQueryMatcher, RequestMatcher};
3-
use crate::response::{Body, Header, Response};
3+
use crate::response::{Body, Header, Response, ResponseStatusCode};
44
use crate::server::RemoteMock;
55
use crate::server::State;
66
use crate::Request;
@@ -346,10 +346,45 @@ impl Mock {
346346
///
347347
#[track_caller]
348348
pub fn with_status(mut self, status: usize) -> Self {
349-
self.inner.response.status = StatusCode::from_u16(status as u16)
350-
.map_err(|_| Error::new_with_context(ErrorKind::InvalidStatusCode, status))
351-
.unwrap();
349+
self.inner.response.status = ResponseStatusCode::StatusCode(
350+
StatusCode::from_u16(status as u16)
351+
.map_err(|_| Error::new_with_context(ErrorKind::InvalidStatusCode, status))
352+
.unwrap(),
353+
);
354+
self
355+
}
352356

357+
///
358+
/// Sets the status code of the mock response dynamically while exposing the request object.
359+
///
360+
/// You can use this method to provide custom status code for every incoming request.
361+
///
362+
/// The function must be thread-safe. If it's a closure, it can't be borrowing its context.
363+
/// Use `move` closures and `Arc` to share any data.
364+
///
365+
/// ### Example
366+
///
367+
/// ```
368+
/// let mut s = mockito::Server::new();
369+
///
370+
/// let _m = s.mock("GET", mockito::Matcher::Any).with_status_code_from_request(|request| {
371+
/// if request.path() == "/bob" {
372+
/// 500
373+
/// } else if request.path() == "/alice" {
374+
/// 400
375+
/// } else {
376+
/// 404
377+
/// }
378+
/// });
379+
/// ```
380+
///
381+
#[track_caller]
382+
pub fn with_status_code_from_request(
383+
mut self,
384+
callback: impl Fn(&Request) -> usize + Send + Sync + 'static,
385+
) -> Self {
386+
self.inner.response.status =
387+
ResponseStatusCode::FnWithRequest(Arc::new(move |req| callback(req)));
353388
self
354389
}
355390

src/response.rs

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,44 @@ use tokio::sync::mpsc;
1212

1313
#[derive(Clone, Debug, PartialEq)]
1414
pub(crate) struct Response {
15-
pub status: StatusCode,
15+
pub status: ResponseStatusCode,
1616
pub headers: HeaderMap<Header>,
1717
pub body: Body,
1818
}
1919

20+
#[derive(Clone)]
21+
pub(crate) enum ResponseStatusCode {
22+
StatusCode(StatusCode),
23+
FnWithRequest(Arc<StatusCodeFnWithRequest>),
24+
}
25+
26+
impl fmt::Debug for ResponseStatusCode {
27+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28+
match *self {
29+
ResponseStatusCode::StatusCode(ref s) => s.fmt(f),
30+
ResponseStatusCode::FnWithRequest(_) => f.write_str("<callback>"),
31+
}
32+
}
33+
}
34+
35+
impl PartialEq for ResponseStatusCode {
36+
fn eq(&self, other: &Self) -> bool {
37+
match (self, other) {
38+
(ResponseStatusCode::StatusCode(ref a), ResponseStatusCode::StatusCode(ref b)) => {
39+
a == b
40+
}
41+
(
42+
ResponseStatusCode::FnWithRequest(ref a),
43+
ResponseStatusCode::FnWithRequest(ref b),
44+
) => std::ptr::eq(
45+
a.as_ref() as *const StatusCodeFnWithRequest as *const u8,
46+
b.as_ref() as *const StatusCodeFnWithRequest as *const u8,
47+
),
48+
_ => false,
49+
}
50+
}
51+
}
52+
2053
#[derive(Clone)]
2154
pub(crate) enum Header {
2255
String(String),
@@ -45,6 +78,7 @@ impl PartialEq for Header {
4578
}
4679
}
4780

81+
type StatusCodeFnWithRequest = dyn Fn(&Request) -> usize + Send + Sync;
4882
type HeaderFnWithRequest = dyn Fn(&Request) -> String + Send + Sync;
4983

5084
type BodyFnWithWriter = dyn Fn(&mut dyn io::Write) -> io::Result<()> + Send + Sync + 'static;
@@ -89,7 +123,7 @@ impl Default for Response {
89123
let mut headers = HeaderMap::with_capacity(1);
90124
headers.insert("connection", Header::String("close".to_string()));
91125
Self {
92-
status: StatusCode::OK,
126+
status: ResponseStatusCode::StatusCode(StatusCode::OK),
93127
headers,
94128
body: Body::Bytes(Bytes::new()),
95129
}

src/server.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::mock::InnerMock;
22
use crate::request::Request;
3-
use crate::response::{Body as ResponseBody, ChunkedStream, Header};
3+
use crate::response::{Body as ResponseBody, ChunkedStream, Header, ResponseStatusCode};
44
use crate::ServerGuard;
55
use crate::{Error, ErrorKind, Matcher, Mock};
66
use bytes::Bytes;
@@ -561,7 +561,15 @@ async fn handle_request(
561561
}
562562

563563
fn respond_with_mock(request: Request, mock: &RemoteMock) -> Result<Response<Body>, Error> {
564-
let status: StatusCode = mock.inner.response.status;
564+
let status: StatusCode = match &mock.inner.response.status {
565+
ResponseStatusCode::StatusCode(c) => *c,
566+
ResponseStatusCode::FnWithRequest(status_code_fn) => {
567+
let status = status_code_fn(&request);
568+
StatusCode::from_u16(status as u16)
569+
.map_err(|_| Error::new_with_context(ErrorKind::InvalidStatusCode, status))
570+
.unwrap()
571+
}
572+
};
565573
let mut response = Response::builder().status(status);
566574

567575
for (name, value) in mock.inner.response.headers.iter() {

tests/lib.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,20 @@ fn test_mock_with_custom_status() {
587587
assert_eq!("HTTP/1.1 499 <none>\r\n", status_line);
588588
}
589589

590+
#[test]
591+
fn test_mock_with_status_code_from_request() {
592+
let mut s = Server::new();
593+
s.mock("GET", Matcher::Any)
594+
.with_status_code_from_request(|request| if request.path() == "/world" { 500 } else { 404 })
595+
.create();
596+
597+
let (status_line, _, _) = request(s.host_with_port(), "GET /world", "");
598+
assert_eq!("HTTP/1.1 500 Internal Server Error\r\n", status_line);
599+
600+
let (status_line, _, _) = request(s.host_with_port(), "GET /", "");
601+
assert_eq!("HTTP/1.1 404 Not Found\r\n", status_line);
602+
}
603+
590604
#[test]
591605
fn test_mock_with_body() {
592606
let mut s = Server::new();

0 commit comments

Comments
 (0)