|
| 1 | +use std::pin::Pin; |
| 2 | +use std::task::{Context as StdContext, Poll}; |
| 3 | + |
| 4 | +use tonic::body::Body; |
| 5 | +use tower::{Layer, Service}; |
| 6 | + |
| 7 | +/// Simple health check layer that intercepts requests to root path. |
| 8 | +/// |
| 9 | +/// The root path is used by load-balancers and options requests to check the health |
| 10 | +/// of the server. Since our gRPC server doesn't serve anything on the root |
| 11 | +/// these get logged as errors. This layer instead intercepts these requests |
| 12 | +/// and returns `Ok(200)`, preventing the errors. |
| 13 | +#[derive(Clone)] |
| 14 | +pub struct HealthCheckLayer; |
| 15 | + |
| 16 | +impl<S> Layer<S> for HealthCheckLayer { |
| 17 | + type Service = HealthCheckService<S>; |
| 18 | + |
| 19 | + fn layer(&self, service: S) -> Self::Service { |
| 20 | + HealthCheckService { inner: service } |
| 21 | + } |
| 22 | +} |
| 23 | + |
| 24 | +#[derive(Clone)] |
| 25 | +pub struct HealthCheckService<S> { |
| 26 | + inner: S, |
| 27 | +} |
| 28 | + |
| 29 | +impl<S> Service<http::Request<Body>> for HealthCheckService<S> |
| 30 | +where |
| 31 | + S: Service<http::Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static, |
| 32 | + S::Future: Send + 'static, |
| 33 | + S::Error: Send + 'static, |
| 34 | +{ |
| 35 | + type Response = http::Response<Body>; |
| 36 | + type Error = S::Error; |
| 37 | + type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>; |
| 38 | + |
| 39 | + fn poll_ready(&mut self, cx: &mut StdContext<'_>) -> Poll<Result<(), Self::Error>> { |
| 40 | + self.inner.poll_ready(cx) |
| 41 | + } |
| 42 | + |
| 43 | + fn call(&mut self, req: http::Request<Body>) -> Self::Future { |
| 44 | + if req.uri().path() == "/" { |
| 45 | + let response = http::Response::builder() |
| 46 | + .status(http::StatusCode::OK) |
| 47 | + .body(Body::empty()) |
| 48 | + .expect("valid empty 200 response"); |
| 49 | + Box::pin(async move { Ok(response) }) |
| 50 | + } else { |
| 51 | + Box::pin(self.inner.call(req)) |
| 52 | + } |
| 53 | + } |
| 54 | +} |
0 commit comments