diff --git a/CHANGELOG.md b/CHANGELOG.md index 0ade25ba437..37e52632787 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,10 @@ and this project adheres to - [#5222](https://github.com/firecracker-microvm/firecracker/pull/5222): Fixed network and rng devices locking up on hosts with non 4K pages. +- [#5226](https://github.com/firecracker-microvm/firecracker/pull/5226): Fixed + MMDS to set `Content-Type` header correctly (i.e. `Content-Type: text/plain` + for IMDS-formatted or error responses and `Content-Type: application/json` for + JSON-formatted responses). ## [1.12.0] diff --git a/src/vmm/src/mmds/data_store.rs b/src/vmm/src/mmds/data_store.rs index 518ae7d901a..22b2ee215fb 100644 --- a/src/vmm/src/mmds/data_store.rs +++ b/src/vmm/src/mmds/data_store.rs @@ -352,7 +352,8 @@ mod tests { ], "member": false, "shares_percentage": 12.12, - "balance": -24 + "balance": -24, + "json_string": "{\n \"hello\": \"world\"\n}" }"#; let data_store: Value = serde_json::from_str(data).unwrap(); mmds.put_data(data_store).unwrap(); @@ -492,6 +493,18 @@ mod tests { .to_string(), MmdsDatastoreError::UnsupportedValueType.to_string() ); + + // Retrieve a string including escapes. + assert_eq!( + mmds.get_value("/json_string".to_string(), OutputFormat::Json) + .unwrap(), + r#""{\n \"hello\": \"world\"\n}""# + ); + assert_eq!( + mmds.get_value("/json_string".to_string(), OutputFormat::Imds) + .unwrap(), + "{\n \"hello\": \"world\"\n}" + ) } #[test] diff --git a/src/vmm/src/mmds/mod.rs b/src/vmm/src/mmds/mod.rs index cdc0052eb42..7f831d384f8 100644 --- a/src/vmm/src/mmds/mod.rs +++ b/src/vmm/src/mmds/mod.rs @@ -51,8 +51,14 @@ impl From for OutputFormat { } // Builds the `micro_http::Response` with a given HTTP version, status code, and body. -fn build_response(http_version: Version, status_code: StatusCode, body: Body) -> Response { +fn build_response( + http_version: Version, + status_code: StatusCode, + content_type: MediaType, + body: Body, +) -> Response { let mut response = Response::new(http_version, status_code); + response.set_content_type(content_type); response.set_body(body); response } @@ -105,6 +111,7 @@ pub fn convert_to_response(mmds: Arc>, request: Request) -> Response return build_response( request.http_version(), StatusCode::BadRequest, + MediaType::PlainText, Body::new(VmmMmdsError::InvalidURI.to_string()), ); } @@ -125,6 +132,7 @@ fn respond_to_request_mmdsv1(mmds: &Mmds, request: Request) -> Response { let mut response = build_response( request.http_version(), StatusCode::MethodNotAllowed, + MediaType::PlainText, Body::new(VmmMmdsError::MethodNotAllowed.to_string()), ); response.allow_method(Method::Get); @@ -141,6 +149,7 @@ fn respond_to_request_mmdsv2(mmds: &mut Mmds, request: Request) -> Response { return build_response( request.http_version(), StatusCode::BadRequest, + MediaType::PlainText, Body::new(err.to_string()), ); } @@ -154,6 +163,7 @@ fn respond_to_request_mmdsv2(mmds: &mut Mmds, request: Request) -> Response { let mut response = build_response( request.http_version(), StatusCode::MethodNotAllowed, + MediaType::PlainText, Body::new(VmmMmdsError::MethodNotAllowed.to_string()), ); response.allow_method(Method::Get); @@ -176,6 +186,7 @@ fn respond_to_get_request_checked( return build_response( request.http_version(), StatusCode::Unauthorized, + MediaType::PlainText, Body::new(error_msg), ); } @@ -187,6 +198,7 @@ fn respond_to_get_request_checked( Ok(false) => build_response( request.http_version(), StatusCode::Unauthorized, + MediaType::PlainText, Body::new(VmmMmdsError::InvalidToken.to_string()), ), Err(_) => unreachable!(), @@ -200,10 +212,13 @@ fn respond_to_get_request_unchecked(mmds: &Mmds, request: Request) -> Response { // sanitize the URI. let json_path = sanitize_uri(uri.to_string()); - match mmds.get_value(json_path, request.headers.accept().into()) { + let content_type = request.headers.accept(); + + match mmds.get_value(json_path, content_type.into()) { Ok(response_body) => build_response( request.http_version(), StatusCode::OK, + content_type, Body::new(response_body), ), Err(err) => match err { @@ -212,17 +227,20 @@ fn respond_to_get_request_unchecked(mmds: &Mmds, request: Request) -> Response { build_response( request.http_version(), StatusCode::NotFound, + MediaType::PlainText, Body::new(error_msg), ) } MmdsError::UnsupportedValueType => build_response( request.http_version(), StatusCode::NotImplemented, + MediaType::PlainText, Body::new(err.to_string()), ), MmdsError::DataStoreLimitExceeded => build_response( request.http_version(), StatusCode::PayloadTooLarge, + MediaType::PlainText, Body::new(err.to_string()), ), _ => unreachable!(), @@ -248,6 +266,7 @@ fn respond_to_put_request( return build_response( request.http_version(), StatusCode::BadRequest, + MediaType::PlainText, Body::new(error_msg), ); } @@ -262,6 +281,7 @@ fn respond_to_put_request( return build_response( request.http_version(), StatusCode::NotFound, + MediaType::PlainText, Body::new(error_msg), ); } @@ -273,6 +293,7 @@ fn respond_to_put_request( return build_response( request.http_version(), StatusCode::BadRequest, + MediaType::PlainText, Body::new(VmmMmdsError::NoTtlProvided.to_string()), ); } @@ -281,15 +302,16 @@ fn respond_to_put_request( // Generate token. let result = mmds.generate_token(ttl_seconds); match result { - Ok(token) => { - let mut response = - build_response(request.http_version(), StatusCode::OK, Body::new(token)); - response.set_content_type(MediaType::PlainText); - response - } + Ok(token) => build_response( + request.http_version(), + StatusCode::OK, + MediaType::PlainText, + Body::new(token), + ), Err(err) => build_response( request.http_version(), StatusCode::BadRequest, + MediaType::PlainText, Body::new(err.to_string()), ), } @@ -343,6 +365,31 @@ mod tests { }"# } + fn get_plain_text_data() -> &'static str { + "age\nname/\nphones/" + } + + fn generate_request_and_expected_response( + request_bytes: &[u8], + media_type: MediaType, + ) -> (Request, Response) { + let request = Request::try_from(request_bytes, None).unwrap(); + + let mut response = Response::new(Version::Http10, StatusCode::OK); + response.set_content_type(media_type); + let body = match media_type { + MediaType::ApplicationJson => { + let mut body = get_json_data().to_string(); + body.retain(|c| !c.is_whitespace()); + body + } + MediaType::PlainText => get_plain_text_data().to_string(), + }; + response.set_body(Body::new(body)); + + (request, response) + } + #[test] fn test_sanitize_uri() { let sanitized = "/a/b/c/d"; @@ -362,6 +409,66 @@ mod tests { assert_eq!(sanitize_uri("//aa//bb///cc//d".to_owned()), "/aa/bb/cc/d"); } + #[test] + fn test_request_accept_header() { + // This test validates the response `Content-Type` header and the response content for + // various request `Accept` headers. + + // Populate MMDS with data. + let mmds = populate_mmds(); + + // Test without `Accept` header. micro-http defaults to `Accept: text/plain`. + let (request, expected_response) = generate_request_and_expected_response( + b"GET http://169.254.169.254/ HTTP/1.0\r\n\r\n", + MediaType::PlainText, + ); + assert_eq!( + convert_to_response(mmds.clone(), request), + expected_response + ); + + // Test with empty `Accept` header. micro-http defaults to `Accept: text/plain`. + let (request, expected_response) = generate_request_and_expected_response( + b"GET http://169.254.169.254/ HTTP/1.0\r\n\" + Accept:\r\n\r\n", + MediaType::PlainText, + ); + assert_eq!( + convert_to_response(mmds.clone(), request), + expected_response + ); + + // Test with `Accept: */*` header. + let (request, expected_response) = generate_request_and_expected_response( + b"GET http://169.254.169.254/ HTTP/1.0\r\n\" + Accept: */*\r\n\r\n", + MediaType::PlainText, + ); + assert_eq!( + convert_to_response(mmds.clone(), request), + expected_response + ); + + // Test with `Accept: text/plain`. + let (request, expected_response) = generate_request_and_expected_response( + b"GET http://169.254.169.254/ HTTP/1.0\r\n\ + Accept: text/plain\r\n\r\n", + MediaType::PlainText, + ); + assert_eq!( + convert_to_response(mmds.clone(), request), + expected_response + ); + + // Test with `Accept: application/json`. + let (request, expected_response) = generate_request_and_expected_response( + b"GET http://169.254.169.254/ HTTP/1.0\r\n\ + Accept: application/json\r\n\r\n", + MediaType::ApplicationJson, + ); + assert_eq!(convert_to_response(mmds, request), expected_response); + } + #[test] fn test_respond_to_request_mmdsv1() { // Populate MMDS with data. @@ -381,6 +488,7 @@ mod tests { let request_bytes = b"GET http://169.254.169.254/invalid HTTP/1.0\r\n\r\n"; let request = Request::try_from(request_bytes, None).unwrap(); let mut expected_response = Response::new(Version::Http10, StatusCode::NotFound); + expected_response.set_content_type(MediaType::PlainText); expected_response.set_body(Body::new( VmmMmdsError::ResourceNotFound(String::from("/invalid")).to_string(), )); @@ -391,6 +499,7 @@ mod tests { let request_bytes = b"GET /age HTTP/1.1\r\n\r\n"; let request = Request::try_from(request_bytes, None).unwrap(); let mut expected_response = Response::new(Version::Http11, StatusCode::NotImplemented); + expected_response.set_content_type(MediaType::PlainText); let body = "Cannot retrieve value. The value has an unsupported type.".to_string(); expected_response.set_body(Body::new(body)); let actual_response = convert_to_response(mmds.clone(), request); @@ -403,6 +512,7 @@ mod tests { let request = Request::try_from(request_bytes.as_bytes(), None).unwrap(); let mut expected_response = Response::new(Version::Http10, StatusCode::MethodNotAllowed); + expected_response.set_content_type(MediaType::PlainText); expected_response.set_body(Body::new(VmmMmdsError::MethodNotAllowed.to_string())); expected_response.allow_method(Method::Get); let actual_response = convert_to_response(mmds.clone(), request); @@ -413,6 +523,7 @@ mod tests { let request_bytes = b"GET http:// HTTP/1.0\r\n\r\n"; let request = Request::try_from(request_bytes, None).unwrap(); let mut expected_response = Response::new(Version::Http10, StatusCode::BadRequest); + expected_response.set_content_type(MediaType::PlainText); expected_response.set_body(Body::new(VmmMmdsError::InvalidURI.to_string())); let actual_response = convert_to_response(mmds.clone(), request); assert_eq!(actual_response, expected_response); @@ -458,6 +569,7 @@ mod tests { let request_bytes = b"PATCH http://169.254.169.255/ HTTP/1.0\r\n\r\n"; let request = Request::try_from(request_bytes, None).unwrap(); let mut expected_response = Response::new(Version::Http10, StatusCode::MethodNotAllowed); + expected_response.set_content_type(MediaType::PlainText); expected_response.set_body(Body::new(VmmMmdsError::MethodNotAllowed.to_string())); expected_response.allow_method(Method::Get); expected_response.allow_method(Method::Put); @@ -470,6 +582,7 @@ mod tests { X-metadata-token-ttl-seconds: application/json\r\n\r\n"; let request = Request::try_from(request_bytes, None).unwrap(); let mut expected_response = Response::new(Version::Http10, StatusCode::BadRequest); + expected_response.set_content_type(MediaType::PlainText); expected_response.set_body(Body::new( "Invalid header. Reason: Invalid value. Key:X-metadata-token-ttl-seconds; \ Value:application/json" @@ -484,6 +597,7 @@ mod tests { X-Forwarded-For: 203.0.113.195\r\n\r\n"; let request = Request::try_from(request_bytes, None).unwrap(); let mut expected_response = Response::new(Version::Http10, StatusCode::BadRequest); + expected_response.set_content_type(MediaType::PlainText); expected_response.set_body(Body::new( "Invalid header. Reason: Unsupported header name. Key: X-Forwarded-For".to_string(), )); @@ -495,6 +609,7 @@ mod tests { X-metadata-token-ttl-seconds: 60\r\n\r\n"; let request = Request::try_from(request_bytes, None).unwrap(); let mut expected_response = Response::new(Version::Http10, StatusCode::NotFound); + expected_response.set_content_type(MediaType::PlainText); expected_response.set_body(Body::new( VmmMmdsError::ResourceNotFound(String::from("/token")).to_string(), )); @@ -511,6 +626,7 @@ mod tests { ); let request = Request::try_from(request_bytes.as_bytes(), None).unwrap(); let mut expected_response = Response::new(Version::Http10, StatusCode::BadRequest); + expected_response.set_content_type(MediaType::PlainText); let error_msg = format!( "Invalid time to live value provided for token: {}. Please provide a value \ between {} and {}.", @@ -525,6 +641,7 @@ mod tests { let request_bytes = b"PUT http://169.254.169.254/latest/api/token HTTP/1.0\r\n\r\n"; let request = Request::try_from(request_bytes, None).unwrap(); let mut expected_response = Response::new(Version::Http10, StatusCode::BadRequest); + expected_response.set_content_type(MediaType::PlainText); expected_response.set_body(Body::new(VmmMmdsError::NoTtlProvided.to_string())); let actual_response = convert_to_response(mmds.clone(), request); assert_eq!(actual_response, expected_response); @@ -559,6 +676,7 @@ mod tests { ); let request = Request::try_from(request_bytes.as_bytes(), None).unwrap(); let mut expected_response = Response::new(Version::Http11, StatusCode::NotImplemented); + expected_response.set_content_type(MediaType::PlainText); let body = "Cannot retrieve value. The value has an unsupported type.".to_string(); expected_response.set_body(Body::new(body)); let actual_response = convert_to_response(mmds.clone(), request); @@ -571,6 +689,7 @@ mod tests { ); let request = Request::try_from(request_bytes.as_bytes(), None).unwrap(); let mut expected_response = Response::new(Version::Http10, StatusCode::NotFound); + expected_response.set_content_type(MediaType::PlainText); expected_response.set_body(Body::new( VmmMmdsError::ResourceNotFound(String::from("/invalid")).to_string(), )); @@ -581,6 +700,7 @@ mod tests { let request_bytes = b"GET http://169.254.169.254/ HTTP/1.0\r\n\r\n"; let request = Request::try_from(request_bytes, None).unwrap(); let mut expected_response = Response::new(Version::Http10, StatusCode::Unauthorized); + expected_response.set_content_type(MediaType::PlainText); expected_response.set_body(Body::new(VmmMmdsError::NoTokenProvided.to_string())); let actual_response = convert_to_response(mmds.clone(), request); assert_eq!(actual_response, expected_response); @@ -590,6 +710,7 @@ mod tests { X-metadata-token: foo\r\n\r\n"; let request = Request::try_from(request_bytes, None).unwrap(); let mut expected_response = Response::new(Version::Http10, StatusCode::Unauthorized); + expected_response.set_content_type(MediaType::PlainText); expected_response.set_body(Body::new(VmmMmdsError::InvalidToken.to_string())); let actual_response = convert_to_response(mmds.clone(), request); assert_eq!(actual_response, expected_response); @@ -614,6 +735,7 @@ mod tests { ); let request = Request::try_from(request_bytes.as_bytes(), None).unwrap(); let mut expected_response = Response::new(Version::Http10, StatusCode::Unauthorized); + expected_response.set_content_type(MediaType::PlainText); expected_response.set_body(Body::new(VmmMmdsError::InvalidToken.to_string())); let actual_response = convert_to_response(mmds.clone(), request); assert_eq!(actual_response, expected_response);