Skip to content

Commit e1350ca

Browse files
committed
Add S3CachingServiceProxy
1 parent 60847b0 commit e1350ca

File tree

5 files changed

+151
-2
lines changed

5 files changed

+151
-2
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ tracing-subscriber = { version = "0.3", features = [
8080

8181
[dev-dependencies]
8282
http = "1"
83+
http-body-util = "0.1"
84+
hyper = { version = "1", features = ["client"] }
8385
mock_instant = "0.6"
8486
testcontainers = "0.27"
8587
testcontainers-modules = { version = "0.15", features = ["minio"] }

src/lib.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ mod fifo_cache;
4040
mod metrics_writer;
4141
mod proxy;
4242
mod s3_cache;
43+
mod service;
4344
mod statistics;
4445
mod telemetry;
4546

@@ -140,11 +141,11 @@ where
140141
config.cache_dry_run,
141142
);
142143

143-
// Build S3 service with auth
144+
// Build S3 service with auth, wrapped in a health check layer
144145
let service = {
145146
let mut b = S3ServiceBuilder::new(caching_proxy);
146147
b.set_auth(auth::create_auth(&config));
147-
b.build()
148+
service::S3CachingServiceProxy::new(b.build())
148149
};
149150

150151
// Start Prometheus metrics writer if configured

src/service.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
use bytes::Bytes;
2+
use hyper::service::Service;
3+
use hyper::{Method, Request, Response, body::Incoming};
4+
use s3s::{Body, HttpError};
5+
use std::future::Future;
6+
use std::pin::Pin;
7+
8+
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
9+
10+
/// Wraps an S3 service and short-circuits `GET /` and `GET /health` requests,
11+
/// returning `200 OK` with a plain-text `"Status OK"` body without forwarding
12+
/// them to the S3 layer or requiring authentication.
13+
pub struct S3CachingServiceProxy<S> {
14+
inner: S,
15+
}
16+
17+
impl<S> S3CachingServiceProxy<S> {
18+
pub fn new(inner: S) -> Self {
19+
Self { inner }
20+
}
21+
}
22+
23+
impl<S: Clone> Clone for S3CachingServiceProxy<S> {
24+
fn clone(&self) -> Self {
25+
Self {
26+
inner: self.inner.clone(),
27+
}
28+
}
29+
}
30+
31+
impl<S> Service<Request<Incoming>> for S3CachingServiceProxy<S>
32+
where
33+
S: Service<Request<Incoming>, Response = Response<Body>, Error = HttpError>,
34+
S::Future: Send + 'static,
35+
{
36+
type Response = Response<Body>;
37+
type Error = HttpError;
38+
type Future = BoxFuture<Result<Self::Response, Self::Error>>;
39+
40+
fn call(&self, req: Request<Incoming>) -> Self::Future {
41+
if req.method() == Method::GET && (req.uri().path() == "/" || req.uri().path() == "/health")
42+
{
43+
let response = Response::builder()
44+
.status(200)
45+
.body(Body::from(Bytes::from_static(b"Status OK")))
46+
.unwrap();
47+
Box::pin(std::future::ready(Ok(response)))
48+
} else {
49+
Box::pin(self.inner.call(req))
50+
}
51+
}
52+
}

tests/integration_health.rs

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
use bytes::Bytes;
2+
use http_body_util::{BodyExt, Empty};
3+
use hyper_util::rt::TokioIo;
4+
use std::net::SocketAddr;
5+
use tokio::net::TcpStream;
6+
7+
async fn start_test_server() -> (SocketAddr, tokio::sync::oneshot::Sender<()>) {
8+
let config = s3_cache::Config {
9+
listen_addr: "127.0.0.1:0".parse().unwrap(),
10+
upstream_endpoint: "http://127.0.0.1:1".to_string(),
11+
upstream_access_key_id: "test".to_string(),
12+
upstream_secret_access_key: "test".to_string(),
13+
upstream_region: "us-east-1".to_string(),
14+
client_access_key_id: "testclient".to_string(),
15+
client_secret_access_key: "testsecret".to_string(),
16+
cache_enabled: false,
17+
cache_dry_run: false,
18+
cache_shards: 4,
19+
cache_max_entries: 100,
20+
cache_max_size_bytes: 1024 * 1024,
21+
cache_max_object_size_bytes: 1024,
22+
cache_ttl_seconds: 60,
23+
worker_threads: 2,
24+
otel_grpc_endpoint_url: None,
25+
prometheus_textfile_dir: None,
26+
};
27+
28+
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
29+
let (addr_tx, addr_rx) = tokio::sync::oneshot::channel::<SocketAddr>();
30+
31+
tokio::spawn(async move {
32+
let shutdown = async move {
33+
let _ = shutdown_rx.await;
34+
};
35+
let _ = s3_cache::start_app_with_shutdown(config, shutdown, addr_tx).await;
36+
});
37+
38+
let addr = addr_rx.await.expect("server startup failed");
39+
(addr, shutdown_tx)
40+
}
41+
42+
async fn http_get(addr: SocketAddr, path: &str) -> (u16, String) {
43+
let stream = TcpStream::connect(addr).await.unwrap();
44+
let io = TokioIo::new(stream);
45+
46+
let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
47+
tokio::spawn(conn);
48+
49+
let req = hyper::Request::builder()
50+
.method(hyper::Method::GET)
51+
.uri(path)
52+
.header("Host", "localhost")
53+
.body(Empty::<Bytes>::new())
54+
.unwrap();
55+
56+
let resp = sender.send_request(req).await.unwrap();
57+
let status = resp.status().as_u16();
58+
let body = resp.collect().await.unwrap().to_bytes();
59+
60+
(status, String::from_utf8_lossy(&body).to_string())
61+
}
62+
63+
// MARK: - Health
64+
65+
#[tokio::test(flavor = "multi_thread")]
66+
async fn health_check_ok() {
67+
let (addr, _shutdown) = start_test_server().await;
68+
69+
let (status, body) = http_get(addr, "/health").await;
70+
71+
assert_eq!(status, 200);
72+
assert_eq!(body, "Status OK");
73+
}
74+
75+
#[tokio::test(flavor = "multi_thread")]
76+
async fn health_check_root_ok() {
77+
let (addr, _shutdown) = start_test_server().await;
78+
79+
let (status, body) = http_get(addr, "/").await;
80+
81+
assert_eq!(status, 200);
82+
assert_eq!(body, "Status OK");
83+
}
84+
85+
#[tokio::test(flavor = "multi_thread")]
86+
async fn health_check_does_not_require_auth() {
87+
let (addr, _shutdown) = start_test_server().await;
88+
89+
// No Authorization header — must still succeed
90+
let (status, _) = http_get(addr, "/health").await;
91+
92+
assert_eq!(status, 200);
93+
}

0 commit comments

Comments
 (0)