Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion http-cache-reqwest/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ anyhow = "1.0.95"
async-trait = "0.1.85"
http = "1.2.0"
http-cache-semantics = "2.1.0"
reqwest = { version = "0.12.12", default-features = false }
reqwest = { version = "0.12.12", default-features = false, features = ["stream"] }
reqwest-middleware = "0.4.0"
serde = { version = "1.0.217", features = ["derive"] }
url = { version = "2.5.4", features = ["serde"] }
Expand Down
34 changes: 15 additions & 19 deletions http-cache-reqwest/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ use reqwest_middleware::{Error, Next};
use url::Url;

pub use http_cache::{
CacheManager, CacheMode, CacheOptions, HttpCache, HttpCacheOptions,
HttpResponse,
Body, CacheManager, CacheMode, CacheOptions, HttpCache, HttpCacheOptions,
HttpResponse, Parts as HttpParts,
};

#[cfg(feature = "manager-cacache")]
Expand Down Expand Up @@ -169,29 +169,25 @@ impl Middleware for ReqwestMiddleware<'_> {
let url = res.url().clone();
let status = res.status().into();
let version = res.version();
let body: Vec<u8> = match res.bytes().await {
Ok(b) => b,
Err(e) => return Err(Box::new(e)),
}
.to_vec();
Ok(HttpResponse {
body,
headers,
status,
url,
version: version.try_into()?,
})
let parts =
HttpParts { headers, status, url, version: version.try_into()? };
Ok(HttpResponse::from_parts(
parts,
Body::wrap_stream(res.bytes_stream()),
))
}
}

