@@ -51,8 +51,16 @@ impl From<MediaType> for OutputFormat {
5151} 
5252
5353// Builds the `micro_http::Response` with a given HTTP version, status code, and body. 
54- fn  build_response ( http_version :  Version ,  status_code :  StatusCode ,  body :  Body )  -> Response  { 
54+ fn  build_response ( 
55+     http_version :  Version , 
56+     status_code :  StatusCode , 
57+     content_type :  Option < MediaType > , 
58+     body :  Body , 
59+ )  -> Response  { 
5560    let  mut  response = Response :: new ( http_version,  status_code) ; 
61+     if  let  Some ( content_type)  = content_type { 
62+         response. set_content_type ( content_type) ; 
63+     } 
5664    response. set_body ( body) ; 
5765    response
5866} 
@@ -105,6 +113,7 @@ pub fn convert_to_response(mmds: Arc<Mutex<Mmds>>, request: Request) -> Response
105113        return  build_response ( 
106114            request. http_version ( ) , 
107115            StatusCode :: BadRequest , 
116+             None , 
108117            Body :: new ( VmmMmdsError :: InvalidURI . to_string ( ) ) , 
109118        ) ; 
110119    } 
@@ -125,6 +134,7 @@ fn respond_to_request_mmdsv1(mmds: &Mmds, request: Request) -> Response {
125134            let  mut  response = build_response ( 
126135                request. http_version ( ) , 
127136                StatusCode :: MethodNotAllowed , 
137+                 None , 
128138                Body :: new ( VmmMmdsError :: MethodNotAllowed . to_string ( ) ) , 
129139            ) ; 
130140            response. allow_method ( Method :: Get ) ; 
@@ -141,6 +151,7 @@ fn respond_to_request_mmdsv2(mmds: &mut Mmds, request: Request) -> Response {
141151            return  build_response ( 
142152                request. http_version ( ) , 
143153                StatusCode :: BadRequest , 
154+                 None , 
144155                Body :: new ( err. to_string ( ) ) , 
145156            ) ; 
146157        } 
@@ -154,6 +165,7 @@ fn respond_to_request_mmdsv2(mmds: &mut Mmds, request: Request) -> Response {
154165            let  mut  response = build_response ( 
155166                request. http_version ( ) , 
156167                StatusCode :: MethodNotAllowed , 
168+                 None , 
157169                Body :: new ( VmmMmdsError :: MethodNotAllowed . to_string ( ) ) , 
158170            ) ; 
159171            response. allow_method ( Method :: Get ) ; 
@@ -176,6 +188,7 @@ fn respond_to_get_request_checked(
176188            return  build_response ( 
177189                request. http_version ( ) , 
178190                StatusCode :: Unauthorized , 
191+                 None , 
179192                Body :: new ( error_msg) , 
180193            ) ; 
181194        } 
@@ -187,6 +200,7 @@ fn respond_to_get_request_checked(
187200        Ok ( false )  => build_response ( 
188201            request. http_version ( ) , 
189202            StatusCode :: Unauthorized , 
203+             None , 
190204            Body :: new ( VmmMmdsError :: InvalidToken . to_string ( ) ) , 
191205        ) , 
192206        Err ( _)  => unreachable ! ( ) , 
@@ -200,10 +214,13 @@ fn respond_to_get_request_unchecked(mmds: &Mmds, request: Request) -> Response {
200214    // sanitize the URI. 
201215    let  json_path = sanitize_uri ( uri. to_string ( ) ) ; 
202216
203-     match  mmds. get_value ( json_path,  request. headers . accept ( ) . into ( ) )  { 
217+     let  content_type = request. headers . accept ( ) ; 
218+ 
219+     match  mmds. get_value ( json_path,  content_type. into ( ) )  { 
204220        Ok ( response_body)  => build_response ( 
205221            request. http_version ( ) , 
206222            StatusCode :: OK , 
223+             Some ( content_type) , 
207224            Body :: new ( response_body) , 
208225        ) , 
209226        Err ( err)  => match  err { 
@@ -212,17 +229,20 @@ fn respond_to_get_request_unchecked(mmds: &Mmds, request: Request) -> Response {
212229                build_response ( 
213230                    request. http_version ( ) , 
214231                    StatusCode :: NotFound , 
232+                     None , 
215233                    Body :: new ( error_msg) , 
216234                ) 
217235            } 
218236            MmdsError :: UnsupportedValueType  => build_response ( 
219237                request. http_version ( ) , 
220238                StatusCode :: NotImplemented , 
239+                 None , 
221240                Body :: new ( err. to_string ( ) ) , 
222241            ) , 
223242            MmdsError :: DataStoreLimitExceeded  => build_response ( 
224243                request. http_version ( ) , 
225244                StatusCode :: PayloadTooLarge , 
245+                 None , 
226246                Body :: new ( err. to_string ( ) ) , 
227247            ) , 
228248            _ => unreachable ! ( ) , 
@@ -248,6 +268,7 @@ fn respond_to_put_request(
248268        return  build_response ( 
249269            request. http_version ( ) , 
250270            StatusCode :: BadRequest , 
271+             None , 
251272            Body :: new ( error_msg) , 
252273        ) ; 
253274    } 
@@ -262,6 +283,7 @@ fn respond_to_put_request(
262283        return  build_response ( 
263284            request. http_version ( ) , 
264285            StatusCode :: NotFound , 
286+             None , 
265287            Body :: new ( error_msg) , 
266288        ) ; 
267289    } 
@@ -273,6 +295,7 @@ fn respond_to_put_request(
273295            return  build_response ( 
274296                request. http_version ( ) , 
275297                StatusCode :: BadRequest , 
298+                 None , 
276299                Body :: new ( VmmMmdsError :: NoTtlProvided . to_string ( ) ) , 
277300            ) ; 
278301        } 
@@ -281,15 +304,16 @@ fn respond_to_put_request(
281304    // Generate token. 
282305    let  result = mmds. generate_token ( ttl_seconds) ; 
283306    match  result { 
284-         Ok ( token)  => { 
285-             let   mut  response = 
286-                  build_response ( request . http_version ( ) ,   StatusCode :: OK ,   Body :: new ( token ) ) ; 
287-             response . set_content_type ( MediaType :: PlainText ) ; 
288-             response 
289-         } 
307+         Ok ( token)  => build_response ( 
308+             request . http_version ( ) , 
309+             StatusCode :: OK , 
310+             Some ( MediaType :: PlainText ) , 
311+             Body :: new ( token ) , 
312+         ) , 
290313        Err ( err)  => build_response ( 
291314            request. http_version ( ) , 
292315            StatusCode :: BadRequest , 
316+             None , 
293317            Body :: new ( err. to_string ( ) ) , 
294318        ) , 
295319    } 
@@ -343,6 +367,31 @@ mod tests {
343367        }"# 
344368    } 
345369
370+     fn  get_plain_text_data ( )  -> & ' static  str  { 
371+         "age\n name/\n phones/" 
372+     } 
373+ 
374+     fn  generate_request_and_expected_response ( 
375+         request_bytes :  & [ u8 ] , 
376+         media_type :  MediaType , 
377+     )  -> ( Request ,  Response )  { 
378+         let  request = Request :: try_from ( request_bytes,  None ) . unwrap ( ) ; 
379+ 
380+         let  mut  response = Response :: new ( Version :: Http10 ,  StatusCode :: OK ) ; 
381+         response. set_content_type ( media_type) ; 
382+         let  body = match  media_type { 
383+             MediaType :: ApplicationJson  => { 
384+                 let  mut  body = get_json_data ( ) . to_string ( ) ; 
385+                 body. retain ( |c| !c. is_whitespace ( ) ) ; 
386+                 body
387+             } 
388+             MediaType :: PlainText  => get_plain_text_data ( ) . to_string ( ) , 
389+         } ; 
390+         response. set_body ( Body :: new ( body) ) ; 
391+ 
392+         ( request,  response) 
393+     } 
394+ 
346395    #[ test]  
347396    fn  test_sanitize_uri ( )  { 
348397        let  sanitized = "/a/b/c/d" ; 
@@ -362,6 +411,66 @@ mod tests {
362411        assert_eq ! ( sanitize_uri( "//aa//bb///cc//d" . to_owned( ) ) ,  "/aa/bb/cc/d" ) ; 
363412    } 
364413
414+     #[ test]  
415+     fn  test_request_accept_header ( )  { 
416+         // This test validates the response `Content-Type` header and the response content for 
417+         // various request `Accept` headers. 
418+ 
419+         // Populate MMDS with data. 
420+         let  mmds = populate_mmds ( ) ; 
421+ 
422+         // Test without `Accept` header. micro-http defaults to `Accept: text/plain`. 
423+         let  ( request,  expected_response)  = generate_request_and_expected_response ( 
424+             b"GET http://169.254.169.254/ HTTP/1.0\r \n \r \n " , 
425+             MediaType :: PlainText , 
426+         ) ; 
427+         assert_eq ! ( 
428+             convert_to_response( mmds. clone( ) ,  request) , 
429+             expected_response
430+         ) ; 
431+ 
432+         // Test with empty `Accept` header. micro-http defaults to `Accept: text/plain`. 
433+         let  ( request,  expected_response)  = generate_request_and_expected_response ( 
434+             b"GET http://169.254.169.254/ HTTP/1.0\r \n \"  
435+             Accept:\r \n \r \n " , 
436+             MediaType :: PlainText , 
437+         ) ; 
438+         assert_eq ! ( 
439+             convert_to_response( mmds. clone( ) ,  request) , 
440+             expected_response
441+         ) ; 
442+ 
443+         // Test with `Accept: */*` header. 
444+         let  ( request,  expected_response)  = generate_request_and_expected_response ( 
445+             b"GET http://169.254.169.254/ HTTP/1.0\r \n \"  
446+             Accept: */*\r \n \r \n " , 
447+             MediaType :: PlainText , 
448+         ) ; 
449+         assert_eq ! ( 
450+             convert_to_response( mmds. clone( ) ,  request) , 
451+             expected_response
452+         ) ; 
453+ 
454+         // Test with `Accept: text/plain`. 
455+         let  ( request,  expected_response)  = generate_request_and_expected_response ( 
456+             b"GET http://169.254.169.254/ HTTP/1.0\r \n \  
457+ \r \n \r \n ", 
458+             MediaType :: PlainText , 
459+         ) ; 
460+         assert_eq ! ( 
461+             convert_to_response( mmds. clone( ) ,  request) , 
462+             expected_response
463+         ) ; 
464+ 
465+         // Test with `Accept: application/json`. 
466+         let  ( request,  expected_response)  = generate_request_and_expected_response ( 
467+             b"GET http://169.254.169.254/ HTTP/1.0\r \n \  
468+ \r \n \r \n ", 
469+             MediaType :: ApplicationJson , 
470+         ) ; 
471+         assert_eq ! ( convert_to_response( mmds,  request) ,  expected_response) ; 
472+     } 
473+ 
365474    #[ test]  
366475    fn  test_respond_to_request_mmdsv1 ( )  { 
367476        // Populate MMDS with data. 
0 commit comments