@@ -4,6 +4,7 @@ use std::net::SocketAddr;
44use std:: pin:: Pin ;
55
66use bytes:: Bytes ;
7+ pub use http:: Method ;
78use http_body_util:: Full ;
89use hyper:: body:: Incoming ;
910use hyper:: service:: Service ;
@@ -19,10 +20,34 @@ use tokio::net::{
1920use tokio:: select;
2021use tokio_util:: sync:: CancellationToken ;
2122
23+ /// A local HTTP server that can be used for returning mock responses to HTTP requests.
24+ ///
25+ /// # Examples
26+ ///
27+ /// ```rust
28+ /// # use fig_test_utils::server::*;
29+ /// # async fn run() -> Result<(), reqwest::Error> {
30+ /// // Hosting a local server that responds to GET requests for "/my-file" with the
31+ /// // body "some text".
32+ /// let test_path = String::from("/my-file");
33+ /// let mock_body = String::from("some text");
34+ /// let test_server_addr = TestServer::new()
35+ /// .await
36+ /// .with_mock_response(Method::GET, test_path.clone(), mock_body.clone())
37+ /// .spawn_listener();
38+ ///
39+ /// let body = reqwest::get(format!("http://{}{}", &test_server_addr, &test_path))
40+ /// .await?
41+ /// .text()
42+ /// .await?;
43+ ///
44+ /// assert_eq!(body, mock_body);
45+ /// # Ok(())
46+ /// # }
47+ /// ```
2248#[ derive( Debug ) ]
2349pub struct TestServer {
2450 listener : TcpListener ,
25- cancellation_token : Option < CancellationToken > ,
2651 mock_responses : HashMap < ( http:: Method , String ) , String > ,
2752}
2853
@@ -33,7 +58,6 @@ impl TestServer {
3358 . expect ( "failed to bind socket to local host" ) ;
3459 Self {
3560 listener,
36- cancellation_token : None ,
3761 mock_responses : HashMap :: default ( ) ,
3862 }
3963 }
@@ -43,14 +67,14 @@ impl TestServer {
4367 self
4468 }
4569
46- pub fn spawn_listener ( mut self ) -> SocketAddr {
47- let addr = self
70+ /// Spawns a new task for accepting requests, returning the address of the listening socket.
71+ pub fn spawn_listener ( self ) -> TestAddress {
72+ let address = self
4873 . listener
4974 . local_addr ( )
5075 . expect ( "listener should be bound to an address" ) ;
51- let token = CancellationToken :: new ( ) ;
52- let token_clone = token. clone ( ) ;
53- self . cancellation_token = Some ( token) ;
76+ let cancellation_token = CancellationToken :: new ( ) ;
77+ let token_clone = cancellation_token. clone ( ) ;
5478 tokio:: task:: spawn ( async move {
5579 loop {
5680 select ! {
@@ -64,7 +88,10 @@ impl TestServer {
6488 }
6589 }
6690 } ) ;
67- addr
91+ TestAddress {
92+ address,
93+ cancellation_token,
94+ }
6895 }
6996
7097 async fn handle_request ( & self , stream : TcpStream ) {
@@ -76,11 +103,27 @@ impl TestServer {
76103 }
77104}
78105
79- impl Drop for TestServer {
106+ #[ derive( Debug ) ]
107+ pub struct TestAddress {
108+ address : SocketAddr ,
109+ cancellation_token : CancellationToken ,
110+ }
111+
112+ impl TestAddress {
113+ pub fn address ( & self ) -> SocketAddr {
114+ self . address
115+ }
116+ }
117+
118+ impl std:: fmt:: Display for TestAddress {
119+ fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
120+ write ! ( f, "{}" , self . address)
121+ }
122+ }
123+
124+ impl Drop for TestAddress {
80125 fn drop ( & mut self ) {
81- if let Some ( token) = & self . cancellation_token {
82- token. cancel ( ) ;
83- }
126+ self . cancellation_token . cancel ( ) ;
84127 }
85128}
86129
@@ -100,3 +143,34 @@ impl Service<Request<Incoming>> for TestServer {
100143 Box :: pin ( async move { Ok ( Response :: builder ( ) . status ( 200 ) . body ( result. into ( ) ) . unwrap ( ) ) } )
101144 }
102145}
146+
147+ #[ cfg( test) ]
148+ mod tests {
149+ use super :: * ;
150+
151+ #[ tokio:: test]
152+ async fn test_server_mock_and_drop ( ) {
153+ let test_path = String :: from ( "/test-path" ) ;
154+ let test_response = String :: from ( "test body" ) ;
155+ let test_server_addr = TestServer :: new ( )
156+ . await
157+ . with_mock_response ( Method :: GET , test_path. clone ( ) , test_response. clone ( ) )
158+ . spawn_listener ( ) ;
159+
160+ let response = reqwest:: get ( format ! ( "http://{}{}" , & test_server_addr, & test_path) )
161+ . await
162+ . unwrap ( )
163+ . text ( )
164+ . await
165+ . unwrap ( ) ;
166+ assert_eq ! ( response, test_response) ;
167+
168+ // Test that dropping TestAddress stops the server
169+ let addr = test_server_addr. address ( ) ;
170+ std:: mem:: drop ( test_server_addr) ;
171+ // wait for the task to complete
172+ tokio:: time:: sleep ( std:: time:: Duration :: from_millis ( 10 ) ) . await ;
173+ let response = reqwest:: get ( format ! ( "http://{}{}" , & addr, & test_path) ) . await ;
174+ assert ! ( response. is_err( ) ) ;
175+ }
176+ }
0 commit comments