// Converts an [`HttpResponse`] to a reqwest [`Response`]
fn convert_response(response: HttpResponse) -> anyhow::Result<Response> {
let (parts, body) = response.into_parts();
let body = reqwest::Body::wrap_stream(body.into_data_stream());
let mut ret_res = http::Response::builder()
.status(response.status)
.url(response.url)
.version(response.version.into())
.body(response.body)?;
for header in response.headers {
.status(parts.status)
.url(parts.url)
.version(parts.version.into())
.body(body)?;
for header in parts.headers {
ret_res.headers_mut().insert(
HeaderName::from_str(header.0.clone().as_str())?,
HeaderValue::from_str(header.1.clone().as_str())?,
Expand Down
10 changes: 9 additions & 1 deletion http-cache/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,33 @@ rust-version = "1.71.1"
[dependencies]
async-trait = "0.1.85"
bincode = { version = "1.3.3", optional = true }
bytes = "1.10.1"
cacache = { version = "13.1.0", default-features = false, features = ["mmap"], optional = true }
futures = "0.3.31"
futures-util = "0.3.31"
http = "1.2.0"
http-body = "1.0.1"
http-body-util = "0.1.3"
http-cache-semantics = "2.1.0"
http-types = { version = "2.12.0", default-features = false, optional = true }
httpdate = "1.0.3"
moka = { version = "0.12.10", features = ["future"], optional = true }
serde = { version = "1.0.217", features = ["derive"] }
tokio = { version = "1", default-features = false, features = ["io-util"], optional = true }
tokio-util = { version = "0.7.14", features = ["io"], optional = true }
url = { version = "2.5.4", features = ["serde"] }

[dev-dependencies]
async-attributes = "1.1.2"
async-std = { version = "1.13.0" }
http-cache-semantics = "2.1.0"
tempfile = "3.19.1"
tokio = { version = "1.43.0", features = [ "macros", "rt", "rt-multi-thread" ] }

[features]
default = ["manager-cacache", "cacache-async-std"]
manager-cacache = ["cacache", "bincode"]
cacache-tokio = ["cacache/tokio-runtime"]
cacache-tokio = ["cacache/tokio-runtime", "tokio", "tokio-util"]
cacache-async-std = ["cacache/async-std"]
manager-moka = ["moka", "bincode"]
with-http-types = ["http-types"]
Expand Down
153 changes: 131 additions & 22 deletions http-cache/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ use std::{
time::SystemTime,
};

use bytes::{BufMut, Bytes};
use futures::StreamExt;
use http::{header::CACHE_CONTROL, request, response, StatusCode};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyDataStream, BodyExt, Full};
use http_cache_semantics::{AfterResponse, BeforeRequest, CachePolicy};
use serde::{Deserialize, Serialize};
use url::Url;
Expand Down Expand Up @@ -117,10 +121,99 @@ impl fmt::Display for HttpVersion {
}

/// A basic generic type that represents an HTTP response
#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Debug)]
pub struct HttpResponse {
/// HTTP response body
pub body: Vec<u8>,
body: Body,
/// HTTP response parts
parts: Parts,
}

/// HTTP response body.
#[derive(Debug)]
pub struct Body {
inner: BodyInner,
}

#[derive(Debug)]
enum BodyInner {
Full(Bytes),
Streaming(BoxBody<Bytes, BoxError>),
}

impl Body {
/// wrap stream
pub fn wrap_stream<S>(stream: S) -> Body
where
S: futures::stream::TryStream + Send + Sync + 'static,
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
Bytes: From<S::Ok>,
{
use futures_util::TryStreamExt;
use http_body::Frame;
use http_body_util::StreamBody;

let body = BoxBody::new(StreamBody::new(
stream.map_ok(|d| Frame::data(Bytes::from(d))).map_err(Into::into),
));
Body { inner: BodyInner::Streaming(body) }
}

/// Get body bytes if body is full.
pub fn as_bytes(&self) -> Option<&[u8]> {
match &self.inner {
BodyInner::Full(bytes) => Some(bytes),
BodyInner::Streaming(_) => None,
}
}

/// Get all bytes of the response, collecting data stream if some.
pub async fn bytes(self) -> Result<Bytes> {
Ok(match self.inner {
BodyInner::Full(bytes) => bytes,
BodyInner::Streaming(boxed_body) => boxed_body
.into_data_stream()
.collect::<Vec<Result<_>>>()
.await
.into_iter()
.collect::<Result<Vec<_>>>()?
.into_iter()
.fold(bytes::BytesMut::new(), |mut acc, chunk| {
acc.put(chunk);
acc
})
.freeze(),
})
}

/// Into data stream
pub fn into_data_stream(self) -> BodyDataStream<BoxBody<Bytes, BoxError>> {
match self.inner {
BodyInner::Full(data) => {
Full::new(data).map_err(Into::into).boxed().into_data_stream()
}
BodyInner::Streaming(boxed_body) => boxed_body.into_data_stream(),
}
}
}

impl From<Vec<u8>> for Body {
fn from(value: Vec<u8>) -> Self {
Self { inner: BodyInner::Full(value.into()) }
}
}

impl From<Bytes> for Body {
fn from(value: Bytes) -> Self {
Self { inner: BodyInner::Full(value) }
}
}

/// HTTP response parts consists of status, version, response URL and headers.
///
/// Serializable alternative to [`http::response::Parts`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Parts {
/// HTTP response headers
pub headers: HashMap<String, String>,
/// HTTP response status code
Expand All @@ -132,13 +225,23 @@ pub struct HttpResponse {
}

impl HttpResponse {
/// Consumes the response returning the head and body parts.
pub fn into_parts(self) -> (Parts, Body) {
(self.parts, self.body)
}

/// Creates a new Response with the given head and body.
pub fn from_parts(parts: Parts, body: Body) -> Self {
Self { body, parts }
}

/// Returns `http::response::Parts`
pub fn parts(&self) -> Result<response::Parts> {
let mut converted =
response::Builder::new().status(self.status).body(())?;
response::Builder::new().status(self.parts.status).body(())?;
{
let headers = converted.headers_mut();
for header in &self.headers {
for header in &self.parts.headers {
headers.insert(
http::header::HeaderName::from_str(header.0.as_str())?,
http::HeaderValue::from_str(header.1.as_str())?,
Expand All @@ -151,7 +254,7 @@ impl HttpResponse {
/// Returns the status code of the warning header if present
#[must_use]
pub fn warning_code(&self) -> Option<usize> {
self.headers.get("warning").and_then(|hdr| {
self.parts.headers.get("warning").and_then(|hdr| {
hdr.as_str().chars().take(3).collect::<String>().parse().ok()
})
}
Expand All @@ -167,7 +270,7 @@ impl HttpResponse {
// warn-text = quoted-string
// warn-date = <"> HTTP-date <">
// (https://tools.ietf.org/html/rfc2616#section-14.46)
self.headers.insert(
self.parts.headers.insert(
"warning".to_string(),
format!(
"{} {} {:?} \"{}\"",
Expand All @@ -181,13 +284,13 @@ impl HttpResponse {

/// Removes a warning header from a response
pub fn remove_warning(&mut self) {
self.headers.remove("warning");
self.parts.headers.remove("warning");
}

/// Update the headers from `http::response::Parts`
pub fn update_headers(&mut self, parts: &response::Parts) -> Result<()> {
for header in parts.headers.iter() {
self.headers.insert(
self.parts.headers.insert(
header.0.as_str().to_string(),
header.1.to_str()?.to_string(),
);
Expand All @@ -198,23 +301,27 @@ impl HttpResponse {
/// Checks if the Cache-Control header contains the must-revalidate directive
#[must_use]
pub fn must_revalidate(&self) -> bool {
self.headers.get(CACHE_CONTROL.as_str()).is_some_and(|val| {
self.parts.headers.get(CACHE_CONTROL.as_str()).is_some_and(|val| {
val.as_str().to_lowercase().contains("must-revalidate")
})
}

/// Adds the custom `x-cache` header to the response
pub fn cache_status(&mut self, hit_or_miss: HitOrMiss) {
self.headers.insert(XCACHE.to_string(), hit_or_miss.to_string());
self.parts.headers.insert(XCACHE.to_string(), hit_or_miss.to_string());
}

/// Adds the custom `x-cache-lookup` header to the response
pub fn cache_lookup_status(&mut self, hit_or_miss: HitOrMiss) {
self.headers.insert(XCACHELOOKUP.to_string(), hit_or_miss.to_string());
self.parts
.headers
.insert(XCACHELOOKUP.to_string(), hit_or_miss.to_string());
}
}

/// A trait providing methods for storing, reading, and removing cache records.
///
/// Generic argument `R` defines the type of HTTP response body which may be put into cache.
#[async_trait::async_trait]
pub trait CacheManager: Send + Sync + 'static {
/// Attempts to pull a cached response and related policy from cache.
Expand Down Expand Up @@ -555,7 +662,7 @@ impl<T: CacheManager> HttpCache<T> {
// the rest of the network for a period of time.
// (https://tools.ietf.org/html/rfc2616#section-14.46)
res.add_warning(
&res.url.clone(),
&res.parts.url.clone(),
112,
"Disconnected operation",
);
Expand All @@ -571,11 +678,13 @@ impl<T: CacheManager> HttpCache<T> {
CacheMode::OnlyIfCached => {
// ENOTCACHED
let mut res = HttpResponse {
body: b"GatewayTimeout".to_vec(),
headers: HashMap::default(),
status: 504,
url: middleware.url()?,
version: HttpVersion::Http11,
body: b"GatewayTimeout".to_vec().into(),
parts: Parts {
headers: HashMap::default(),
status: 504,
url: middleware.url()?,
version: HttpVersion::Http11,
},
};
if self.options.cache_status_headers {
res.cache_status(HitOrMiss::MISS);
Expand Down Expand Up @@ -615,9 +724,9 @@ impl<T: CacheManager> HttpCache<T> {
let mode = self.cache_mode(middleware)?;
let mut is_cacheable = is_get_head
&& mode != CacheMode::NoStore
&& res.status == 200
&& res.parts.status == 200
&& policy.is_storable();
if mode == CacheMode::IgnoreRules && res.status == 200 {
if mode == CacheMode::IgnoreRules && res.parts.status == 200 {
is_cacheable = true;
}
if is_cacheable {
Expand Down Expand Up @@ -670,7 +779,7 @@ impl<T: CacheManager> HttpCache<T> {
let req_url = middleware.url()?;
match middleware.remote_fetch().await {
Ok(mut cond_res) => {
let status = StatusCode::from_u16(cond_res.status)?;
let status = StatusCode::from_u16(cond_res.parts.status)?;
if status.is_server_error() && cached_res.must_revalidate() {
// 111 Revalidation failed
// MUST be included if a cache returns a stale response
Expand All @@ -686,7 +795,7 @@ impl<T: CacheManager> HttpCache<T> {
cached_res.cache_status(HitOrMiss::HIT);
}
Ok(cached_res)
} else if cond_res.status == 304 {
} else if cond_res.parts.status == 304 {
let after_res = policy.after_response(
&middleware.parts()?,
&cond_res.parts()?,
Expand All @@ -713,7 +822,7 @@ impl<T: CacheManager> HttpCache<T> {
)
.await?;
Ok(res)
} else if cond_res.status == 200 {
} else if cond_res.parts.status == 200 {
let policy = match self.options.cache_options {
Some(options) => middleware
.policy_with_options(&cond_res, options)?,
Expand Down
Loading