diff --git a/dropshot/src/api_description.rs b/dropshot/src/api_description.rs index 51604716..f1899269 100644 --- a/dropshot/src/api_description.rs +++ b/dropshot/src/api_description.rs @@ -335,11 +335,31 @@ impl ApiEndpointBodyContentType { CONTENT_TYPE_JSON => Ok(Self::Json), CONTENT_TYPE_URL_ENCODED => Ok(Self::UrlEncoded), CONTENT_TYPE_MULTIPART_FORM_DATA => Ok(Self::MultipartFormData), - _ => Err(mime_type.to_string()), + _ => match mime_split(mime_type) { + // We may see content-type that is of the form + // application/XXX+json which means "XXX protocol serialized as + // JSON". A more pedantic implementation might involve a server + // (or subset of its API) indicating that it expects (and + // produces) bodies in a particular format, but for now it + // suffices to treat input bodies of this form as equivalent to + // application/json. + Some(("application", _, Some("json"))) => Ok(Self::Json), + _ => Err(mime_type.to_string()), + }, } } } +/// Split the mime type in to the type, subtype, and optional suffix +/// components. +fn mime_split(mime_type: &str) -> Option<(&str, &str, Option<&str>)> { + let (type_, rest) = mime_type.split_once('/')?; + let mut sub_parts = rest.splitn(2, '+'); + let subtype = sub_parts.next()?; + let suffix = sub_parts.next(); + Some((type_, subtype, suffix)) +} + #[derive(Debug)] pub struct ApiEndpointHeader { pub name: String, diff --git a/dropshot/src/extractor/body.rs b/dropshot/src/extractor/body.rs index 144e7415..d5dde1ff 100644 --- a/dropshot/src/extractor/body.rs +++ b/dropshot/src/extractor/body.rs @@ -121,15 +121,16 @@ impl ExclusiveExtractor for MultipartBody { /// Given an HTTP request, attempt to read the body, parse it according /// to the content type, and deserialize it to an instance of `BodyType`. -async fn http_request_load_body( - rqctx: &RequestContext, +async fn http_request_load_body( request: hyper::Request, + request_body_max_bytes: usize, + expected_body_content_type: &ApiEndpointBodyContentType, ) -> Result, HttpError> where BodyType: JsonSchema + DeserializeOwned + Send + Sync, { let (parts, body) = request.into_parts(); - let body = StreamingBody::new(body, rqctx.request_body_max_bytes()) + let body = StreamingBody::new(body, request_body_max_bytes) .into_bytes_mut() .await?; @@ -150,14 +151,19 @@ where .unwrap_or(Ok(CONTENT_TYPE_JSON))?; let end = content_type.find(';').unwrap_or_else(|| content_type.len()); let mime_type = content_type[..end].trim_end().to_lowercase(); - let body_content_type = - ApiEndpointBodyContentType::from_mime_type(&mime_type) - .map_err(|e| HttpError::for_bad_request(None, e))?; - let expected_content_type = rqctx.endpoint.body_content_type.clone(); + let body_content_type = ApiEndpointBodyContentType::from_mime_type( + &mime_type, + ) + .map_err(|e| { + HttpError::for_bad_request( + None, + format!("unsupported content-type: {}", e), + ) + })?; use ApiEndpointBodyContentType::*; - let content = match (expected_content_type, body_content_type) { + let content = match (expected_body_content_type, body_content_type) { (Json, Json) => { let jd = &mut serde_json::Deserializer::from_slice(&body); serde_path_to_error::deserialize(jd).map_err(|e| { @@ -186,7 +192,7 @@ where expected.mime_type(), requested.mime_type() ), - )) + )); } }; Ok(TypedBody { inner: content }) @@ -207,7 +213,12 @@ where rqctx: &RequestContext, request: hyper::Request, ) -> Result, HttpError> { - http_request_load_body(rqctx, request).await + http_request_load_body( + request, + rqctx.request_body_max_bytes(), + &rqctx.endpoint.body_content_type, + ) + .await } fn metadata(content_type: ApiEndpointBodyContentType) -> ExtractorMetadata { @@ -457,3 +468,32 @@ fn untyped_metadata() -> ExtractorMetadata { extension_mode: ExtensionMode::None, } } + +#[cfg(test)] +mod tests { + use schemars::JsonSchema; + use serde::Deserialize; + + use crate::extractor::body::http_request_load_body; + + #[tokio::test] + async fn test_content_plus_json() { + #[derive(Deserialize, JsonSchema)] + struct TheRealScimShady {} + + let body = "{}"; + let request = hyper::Request::builder() + .header(http::header::CONTENT_TYPE, "application/scim+json") + .body(crate::Body::with_content(body)) + .unwrap(); + + let r = http_request_load_body::( + request, + 9000, + &crate::ApiEndpointBodyContentType::Json, + ) + .await; + + assert!(r.is_ok()) + } +}