diff --git a/.gitignore b/.gitignore index 023b666..5e26867 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,8 @@ Cargo.lock http-cacache/ /.idea /public -/.DS_Store +**/.DS_Store + +# Cache directories from examples +**/cache/ +**/cache-*/ diff --git a/Cargo.toml b/Cargo.toml index 14fae20..29f90ad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,5 +6,6 @@ members = [ "http-cache-surf", "http-cache-quickcache", "http-cache-tower", + "http-cache-tower-server", "http-cache-ureq" ] \ No newline at end of file diff --git a/docs/src/SUMMARY.md b/docs/src/SUMMARY.md index 190bdd6..39fdfe3 100644 --- a/docs/src/SUMMARY.md +++ b/docs/src/SUMMARY.md @@ -6,11 +6,13 @@ - [Development](./development/development.md) - [Supporting a Backend Cache Manager](./development/supporting-a-backend-cache-manager.md) - [Supporting an HTTP Client](./development/supporting-an-http-client.md) -- [Client Implementations](./clients/clients.md) +- [Client-Side Caching](./clients/clients.md) - [reqwest](./clients/reqwest.md) - [surf](./clients/surf.md) - [ureq](./clients/ureq.md) - [tower](./clients/tower.md) +- [Server-Side Caching](./server/server.md) + - [tower-server](./server/tower-server.md) - [Backend Cache Manager Implementations](./managers/managers.md) - [cacache](./managers/cacache.md) - [moka](./managers/moka.md) diff --git a/docs/src/clients/clients.md b/docs/src/clients/clients.md index b34d492..72d105c 100644 --- a/docs/src/clients/clients.md +++ b/docs/src/clients/clients.md @@ -1,4 +1,17 @@ -# Client Implementations +# Client-Side Caching + +These middleware implementations cache responses from external APIs that your application calls. This is different from server-side caching, which caches your own application's responses. 
+ +**Use client-side caching when:** +- Calling external APIs +- Reducing API rate limit consumption +- Improving offline support +- Reducing bandwidth usage +- Speeding up repeated API calls + +**For server-side caching** (caching your own app's responses), see [Server-Side Caching](../server/server.md). + +## Available Client Implementations The following client implementations are provided by this crate: diff --git a/docs/src/introduction.md b/docs/src/introduction.md index b521d49..e9190b9 100644 --- a/docs/src/introduction.md +++ b/docs/src/introduction.md @@ -1,19 +1,62 @@ # Introduction -`http-cache` is a library that acts as a middleware for caching HTTP responses. It is intended to be used by other libraries to support multiple HTTP clients and backend cache managers, though it does come with multiple optional manager implementations out of the box. `http-cache` is built on top of [`http-cache-semantics`](https://github.com/kornelski/rusty-http-cache-semantics) which parses HTTP headers to correctly compute cacheability of responses. +`http-cache` is a comprehensive library for HTTP response caching in Rust. It provides both **client-side** and **server-side** caching middleware for multiple HTTP clients and frameworks. Built on top of [`http-cache-semantics`](https://github.com/kornelski/rusty-http-cache-semantics), it correctly implements HTTP cache semantics as defined in RFC 7234. 
## Key Features +- **Client-Side Caching**: Cache responses from external APIs you're calling +- **Server-Side Caching**: Cache your own application's responses to reduce load - **Traditional Caching**: Standard HTTP response caching with full buffering - **Streaming Support**: Memory-efficient caching for large responses without full buffering -- **Cache-Aware Rate Limiting**: Intelligent rate limiting that only applies on cache misses, not cache hits +- **Cache-Aware Rate Limiting**: Intelligent rate limiting that only applies on cache misses - **Multiple Backends**: Support for disk-based (cacache) and in-memory (moka, quick-cache) storage - **Client Integrations**: Support for reqwest, surf, tower, and ureq HTTP clients +- **Server Framework Support**: Tower-based servers (Axum, Hyper, Tonic) - **RFC 7234 Compliance**: Proper HTTP cache semantics with respect for cache-control headers +## Client-Side vs Server-Side Caching + +### Client-Side Caching + +Cache responses from external APIs your application calls: + +```rust +// Example: Caching API responses you fetch +let client = reqwest::Client::new(); +let cached_client = HttpCache::new(client, cache_manager); +let response = cached_client.get("https://api.example.com/users").send().await?; +``` + +**Use cases:** +- Reducing calls to external APIs +- Offline support +- Bandwidth optimization +- Rate limit compliance + +### Server-Side Caching + +Cache responses your application generates: + +```rust +// Example: Caching your own endpoint responses +let app = Router::new() + .route("/users/{id}", get(get_user)) + .layer(ServerCacheLayer::new(cache_manager)); // Cache your responses +``` + +**Use cases:** +- Reducing database queries +- Caching expensive computations +- Improving response times +- Reducing server load + +**Critical:** Server-side cache middleware must be placed **after** routing to preserve request context (path parameters, state, etc.).
+ ## Streaming vs Traditional Caching The library supports two caching approaches: - **Traditional Caching** (`CacheManager` trait): Buffers entire responses in memory before caching. Suitable for smaller responses and simpler use cases. - **Streaming Caching** (`StreamingCacheManager` trait): Processes responses as streams without full buffering. Ideal for large files, media content, or memory-constrained environments. + +Note: Streaming is currently only available for client-side caching. Server-side caching uses buffered responses. diff --git a/docs/src/server/server.md b/docs/src/server/server.md new file mode 100644 index 0000000..0fc361f --- /dev/null +++ b/docs/src/server/server.md @@ -0,0 +1,201 @@ +# Server-Side Caching + +Server-side HTTP response caching is fundamentally different from client-side caching. While client-side middleware caches responses from external APIs, server-side middleware caches your own application's responses to reduce load and improve performance. + +## What is Server-Side Caching? + +Server-side caching stores the responses your application generates so that subsequent identical requests can be served from cache without re-executing expensive operations like database queries or complex computations. 
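The core idea can be sketched in a few lines of plain Rust. This is a conceptual model only (the real middleware is async, header-aware, and backed by a pluggable cache manager); the `ServerCache` type and `lookup_or_generate` function are illustrative names, not part of the crate:

```rust
use std::collections::HashMap;

/// Conceptual model of a server-side cache: key by method + path,
/// and run the (expensive) handler only on a miss.
struct ServerCache {
    store: HashMap<String, String>,
    handler_calls: u32, // counts how often the "expensive" handler actually ran
}

impl ServerCache {
    fn new() -> Self {
        Self { store: HashMap::new(), handler_calls: 0 }
    }

    fn lookup_or_generate(&mut self, method: &str, path: &str) -> (String, &'static str) {
        let key = format!("{} {}", method, path);
        if let Some(cached) = self.store.get(&key) {
            return (cached.clone(), "HIT"); // served without re-running the handler
        }
        // Cache miss: run the expensive handler (e.g. a database query).
        self.handler_calls += 1;
        let body = format!("response for {}", path);
        self.store.insert(key, body.clone());
        (body, "MISS")
    }
}

fn main() {
    let mut cache = ServerCache::new();
    assert_eq!(cache.lookup_or_generate("GET", "/users/1").1, "MISS");
    assert_eq!(cache.lookup_or_generate("GET", "/users/1").1, "HIT");
    assert_eq!(cache.handler_calls, 1);
    println!("handler ran {} time(s)", cache.handler_calls);
}
```

Two requests for the same key run the handler once; every later request is a hit until the entry expires or is invalidated.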
+ +### Example Flow + +**Without Server-Side Caching:** +``` +Request → Routing → Handler → Database Query → Response (200ms) +Request → Routing → Handler → Database Query → Response (200ms) +Request → Routing → Handler → Database Query → Response (200ms) +``` + +**With Server-Side Caching:** +``` +Request → Routing → Cache MISS → Handler → Database Query → Response (200ms) → Cached +Request → Routing → Cache HIT → Response (2ms) +Request → Routing → Cache HIT → Response (2ms) +``` + +## Key Differences from Client-Side Caching + +| Aspect | Client-Side | Server-Side | +|--------|-------------|-------------| +| **What it caches** | External API responses | Your app's responses | +| **Position** | Before making outbound requests | After routing, before handlers | +| **Use case** | Reduce external API calls | Reduce internal computation | +| **RFC 7234 behavior** | Client cache rules | Shared cache rules | +| **Request extensions** | N/A | Must preserve (path params, state) | + +## Available Implementations + +Currently, server-side caching is available for: + +- **Tower-based servers** (Axum, Hyper, Tonic) - See [tower-server](./tower-server.md) + +## When to Use Server-Side Caching + +### Good Use Cases ✅ + +1. **Public API endpoints** with expensive database queries +2. **Read-heavy workloads** where data doesn't change frequently +3. **Dashboard or analytics data** that updates periodically +4. **Static-like content** that requires dynamic generation +5. **Search results** for common queries +6. **Rendered HTML** for public pages + +### Avoid Caching ❌ + +1. **User-specific data** (unless using proper cache key differentiation) +2. **Authenticated endpoints** (without user ID in cache key) +3. **Real-time data** that must always be fresh +4. **Write operations** (POST/PUT/DELETE requests) +5. **Sensitive information** that shouldn't be shared +6. 
**Session-dependent responses** (without session ID in cache key) + +## Security Considerations + +Server-side caches are **shared caches** - cached responses are served to ALL users. This is different from client-side caches which are per-client. + +### Critical Security Rule + +**Never cache user-specific data without including the user/session identifier in the cache key.** + +### Safe Patterns + +**Pattern 1: Mark user-specific responses as private** +```rust +async fn user_profile() -> Response { + ( + [(header::CACHE_CONTROL, "private")], // Won't be cached + "User profile data" + ).into_response() +} +``` + +**Pattern 2: Include user ID in cache key** +```rust +let keyer = CustomKeyer::new(|req: &Request<()>| { + let user_id = extract_user_id(req); + format!("{} {} user:{}", req.method(), req.uri().path(), user_id) +}); +``` + +**Pattern 3: Don't cache at all** +```rust +async fn sensitive_data() -> Response { + ( + [(header::CACHE_CONTROL, "no-store")], + "Sensitive data" + ).into_response() +} +``` + +## RFC 7234 Compliance + +Server-side caches implement **shared cache** semantics as defined in RFC 7234: + +### Must NOT Cache + +- Responses with `Cache-Control: private` (user-specific) +- Responses with `Cache-Control: no-store` (sensitive) +- Responses with `Cache-Control: no-cache` (requires revalidation) +- Non-2xx status codes (errors) +- Responses with `Authorization` header (unless explicitly allowed) + +### Must Cache Correctly + +- Prefer `s-maxage` over `max-age` (shared cache specific) +- Respect `Vary` headers (content negotiation) +- Handle `Expires` header as fallback +- Support `max-age` and `public` directives + +## Performance Characteristics + +### Benefits + +- **Reduced database load**: Cached responses don't hit the database +- **Lower CPU usage**: Expensive computations run once +- **Faster response times**: Cache hits are typically <5ms +- **Better scalability**: Handle more requests with same resources + +### Considerations + +- 
**Memory usage**: Cached responses stored in memory or disk +- **Stale data**: Cached data may become outdated +- **Cache warming**: Initial requests (cache misses) are slower +- **Invalidation complexity**: Updating cached data can be tricky + +## Cache Invalidation Strategies + +### Time-Based (TTL) + +Set expiration times on cached responses: + +```rust +async fn handler() -> Response { + ( + [(header::CACHE_CONTROL, "max-age=300")], // 5 minutes + "Response data" + ).into_response() +} +``` + +### Event-Based + +Manually invalidate cache entries when data changes: + +```rust +// After updating user data +cache_manager.delete(&format!("GET /users/{}", user_id)).await?; +``` + +### Hybrid Approach + +Combine TTL with manual invalidation: +- Use TTL for automatic expiration +- Invalidate early when you know data changed + +## Best Practices + +1. **Start conservative**: Use shorter TTLs initially, increase as you gain confidence +2. **Monitor cache hit rates**: Track X-Cache headers to measure effectiveness +3. **Set size limits**: Prevent cache from consuming too much memory +4. **Use appropriate keyers**: Match cache key strategy to your needs +5. **Document caching behavior**: Make it clear which endpoints are cached +6. **Test cache invalidation**: Ensure updates propagate correctly +7. **Consider cache warming**: Pre-populate cache for common requests +8. 
**Handle cache failures gracefully**: Application should work even if cache fails + +## Monitoring and Debugging + +### Enable Cache Status Headers + +```rust +let options = ServerCacheOptions { + cache_status_headers: true, + ..Default::default() +}; +``` + +This adds `X-Cache` headers to responses: +- `X-Cache: HIT` - Served from cache +- `X-Cache: MISS` - Generated by handler + +### Track Metrics + +Monitor these key metrics: +- Cache hit rate (hits / total requests) +- Average response time (hits vs misses) +- Cache size and memory usage +- Cache eviction rate +- Stale response rate + +## Getting Started + +See the [tower-server](./tower-server.md) documentation for detailed implementation guide. diff --git a/docs/src/server/tower-server.md b/docs/src/server/tower-server.md new file mode 100644 index 0000000..e4a270c --- /dev/null +++ b/docs/src/server/tower-server.md @@ -0,0 +1,418 @@ +# tower-server + +The [`http-cache-tower-server`](https://github.com/06chaynes/http-cache/tree/main/http-cache-tower-server) crate provides Tower Layer and Service implementations for server-side HTTP response caching. Unlike client-side caching, this middleware caches your own application's responses to reduce database queries, computation, and improve response times. + +## Key Differences from Client-Side Caching + +**Client-Side (`http-cache-tower`)**: Caches responses from external APIs you're calling +**Server-Side (`http-cache-tower-server`)**: Caches responses your application generates + +**Critical:** Server-side cache middleware must be placed **AFTER** routing in your middleware stack to preserve request extensions like path parameters (see [Issue #121](https://github.com/06chaynes/http-cache/issues/121)). + +## Getting Started + +```sh +cargo add http-cache-tower-server +``` + +## Features + +- `manager-cacache`: (default) Enables the [`CACacheManager`](https://docs.rs/http-cache/latest/http_cache/struct.CACacheManager.html) backend cache manager. 
+- `manager-moka`: Enables the [`MokaManager`](https://docs.rs/http-cache/latest/http_cache/struct.MokaManager.html) backend cache manager. + +## Basic Usage with Axum + +```rust +use axum::{ + routing::get, + Router, + extract::Path, +}; +use http_cache_tower_server::ServerCacheLayer; +use http_cache::CACacheManager; +use std::path::PathBuf; + +#[tokio::main] +async fn main() { + // Create cache manager + let cache_manager = CACacheManager::new(PathBuf::from("./cache"), false); + + // Create the server cache layer + let cache_layer = ServerCacheLayer::new(cache_manager); + + // Build your Axum app + let app = Router::new() + .route("/users/{id}", get(get_user)) + .route("/posts/{id}", get(get_post)) + // IMPORTANT: Place cache layer AFTER routing + .layer(cache_layer); + + // Run the server + let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") + .await + .unwrap(); + axum::serve(listener, app).await.unwrap(); +} + +async fn get_user(Path(id): Path<u64>) -> String { + // Expensive database query or computation + format!("User {}", id) +} + +async fn get_post(Path(id): Path<u64>) -> String { + format!("Post {}", id) +} +``` + +## Cache Control with Response Headers + +The middleware respects standard HTTP Cache-Control headers from your handlers: + +```rust +use axum::{ + response::{IntoResponse, Response}, + http::header, +}; + +async fn cacheable_handler() -> Response { + ( + [(header::CACHE_CONTROL, "max-age=300")], // Cache for 5 minutes + "This response will be cached" + ).into_response() +} + +async fn no_cache_handler() -> Response { + ( + [(header::CACHE_CONTROL, "no-store")], // Don't cache + "This response will NOT be cached" + ).into_response() +} + +async fn private_handler() -> Response { + ( + [(header::CACHE_CONTROL, "private")], // User-specific data + "This response will NOT be cached (shared cache)" + ).into_response() +} +``` + +## RFC 7234 Compliance + +This implementation acts as a **shared cache** per RFC 7234: + +### Automatically Rejects + +- 
`no-store` directive +- `no-cache` directive (requires revalidation, which is not supported) +- `private` directive (shared caches cannot store private responses) +- Non-2xx status codes + +### Supports + +- `max-age`: Cache lifetime in seconds +- `s-maxage`: Shared cache specific lifetime (takes precedence over max-age) +- `public`: Makes response cacheable +- `Expires`: Fallback header when Cache-Control is absent + +## Cache Key Strategies + +### DefaultKeyer (Default) + +Caches based on HTTP method and path: + +```rust +use http_cache_tower_server::{ServerCacheLayer, DefaultKeyer}; + +let cache_layer = ServerCacheLayer::new(cache_manager); +// GET /users/123 and GET /users/456 are cached separately +``` + +### QueryKeyer + +Includes query parameters in the cache key: + +```rust +use http_cache_tower_server::{ServerCacheLayer, QueryKeyer}; + +let cache_layer = ServerCacheLayer::with_keyer(cache_manager, QueryKeyer); +// GET /search?q=rust and GET /search?q=python are cached separately +``` + +### CustomKeyer + +For advanced use cases like content negotiation or user-specific caching: + +```rust +use http_cache_tower_server::{ServerCacheLayer, CustomKeyer}; + +// Example: Include Accept-Language header in cache key +let keyer = CustomKeyer::new(|req: &http::Request<()>| { + let lang = req.headers() + .get("accept-language") + .and_then(|v| v.to_str().ok()) + .unwrap_or("en"); + format!("{} {} lang:{}", req.method(), req.uri().path(), lang) +}); + +let cache_layer = ServerCacheLayer::with_keyer(cache_manager, keyer); +``` + +## Configuration Options + +```rust +use http_cache_tower_server::{ServerCacheLayer, ServerCacheOptions}; +use std::time::Duration; + +let options = ServerCacheOptions { + // Default TTL when no Cache-Control header present + default_ttl: Some(Duration::from_secs(60)), + + // Maximum TTL (even if response specifies longer) + max_ttl: Some(Duration::from_secs(3600)), + + // Minimum TTL (even if response specifies shorter) + min_ttl: 
Some(Duration::from_secs(10)), + + // Add X-Cache: HIT/MISS headers for debugging + cache_status_headers: true, + + // Maximum body size to cache (bytes) + max_body_size: 128 * 1024 * 1024, // 128 MB + + // Cache responses without Cache-Control header + cache_by_default: false, + + ..Default::default() +}; + +let cache_layer = ServerCacheLayer::new(cache_manager) + .with_options(options); +``` + +## Security Warnings + +### Shared Cache Behavior + +This is a **shared cache** - cached responses are served to ALL users. Improper configuration can leak user-specific data. + +### Do NOT Cache + +- Authenticated endpoints (unless using appropriate CustomKeyer) +- User-specific data (unless keyed by user/session ID) +- Responses with sensitive information + +### Safe Approaches + +**Option 1: Use Cache-Control: private** + +```rust +async fn user_specific_handler() -> Response { + ( + [(header::CACHE_CONTROL, "private")], + "User-specific data - won't be cached" + ).into_response() +} +``` + +**Option 2: Include user ID in cache key** + +```rust +let keyer = CustomKeyer::new(|req: &http::Request<()>| { + let user_id = req.headers() + .get("x-user-id") + .and_then(|v| v.to_str().ok()) + .unwrap_or("anonymous"); + format!("{} {} user:{}", req.method(), req.uri().path(), user_id) +}); +``` + +**Option 3: Don't cache at all** + +```rust +async fn sensitive_handler() -> Response { + ( + [(header::CACHE_CONTROL, "no-store")], + "Sensitive data - never cached" + ).into_response() +} +``` + +## Content Negotiation + +The middleware extracts `Vary` headers but does not automatically enforce them. 
For content negotiation, use a `CustomKeyer`: + +```rust +// Example: Cache different responses based on Accept-Language +let keyer = CustomKeyer::new(|req: &http::Request<()>| { + let lang = req.headers() + .get("accept-language") + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.split(',').next()) + .unwrap_or("en"); + format!("{} {} lang:{}", req.method(), req.uri().path(), lang) +}); +``` + +## Cache Inspection + +Responses include `X-Cache` headers when `cache_status_headers` is enabled: + +- `X-Cache: HIT` - Response served from cache +- `X-Cache: MISS` - Response generated by handler and cached (if cacheable) +- No header - Response not cacheable (or headers disabled) + +## Complete Example + +```rust +use axum::{ + routing::get, + Router, + extract::Path, + response::{IntoResponse, Response}, + http::header, +}; +use http_cache_tower_server::{ServerCacheLayer, ServerCacheOptions, QueryKeyer}; +use http_cache::CACacheManager; +use std::time::Duration; +use std::path::PathBuf; + +#[tokio::main] +async fn main() { + // Configure cache manager + let cache_manager = CACacheManager::new(PathBuf::from("./cache"), false); + + // Configure cache options + let options = ServerCacheOptions { + default_ttl: Some(Duration::from_secs(60)), + max_ttl: Some(Duration::from_secs(3600)), + cache_status_headers: true, + ..Default::default() + }; + + // Create cache layer with query parameter support + let cache_layer = ServerCacheLayer::with_keyer(cache_manager, QueryKeyer) + .with_options(options); + + // Build app + let app = Router::new() + .route("/users/{id}", get(get_user)) + .route("/search", get(search)) + .route("/admin/stats", get(admin_stats)) + .layer(cache_layer); // AFTER routing + + let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") + .await + .unwrap(); + axum::serve(listener, app).await.unwrap(); +} + +// Cacheable for 5 minutes +async fn get_user(Path(id): Path<u64>) -> Response { + ( + [(header::CACHE_CONTROL, "max-age=300")], + format!("User {}", id)
+ ).into_response() +} + +// Cacheable with query parameters +async fn search(query: axum::extract::Query<std::collections::HashMap<String, String>>) -> Response { + ( + [(header::CACHE_CONTROL, "max-age=60")], + format!("Search results: {:?}", query) + ).into_response() +} + +// Never cached (admin data) +async fn admin_stats() -> Response { + ( + [(header::CACHE_CONTROL, "no-store")], + "Admin statistics - not cached" + ).into_response() +} +``` + +## Best Practices + +1. **Place middleware after routing** to preserve request extensions +2. **Set appropriate Cache-Control headers** in your handlers +3. **Use `private` directive** for user-specific responses +4. **Monitor cache hit rates** using X-Cache headers +5. **Set reasonable TTL limits** to prevent stale data +6. **Use CustomKeyer** for content negotiation or user-specific caching +7. **Don't cache authenticated endpoints** without proper keying + +## Troubleshooting + +### Path parameters not working + +**Problem:** Axum path extractors fail with cached responses + +**Solution:** Ensure cache layer is placed AFTER routing: + +```rust +// ❌ Wrong - cache layer before routing +let app = Router::new() + .layer(cache_layer) // Too early! + .route("/users/{id}", get(handler)); + +// ✅ Correct - cache layer after routing +let app = Router::new() + .route("/users/{id}", get(handler)) + .layer(cache_layer); // After routing +``` + +### Responses not being cached + +**Possible causes:** +1. Response has `no-store`, `no-cache`, or `private` directive +2. Response status code is not 2xx +3. Response body exceeds `max_body_size` +4. 
`cache_by_default` is false and no Cache-Control header present + +**Solution:** Add appropriate Cache-Control headers: + +```rust +async fn handler() -> Response { + ( + [(header::CACHE_CONTROL, "max-age=300")], + "Response body" + ).into_response() +} +``` + +### User data leaking between requests + +**Problem:** Cached user-specific responses served to other users + +**Solution:** Use `CustomKeyer` with user identifier: + +```rust +let keyer = CustomKeyer::new(|req: &http::Request<()>| { + let user = req.headers() + .get("x-user-id") + .and_then(|v| v.to_str().ok()) + .unwrap_or("anonymous"); + format!("{} {} user:{}", req.method(), req.uri().path(), user) +}); +``` + +Or use `Cache-Control: private` to prevent caching entirely. + +## Performance Considerations + +- Cache writes are fire-and-forget (non-blocking) +- Cache lookups are async but fast (especially with in-memory managers) +- Body buffering is required (responses are fully buffered before caching) +- Consider using moka manager for frequently accessed data +- Use cacache manager for larger datasets with disk persistence + +## Comparison with Other Frameworks + +| Feature | http-cache-tower-server | Django Cache | NGINX FastCGI | +|---------|------------------------|--------------|---------------| +| Middleware-based | ✅ | ✅ | ❌ | +| RFC 7234 compliant | ✅ | ⚠️ Partial | ⚠️ Partial | +| Pluggable backends | ✅ | ✅ | ❌ | +| Custom cache keys | ✅ | ✅ | ✅ | +| Type-safe | ✅ | ❌ | ❌ | +| Async-first | ✅ | ❌ | ✅ | diff --git a/http-cache-tower-server/Cargo.toml b/http-cache-tower-server/Cargo.toml new file mode 100644 index 0000000..fbe3c10 --- /dev/null +++ b/http-cache-tower-server/Cargo.toml @@ -0,0 +1,54 @@ +[package] +name = "http-cache-tower-server" +version = "0.1.0" +description = "Server-side HTTP response caching middleware for Tower/Axum" +authors = ["Christian Haynes <06chaynes@gmail.com>", "Kat Marchán "] +repository = "https://github.com/06chaynes/http-cache" +homepage = 
"https://http-cache.rs" +license = "MIT OR Apache-2.0" +readme = "README.md" +keywords = ["cache", "http", "middleware", "tower", "server"] +categories = [ + "caching", + "web-programming::http-server" +] +edition = "2021" +rust-version = "1.83.0" + +[dependencies] +http-cache = { version = "1.0.0-alpha.2", path = "../http-cache", default-features = false } +http-cache-semantics = "2.1.0" +tower = { version = "0.5.2", features = ["util"] } +http = "1.2.0" +http-body = "1.0.1" +http-body-util = "0.1.2" +bytes = "1.8.0" +serde = { version = "1.0.217", features = ["derive"] } +serde_json = "1.0" +tokio = { version = "1.43.0", features = ["time"] } +async-trait = "0.1.85" +httpdate = "1.0.3" + +[dev-dependencies] +tokio = { version = "1.43.0", features = [ "macros", "rt", "rt-multi-thread", "time" ] } +tokio-test = "0.4.4" +tower-test = "0.4.0" +http-body-util = "0.1.2" +tempfile = "3.13.0" +tower = { version = "0.5.2", features = ["util"] } +tower-http = { version = "0.6", features = ["catch-panic"] } +axum = "0.8.7" +async-trait = "0.1.85" + +[[example]] +name = "axum_basic" +required-features = ["manager-cacache"] + +[[example]] +name = "axum_advanced" +required-features = ["manager-cacache"] + +[features] +default = ["manager-cacache"] +manager-cacache = ["http-cache/manager-cacache", "http-cache/cacache-tokio"] +manager-moka = ["http-cache/manager-moka"] diff --git a/http-cache-tower-server/README.md b/http-cache-tower-server/README.md new file mode 100644 index 0000000..611ccc5 --- /dev/null +++ b/http-cache-tower-server/README.md @@ -0,0 +1,412 @@ +# http-cache-tower-server + +[![Crates.io](https://img.shields.io/crates/v/http-cache-tower-server?style=for-the-badge)](https://crates.io/crates/http-cache-tower-server) +[![Docs.rs](https://img.shields.io/docsrs/http-cache-tower-server?style=for-the-badge)](https://docs.rs/http-cache-tower-server) +![Crates.io](https://img.shields.io/crates/l/http-cache-tower-server?style=for-the-badge) + +Server-side HTTP response 
caching middleware for Tower-based frameworks (Axum, Hyper, Tonic). + +## Overview + +This crate provides Tower middleware for caching your server's HTTP responses to improve performance and reduce load. Unlike client-side caching, this middleware caches responses **after** your handlers execute, making it ideal for expensive operations like database queries or complex computations. + +## When to Use This + +Use `http-cache-tower-server` when you want to: + +- Cache expensive API responses (database queries, aggregations) +- Reduce load on backend services +- Improve response times for read-heavy workloads +- Cache server-rendered content +- Speed up responses that are computed but rarely change + +## Client vs Server Caching + +| Crate | Purpose | Use Case | +|-------|---------|----------| +| `http-cache-tower` | **Client-side caching** | Cache responses from external APIs you call | +| `http-cache-tower-server` | **Server-side caching** | Cache your own application's responses | + +**Important:** If you're experiencing issues with path parameter extraction or routing when using `http-cache-tower` in a server application, you should use this crate instead. See [Issue #121](https://github.com/06chaynes/http-cache/issues/121) for details. + +## Installation + +```sh +cargo add http-cache-tower-server +``` + +### Features + +By default, `manager-cacache` is enabled. 
+ +- `manager-cacache` (default): Enable [cacache](https://github.com/zkat/cacache-rs) disk-based cache backend +- `manager-moka`: Enable [moka](https://github.com/moka-rs/moka) in-memory cache backend + +## Quick Start + +### Basic Example (Axum) + +```rust +use axum::{Router, routing::get, response::IntoResponse}; +use http_cache_tower_server::ServerCacheLayer; +use http_cache::CACacheManager; + +async fn expensive_handler() -> impl IntoResponse { + // Simulate expensive operation + tokio::time::sleep(tokio::time::Duration::from_secs(2)).await; + + // Set cache control to cache for 60 seconds + ( + [("cache-control", "max-age=60")], + "This response is cached for 60 seconds" + ) +} + +#[tokio::main] +async fn main() { + // Create cache manager + let manager = CACacheManager::new("./cache", false); + + // Create router with cache layer + let app = Router::new() + .route("/expensive", get(expensive_handler)) + .layer(ServerCacheLayer::new(manager)); + + // Run server + let listener = tokio::net::TcpListener::bind("0.0.0.0:3000") + .await + .unwrap(); + axum::serve(listener, app).await.unwrap(); +} +``` + +## How It Works + +1. **Request arrives** → Routing layer processes it (path params extracted) +2. **Cache lookup** → Check if response is cached +3. **Cache hit** → Return cached response immediately +4. **Cache miss** → Call your handler +5. **Handler returns** → Check Cache-Control headers +6. **Should cache?** → Store response if cacheable +7. 
**Return response** → Send to client + +### Cache Status Headers + +Responses include an `x-cache` header indicating cache status: + +- `x-cache: HIT` → Response served from cache +- `x-cache: MISS` → Response generated by handler (may be cached) +- No header → Response not cacheable + +## Cache Key Generation + +### Built-in Keyers + +#### DefaultKeyer (default) + +Caches based on HTTP method and path: + +```rust +use http_cache_tower_server::{ServerCacheLayer, DefaultKeyer}; + +let layer = ServerCacheLayer::new(manager); +// GET /users/123 → "GET /users/123" +// GET /users/456 → "GET /users/456" +``` + +#### QueryKeyer + +Includes query parameters in cache key: + +```rust +use http_cache_tower_server::{ServerCacheLayer, QueryKeyer}; + +let layer = ServerCacheLayer::with_keyer(manager, QueryKeyer); +// GET /search?q=rust → "GET /search?q=rust" +// GET /search?q=http → "GET /search?q=http" +``` + +### CustomKeyer + +For advanced scenarios (authentication, content negotiation, etc.): + +```rust +use http_cache_tower_server::{ServerCacheLayer, CustomKeyer}; +use http::Request; + +// Include user ID from headers in cache key +let keyer = CustomKeyer::new(|req: &Request<()>| { + let user_id = req.headers() + .get("x-user-id") + .and_then(|v| v.to_str().ok()) + .unwrap_or("anonymous"); + + format!("{} {} user:{}", req.method(), req.uri().path(), user_id) +}); + +let layer = ServerCacheLayer::with_keyer(manager, keyer); +// GET /dashboard with x-user-id: 123 → "GET /dashboard user:123" +// GET /dashboard with x-user-id: 456 → "GET /dashboard user:456" +``` + +## Configuration Options + +```rust +use http_cache_tower_server::{ServerCacheLayer, ServerCacheOptions}; +use std::time::Duration; + +let options = ServerCacheOptions { + // Default TTL when no Cache-Control header present + default_ttl: Some(Duration::from_secs(60)), + + // Maximum TTL (even if response specifies longer) + max_ttl: Some(Duration::from_secs(3600)), + + // Minimum TTL (even if response specifies 
shorter) + min_ttl: Some(Duration::from_secs(10)), + + // Add X-Cache headers (HIT/MISS) + cache_status_headers: true, + + // Maximum response body size to cache (128 MB) + max_body_size: 128 * 1024 * 1024, + + // Cache responses without explicit Cache-Control + cache_by_default: false, + + // Respect Vary header (currently extracted but not enforced) + respect_vary: true, +}; + +let layer = ServerCacheLayer::new(manager) + .with_options(options); +``` + +## Caching Behavior (RFC 9111 Compliant) + +This middleware implements a **shared cache** per RFC 9111 (HTTP Caching). + +### Cached Responses + +Responses are cached when they have: + +- Status code: 2xx (200, 201, 204, etc.) +- Cache-Control: `max-age=X` → Cached for X seconds +- Cache-Control: `s-maxage=X` → Cached for X seconds (shared cache specific) +- Cache-Control: `public` → Cached with default TTL + +### Never Cached + +Responses are **never** cached if they have: + +- Status code: Non-2xx (4xx, 5xx, 3xx) +- Cache-Control: `no-store` → Prevents all caching +- Cache-Control: `no-cache` → Requires revalidation (not supported) +- Cache-Control: `private` → Only for private caches + +### Directive Precedence + +When multiple directives are present: + +1. `s-maxage` (shared cache specific) takes precedence +2. `max-age` (general directive) +3. `public` (uses default TTL) +4. Expires header (fallback, not currently parsed) + +### Example Headers + +```rust +// Cached for 60 seconds +("cache-control", "max-age=60") + +// Cached for 120 seconds (s-maxage overrides max-age for shared caches) +("cache-control", "max-age=60, s-maxage=120") + +// Cached with default TTL +("cache-control", "public") + +// Never cached +("cache-control", "no-store") +("cache-control", "private") +("cache-control", "no-cache") +``` + +## Security Considerations + +### ⚠️ This is a Shared Cache + +**Critical:** Cached responses are served to **ALL users**. Never cache user-specific data without appropriate measures. 
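The failure mode described above, and the keyer-based fix, can be modeled with a plain `HashMap` (a conceptual sketch of cache-key behavior, not the crate's API):

```rust
use std::collections::HashMap;

fn main() {
    // Shared cache keyed only by method + path (DefaultKeyer-style):
    let mut shared: HashMap<String, String> = HashMap::new();
    shared.insert("GET /profile".to_string(), "alice's data".to_string());
    // Bob's request maps to the same key and receives Alice's entry - a leak.
    assert_eq!(shared.get("GET /profile"), Some(&"alice's data".to_string()));

    // Key that includes the user id (what a CustomKeyer would produce):
    let mut keyed: HashMap<String, String> = HashMap::new();
    keyed.insert("GET /profile user:alice".to_string(), "alice's data".to_string());
    // Bob's key misses, so his own handler runs instead of serving Alice's data.
    assert!(keyed.get("GET /profile user:bob").is_none());
}
```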
+ +### Safe Usage Patterns + +#### ✅ Public Content + +```rust +async fn public_page() -> impl IntoResponse { + ( + [("cache-control", "max-age=300")], + "Public content safe to cache" + ) +} +``` + +#### ✅ User-Specific with CustomKeyer + +```rust +// Include user ID in cache key +let keyer = CustomKeyer::new(|req: &Request<()>| { + let user_id = extract_user_id(req); + format!("{} {} user:{}", req.method(), req.uri().path(), user_id) +}); +``` + +#### ❌ UNSAFE: User Data Without Keyer + +```rust +// ❌ DANGEROUS: Will serve user123's data to user456! +async fn user_profile() -> impl IntoResponse { + let user_data = get_current_user_data().await; + ( + [("cache-control", "max-age=60")], // ❌ Don't do this! + user_data + ) +} +``` + +#### ✅ User Data with Private Directive + +```rust +// ✅ Safe: Won't be cached +async fn user_profile() -> impl IntoResponse { + let user_data = get_current_user_data().await; + ( + [("cache-control", "private")], // Won't be cached + user_data + ) +} +``` + +### Best Practices + +1. **Never cache authenticated endpoints** unless using a CustomKeyer that includes session/user ID +2. **Use `Cache-Control: private`** for user-specific responses +3. **Validate cache keys** to prevent cache poisoning +4. **Set body size limits** to prevent DoS attacks +5. 
**Use TTL constraints** to prevent cache bloat + +## Advanced Examples + +### Content Negotiation + +For responses that vary by Accept-Language: + +```rust +let keyer = CustomKeyer::new(|req: &Request<()>| { + let lang = req.headers() + .get("accept-language") + .and_then(|v| v.to_str().ok()) + .unwrap_or("en"); + + format!("{} {} lang:{}", req.method(), req.uri().path(), lang) +}); + +let layer = ServerCacheLayer::with_keyer(manager, keyer); +``` + +### Conditional Caching + +Only cache certain routes: + +```rust +use axum::{extract::Request, http::Method, middleware::Next, response::Response}; + +async fn cache_middleware( + req: Request, + next: Next, +) -> Response { + // Only cache GET requests to /api/* (sketch: route such requests + // through a service wrapped with ServerCacheLayer) + if req.method() == Method::GET && req.uri().path().starts_with("/api/") { + // Apply cache layer + } + next.run(req).await +} +``` + +### TTL by Route + +```rust +async fn long_cache_handler() -> impl IntoResponse { + ( + [("cache-control", "max-age=3600")], // 1 hour + "Rarely changing content" + ) +} + +async fn short_cache_handler() -> impl IntoResponse { + ( + [("cache-control", "max-age=60")], // 1 minute + "Frequently updated content" + ) +} +``` + +## Behavior Notes and Limitations + +### Vary Header + +`Vary` headers are enforced during cache lookup via `http-cache-semantics`: a cached response is only served when the request's header values match those the response varies on, and mismatched requests are cache misses. To store separate variants keyed by header value, use a `CustomKeyer` that includes the relevant headers in the cache key. + +### Authorization Header + +Per RFC 9111 §3.5, requests carrying an `Authorization` header are not cached by default (`respect_authorization: true`) unless the response explicitly permits it via `public`, `s-maxage`, or `must-revalidate`. Authenticated endpoints should additionally either: + +- Use `Cache-Control: private` (won't be cached), OR +- Use a `CustomKeyer` that includes user/session ID, OR +- Not be cached at all + +### Expires Header + +The `Expires` header is recognized but not currently parsed. Modern applications should use `Cache-Control` directives instead. 
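Since `Expires` is not parsed, freshness comes entirely from `Cache-Control`. The directive precedence documented earlier can be sketched as a small TTL-resolution function (an illustration of the documented rules only; `resolve_ttl` is not part of this crate's API):

```rust
use std::time::Duration;

// Resolve a shared-cache TTL from a Cache-Control value, following the
// documented precedence: no-store/no-cache/private forbid caching, then
// s-maxage, then max-age, then `public` with the configured default TTL.
fn resolve_ttl(cache_control: &str, default_ttl: Duration) -> Option<Duration> {
    let directives: Vec<&str> =
        cache_control.split(',').map(str::trim).collect();

    // Directives a shared cache must never store.
    if directives
        .iter()
        .any(|d| matches!(*d, "no-store" | "no-cache" | "private"))
    {
        return None;
    }

    // Find a `name=seconds` directive and parse its value.
    let seconds = |name: &str| {
        directives.iter().find_map(|d| {
            d.strip_prefix(name)
                .and_then(|rest| rest.strip_prefix('='))
                .and_then(|v| v.parse::<u64>().ok())
        })
    };

    if let Some(s) = seconds("s-maxage") {
        return Some(Duration::from_secs(s)); // 1. shared-cache specific
    }
    if let Some(s) = seconds("max-age") {
        return Some(Duration::from_secs(s)); // 2. general directive
    }
    if directives.iter().any(|d| *d == "public") {
        return Some(default_ttl); // 3. cache with the default TTL
    }
    None // no directive: not cached unless cache_by_default is set
}
```

Per these rules, `max-age=60, s-maxage=120` resolves to 120 seconds in a shared cache, matching the example headers shown earlier.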
+ +## Examples + +See the [examples](examples/) directory: + +- [`axum_basic.rs`](examples/axum_basic.rs) - Basic usage with Axum +- [`axum_advanced.rs`](examples/axum_advanced.rs) - Custom keyers, cache metrics, and invalidation + +Run with: +```bash +cargo run --example axum_basic --features manager-cacache +``` + +## Comparison with Other Crates + +### vs axum-response-cache + +- This crate: RFC 9111 compliant, respects Cache-Control headers +- axum-response-cache: Simpler API, less complete RFC compliance + +### vs tower-cache-control + +- This crate: Full caching implementation with storage +- tower-cache-control: Only sets Cache-Control headers + +## Minimum Supported Rust Version (MSRV) + +1.82.0 + +## Contributing + +Contributions are welcome! Please see the [main repository](https://github.com/06chaynes/http-cache) for contribution guidelines. + +## License + +Licensed under either of + +- Apache License, Version 2.0 ([LICENSE-APACHE](../LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) +- MIT license ([LICENSE-MIT](../LICENSE-MIT) or http://opensource.org/licenses/MIT) + +at your option. diff --git a/http-cache-tower-server/examples/axum_advanced.rs b/http-cache-tower-server/examples/axum_advanced.rs new file mode 100644 index 0000000..c31cde8 --- /dev/null +++ b/http-cache-tower-server/examples/axum_advanced.rs @@ -0,0 +1,274 @@ +//! Advanced HTTP caching with custom keyers, invalidation, and metrics +//! +//! This example demonstrates query-based caching, cache metrics, and invalidation. +//! +//! ## Quick Start +//! +//! ```bash +//! cargo run --example axum_advanced --features manager-cacache +//! ``` +//! +//! ## Step-by-Step Demo +//! +//! ### 1. Check initial metrics (everything at zero) +//! ```bash +//! curl http://localhost:3000/metrics +//! # Cache Metrics: +//! # Hits: 0 +//! # Misses: 0 +//! # Stores: 0 +//! # Hit Rate: 0.0% +//! ``` +//! +//! ### 2. Make a search request (cache MISS) +//! ```bash +//! curl -i http://localhost:3000/search?q=rust +//! # HTTP/1.1 200 OK +//! # x-cache: MISS +//! # cache-control: public, max-age=300 +//! 
# Search results for: rust +//! ``` +//! +//! ### 3. Repeat the same request (cache HIT) +//! ```bash +//! curl -i http://localhost:3000/search?q=rust +//! # HTTP/1.1 200 OK +//! # x-cache: HIT +//! # Search results for: rust +//! ``` +//! +//! ### 4. Try a different query (cache MISS - different cache key) +//! ```bash +//! curl -i http://localhost:3000/search?q=cache +//! # x-cache: MISS +//! ``` +//! +//! ### 5. Check metrics again +//! ```bash +//! curl http://localhost:3000/metrics +//! # Cache Metrics: +//! # Hits: 1 +//! # Misses: 2 +//! # Stores: 2 +//! # Hit Rate: 33.3% +//! ``` +//! +//! ### 6. Invalidate a cached entry +//! ```bash +//! curl -X DELETE "http://localhost:3000/cache?key=GET%20/search?q=rust" +//! # Invalidated cache key: GET /search?q=rust +//! ``` +//! +//! ### 7. Request again (cache MISS after invalidation) +//! ```bash +//! curl -i http://localhost:3000/search?q=rust +//! # x-cache: MISS +//! ``` +//! +//! ## Other Endpoints +//! +//! ```bash +//! # Product details (cached 10 minutes) +//! curl http://localhost:3000/products/42 +//! +//! # Dashboard (private cache - not stored in shared cache) +//! curl http://localhost:3000/dashboard +//! 
``` + +use axum::{ + error_handling::HandleErrorLayer, + extract::{Query, State}, + response::{IntoResponse, Response}, + routing::{delete, get}, + BoxError, Router, +}; +use http::{Request, StatusCode}; +use http_cache::CACacheManager; +use http_cache_tower_server::{ + CacheMetrics, CustomKeyer, QueryKeyer, ServerCacheLayer, ServerCacheOptions, +}; +use serde::Deserialize; +use std::{sync::Arc, time::Duration}; +use tempfile::TempDir; +use tower::ServiceBuilder; + +#[derive(Clone)] +struct AppState { + metrics: Arc<CacheMetrics>, + cache_layer: Arc<ServerCacheLayer<CACacheManager, QueryKeyer>>, +} + +#[tokio::main] +async fn main() { + // Create cache storage + let temp_dir = TempDir::new().expect("Failed to create temp directory"); + let manager = CACacheManager::new(temp_dir.path().to_path_buf(), false); + + // Configure cache options + let options = ServerCacheOptions { + default_ttl: Some(Duration::from_secs(120)), + max_ttl: Some(Duration::from_secs(3600)), + cache_status_headers: true, + ..Default::default() + }; + + // Create cache layer with QueryKeyer (includes query params in cache key) + let cache_layer = + ServerCacheLayer::with_keyer(manager, QueryKeyer).with_options(options); + + // Store references for metrics and invalidation + let state = AppState { + metrics: cache_layer.metrics().clone(), + cache_layer: Arc::new(cache_layer.clone()), + }; + + // Routes that should be cached + let cached_routes = Router::new() + .route("/search", get(search)) + .route("/dashboard", get(dashboard)) + .route("/products/{id}", get(get_product)) + .layer( + ServiceBuilder::new() + .layer(HandleErrorLayer::new(handle_cache_error)) + .layer(cache_layer), + ); + + // Monitoring routes bypass the cache + let admin_routes = Router::new() + .route("/metrics", get(metrics)) + .route("/cache", delete(invalidate_cache)); + + // Merge all routes + let app = Router::new() + .merge(cached_routes) + .merge(admin_routes) + .with_state(state); + + // Run server + let listener = + 
tokio::net::TcpListener::bind("127.0.0.1:3000").await.unwrap(); + + println!("Server running at http://localhost:3000"); + println!(); + println!("Endpoints:"); + println!(" GET /search?q=... - Cached by query params"); + println!(" GET /dashboard - User-specific content"); + println!(" GET /products/{{id}} - Product details"); + println!(" GET /metrics - Cache statistics"); + println!(" DELETE /cache?key=... - Invalidate cache entry"); + + axum::serve(listener, app).await.unwrap(); +} + +async fn handle_cache_error(err: BoxError) -> Response { + (StatusCode::INTERNAL_SERVER_ERROR, format!("Cache error: {}", err)) + .into_response() +} + +#[derive(Deserialize)] +struct SearchQuery { + q: String, +} + +async fn search(Query(params): Query<SearchQuery>) -> Response { + // Simulate database query + tokio::time::sleep(Duration::from_millis(50)).await; + + ( + StatusCode::OK, + [("cache-control", "public, max-age=300")], + format!("Search results for: {}", params.q), + ) + .into_response() +} + +async fn dashboard() -> Response { + // Note: In a real app, you'd use a CustomKeyer that includes session ID + // to prevent serving User A's dashboard to User B + ( + StatusCode::OK, + [("cache-control", "private, max-age=60")], + "User dashboard - private cache only", + ) + .into_response() +} + +async fn get_product( + axum::extract::Path(id): axum::extract::Path<u32>, +) -> Response { + // Simulate slow database lookup + tokio::time::sleep(Duration::from_millis(100)).await; + + ( + StatusCode::OK, + [("cache-control", "public, max-age=600")], + format!("Product {} details - cached for 10 minutes", id), + ) + .into_response() +} + +async fn metrics(State(state): State<AppState>) -> Response { + let metrics = &state.metrics; + let hits = metrics.hits.load(std::sync::atomic::Ordering::Relaxed); + let misses = metrics.misses.load(std::sync::atomic::Ordering::Relaxed); + let stores = metrics.stores.load(std::sync::atomic::Ordering::Relaxed); + + let total = hits + misses; + let hit_rate = + if total > 0 { 
(hits as f64 / total as f64) * 100.0 } else { 0.0 }; + + let body = format!( + "Cache Metrics:\n Hits: {}\n Misses: {}\n Stores: {}\n Hit Rate: {:.1}%", + hits, misses, stores, hit_rate + ); + + (StatusCode::OK, [("cache-control", "no-store")], body).into_response() +} + +#[derive(Deserialize)] +struct InvalidateQuery { + key: String, +} + +async fn invalidate_cache( + State(state): State<AppState>, + Query(params): Query<InvalidateQuery>, +) -> Response { + match state.cache_layer.invalidate(&params.key).await { + Ok(()) => { + (StatusCode::OK, format!("Invalidated cache key: {}", params.key)) + .into_response() + } + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to invalidate: {}", e), + ) + .into_response(), + } +} + +// Example: Creating a session-aware cache layer +#[allow(dead_code)] +fn create_session_cache_layer( + manager: CACacheManager, +) -> ServerCacheLayer< + CACacheManager, + CustomKeyer<impl Fn(&Request<()>) -> String + Clone>, +> { + let keyer = CustomKeyer::new(|req: &Request<()>| { + let session = req + .headers() + .get("cookie") + .and_then(|v| v.to_str().ok()) + .and_then(|cookies| { + cookies + .split(';') + .find_map(|c| c.trim().strip_prefix("session=")) + }) + .unwrap_or("anonymous"); + + format!("{} {} session:{}", req.method(), req.uri().path(), session) + }); + + ServerCacheLayer::with_keyer(manager, keyer) +} diff --git a/http-cache-tower-server/examples/axum_basic.rs b/http-cache-tower-server/examples/axum_basic.rs new file mode 100644 index 0000000..1738ae0 --- /dev/null +++ b/http-cache-tower-server/examples/axum_basic.rs @@ -0,0 +1,96 @@ +//! Basic HTTP caching with http-cache-tower-server and Axum +//! +//! This example runs a real HTTP server that you can test with curl: +//! +//! ```bash +//! # Start the server +//! cargo run --example axum_basic --features manager-cacache +//! +//! # Test caching behavior +//! curl -v http://localhost:3000/ # First request: MISS +//! curl -v http://localhost:3000/ # Second request: HIT +//! 
curl -v http://localhost:3000/users/42 # User endpoint with 30s cache +//! curl -v http://localhost:3000/no-cache # Never cached +//! ``` +//! +//! Run with: cargo run --example axum_basic --features manager-cacache + +use axum::{ + error_handling::HandleErrorLayer, + extract::Path, + response::{IntoResponse, Response}, + routing::get, + BoxError, Router, +}; +use http::StatusCode; +use http_cache::CACacheManager; +use http_cache_tower_server::ServerCacheLayer; +use tempfile::TempDir; +use tower::ServiceBuilder; + +#[tokio::main] +async fn main() { + // Create cache storage (use a persistent path in production) + let temp_dir = TempDir::new().expect("Failed to create temp directory"); + let manager = CACacheManager::new(temp_dir.path().to_path_buf(), false); + + // Build the router with standard Axum handlers + let app = Router::new() + .route("/", get(index)) + .route("/users/{id}", get(get_user)) + .route("/no-cache", get(no_cache)) + .layer( + ServiceBuilder::new() + .layer(HandleErrorLayer::new(handle_cache_error)) + .layer(ServerCacheLayer::new(manager)), + ); + + // Run the server + let listener = + tokio::net::TcpListener::bind("127.0.0.1:3000").await.unwrap(); + + println!("Server running at http://localhost:3000"); + println!(); + println!("Try these commands:"); + println!( + " curl -v http://localhost:3000/ # Watch X-Cache header" + ); + println!( + " curl -v http://localhost:3000/users/42 # User-specific endpoint" + ); + println!(" curl -v http://localhost:3000/no-cache # Never cached"); + + axum::serve(listener, app).await.unwrap(); +} + +async fn handle_cache_error(err: BoxError) -> Response { + (StatusCode::INTERNAL_SERVER_ERROR, format!("Cache error: {}", err)) + .into_response() +} + +async fn index() -> Response { + ( + StatusCode::OK, + [("cache-control", "max-age=60")], + "Hello! 
This response is cached for 60 seconds.", + ) + .into_response() +} + +async fn get_user(Path(id): Path<u32>) -> Response { + ( + StatusCode::OK, + [("cache-control", "max-age=30")], + format!("User {} - Cached for 30 seconds", id), + ) + .into_response() +} + +async fn no_cache() -> Response { + ( + StatusCode::OK, + [("cache-control", "no-store")], + "This response is never cached", + ) + .into_response() +} diff --git a/http-cache-tower-server/src/lib.rs b/http-cache-tower-server/src/lib.rs new file mode 100644 index 0000000..431e077 --- /dev/null +++ b/http-cache-tower-server/src/lib.rs @@ -0,0 +1,1018 @@ +//! Server-side HTTP response caching middleware for Tower. +//! +//! This crate provides Tower middleware for caching HTTP responses on the server side. +//! Unlike client-side caching, this middleware caches your own application's responses +//! to reduce load and improve performance. +//! +//! # Key Features +//! +//! - Response-first architecture: Caches based on response headers, not requests +//! - Preserves request context: Maintains all request extensions (path params, state, etc.) +//! - Handler-centric: Calls the handler first, then decides whether to cache +//! - RFC 9111 compliant: Respects Cache-Control, Vary, and other standard headers +//! - Reuses existing infrastructure: Leverages `CacheManager` trait from `http-cache` +//! +//! # Example +//! +//! ```rust +//! use http::{Request, Response}; +//! use http_body_util::Full; +//! use bytes::Bytes; +//! use http_cache_tower_server::ServerCacheLayer; +//! use tower::{Service, Layer}; +//! # use http_cache::{CacheManager, HttpResponse, HttpVersion}; +//! # use http_cache_semantics::CachePolicy; +//! # use std::collections::HashMap; +//! # use std::sync::{Arc, Mutex}; +//! # +//! # #[derive(Clone)] +//! # struct MemoryCacheManager { +//! # store: Arc<Mutex<HashMap<String, (HttpResponse, CachePolicy)>>>, +//! # } +//! # +//! # impl MemoryCacheManager { +//! # fn new() -> Self { +//! # Self { store: Arc::new(Mutex::new(HashMap::new())) } +//! # } +//! 
# } +//! # +//! # #[async_trait::async_trait] +//! # impl CacheManager for MemoryCacheManager { +//! # async fn get(&self, cache_key: &str) -> http_cache::Result<Option<(HttpResponse, CachePolicy)>> { +//! # Ok(self.store.lock().unwrap().get(cache_key).cloned()) +//! # } +//! # async fn put(&self, cache_key: String, res: HttpResponse, policy: CachePolicy) -> http_cache::Result<HttpResponse> { +//! # self.store.lock().unwrap().insert(cache_key, (res.clone(), policy)); +//! # Ok(res) +//! # } +//! # async fn delete(&self, cache_key: &str) -> http_cache::Result<()> { +//! # self.store.lock().unwrap().remove(cache_key); +//! # Ok(()) +//! # } +//! # } +//! +//! # tokio_test::block_on(async { +//! let manager = MemoryCacheManager::new(); +//! let layer = ServerCacheLayer::new(manager); +//! +//! // Apply the layer to your Tower service +//! let service = tower::service_fn(|_req: Request<Full<Bytes>>| async { +//! Ok::<_, std::io::Error>( +//! Response::builder() +//! .header("cache-control", "max-age=60") +//! .body(Full::new(Bytes::from("Hello, World!"))) +//! .unwrap() +//! ) +//! }); +//! +//! let mut cached_service = layer.layer(service); +//! # }); +//! ``` +//! +//! # Vary Header Support +//! +//! This cache enforces `Vary` headers using `http-cache-semantics`. When a response includes +//! a `Vary` header, subsequent requests must have matching header values to receive the cached +//! response. Requests with different header values will result in cache misses. +//! +//! For example, if a response has `Vary: Accept-Language`, a cached English response won't be +//! served to a request with `Accept-Language: de`. +//! +//! # Security Warnings +//! +//! This is a **shared cache** - cached responses are served to ALL users. Improper configuration +//! can leak user-specific data between different users. +//! +//! ## Authorization and Authentication +//! +//! This cache does not inspect session cookies, and `Authorization` headers are only +//! respected when `respect_authorization` is enabled (the default, per RFC 9111 §3.5). +//! 
Caching authenticated endpoints without proper cache key differentiation will cause +//! user A's response to be served to user B. +//! +//! **Do NOT cache authenticated endpoints** unless you use a `CustomKeyer` that includes +//! the user or session identifier in the cache key: +//! +//! ```rust +//! # use http_cache_tower_server::CustomKeyer; +//! # use http::Request; +//! // Example: Include session ID in cache key +//! let keyer = CustomKeyer::new(|req: &Request<()>| { +//! let session = req.headers() +//! .get("cookie") +//! .and_then(|v| v.to_str().ok()) +//! .and_then(|c| extract_session_id(c)) +//! .unwrap_or("anonymous"); +//! format!("{} {} session:{}", req.method(), req.uri().path(), session) +//! }); +//! # fn extract_session_id(cookie: &str) -> Option<&str> { None } +//! ``` +//! +//! ## General Security Considerations +//! +//! - Never cache responses containing user-specific data without user-specific cache keys +//! - Validate cache keys to prevent cache poisoning attacks +//! - Be careful with header-based caching due to header injection risks +//! - Consider the `private` Cache-Control directive for user-specific responses (automatically rejected by this cache) + +#![warn(missing_docs)] +#![deny(unsafe_code)] + +use bytes::Bytes; +use http::{header::HeaderValue, Request, Response}; +use http_body::{Body as HttpBody, Frame}; +use http_body_util::BodyExt; +use http_cache::{CacheManager, HttpResponse, HttpVersion}; +use http_cache_semantics::{BeforeRequest, CachePolicy}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::error::Error as StdError; +use std::pin::Pin; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::time::{Duration, SystemTime}; +use tower::{Layer, Service}; + +type BoxError = Box<dyn StdError + Send + Sync>; + +/// Cache performance metrics. +/// +/// Tracks hits, misses, stores, and skipped responses for monitoring cache effectiveness. 
+#[derive(Debug, Default)] +pub struct CacheMetrics { + /// Number of cache hits. + pub hits: AtomicU64, + /// Number of cache misses. + pub misses: AtomicU64, + /// Number of responses stored in cache. + pub stores: AtomicU64, + /// Number of responses skipped (too large, not cacheable, etc.). + pub skipped: AtomicU64, +} + +impl CacheMetrics { + /// Create new metrics instance. + pub fn new() -> Self { + Self::default() + } + + /// Calculate the cache hit rate as a fraction between 0.0 and 1.0. + pub fn hit_rate(&self) -> f64 { + let hits = self.hits.load(Ordering::Relaxed); + let total = hits + self.misses.load(Ordering::Relaxed); + if total == 0 { + 0.0 + } else { + hits as f64 / total as f64 + } + } + + /// Reset all metrics to zero. + pub fn reset(&self) { + self.hits.store(0, Ordering::Relaxed); + self.misses.store(0, Ordering::Relaxed); + self.stores.store(0, Ordering::Relaxed); + self.skipped.store(0, Ordering::Relaxed); + } +} + +/// A trait for generating cache keys from HTTP requests. +pub trait Keyer: Clone + Send + Sync + 'static { + /// Generate a cache key for the given request. + fn cache_key<B>(&self, req: &Request<B>) -> String; +} + +/// Default keyer that uses HTTP method and path. +/// +/// Generates keys in the format: `{METHOD} {path}` +/// +/// # Example +/// +/// ``` +/// # use http::Request; +/// # use http_cache_tower_server::{Keyer, DefaultKeyer}; +/// let keyer = DefaultKeyer; +/// let req = Request::get("/users/123").body(()).unwrap(); +/// let key = keyer.cache_key(&req); +/// assert_eq!(key, "GET /users/123"); +/// ``` +#[derive(Debug, Clone, Copy, Default)] +pub struct DefaultKeyer; + +impl Keyer for DefaultKeyer { + fn cache_key<B>(&self, req: &Request<B>) -> String { + format!("{} {}", req.method(), req.uri().path()) + } +} + +/// Keyer that includes query parameters in the cache key. 
+/// +/// Generates keys in the format: `{METHOD} {path}?{query}` +/// +/// # Example +/// +/// ``` +/// # use http::Request; +/// # use http_cache_tower_server::{Keyer, QueryKeyer}; +/// let keyer = QueryKeyer; +/// let req = Request::get("/users?page=1").body(()).unwrap(); +/// let key = keyer.cache_key(&req); +/// assert_eq!(key, "GET /users?page=1"); +/// ``` +#[derive(Debug, Clone, Copy, Default)] +pub struct QueryKeyer; + +impl Keyer for QueryKeyer { + fn cache_key<B>(&self, req: &Request<B>) -> String { + format!("{} {}", req.method(), req.uri()) + } +} + +/// Custom keyer that uses a user-provided function. +/// +/// Use this when the default method+path keying is insufficient, such as: +/// - Content negotiation based on request headers (Accept-Language, Accept-Encoding) +/// - User-specific or session-specific caching +/// - Query parameter normalization +/// +/// # Examples +/// +/// Basic custom format: +/// +/// ``` +/// # use http::Request; +/// # use http_cache_tower_server::{Keyer, CustomKeyer}; +/// let keyer = CustomKeyer::new(|req: &Request<()>| { +/// format!("custom-{}-{}", req.method(), req.uri().path()) +/// }); +/// let req = Request::get("/users").body(()).unwrap(); +/// let key = keyer.cache_key(&req); +/// assert_eq!(key, "custom-GET-/users"); +/// ``` +/// +/// Content negotiation (Accept-Language): +/// +/// ``` +/// # use http::Request; +/// # use http_cache_tower_server::{Keyer, CustomKeyer}; +/// let keyer = CustomKeyer::new(|req: &Request<()>| { +/// let lang = req.headers() +/// .get("accept-language") +/// .and_then(|v| v.to_str().ok()) +/// .and_then(|s| s.split(',').next()) +/// .unwrap_or("en"); +/// format!("{} {} lang:{}", req.method(), req.uri().path(), lang) +/// }); +/// ``` +/// +/// User-specific caching (session-based): +/// +/// ``` +/// # use http::Request; +/// # use http_cache_tower_server::{Keyer, CustomKeyer}; +/// let keyer = CustomKeyer::new(|req: &Request<()>| { +/// let user_id = req.headers() +/// .get("x-user-id") 
+/// .and_then(|v| v.to_str().ok()) +/// .unwrap_or("anonymous"); +/// format!("{} {} user:{}", req.method(), req.uri().path(), user_id) +/// }); +/// ``` +/// +/// # Security Warning +/// +/// When caching user-specific or session-specific data, ensure the user/session identifier +/// is included in the cache key. Failure to do so will cause responses from one user to be +/// served to other users. +#[derive(Clone)] +pub struct CustomKeyer<F> { + func: F, +} + +impl<F> CustomKeyer<F> { + /// Create a new custom keyer with the given function. + pub fn new(func: F) -> Self { + Self { func } + } +} + +impl<F> Keyer for CustomKeyer<F> +where + F: Fn(&Request<()>) -> String + Clone + Send + Sync + 'static, +{ + fn cache_key<B>(&self, req: &Request<B>) -> String { + // Create a temporary request with the same parts but () body + let mut temp_req = Request::builder() + .method(req.method()) + .uri(req.uri()) + .version(req.version()) + .body(()) + .unwrap(); + + // Copy headers for content negotiation support + *temp_req.headers_mut() = req.headers().clone(); + + (self.func)(&temp_req) + } +} + +/// Configuration options for server-side caching. +#[derive(Debug, Clone)] +pub struct ServerCacheOptions { + /// Default TTL when response has no Cache-Control header. + pub default_ttl: Option<Duration>, + + /// Maximum TTL, even if response specifies longer. + pub max_ttl: Option<Duration>, + + /// Minimum TTL, even if response specifies shorter. + pub min_ttl: Option<Duration>, + + /// Whether to add X-Cache headers (HIT/MISS). + pub cache_status_headers: bool, + + /// Maximum response body size to cache (in bytes). + pub max_body_size: usize, + + /// Whether to cache responses without explicit Cache-Control. + pub cache_by_default: bool, + + /// Whether to respect Vary header for content negotiation. + /// + /// When true (default), cached responses are only served if the request's + /// headers match those specified in the response's Vary header. This is + /// enforced via `http-cache-semantics`. 
+ pub respect_vary: bool, + + /// Whether to respect Authorization headers per RFC 9111 §3.5. + /// + /// When true (default), requests with `Authorization` headers are not cached + /// unless the response explicitly permits it via `public`, `s-maxage`, or + /// `must-revalidate` directives. + /// + /// This prevents accidental caching of authenticated responses that could + /// leak user-specific data to other users. + pub respect_authorization: bool, +} + +impl Default for ServerCacheOptions { + fn default() -> Self { + Self { + default_ttl: Some(Duration::from_secs(60)), + max_ttl: Some(Duration::from_secs(3600)), + min_ttl: None, + cache_status_headers: true, + max_body_size: 128 * 1024 * 1024, + cache_by_default: false, + respect_vary: true, + respect_authorization: true, + } + } +} + +/// A cached HTTP response with metadata. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CachedResponse { + /// Response status code. + pub status: u16, + + /// Response headers. + pub headers: HashMap<String, String>, + + /// Response body bytes. + pub body: Vec<u8>, + + /// When this response was cached. + pub cached_at: SystemTime, + + /// Time-to-live duration. + pub ttl: Duration, + + /// Optional vary headers for content negotiation. + pub vary: Option<Vec<String>>, +} + +impl CachedResponse { + /// Check if this cached response is stale. + pub fn is_stale(&self) -> bool { + SystemTime::now() + .duration_since(self.cached_at) + .unwrap_or(Duration::MAX) + > self.ttl + } + + /// Convert to an HTTP response. + pub fn into_response(self) -> Response<Bytes> { + let mut builder = Response::builder().status(self.status); + + for (key, value) in self.headers { + if let Ok(header_value) = HeaderValue::from_str(&value) { + builder = builder.header(key, header_value); + } + } + + builder.body(Bytes::from(self.body)).unwrap() + } +} + +/// Response body types. +#[derive(Debug)] +pub enum ResponseBody { + /// Cached response body. + Cached(Bytes), + /// Fresh response body. 
+ Fresh(Bytes), + /// Uncacheable response body. + Uncacheable(Bytes), +} + +impl HttpBody for ResponseBody { + type Data = Bytes; + type Error = BoxError; + + fn poll_frame( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> { + let bytes = match &mut *self { + ResponseBody::Cached(b) + | ResponseBody::Fresh(b) + | ResponseBody::Uncacheable(b) => { + std::mem::replace(b, Bytes::new()) + } + }; + + if bytes.is_empty() { + Poll::Ready(None) + } else { + Poll::Ready(Some(Ok(Frame::data(bytes)))) + } + } + + fn is_end_stream(&self) -> bool { + match self { + ResponseBody::Cached(b) + | ResponseBody::Fresh(b) + | ResponseBody::Uncacheable(b) => b.is_empty(), + } + } +} + +/// Tower layer for server-side HTTP response caching. +/// +/// This layer should be placed AFTER routing to ensure request +/// extensions (like path parameters) are preserved. +/// +/// # Shared Cache Behavior +/// +/// This implements a **shared cache** as defined in RFC 9111. Responses cached by this layer +/// are served to all users making requests with matching cache keys. The cache automatically +/// rejects responses with the `private` directive, but does not inspect `Authorization` headers +/// or session cookies. +/// +/// For authenticated or user-specific endpoints, either: +/// - Set `Cache-Control: private` in responses (prevents caching) +/// - Use a `CustomKeyer` that includes user/session identifiers in the cache key +#[derive(Clone)] +pub struct ServerCacheLayer<M, K = DefaultKeyer> +where + M: CacheManager, + K: Keyer, +{ + manager: M, + keyer: K, + options: ServerCacheOptions, + metrics: Arc<CacheMetrics>, +} + +impl<M> ServerCacheLayer<M> +where + M: CacheManager, +{ + /// Create a new cache layer with default options. 
+ pub fn new(manager: M) -> Self { + Self { + manager, + keyer: DefaultKeyer, + options: ServerCacheOptions::default(), + metrics: Arc::new(CacheMetrics::new()), + } + } +} + +impl<M, K> ServerCacheLayer<M, K> +where + M: CacheManager, + K: Keyer, +{ + /// Create a cache layer with a custom keyer. + pub fn with_keyer(manager: M, keyer: K) -> Self { + Self { + manager, + keyer, + options: ServerCacheOptions::default(), + metrics: Arc::new(CacheMetrics::new()), + } + } + + /// Set custom options. + pub fn with_options(mut self, options: ServerCacheOptions) -> Self { + self.options = options; + self + } + + /// Get a reference to the cache metrics. + pub fn metrics(&self) -> &Arc<CacheMetrics> { + &self.metrics + } + + /// Invalidate a specific cache entry by its key. + pub async fn invalidate(&self, cache_key: &str) -> Result<(), BoxError> { + self.manager.delete(cache_key).await + } + + /// Invalidate cache entry for a specific request. + /// + /// Uses the configured keyer to generate the cache key from the request. + pub async fn invalidate_request<B>( + &self, + req: &Request<B>, + ) -> Result<(), BoxError> { + let cache_key = self.keyer.cache_key(req); + self.invalidate(&cache_key).await + } +} + +impl<S, M, K> Layer<S> for ServerCacheLayer<M, K> +where + M: CacheManager + Clone, + K: Keyer, +{ + type Service = ServerCacheService<S, M, K>; + + fn layer(&self, inner: S) -> Self::Service { + ServerCacheService { + inner, + manager: self.manager.clone(), + keyer: self.keyer.clone(), + options: self.options.clone(), + metrics: self.metrics.clone(), + } + } +} + +/// Tower service that implements response caching. 
+#[derive(Clone)] +pub struct ServerCacheService<S, M, K> +where + M: CacheManager, + K: Keyer, +{ + inner: S, + manager: M, + keyer: K, + options: ServerCacheOptions, + metrics: Arc<CacheMetrics>, +} + +impl<S, M, K, ReqBody, ResBody> Service<Request<ReqBody>> + for ServerCacheService<S, M, K> +where + S: Service<Request<ReqBody>, Response = Response<ResBody>> + + Clone + + Send + + 'static, + S::Error: Into<BoxError>, + S::Future: Send + 'static, + M: CacheManager + Clone, + K: Keyer, + ReqBody: Send + 'static, + ResBody: HttpBody + Send + 'static, + ResBody::Data: Send, + ResBody::Error: Into<BoxError>, +{ + type Response = Response<ResponseBody>; + type Error = BoxError; + type Future = Pin< + Box< + dyn std::future::Future< + Output = std::result::Result<Self::Response, Self::Error>, + > + Send, + >, + >; + + fn poll_ready( + &mut self, + cx: &mut Context<'_>, + ) -> Poll<std::result::Result<(), Self::Error>> { + self.inner.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, req: Request<ReqBody>) -> Self::Future { + let manager = self.manager.clone(); + let keyer = self.keyer.clone(); + let options = self.options.clone(); + let metrics = self.metrics.clone(); + let mut inner = self.inner.clone(); + + Box::pin(async move { + // Store request parts for later use in should_cache + let (req_parts, req_body) = req.into_parts(); + + // Generate cache key from request parts + let temp_req = Request::from_parts(req_parts.clone(), ()); + let cache_key = keyer.cache_key(&temp_req); + + // Try to get from cache + if let Ok(Some((cached_resp, policy))) = + manager.get(&cache_key).await + { + // Deserialize cached response first + if let Ok(cached) = + serde_json::from_slice::<CachedResponse>(&cached_resp.body) + { + // Check freshness using both CachePolicy and our TTL tracking. + // CachePolicy handles Vary header matching. + // Our is_stale() handles the TTL we assigned (especially for cache_by_default). + let before_req = + policy.before_request(&req_parts, SystemTime::now()); + + // Determine if response had explicit freshness directives + // (max-age or s-maxage). If it only has "public" or other directives + // without explicit TTL, we use our own TTL tracking. 
+                    let has_explicit_ttl =
+                        cached.headers.get("cache-control").is_some_and(|cc| {
+                            cc.contains("max-age") || cc.contains("s-maxage")
+                        });
+
+                    let is_fresh = match before_req {
+                        BeforeRequest::Fresh(_) => {
+                            // CachePolicy says fresh - use it
+                            true
+                        }
+                        BeforeRequest::Stale { .. } => {
+                            // CachePolicy says stale. This could be due to:
+                            // 1. Vary header mismatch
+                            // 2. Time-based staleness per cache headers
+                            // 3. No explicit TTL (cache_by_default or public-only)
+                            //
+                            // For case 3, our TTL tracking is authoritative.
+                            // For cases 1-2, we should respect CachePolicy.
+                            if has_explicit_ttl {
+                                // Had explicit TTL - trust CachePolicy
+                                false
+                            } else {
+                                // No explicit TTL - use our TTL
+                                !cached.is_stale()
+                            }
+                        }
+                    };
+
+                    if is_fresh {
+                        // Cache hit
+                        metrics.hits.fetch_add(1, Ordering::Relaxed);
+                        let mut response = cached.into_response();
+
+                        if options.cache_status_headers {
+                            response.headers_mut().insert(
+                                "x-cache",
+                                HeaderValue::from_static("HIT"),
+                            );
+                        }
+
+                        return Ok(response.map(ResponseBody::Cached));
+                    }
+                }
+            }
+
+            // Reconstruct request for handler
+            let req = Request::from_parts(req_parts.clone(), req_body);
+
+            // Cache miss or stale - call the handler
+            metrics.misses.fetch_add(1, Ordering::Relaxed);
+            let response = inner.call(req).await.map_err(Into::into)?;
+
+            // Split response to check if we should cache
+            let (res_parts, body) = response.into_parts();
+
+            // Check if we should cache this response
+            if let Some(ttl) = should_cache(&req_parts, &res_parts, &options) {
+                // Buffer the response body
+                let body_bytes = match collect_body(body).await {
+                    Ok(bytes) => bytes,
+                    Err(e) => {
+                        // If we can't collect the body, propagate the error
+                        return Err(e);
+                    }
+                };
+
+                // Check size limit
+                if body_bytes.len() <= options.max_body_size {
+                    metrics.stores.fetch_add(1, Ordering::Relaxed);
+                    // Create cached response
+                    let cached = CachedResponse {
+                        status: res_parts.status.as_u16(),
+                        headers: res_parts
+                            .headers
+                            .iter()
+                            .filter_map(|(k, v)| {
+                                v.to_str()
+                                    .ok()
+                                    .map(|s| (k.to_string(), s.to_string()))
+                            })
+                            .collect(),
+                        body: body_bytes.to_vec(),
+                        cached_at: SystemTime::now(),
+                        ttl,
+                        vary: extract_vary_headers(&res_parts),
+                    };
+
+                    // Store in cache (fire and forget)
+                    let cached_json = serde_json::to_vec(&cached)
+                        .map_err(|e| Box::new(e) as BoxError)?;
+                    let http_response = HttpResponse {
+                        body: cached_json,
+                        headers: Default::default(),
+                        status: 200,
+                        url: cache_key.clone().parse().unwrap_or_else(|_| {
+                            "http://localhost/".parse().unwrap()
+                        }),
+                        version: HttpVersion::Http11,
+                    };
+
+                    // Create CachePolicy from actual request/response for Vary support
+                    let policy_req = Request::from_parts(req_parts.clone(), ());
+                    let policy_res =
+                        Response::from_parts(res_parts.clone(), ());
+                    let policy = CachePolicy::new(&policy_req, &policy_res);
+
+                    // Spawn cache write asynchronously
+                    let manager_clone = manager.clone();
+                    tokio::spawn(async move {
+                        let _ = manager_clone
+                            .put(cache_key, http_response, policy)
+                            .await;
+                    });
+                } else {
+                    // Body too large
+                    metrics.skipped.fetch_add(1, Ordering::Relaxed);
+                }
+
+                // Return response with MISS header
+                let mut response = Response::from_parts(res_parts, body_bytes);
+                if options.cache_status_headers {
+                    response
+                        .headers_mut()
+                        .insert("x-cache", HeaderValue::from_static("MISS"));
+                }
+                return Ok(response.map(ResponseBody::Fresh));
+            }
+
+            // Don't cache - just return
+            metrics.skipped.fetch_add(1, Ordering::Relaxed);
+            let body_bytes = collect_body(body).await?;
+            Ok(Response::from_parts(res_parts, body_bytes)
+                .map(ResponseBody::Uncacheable))
+        })
+    }
+}
+
+/// Collect a body into bytes.
+async fn collect_body<B>(body: B) -> std::result::Result<Bytes, BoxError>
+where
+    B: HttpBody,
+    B::Error: Into<BoxError>,
+{
+    body.collect()
+        .await
+        .map(|collected| collected.to_bytes())
+        .map_err(Into::into)
+}
+
+/// Extract Vary headers from response parts.
+fn extract_vary_headers(
+    parts: &http::response::Parts,
+) -> Option<Vec<String>> {
+    parts
+        .headers
+        .get(http::header::VARY)
+        .and_then(|v| v.to_str().ok())
+        .map(|s| s.split(',').map(|h| h.trim().to_string()).collect())
+}
+
+/// Helper function to check if a Cache-Control directive is present.
+/// This properly parses directives by splitting on commas and matching exact names.
+fn has_directive(cache_control: &str, directive: &str) -> bool {
+    cache_control
+        .split(',')
+        .map(|d| d.trim())
+        .any(|d| d == directive || d.starts_with(&format!("{}=", directive)))
+}
+
+/// Check if a response explicitly permits caching of authorized requests per RFC 9111 §3.5.
+///
+/// Returns true if the response contains directives that allow caching despite
+/// the request having an Authorization header.
+fn response_permits_authorized_caching(cc_str: &str) -> bool {
+    has_directive(cc_str, "public")
+        || has_directive(cc_str, "s-maxage")
+        || has_directive(cc_str, "must-revalidate")
+}
+
+/// Determine if a response should be cached based on its headers.
+/// Implements RFC 7234/9111 requirements for shared caches.
+fn should_cache(
+    req_parts: &http::request::Parts,
+    res_parts: &http::response::Parts,
+    options: &ServerCacheOptions,
+) -> Option<Duration> {
+    // RFC 7234: Only cache successful responses (2xx)
+    if !res_parts.status.is_success() {
+        return None;
+    }
+
+    // RFC 9111 §3.5: Check Authorization header
+    let has_authorization =
+        req_parts.headers.contains_key(http::header::AUTHORIZATION);
+
+    // RFC 7234: Check Cache-Control directives
+    if let Some(cc) = res_parts.headers.get(http::header::CACHE_CONTROL) {
+        let cc_str = cc.to_str().ok()?;
+
+        // RFC 9111 §3.5: If request has Authorization header, only cache if
+        // response explicitly permits it
+        if has_authorization
+            && options.respect_authorization
+            && !response_permits_authorized_caching(cc_str)
+        {
+            return None;
+        }
+
+        // RFC 7234: MUST NOT store if no-store directive present
+        if has_directive(cc_str, "no-store") {
+            return None;
+        }
+
+        // RFC 7234: MUST NOT store if no-cache.
+        // Note: Per RFC, no-cache means "cache but always revalidate". However,
+        // without conditional request support (ETag/If-None-Match), we cannot
+        // revalidate, so we skip caching entirely.
+        if has_directive(cc_str, "no-cache") {
+            return None;
+        }
+
+        // RFC 7234: Shared caches MUST NOT store responses with private directive
+        if has_directive(cc_str, "private") {
+            return None;
+        }
+
+        // RFC 7234: s-maxage directive overrides max-age for shared caches
+        if let Some(s_maxage) = parse_s_maxage(cc_str) {
+            let ttl = Duration::from_secs(s_maxage);
+            let ttl = apply_ttl_constraints(ttl, options);
+            return Some(ttl);
+        }
+
+        // RFC 7234: Extract max-age for cache lifetime
+        if let Some(max_age) = parse_max_age(cc_str) {
+            let ttl = Duration::from_secs(max_age);
+            let ttl = apply_ttl_constraints(ttl, options);
+            return Some(ttl);
+        }
+
+        // RFC 7234: public directive makes response cacheable
+        if has_directive(cc_str, "public") {
+            return options.default_ttl;
+        }
+    } else {
+        // No Cache-Control header.
+        // RFC 9111 §3.5: Don't cache authorized requests without explicit permission
+        if has_authorization && options.respect_authorization {
+            return None;
+        }
+    }
+
+    // RFC 7234: Check for Expires header if no Cache-Control
+    if let Some(expires) = res_parts.headers.get(http::header::EXPIRES) {
+        if let Ok(expires_str) = expires.to_str() {
+            if let Some(ttl) = parse_expires(expires_str) {
+                let ttl = apply_ttl_constraints(ttl, options);
+                return Some(ttl);
+            }
+        }
+    }
+
+    // No explicit caching directive
+    if options.cache_by_default {
+        options.default_ttl
+    } else {
+        None
+    }
+}
+
+/// Apply min/max TTL constraints from options.
+fn apply_ttl_constraints(
+    ttl: Duration,
+    options: &ServerCacheOptions,
+) -> Duration {
+    let mut result = ttl;
+
+    if let Some(max) = options.max_ttl {
+        result = result.min(max);
+    }
+
+    if let Some(min) = options.min_ttl {
+        result = result.max(min);
+    }
+
+    result
+}
+
+/// Parse max-age from a Cache-Control header.
+fn parse_max_age(cache_control: &str) -> Option<u64> {
+    for directive in cache_control.split(',') {
+        let directive = directive.trim();
+        if let Some(value) = directive.strip_prefix("max-age=") {
+            return value.parse().ok();
+        }
+    }
+    None
+}
+
+/// Parse s-maxage from a Cache-Control header (shared-cache specific).
+fn parse_s_maxage(cache_control: &str) -> Option<u64> {
+    for directive in cache_control.split(',') {
+        let directive = directive.trim();
+        if let Some(value) = directive.strip_prefix("s-maxage=") {
+            return value.parse().ok();
+        }
+    }
+    None
+}
+
+/// Parse the Expires header to calculate a TTL.
+///
+/// Returns the duration until expiration, or None if the date is invalid or in the past.
+fn parse_expires(expires: &str) -> Option<Duration> {
+    let expires_time = httpdate::parse_http_date(expires).ok()?;
+    let now = SystemTime::now();
+
+    expires_time.duration_since(now).ok()
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_default_keyer() {
+        let keyer = DefaultKeyer;
+        let req = Request::get("/users/123").body(()).unwrap();
+        let key = keyer.cache_key(&req);
+        assert_eq!(key, "GET /users/123");
+    }
+
+    #[test]
+    fn test_query_keyer() {
+        let keyer = QueryKeyer;
+        let req = Request::get("/users?page=1").body(()).unwrap();
+        let key = keyer.cache_key(&req);
+        assert_eq!(key, "GET /users?page=1");
+    }
+
+    #[test]
+    fn test_parse_max_age() {
+        assert_eq!(parse_max_age("max-age=3600"), Some(3600));
+        assert_eq!(parse_max_age("public, max-age=3600"), Some(3600));
+        assert_eq!(parse_max_age("max-age=3600, public"), Some(3600));
+        assert_eq!(parse_max_age("public"), None);
+    }
+
+    #[test]
+    fn test_parse_s_maxage() {
+        assert_eq!(parse_s_maxage("s-maxage=7200"), Some(7200));
+        assert_eq!(parse_s_maxage("public, s-maxage=7200"), Some(7200));
+        assert_eq!(parse_s_maxage("s-maxage=7200, max-age=3600"), Some(7200));
+        assert_eq!(parse_s_maxage("public"), None);
+    }
+
+    #[test]
+    fn test_apply_ttl_constraints() {
+        let options = ServerCacheOptions {
+            min_ttl: Some(Duration::from_secs(10)),
+            max_ttl: Some(Duration::from_secs(100)),
+            ..Default::default()
+        };
+
+        assert_eq!(
+            apply_ttl_constraints(Duration::from_secs(5), &options),
+            Duration::from_secs(10)
+        );
+        assert_eq!(
+            apply_ttl_constraints(Duration::from_secs(50), &options),
+            Duration::from_secs(50)
+        );
+        assert_eq!(
+            apply_ttl_constraints(Duration::from_secs(200), &options),
+            Duration::from_secs(100)
+        );
+    }
+}
diff --git a/http-cache-tower-server/tests/integration.rs b/http-cache-tower-server/tests/integration.rs
new file mode 100644
index 0000000..11cf17e
--- /dev/null
+++ b/http-cache-tower-server/tests/integration.rs
@@ -0,0 +1,1359 @@
+use bytes::Bytes;
+use http::{Request, Response, StatusCode};
+use http_body_util::Full;
+use http_cache::{CacheManager, HttpResponse, Result};
+use http_cache_semantics::CachePolicy;
+use http_cache_tower_server::{
+    CustomKeyer, DefaultKeyer, Keyer, QueryKeyer, ServerCacheLayer,
+    ServerCacheOptions,
+};
+use std::collections::HashMap;
+use std::sync::{Arc, Mutex};
+use tower::{Layer, Service, ServiceExt};
+
+// Extension type for testing path parameter preservation
+#[derive(Debug, Clone, PartialEq)]
+struct PathParams {
+    id: String,
+}
+
+// Simple in-memory cache manager for testing
+#[derive(Clone)]
+struct MemoryCacheManager {
+    store: Arc<Mutex<HashMap<String, (HttpResponse, CachePolicy)>>>,
+}
+
+impl MemoryCacheManager {
+    fn new() -> Self {
+        Self { store: Arc::new(Mutex::new(HashMap::new())) }
+    }
+}
+
+#[async_trait::async_trait]
+impl CacheManager for MemoryCacheManager {
+    async fn get(
+        &self,
+        cache_key: &str,
+    ) -> Result<Option<(HttpResponse, CachePolicy)>> {
+        Ok(self.store.lock().unwrap().get(cache_key).cloned())
+    }
+
+    async fn put(
+        &self,
+        cache_key: String,
+        res: HttpResponse,
+        policy: CachePolicy,
+    ) -> Result<HttpResponse> {
+        self.store.lock().unwrap().insert(cache_key, (res.clone(), policy));
+        Ok(res)
+    }
+
+    async fn delete(&self, cache_key: &str) -> Result<()> {
+        self.store.lock().unwrap().remove(cache_key);
+        Ok(())
+    }
+}
+
+#[tokio::test]
+async fn test_cache_hit_and_miss() {
+    let manager = MemoryCacheManager::new();
+    let layer = ServerCacheLayer::new(manager.clone());
+
+    let mut service =
+        layer.layer(tower::service_fn(|_req: Request<Full<Bytes>>| async {
+            Ok::<_, std::io::Error>(
+                Response::builder()
+                    .status(StatusCode::OK)
+                    .header("cache-control", "max-age=60")
+                    .body(Full::new(Bytes::from("Hello, World!")))
+                    .unwrap(),
+            )
+        }));
+
+    // First request - cache miss
+    let req = Request::get("/test").body(Full::new(Bytes::new())).unwrap();
+    let res = service.ready().await.unwrap().call(req).await.unwrap();
+    assert_eq!(res.headers().get("x-cache").unwrap(), "MISS");
+
+    // Give cache write time to complete
+    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
+
+    // Second request - cache hit
+    let req = Request::get("/test").body(Full::new(Bytes::new())).unwrap();
+    let res = service.ready().await.unwrap().call(req).await.unwrap();
+    assert_eq!(res.headers().get("x-cache").unwrap(), "HIT");
+}
+
+#[tokio::test]
+async fn test_no_store_directive() {
+    let manager = MemoryCacheManager::new();
+    let layer = ServerCacheLayer::new(manager.clone());
+
+    let mut service =
+        layer.layer(tower::service_fn(|_req: Request<Full<Bytes>>| async {
+            Ok::<_, std::io::Error>(
+                Response::builder()
+                    .status(StatusCode::OK)
+                    .header("cache-control", "no-store")
+                    .body(Full::new(Bytes::from("Don't cache me")))
+                    .unwrap(),
+            )
+        }));
+
+    // Request should not be cached
+    let req = Request::get("/no-store").body(Full::new(Bytes::new())).unwrap();
+    let res = service.ready().await.unwrap().call(req).await.unwrap();
+
+    // Should not have cache header if not cached
+    assert!(
+        res.headers().get("x-cache").is_none()
+            || res.headers().get("x-cache").unwrap() != "MISS"
+    );
+
+    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
+
+    // Second request should also not hit cache
+    let req = Request::get("/no-store").body(Full::new(Bytes::new())).unwrap();
+    let res = service.ready().await.unwrap().call(req).await.unwrap();
+
+    // Should not be a cache hit
+    assert!(
+        res.headers().get("x-cache").is_none()
+            || res.headers().get("x-cache").unwrap() != "HIT"
+    );
+}
+
+#[tokio::test]
+async fn test_private_directive() {
+    let manager = MemoryCacheManager::new();
+    let layer = ServerCacheLayer::new(manager.clone());
+
+    let mut service =
+        layer.layer(tower::service_fn(|_req: Request<Full<Bytes>>| async {
+            Ok::<_, std::io::Error>(
+                Response::builder()
+                    .status(StatusCode::OK)
+                    .header("cache-control", "private, max-age=60")
+                    .body(Full::new(Bytes::from("Private data")))
+                    .unwrap(),
+            )
+        }));
+
+    // Request should not be cached (shared cache)
+    let req =
+        Request::get("/private").body(Full::new(Bytes::new())).unwrap();
+    let _res = service.ready().await.unwrap().call(req).await.unwrap();
+
+    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
+
+    // Second request should not hit cache
+    let req = Request::get("/private").body(Full::new(Bytes::new())).unwrap();
+    let res = service.ready().await.unwrap().call(req).await.unwrap();
+    assert!(
+        res.headers().get("x-cache").is_none()
+            || res.headers().get("x-cache").unwrap() != "HIT"
+    );
+}
+
+#[tokio::test]
+async fn test_s_maxage_override() {
+    let manager = MemoryCacheManager::new();
+    let layer = ServerCacheLayer::new(manager.clone());
+
+    let mut service =
+        layer.layer(tower::service_fn(|_req: Request<Full<Bytes>>| async {
+            Ok::<_, std::io::Error>(
+                Response::builder()
+                    .status(StatusCode::OK)
+                    .header("cache-control", "max-age=60, s-maxage=120")
+                    .body(Full::new(Bytes::from("Shared cache data")))
+                    .unwrap(),
+            )
+        }));
+
+    // Request should be cached with s-maxage (120s, not 60s)
+    let req = Request::get("/s-maxage").body(Full::new(Bytes::new())).unwrap();
+    let res = service.ready().await.unwrap().call(req).await.unwrap();
+    assert_eq!(res.headers().get("x-cache").unwrap(), "MISS");
+
+    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
+
+    // Second request should hit cache
+    let req = Request::get("/s-maxage").body(Full::new(Bytes::new())).unwrap();
+    let res = service.ready().await.unwrap().call(req).await.unwrap();
+    assert_eq!(res.headers().get("x-cache").unwrap(), "HIT");
+}
+
+#[tokio::test]
+async fn test_only_cache_success_status() {
+    let manager = MemoryCacheManager::new();
+    let layer = ServerCacheLayer::new(manager.clone());
+
+    let mut service =
+        layer.layer(tower::service_fn(|_req: Request<Full<Bytes>>| async {
+            Ok::<_, std::io::Error>(
+                Response::builder()
+                    .status(StatusCode::NOT_FOUND)
+                    .header("cache-control", "max-age=60")
+                    .body(Full::new(Bytes::from("Not found")))
+                    .unwrap(),
+            )
+        }));
+
+    // A 404 should not be cached
+    let req = Request::get("/not-found").body(Full::new(Bytes::new())).unwrap();
+    let res = service.ready().await.unwrap().call(req).await.unwrap();
+    assert_eq!(res.status(), StatusCode::NOT_FOUND);
+
+    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
+
+    // Second request should not hit cache
+    let req = Request::get("/not-found").body(Full::new(Bytes::new())).unwrap();
+    let res = service.ready().await.unwrap().call(req).await.unwrap();
+    assert!(
+        res.headers().get("x-cache").is_none()
+            || res.headers().get("x-cache").unwrap() != "HIT"
+    );
+}
+
+#[tokio::test]
+async fn test_default_keyer() {
+    let keyer = DefaultKeyer;
+    let req = Request::get("/users/123?page=1").body(()).unwrap();
+    let key = keyer.cache_key(&req);
+
+    // Default keyer should only include path, not query
+    assert_eq!(key, "GET /users/123");
+}
+
+#[tokio::test]
+async fn test_query_keyer() {
+    let keyer = QueryKeyer;
+    let req = Request::get("/users/123?page=1").body(()).unwrap();
+    let key = keyer.cache_key(&req);
+
+    // Query keyer should include query parameters
+    assert_eq!(key, "GET /users/123?page=1");
+}
+
+#[tokio::test]
+async fn test_body_size_limit() {
+    let manager = MemoryCacheManager::new();
+    let options = ServerCacheOptions {
+        max_body_size: 10, // Very small limit
+        ..Default::default()
+    };
+    let layer = ServerCacheLayer::new(manager.clone()).with_options(options);
+
+    let mut service =
+        layer.layer(tower::service_fn(|_req: Request<Full<Bytes>>| async {
+            Ok::<_, std::io::Error>(
+                Response::builder()
+                    .status(StatusCode::OK)
+                    .header("cache-control", "max-age=60")
+                    .body(Full::new(Bytes::from(
+                        "This is a long response body",
+                    )))
+                    .unwrap(),
+            )
+        }));
+
+    // First request - too large to cache
+    let req = Request::get("/large").body(Full::new(Bytes::new())).unwrap();
+    let res = service.ready().await.unwrap().call(req).await.unwrap();
+    assert_eq!(res.headers().get("x-cache").unwrap(), "MISS");
+
+    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
+
+    // Second request - should not hit cache
+    let req = Request::get("/large").body(Full::new(Bytes::new())).unwrap();
+    let res = service.ready().await.unwrap().call(req).await.unwrap();
+    assert_eq!(res.headers().get("x-cache").unwrap(), "MISS");
+}
+
+#[tokio::test]
+async fn test_ttl_constraints() {
+    let manager = MemoryCacheManager::new();
+    let options = ServerCacheOptions {
+        min_ttl: Some(std::time::Duration::from_secs(30)),
+        max_ttl: Some(std::time::Duration::from_secs(90)),
+        ..Default::default()
+    };
+    let layer = ServerCacheLayer::new(manager.clone()).with_options(options);
+
+    let mut service =
+        layer.layer(tower::service_fn(|_req: Request<Full<Bytes>>| async {
+            Ok::<_, std::io::Error>(
+                Response::builder()
+                    .status(StatusCode::OK)
+                    .header("cache-control", "max-age=10") // Below min_ttl
+                    .body(Full::new(Bytes::from("Response")))
+                    .unwrap(),
+            )
+        }));
+
+    // Request should be cached with min_ttl (30s, not 10s)
+    let req = Request::get("/ttl").body(Full::new(Bytes::new())).unwrap();
+    let res = service.ready().await.unwrap().call(req).await.unwrap();
+    assert_eq!(res.headers().get("x-cache").unwrap(), "MISS");
+
+    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
+
+    // Second request should hit cache
+    let req = Request::get("/ttl").body(Full::new(Bytes::new())).unwrap();
+    let res = service.ready().await.unwrap().call(req).await.unwrap();
+    assert_eq!(res.headers().get("x-cache").unwrap(), "HIT");
+}
+
+#[tokio::test]
+async fn test_public_directive() {
+    let manager = MemoryCacheManager::new();
+    let layer = ServerCacheLayer::new(manager.clone());
+
+    let mut service =
+        layer.layer(tower::service_fn(|_req: Request<Full<Bytes>>| async {
+            Ok::<_, std::io::Error>(
+                Response::builder()
+                    .status(StatusCode::OK)
+                    .header("cache-control", "public")
+                    .body(Full::new(Bytes::from("Public data")))
+                    .unwrap(),
+            )
+        }));
+
+    // Request should be cached with default TTL
+    let req =
+        Request::get("/public").body(Full::new(Bytes::new())).unwrap();
+    let res = service.ready().await.unwrap().call(req).await.unwrap();
+    assert_eq!(res.headers().get("x-cache").unwrap(), "MISS");
+
+    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
+
+    // Second request should hit cache
+    let req = Request::get("/public").body(Full::new(Bytes::new())).unwrap();
+    let res = service.ready().await.unwrap().call(req).await.unwrap();
+    assert_eq!(res.headers().get("x-cache").unwrap(), "HIT");
+}
+
+#[tokio::test]
+async fn test_no_cache_directive() {
+    let manager = MemoryCacheManager::new();
+    let layer = ServerCacheLayer::new(manager.clone());
+
+    let mut service =
+        layer.layer(tower::service_fn(|_req: Request<Full<Bytes>>| async {
+            Ok::<_, std::io::Error>(
+                Response::builder()
+                    .status(StatusCode::OK)
+                    .header("cache-control", "no-cache")
+                    .body(Full::new(Bytes::from("No cache")))
+                    .unwrap(),
+            )
+        }));
+
+    // Request should not be cached
+    let req = Request::get("/no-cache").body(Full::new(Bytes::new())).unwrap();
+    let _res = service.ready().await.unwrap().call(req).await.unwrap();
+
+    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
+
+    // Second request should not hit cache
+    let req = Request::get("/no-cache").body(Full::new(Bytes::new())).unwrap();
+    let res = service.ready().await.unwrap().call(req).await.unwrap();
+    assert!(
+        res.headers().get("x-cache").is_none()
+            || res.headers().get("x-cache").unwrap() != "HIT"
+    );
+}
+
+#[tokio::test]
+async fn test_expires_future_date() {
+    let manager = MemoryCacheManager::new();
+    let layer = ServerCacheLayer::new(manager.clone());
+
+    let future_time =
+        std::time::SystemTime::now() + std::time::Duration::from_secs(60);
+    let expires_date = httpdate::fmt_http_date(future_time);
+
+    let mut service =
+        layer.layer(tower::service_fn(move |_req: Request<Full<Bytes>>| {
+            let expires = expires_date.clone();
+            async move {
+                Ok::<_, std::io::Error>(
+                    Response::builder()
+                        .status(StatusCode::OK)
+                        .header("expires", expires)
+                        .body(Full::new(Bytes::from("Cacheable with Expires")))
+                        .unwrap(),
+                )
+            }
+        }));
+
+    let req =
+        Request::get("/expires-future").body(Full::new(Bytes::new())).unwrap();
+    let res = service.ready().await.unwrap().call(req).await.unwrap();
+    assert_eq!(res.headers().get("x-cache").unwrap(), "MISS");
+
+    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
+
+    let req =
+        Request::get("/expires-future").body(Full::new(Bytes::new())).unwrap();
+    let res = service.ready().await.unwrap().call(req).await.unwrap();
+    assert_eq!(res.headers().get("x-cache").unwrap(), "HIT");
+}
+
+#[tokio::test]
+async fn test_expires_past_date() {
+    let manager = MemoryCacheManager::new();
+    let layer = ServerCacheLayer::new(manager.clone());
+
+    let past_time =
+        std::time::SystemTime::now() - std::time::Duration::from_secs(60);
+    let expires_date = httpdate::fmt_http_date(past_time);
+
+    let mut service =
+        layer.layer(tower::service_fn(move |_req: Request<Full<Bytes>>| {
+            let expires = expires_date.clone();
+            async move {
+                Ok::<_, std::io::Error>(
+                    Response::builder()
+                        .status(StatusCode::OK)
+                        .header("expires", expires)
+                        .body(Full::new(Bytes::from("Already expired")))
+                        .unwrap(),
+                )
+            }
+        }));
+
+    let req =
+        Request::get("/expires-past").body(Full::new(Bytes::new())).unwrap();
+    let _res = service.ready().await.unwrap().call(req).await.unwrap();
+
+    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
+
+    let req =
+        Request::get("/expires-past").body(Full::new(Bytes::new())).unwrap();
+    let res = service.ready().await.unwrap().call(req).await.unwrap();
+    assert!(
+        res.headers().get("x-cache").is_none()
+            || res.headers().get("x-cache").unwrap() != "HIT"
+    );
+}
+
+#[tokio::test]
+async fn test_expires_invalid_format() {
+    let manager = MemoryCacheManager::new();
+    let layer = ServerCacheLayer::new(manager.clone());
+
+    let mut service =
+        layer.layer(tower::service_fn(|_req: Request<Full<Bytes>>| async {
+            Ok::<_, std::io::Error>(
+                Response::builder()
+                    .status(StatusCode::OK)
+                    .header("expires", "not-a-valid-date")
+                    .body(Full::new(Bytes::from("Invalid expires")))
+                    .unwrap(),
+            )
+        }));
+
+    let req =
+        Request::get("/invalid-expires").body(Full::new(Bytes::new())).unwrap();
+    let res = service.ready().await.unwrap().call(req).await.unwrap();
+    assert_eq!(res.status(), StatusCode::OK);
+
+    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
+
+    let req =
+        Request::get("/invalid-expires").body(Full::new(Bytes::new())).unwrap();
+    let res = service.ready().await.unwrap().call(req).await.unwrap();
+    assert!(
+        res.headers().get("x-cache").is_none()
+            || res.headers().get("x-cache").unwrap() != "HIT"
+    );
+}
+
+#[tokio::test]
+async fn test_cache_control_overrides_expires() {
+    let manager = MemoryCacheManager::new();
+    let layer = ServerCacheLayer::new(manager.clone());
+
+    let future_time =
+        std::time::SystemTime::now() + std::time::Duration::from_secs(10);
+    let expires_date = httpdate::fmt_http_date(future_time);
+
+    let mut service =
+        layer.layer(tower::service_fn(move |_req: Request<Full<Bytes>>| {
+            let expires = expires_date.clone();
+            async move {
+                Ok::<_, std::io::Error>(
+                    Response::builder()
+                        .status(StatusCode::OK)
+                        .header("cache-control", "max-age=60")
+                        .header("expires", expires)
+                        .body(Full::new(Bytes::from("Both headers")))
+                        .unwrap(),
+                )
+            }
+        }));
+
+    let req =
+        Request::get("/both-headers").body(Full::new(Bytes::new())).unwrap();
+    let res = service.ready().await.unwrap().call(req).await.unwrap();
+    assert_eq!(res.headers().get("x-cache").unwrap(), "MISS");
+
+    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
+
+    let req =
+        Request::get("/both-headers").body(Full::new(Bytes::new())).unwrap();
+    let res = service.ready().await.unwrap().call(req).await.unwrap();
+    assert_eq!(res.headers().get("x-cache").unwrap(), "HIT");
+}
+
+#[tokio::test]
+async fn test_expires_only_no_cache_control() {
+    let manager = MemoryCacheManager::new();
+    let layer = ServerCacheLayer::new(manager.clone());
+
+    let future_time =
+        std::time::SystemTime::now() + std::time::Duration::from_secs(60);
+    let expires_date = httpdate::fmt_http_date(future_time);
+
+    let mut service =
+        layer.layer(tower::service_fn(move |_req: Request<Full<Bytes>>| {
+            let expires = expires_date.clone();
+            async move {
+                Ok::<_, std::io::Error>(
+                    Response::builder()
+                        .status(StatusCode::OK)
+                        .header("expires", expires)
+                        .body(Full::new(Bytes::from("Expires only")))
+                        .unwrap(),
+                )
+            }
+        }));
+
+    let req =
+        Request::get("/expires-only").body(Full::new(Bytes::new())).unwrap();
+    let res = service.ready().await.unwrap().call(req).await.unwrap();
+    assert_eq!(res.headers().get("x-cache").unwrap(), "MISS");
+
+    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
+
+    let req =
+        Request::get("/expires-only").body(Full::new(Bytes::new())).unwrap();
+    let res = service.ready().await.unwrap().call(req).await.unwrap();
+    assert_eq!(res.headers().get("x-cache").unwrap(), "HIT");
+}
+
+#[tokio::test]
+async fn test_expires_with_ttl_constraints() {
+    let manager = MemoryCacheManager::new();
+    let options = ServerCacheOptions {
+        max_ttl: Some(std::time::Duration::from_secs(30)),
+        ..Default::default()
+    };
+    let layer = ServerCacheLayer::new(manager.clone()).with_options(options);
+
+    let future_time =
+        std::time::SystemTime::now() + std::time::Duration::from_secs(3600);
+    let expires_date = httpdate::fmt_http_date(future_time);
+
+    let mut service =
+        layer.layer(tower::service_fn(move |_req: Request<Full<Bytes>>| {
+            let expires = expires_date.clone();
+            async move {
+                Ok::<_, std::io::Error>(
+                    Response::builder()
+                        .status(StatusCode::OK)
+                        .header("expires", expires)
+                        .body(Full::new(Bytes::from(
+                            "Long expires with max_ttl",
+                        )))
+                        .unwrap(),
+                )
+            }
+        }));
+
+    let req =
+        Request::get("/expires-capped").body(Full::new(Bytes::new())).unwrap();
+    let res = service.ready().await.unwrap().call(req).await.unwrap();
+    assert_eq!(res.headers().get("x-cache").unwrap(), "MISS");
+
+    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
+
+    let req =
+        Request::get("/expires-capped").body(Full::new(Bytes::new())).unwrap();
+    let res = service.ready().await.unwrap().call(req).await.unwrap();
+    assert_eq!(res.headers().get("x-cache").unwrap(), "HIT");
+}
+
+#[tokio::test]
+async fn test_concurrent_cache_requests() {
+    let manager = MemoryCacheManager::new();
+    let layer = ServerCacheLayer::new(manager.clone());
+
+    let request_count = Arc::new(Mutex::new(0));
+    let request_count_clone = request_count.clone();
+
+    let mut service =
+        layer.layer(tower::service_fn(move |_req: Request<Full<Bytes>>| {
+            let count = request_count_clone.clone();
+            async move {
+                // Increment counter to track how many times the backend is called
+                *count.lock().unwrap() += 1;
+                tokio::time::sleep(tokio::time::Duration::from_millis(50))
+                    .await;
+                Ok::<_, std::io::Error>(
+                    Response::builder()
+                        .status(StatusCode::OK)
+                        .header("cache-control", "max-age=60")
+                        .body(Full::new(Bytes::from("Concurrent response")))
+                        .unwrap(),
+                )
+            }
+        }));
+
+    // Make multiple concurrent requests to the same endpoint
+    let mut handles = vec![];
+    for _ in 0..5 {
+        let req =
+            Request::get("/concurrent").body(Full::new(Bytes::new())).unwrap();
+        let mut svc = service.clone();
+        let handle = tokio::spawn(async move {
+            svc.ready().await.unwrap().call(req).await.unwrap()
+        });
+        handles.push(handle);
+    }
+
+    // Wait for all requests to complete
+    let mut responses = vec![];
+    for handle in handles {
+        responses.push(handle.await.unwrap());
+    }
+
+    // Verify all requests succeeded
+    assert_eq!(responses.len(), 5, "All concurrent requests should complete");
+
+    // At least one should be a MISS (the first one)
+    let miss_count = responses
+        .iter()
+        .filter(|r| {
+            r.headers().get("x-cache").map(|v| v == "MISS").unwrap_or(false)
+        })
+        .count();
+    assert!(miss_count >= 1, "At least one request should be a cache MISS");
+
+    // Give cache writes time to complete
+    tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
+
+    // Verify a subsequent request hits the cache
+    let req =
+        Request::get("/concurrent").body(Full::new(Bytes::new())).unwrap();
+    let res = service.ready().await.unwrap().call(req).await.unwrap();
+    assert_eq!(
+        res.headers().get("x-cache").unwrap(),
+        "HIT",
+        "Subsequent request should hit cache"
+    );
+}
+
+#[tokio::test]
+async fn test_stale_cache_expiration() {
+    let manager = MemoryCacheManager::new();
+    let layer = ServerCacheLayer::new(manager.clone());
+
+    let mut service =
+        layer.layer(tower::service_fn(|_req: Request<Full<Bytes>>| async {
+            Ok::<_, std::io::Error>(
+                Response::builder()
+                    .status(StatusCode::OK)
+                    .header("cache-control", "max-age=1") // Very short TTL
+                    .body(Full::new(Bytes::from("Expires soon")))
+                    .unwrap(),
+            )
+        }));
+
+    // First request - cache miss
+    let req = Request::get("/stale").body(Full::new(Bytes::new())).unwrap();
+    let res = service.ready().await.unwrap().call(req).await.unwrap();
+    assert_eq!(
+        res.headers().get("x-cache").unwrap(),
+        "MISS",
+        "First request should be a cache MISS"
+    );
+
+    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
+
+    // Second request - should hit cache while still fresh
+    let req = Request::get("/stale").body(Full::new(Bytes::new())).unwrap();
+    let res = service.ready().await.unwrap().call(req).await.unwrap();
+    assert_eq!(
+        res.headers().get("x-cache").unwrap(),
+        "HIT",
+        "Request within TTL should be a cache HIT"
+    );
+
+    // Wait for the cache entry to expire (1 second + buffer)
+    tokio::time::sleep(tokio::time::Duration::from_millis(1100)).await;
+
+    // Third request - should be a miss due to expiration
+    let req = Request::get("/stale").body(Full::new(Bytes::new())).unwrap();
+    let res = service.ready().await.unwrap().call(req).await.unwrap();
+    assert_eq!(
+        res.headers().get("x-cache").unwrap(),
+        "MISS",
+        "Request after expiration should be a cache MISS"
+    );
+}
+
+#[tokio::test]
+async fn test_multiple_directives() { + let manager = MemoryCacheManager::new(); + let layer = ServerCacheLayer::new(manager.clone()); + + let mut service = + layer.layer(tower::service_fn(|_req: Request>| async { + Ok::<_, std::io::Error>( + Response::builder() + .status(StatusCode::OK) + .header( + "cache-control", + "max-age=60, public, must-revalidate", + ) + .body(Full::new(Bytes::from("Multiple directives"))) + .unwrap(), + ) + })); + + // First request - cache miss + let req = + Request::get("/multi-directive").body(Full::new(Bytes::new())).unwrap(); + let res = service.ready().await.unwrap().call(req).await.unwrap(); + assert_eq!( + res.headers().get("x-cache").unwrap(), + "MISS", + "First request should be a cache MISS" + ); + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Second request - should hit cache (all directives should be recognized) + let req = + Request::get("/multi-directive").body(Full::new(Bytes::new())).unwrap(); + let res = service.ready().await.unwrap().call(req).await.unwrap(); + assert_eq!( + res.headers().get("x-cache").unwrap(), + "HIT", + "Cache should handle multiple directives correctly" + ); + + // Also test with spaces variations + let mut service2 = + layer.layer(tower::service_fn(|_req: Request>| async { + Ok::<_, std::io::Error>( + Response::builder() + .status(StatusCode::OK) + .header("cache-control", "max-age=60,public,s-maxage=120") + .body(Full::new(Bytes::from("No spaces"))) + .unwrap(), + ) + })); + + let req = + Request::get("/multi-no-space").body(Full::new(Bytes::new())).unwrap(); + let res = service2.ready().await.unwrap().call(req).await.unwrap(); + assert_eq!( + res.headers().get("x-cache").unwrap(), + "MISS", + "Should handle directives without spaces" + ); + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + let req = + Request::get("/multi-no-space").body(Full::new(Bytes::new())).unwrap(); + let res = service2.ready().await.unwrap().call(req).await.unwrap(); + 
assert_eq!( + res.headers().get("x-cache").unwrap(), + "HIT", + "Should cache with directives without spaces" + ); +} + +#[tokio::test] +async fn test_malformed_cache_control() { + let manager = MemoryCacheManager::new(); + let layer = ServerCacheLayer::new(manager.clone()); + + // Test with invalid directive + let mut service1 = layer.clone().layer(tower::service_fn( + |_req: Request>| async { + Ok::<_, std::io::Error>( + Response::builder() + .status(StatusCode::OK) + .header("cache-control", "invalid-directive") + .body(Full::new(Bytes::from("Invalid directive"))) + .unwrap(), + ) + }, + )); + + let req = Request::get("/invalid-directive") + .body(Full::new(Bytes::new())) + .unwrap(); + let res = service1.ready().await.unwrap().call(req).await.unwrap(); + assert_eq!( + res.status(), + StatusCode::OK, + "Should handle invalid directive gracefully" + ); + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + let req = Request::get("/invalid-directive") + .body(Full::new(Bytes::new())) + .unwrap(); + let res = service1.ready().await.unwrap().call(req).await.unwrap(); + // Should not cache with invalid directive + assert!( + res.headers().get("x-cache").is_none() + || res.headers().get("x-cache").unwrap() != "HIT", + "Should not cache with invalid directive" + ); + + // Test with malformed max-age value + let mut service2 = layer.clone().layer(tower::service_fn( + |_req: Request>| async { + Ok::<_, std::io::Error>( + Response::builder() + .status(StatusCode::OK) + .header("cache-control", "max-age=notanumber") + .body(Full::new(Bytes::from("Invalid max-age"))) + .unwrap(), + ) + }, + )); + + let req = + Request::get("/bad-max-age").body(Full::new(Bytes::new())).unwrap(); + let res = service2.ready().await.unwrap().call(req).await.unwrap(); + assert_eq!( + res.status(), + StatusCode::OK, + "Should handle malformed max-age gracefully" + ); + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + let req = + 
Request::get("/bad-max-age").body(Full::new(Bytes::new())).unwrap(); + let res = service2.ready().await.unwrap().call(req).await.unwrap(); + // Should not cache with malformed max-age + assert!( + res.headers().get("x-cache").is_none() + || res.headers().get("x-cache").unwrap() != "HIT", + "Should not cache with malformed max-age" + ); + + // Test with empty cache-control + let mut service3 = + layer.layer(tower::service_fn(|_req: Request>| async { + Ok::<_, std::io::Error>( + Response::builder() + .status(StatusCode::OK) + .header("cache-control", "") + .body(Full::new(Bytes::from("Empty cache-control"))) + .unwrap(), + ) + })); + + let req = Request::get("/empty-cc").body(Full::new(Bytes::new())).unwrap(); + let res = service3.ready().await.unwrap().call(req).await.unwrap(); + assert_eq!( + res.status(), + StatusCode::OK, + "Should handle empty cache-control gracefully" + ); +} + +#[tokio::test] +async fn test_path_parameter_preservation() { + // This is a regression test for issue #121 + // Verifies that request extensions (like Axum path parameters) are preserved + // through the caching layer and accessible to the handler + + let manager = MemoryCacheManager::new(); + let layer = ServerCacheLayer::new(manager.clone()); + + // Counter to track handler invocations + let call_count = Arc::new(Mutex::new(0)); + let call_count_clone = call_count.clone(); + + let mut service = + layer.layer(tower::service_fn(move |req: Request>| { + let count = call_count_clone.clone(); + async move { + // Increment handler call counter + *count.lock().unwrap() += 1; + + // Extract the path parameter from request extensions + // This simulates what Axum does after routing + let path_params = req + .extensions() + .get::() + .expect("PathParams extension should be present"); + + // Generate response that includes the path parameter + // This proves the extension was preserved through the cache layer + let body = format!("User ID: {}", path_params.id); + + Ok::<_, std::io::Error>( + 
Response::builder() + .status(StatusCode::OK) + .header("cache-control", "max-age=60") + .body(Full::new(Bytes::from(body))) + .unwrap(), + ) + } + })); + + // First request with path parameter "123" - should be a cache miss + let mut req1 = + Request::get("/users/123").body(Full::new(Bytes::new())).unwrap(); + req1.extensions_mut().insert(PathParams { id: "123".to_string() }); + + let res1 = service.ready().await.unwrap().call(req1).await.unwrap(); + assert_eq!(res1.status(), StatusCode::OK); + assert_eq!(res1.headers().get("x-cache").unwrap(), "MISS"); + + // Collect body to verify the handler received the correct extension + let body1 = http_body_util::BodyExt::collect(res1.into_body()) + .await + .unwrap() + .to_bytes(); + assert_eq!( + body1, "User ID: 123", + "Handler should receive path parameter on cache miss" + ); + + // Verify handler was called once + assert_eq!( + *call_count.lock().unwrap(), + 1, + "Handler should be called on cache miss" + ); + + // Give cache write time to complete + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Second request to same path with same parameter - should be a cache hit + let mut req2 = + Request::get("/users/123").body(Full::new(Bytes::new())).unwrap(); + req2.extensions_mut().insert(PathParams { id: "123".to_string() }); + + let res2 = service.ready().await.unwrap().call(req2).await.unwrap(); + assert_eq!(res2.status(), StatusCode::OK); + assert_eq!(res2.headers().get("x-cache").unwrap(), "HIT"); + + // Verify cached response has correct content + let body2 = http_body_util::BodyExt::collect(res2.into_body()) + .await + .unwrap() + .to_bytes(); + assert_eq!( + body2, "User ID: 123", + "Cached response should have correct content" + ); + + // Verify handler was NOT called again (cache hit) + assert_eq!( + *call_count.lock().unwrap(), + 1, + "Handler should not be called on cache hit" + ); + + // Third request with different path parameter - should be a cache miss + // This verifies that 
different requests don't interfere with each other + let mut req3 = + Request::get("/users/456").body(Full::new(Bytes::new())).unwrap(); + req3.extensions_mut().insert(PathParams { id: "456".to_string() }); + + let res3 = service.ready().await.unwrap().call(req3).await.unwrap(); + assert_eq!(res3.status(), StatusCode::OK); + assert_eq!(res3.headers().get("x-cache").unwrap(), "MISS"); + + // Verify the handler received the NEW path parameter + let body3 = http_body_util::BodyExt::collect(res3.into_body()) + .await + .unwrap() + .to_bytes(); + assert_eq!( + body3, "User ID: 456", + "Handler should receive different path parameter for different request" + ); + + // Verify handler was called again for the new path + assert_eq!( + *call_count.lock().unwrap(), + 2, + "Handler should be called for new path" + ); +} + +#[tokio::test] +async fn test_request_extensions_not_stripped() { + // Verifies that the cache layer doesn't strip request extensions + // even when they're not used by the handler + + let manager = MemoryCacheManager::new(); + let layer = ServerCacheLayer::new(manager.clone()); + + // Custom extension type + #[derive(Debug, Clone, PartialEq)] + struct CustomExtension { + value: String, + } + + let mut service = layer.layer(tower::service_fn( + |req: Request>| async move { + // Verify extension is still present + let ext = req.extensions().get::(); + assert!( + ext.is_some(), + "Extension should be preserved through cache layer" + ); + assert_eq!(ext.unwrap().value, "test-value"); + + Ok::<_, std::io::Error>( + Response::builder() + .status(StatusCode::OK) + .header("cache-control", "max-age=60") + .body(Full::new(Bytes::from("OK"))) + .unwrap(), + ) + }, + )); + + // Make request with custom extension + let mut req = Request::get("/test").body(Full::new(Bytes::new())).unwrap(); + req.extensions_mut() + .insert(CustomExtension { value: "test-value".to_string() }); + + let res = service.ready().await.unwrap().call(req).await.unwrap(); + 
assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.headers().get("x-cache").unwrap(), "MISS"); +} + +#[tokio::test] +async fn test_cache_by_default_option() { + let manager = MemoryCacheManager::new(); + + // With cache_by_default = false (default), no caching without directives + let options_disabled = + ServerCacheOptions { cache_by_default: false, ..Default::default() }; + let layer = + ServerCacheLayer::new(manager.clone()).with_options(options_disabled); + + let mut service = + layer.layer(tower::service_fn(|_req: Request>| async { + Ok::<_, std::io::Error>( + Response::builder() + .status(StatusCode::OK) + .body(Full::new(Bytes::from("No directives"))) + .unwrap(), + ) + })); + + let req = + Request::get("/no-directive").body(Full::new(Bytes::new())).unwrap(); + let _res = service.ready().await.unwrap().call(req).await.unwrap(); + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + let req = + Request::get("/no-directive").body(Full::new(Bytes::new())).unwrap(); + let res = service.ready().await.unwrap().call(req).await.unwrap(); + assert!( + res.headers().get("x-cache").is_none() + || res.headers().get("x-cache").unwrap() != "HIT", + "Should not cache without directives when cache_by_default is false" + ); + + // With cache_by_default = true, should cache even without directives + let options_enabled = + ServerCacheOptions { cache_by_default: true, ..Default::default() }; + let layer = + ServerCacheLayer::new(manager.clone()).with_options(options_enabled); + + let mut service = + layer.layer(tower::service_fn(|_req: Request>| async { + Ok::<_, std::io::Error>( + Response::builder() + .status(StatusCode::OK) + .body(Full::new(Bytes::from("No directives but cached"))) + .unwrap(), + ) + })); + + let req = Request::get("/cache-by-default") + .body(Full::new(Bytes::new())) + .unwrap(); + let res = service.ready().await.unwrap().call(req).await.unwrap(); + assert_eq!(res.headers().get("x-cache").unwrap(), "MISS"); + + 
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + let req = Request::get("/cache-by-default") + .body(Full::new(Bytes::new())) + .unwrap(); + let res = service.ready().await.unwrap().call(req).await.unwrap(); + assert_eq!( + res.headers().get("x-cache").unwrap(), + "HIT", + "Should cache without directives when cache_by_default is true" + ); +} + +#[tokio::test] +async fn test_custom_keyer() { + let manager = MemoryCacheManager::new(); + + // Create a custom keyer that includes a header in the cache key + let keyer = CustomKeyer::new(|req: &Request<()>| { + let lang = req + .headers() + .get("accept-language") + .and_then(|v| v.to_str().ok()) + .unwrap_or("en"); + format!("{} {} lang:{}", req.method(), req.uri().path(), lang) + }); + + let layer = ServerCacheLayer::with_keyer(manager.clone(), keyer); + + let mut service = layer.layer(tower::service_fn( + |req: Request>| async move { + let lang = req + .headers() + .get("accept-language") + .and_then(|v| v.to_str().ok()) + .unwrap_or("en"); + let body = format!("Response for {}", lang); + Ok::<_, std::io::Error>( + Response::builder() + .status(StatusCode::OK) + .header("cache-control", "max-age=60") + .body(Full::new(Bytes::from(body))) + .unwrap(), + ) + }, + )); + + // Request with English + let req = Request::get("/test") + .header("accept-language", "en") + .body(Full::new(Bytes::new())) + .unwrap(); + let res = service.ready().await.unwrap().call(req).await.unwrap(); + assert_eq!(res.headers().get("x-cache").unwrap(), "MISS"); + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Same request with English - should hit cache + let req = Request::get("/test") + .header("accept-language", "en") + .body(Full::new(Bytes::new())) + .unwrap(); + let res = service.ready().await.unwrap().call(req).await.unwrap(); + assert_eq!(res.headers().get("x-cache").unwrap(), "HIT"); + + // Request with French - should miss cache (different key) + let req = Request::get("/test") + 
.header("accept-language", "fr") + .body(Full::new(Bytes::new())) + .unwrap(); + let res = service.ready().await.unwrap().call(req).await.unwrap(); + assert_eq!( + res.headers().get("x-cache").unwrap(), + "MISS", + "Different language should have different cache key" + ); +} + +#[tokio::test] +async fn test_directive_parsing_edge_cases() { + let manager = MemoryCacheManager::new(); + let layer = ServerCacheLayer::new(manager.clone()); + + // Test that "no-store-something" does NOT match "no-store" + let mut service1 = layer.clone().layer(tower::service_fn( + |_req: Request>| async { + Ok::<_, std::io::Error>( + Response::builder() + .status(StatusCode::OK) + // This should NOT prevent caching since it's not "no-store" + .header("cache-control", "max-age=60, no-store-custom") + .body(Full::new(Bytes::from("Should be cached"))) + .unwrap(), + ) + }, + )); + + let req = + Request::get("/no-store-custom").body(Full::new(Bytes::new())).unwrap(); + let res = service1.ready().await.unwrap().call(req).await.unwrap(); + assert_eq!(res.headers().get("x-cache").unwrap(), "MISS"); + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + let req = + Request::get("/no-store-custom").body(Full::new(Bytes::new())).unwrap(); + let res = service1.ready().await.unwrap().call(req).await.unwrap(); + assert_eq!( + res.headers().get("x-cache").unwrap(), + "HIT", + "no-store-custom should not prevent caching" + ); + + // Test that "private-something" does NOT match "private" + let mut service2 = layer.clone().layer(tower::service_fn( + |_req: Request>| async { + Ok::<_, std::io::Error>( + Response::builder() + .status(StatusCode::OK) + .header("cache-control", "max-age=60, private-ext") + .body(Full::new(Bytes::from("Should be cached"))) + .unwrap(), + ) + }, + )); + + let req = + Request::get("/private-ext").body(Full::new(Bytes::new())).unwrap(); + let res = service2.ready().await.unwrap().call(req).await.unwrap(); + assert_eq!(res.headers().get("x-cache").unwrap(), 
"MISS"); + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + let req = + Request::get("/private-ext").body(Full::new(Bytes::new())).unwrap(); + let res = service2.ready().await.unwrap().call(req).await.unwrap(); + assert_eq!( + res.headers().get("x-cache").unwrap(), + "HIT", + "private-ext should not prevent caching" + ); +} + +#[tokio::test] +async fn test_zero_max_age() { + let manager = MemoryCacheManager::new(); + let layer = ServerCacheLayer::new(manager.clone()); + + let mut service = + layer.layer(tower::service_fn(|_req: Request>| async { + Ok::<_, std::io::Error>( + Response::builder() + .status(StatusCode::OK) + .header("cache-control", "max-age=0") + .body(Full::new(Bytes::from("Zero TTL"))) + .unwrap(), + ) + })); + + // First request + let req = Request::get("/zero-ttl").body(Full::new(Bytes::new())).unwrap(); + let res = service.ready().await.unwrap().call(req).await.unwrap(); + assert_eq!(res.headers().get("x-cache").unwrap(), "MISS"); + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Second request - should not hit cache because TTL is 0 + let req = Request::get("/zero-ttl").body(Full::new(Bytes::new())).unwrap(); + let res = service.ready().await.unwrap().call(req).await.unwrap(); + // With zero TTL, entry is immediately stale + assert_eq!( + res.headers().get("x-cache").unwrap(), + "MISS", + "Zero max-age should result in immediately stale cache entry" + ); +} + +#[tokio::test] +async fn test_different_http_methods() { + let manager = MemoryCacheManager::new(); + let layer = ServerCacheLayer::new(manager.clone()); + + let mut service = layer.layer(tower::service_fn( + |req: Request>| async move { + let body = format!("Method: {}", req.method()); + Ok::<_, std::io::Error>( + Response::builder() + .status(StatusCode::OK) + .header("cache-control", "max-age=60") + .body(Full::new(Bytes::from(body))) + .unwrap(), + ) + }, + )); + + // GET request + let req = + 
Request::get("/method-test").body(Full::new(Bytes::new())).unwrap(); + let res = service.ready().await.unwrap().call(req).await.unwrap(); + assert_eq!(res.headers().get("x-cache").unwrap(), "MISS"); + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Same GET request - should hit cache + let req = + Request::get("/method-test").body(Full::new(Bytes::new())).unwrap(); + let res = service.ready().await.unwrap().call(req).await.unwrap(); + assert_eq!(res.headers().get("x-cache").unwrap(), "HIT"); + + // POST request to same path - should be a different cache key + let req = + Request::post("/method-test").body(Full::new(Bytes::new())).unwrap(); + let res = service.ready().await.unwrap().call(req).await.unwrap(); + assert_eq!( + res.headers().get("x-cache").unwrap(), + "MISS", + "POST should have different cache key than GET" + ); +} + +#[tokio::test] +async fn test_cache_status_headers_disabled() { + let manager = MemoryCacheManager::new(); + let options = ServerCacheOptions { + cache_status_headers: false, + ..Default::default() + }; + let layer = ServerCacheLayer::new(manager.clone()).with_options(options); + + let mut service = + layer.layer(tower::service_fn(|_req: Request>| async { + Ok::<_, std::io::Error>( + Response::builder() + .status(StatusCode::OK) + .header("cache-control", "max-age=60") + .body(Full::new(Bytes::from("No status headers"))) + .unwrap(), + ) + })); + + // First request + let req = Request::get("/no-status").body(Full::new(Bytes::new())).unwrap(); + let res = service.ready().await.unwrap().call(req).await.unwrap(); + assert!( + res.headers().get("x-cache").is_none(), + "Should not have x-cache header when disabled" + ); + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Second request - should hit cache but no header + let req = Request::get("/no-status").body(Full::new(Bytes::new())).unwrap(); + let res = service.ready().await.unwrap().call(req).await.unwrap(); + assert!( + 
+        res.headers().get("x-cache").is_none(),
+        "Should not have x-cache header when disabled, even on HIT"
+    );
+}
diff --git a/http-cache/README.md b/http-cache/README.md
index 96f8e29..b2cb74e 100644
--- a/http-cache/README.md
+++ b/http-cache/README.md
@@ -51,6 +51,11 @@ The following features are available. By default `manager-cacache` and `cacache-
 - **Reqwest**: See [README](https://github.com/06chaynes/http-cache/blob/main/http-cache-reqwest/README.md) for more details
 - **Tower**: See [README](https://github.com/06chaynes/http-cache/blob/main/http-cache-tower/README.md) for more details
 - **Surf**: See [README](https://github.com/06chaynes/http-cache/blob/main/http-cache-surf/README.md) for more details
+- **Ureq**: See [README](https://github.com/06chaynes/http-cache/blob/main/http-cache-ureq/README.md) for more details
+
+## Server-Side Caching Middleware
+
+- **Tower Server**: See [README](https://github.com/06chaynes/http-cache/blob/main/http-cache-tower-server/README.md) for more details
 
 ## Additional Manager Implementations
diff --git a/http-cache/src/managers/streaming_cache.rs b/http-cache/src/managers/streaming_cache.rs
index 252c9b3..c789725 100644
--- a/http-cache/src/managers/streaming_cache.rs
+++ b/http-cache/src/managers/streaming_cache.rs
@@ -1582,8 +1582,8 @@ mod tests {
     /// Test concurrent access to reference counting
     #[tokio::test]
     async fn test_concurrent_reference_counting() {
+        use futures::future::join_all;
         use std::sync::Arc;
-        use tokio::task;
 
         let temp_dir = TempDir::new().unwrap();
         let cache =
@@ -1593,43 +1593,46 @@
         let shared_content = Bytes::from("concurrent test content");
         let tasks_count = 10;
-        // Create multiple tasks that store identical content concurrently
-        let mut handles = Vec::new();
-        for i in 0..tasks_count {
-            let cache = Arc::clone(&cache);
-            let content = shared_content.clone();
-            let url = request_url.clone();
+        // Create multiple futures that store identical content concurrently
+        let put_futures: Vec<_> = (0..tasks_count)
+            .map(|i| {
+                let cache = Arc::clone(&cache);
+                let content = shared_content.clone();
+                let url = request_url.clone();
-            let handle = task::spawn(async move {
-                let response = Response::builder()
-                    .status(200)
-                    .header("x-task-id", i.to_string())
-                    .body(Full::new(content))
-                    .unwrap();
+                async move {
+                    let response = Response::builder()
+                        .status(200)
+                        .header("x-task-id", i.to_string())
+                        .body(Full::new(content))
+                        .unwrap();
-                let policy = CachePolicy::new(
-                    &http::request::Request::builder()
-                        .method("GET")
-                        .uri(format!("/concurrent-test-{}", i))
-                        .body(())
-                        .unwrap()
-                        .into_parts()
-                        .0,
-                    &response.clone().map(|_| ()),
-                );
+                    let policy = CachePolicy::new(
+                        &http::request::Request::builder()
+                            .method("GET")
+                            .uri(format!("/concurrent-test-{}", i))
+                            .body(())
+                            .unwrap()
+                            .into_parts()
+                            .0,
+                        &response.clone().map(|_| ()),
+                    );
-                cache
-                    .put(format!("concurrent-key-{}", i), response, policy, url)
-                    .await
-                    .unwrap();
-            });
-            handles.push(handle);
-        }
+                    cache
+                        .put(
+                            format!("concurrent-key-{}", i),
+                            response,
+                            policy,
+                            url,
+                        )
+                        .await
+                        .unwrap();
+                }
+            })
+            .collect();
-        // Wait for all tasks to complete
-        for handle in handles {
-            handle.await.unwrap();
-        }
+        // Wait for all futures to complete
+        join_all(put_futures).await;
 
         // Verify all entries can be retrieved
         for i in 0..tasks_count {
@@ -1645,35 +1648,37 @@
         assert!(content_path.exists(), "Shared content file should exist");
 
         // Delete half the entries concurrently
-        let mut delete_handles = Vec::new();
-        for i in 0..tasks_count / 2 {
-            let cache = Arc::clone(&cache);
-            let handle = task::spawn(async move {
-                cache.delete(&format!("concurrent-key-{}", i)).await.unwrap();
-            });
-            delete_handles.push(handle);
-        }
+        let delete_futures: Vec<_> = (0..tasks_count / 2)
+            .map(|i| {
+                let cache = Arc::clone(&cache);
+                async move {
+                    cache
+                        .delete(&format!("concurrent-key-{}", i))
+                        .await
+                        .unwrap();
+                }
+            })
+            .collect();
-        for handle in delete_handles {
-            handle.await.unwrap();
-        }
+        join_all(delete_futures).await;
 
         // Content should still exist (remaining references)
         assert!(content_path.exists(), "Content file should still exist");
 
         // Delete remaining entries
-        let mut final_delete_handles = Vec::new();
-        for i in tasks_count / 2..tasks_count {
-            let cache = Arc::clone(&cache);
-            let handle = task::spawn(async move {
-                cache.delete(&format!("concurrent-key-{}", i)).await.unwrap();
-            });
-            final_delete_handles.push(handle);
-        }
+        let final_delete_futures: Vec<_> = (tasks_count / 2..tasks_count)
+            .map(|i| {
+                let cache = Arc::clone(&cache);
+                async move {
+                    cache
+                        .delete(&format!("concurrent-key-{}", i))
+                        .await
+                        .unwrap();
+                }
+            })
+            .collect();
-        for handle in final_delete_handles {
-            handle.await.unwrap();
-        }
+        join_all(final_delete_futures).await;
 
         // Now content should be deleted
         assert!(
diff --git a/justfile b/justfile
index ae8331d..155be8f 100644
--- a/justfile
+++ b/justfile
@@ -21,6 +21,8 @@
     cd http-cache-ureq && cargo nextest run --all-features
     echo "\n----------\nTower middleware:\n"
     cd http-cache-tower && cargo nextest run --all-features
+    echo "\n----------\nTower server middleware:\n"
+    cd http-cache-tower-server && cargo nextest run --all-features
     echo "\n----------\nQuickcache middleware:\n"
     cd http-cache-quickcache && cargo nextest run --all-features
@@ -43,6 +45,8 @@
     cd http-cache-ureq && cargo test --doc --all-features
     echo "\n----------\nTower middleware:\n"
     cd http-cache-tower && cargo test --doc --all-features
+    echo "\n----------\nTower server middleware:\n"
+    cd http-cache-tower-server && cargo test --doc --all-features
     echo "\n----------\nQuickcache middleware:\n"
     cd http-cache-quickcache && cargo test --doc --all-features
@@ -64,6 +68,8 @@
     cd http-cache-ureq && cargo check --all-features
     echo "\n----------\nTower middleware:\n"
     cd http-cache-tower && cargo check --all-features
+    echo "\n----------\nTower server middleware:\n"
+    cd http-cache-tower-server && cargo check --all-features
     echo "\n----------\nQuickcache middleware:\n"
     cd http-cache-quickcache && cargo check --all-features
@@ -135,6 +141,8 @@ changelog TAG:
     cd http-cache-ureq && cargo clippy --lib --tests --all-targets --all-features -- -D warnings
     echo "\n----------\nTower middleware:\n"
     cd http-cache-tower && cargo clippy --lib --tests --all-targets --all-features -- -D warnings
+    echo "\n----------\nTower server middleware:\n"
+    cd http-cache-tower-server && cargo clippy --lib --tests --all-targets --all-features -- -D warnings
     echo "\n----------\nQuickcache middleware:\n"
     cd http-cache-quickcache && cargo clippy --lib --tests --all-targets --all-features -- -D warnings
     echo "\n----------\nFormatting check:\n"
@@ -156,6 +164,8 @@ changelog TAG:
     cd http-cache-ureq && cargo msrv find
     echo "\n----------\nTower middleware:\n"
     cd http-cache-tower && cargo msrv find
+    echo "\n----------\nTower server middleware:\n"
+    cd http-cache-tower-server && cargo msrv find
     echo "\n----------\nQuickcache middleware:\n"
     cd http-cache-quickcache && cargo msrv find
@@ -171,6 +181,8 @@
     cd http-cache-ureq && cargo msrv verify
     echo "\n----------\nTower middleware:\n"
     cd http-cache-tower && cargo msrv verify
+    echo "\n----------\nTower server middleware:\n"
+    cd http-cache-tower-server && cargo msrv verify
     echo "\n----------\nQuickcache middleware:\n"
     cd http-cache-quickcache && cargo msrv verify
@@ -190,6 +202,8 @@
     cd http-cache-ureq && cargo publish --dry-run
     echo "Tower middleware:"
     cd http-cache-tower && cargo publish --dry-run
+    echo "Tower server middleware:"
+    cd http-cache-tower-server && cargo publish --dry-run
     echo "Quickcache middleware:"
     cd http-cache-quickcache && cargo publish --dry-run
@@ -212,6 +226,8 @@
     cd http-cache-ureq && cargo publish
     echo "Tower middleware:"
     cd http-cache-tower && cargo publish
+    echo "Tower server middleware:"
+    cd http-cache-tower-server && cargo publish
     echo "Quickcache middleware:"
     cd http-cache-quickcache && cargo publish