|
| 1 | +// Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs. |
| 2 | +// SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +//! Merge a ValidateRequest and an AsyncAuthorizeRequest |
| 5 | +//! |
| 6 | +//! executes a ValidateRequest returning the request if it succeeds |
| 7 | +//! or else, executed the future and return it |
| 8 | +
|
| 9 | +use std::{future::Future, marker::PhantomData, pin::Pin, task::Poll}; |
| 10 | + |
| 11 | +use axum::http::{Request, Response}; |
| 12 | +use pin_project::pin_project; |
| 13 | +use tower_http::{auth::AsyncAuthorizeRequest, validate_request::ValidateRequest}; |
| 14 | + |
| 15 | +pub trait OrExt<T, B, Resp>: Sized { |
| 16 | + fn or(self, other: T) -> Or<Self, T, B, Resp>; |
| 17 | +} |
| 18 | + |
| 19 | +impl<T, A, B, Resp, Fut> OrExt<A, B, Resp> for T |
| 20 | +where |
| 21 | + B: 'static + Send, |
| 22 | + Resp: 'static + Send, |
| 23 | + T: ValidateRequest<B, ResponseBody = Resp>, |
| 24 | + A: AsyncAuthorizeRequest<B, RequestBody = B, ResponseBody = Resp, Future = Fut> |
| 25 | + + Clone |
| 26 | + + 'static |
| 27 | + + Send, |
| 28 | + Fut: Future<Output = Result<Request<B>, Response<Resp>>> + Send, |
| 29 | +{ |
| 30 | + fn or(self, other: A) -> Or<Self, A, B, Resp> { |
| 31 | + Or(self, other, PhantomData) |
| 32 | + } |
| 33 | +} |
| 34 | + |
| 35 | +pub struct Or<T, E, B, Resp>(T, E, PhantomData<fn(B) -> Resp>); |
| 36 | + |
| 37 | +impl<T, E, B, Resp> Clone for Or<T, E, B, Resp> |
| 38 | +where |
| 39 | + T: Clone, |
| 40 | + E: Clone, |
| 41 | +{ |
| 42 | + fn clone(&self) -> Self { |
| 43 | + Self(self.0.clone(), self.1.clone(), self.2) |
| 44 | + } |
| 45 | +} |
| 46 | + |
| 47 | +impl<T, E, Req, Resp, Fut> AsyncAuthorizeRequest<Req> for Or<T, E, Req, Resp> |
| 48 | +where |
| 49 | + Req: 'static + Send, |
| 50 | + Resp: 'static + Send, |
| 51 | + T: ValidateRequest<Req, ResponseBody = Resp>, |
| 52 | + E: AsyncAuthorizeRequest<Req, RequestBody = Req, ResponseBody = Resp, Future = Fut> |
| 53 | + + Clone |
| 54 | + + 'static |
| 55 | + + Send, |
| 56 | + Fut: Future<Output = Result<Request<Req>, Response<Resp>>> + Send, |
| 57 | +{ |
| 58 | + type RequestBody = Req; |
| 59 | + type ResponseBody = Resp; |
| 60 | + |
| 61 | + type Future = OrFuture<Fut, Req, Resp>; |
| 62 | + |
| 63 | + fn authorize(&mut self, mut request: axum::http::Request<Req>) -> Self::Future { |
| 64 | + let mut this = self.1.clone(); |
| 65 | + if self.0.validate(&mut request).is_ok() { |
| 66 | + return OrFuture::with_result(Ok(request)); |
| 67 | + } |
| 68 | + OrFuture::with_future(this.authorize(request)) |
| 69 | + } |
| 70 | +} |
| 71 | + |
| 72 | +#[pin_project::pin_project(project = KindProj)] |
| 73 | +pub enum Kind<Fut, Req, Resp> { |
| 74 | + QueryResult { |
| 75 | + #[pin] |
| 76 | + fut: Fut, |
| 77 | + }, |
| 78 | + ReturnResult { |
| 79 | + validation_result: Option<Result<Request<Req>, Response<Resp>>>, |
| 80 | + }, |
| 81 | +} |
| 82 | + |
| 83 | +#[pin_project] |
| 84 | +pub struct OrFuture<Fut, Req, Resp> { |
| 85 | + #[pin] |
| 86 | + kind: Kind<Fut, Req, Resp>, |
| 87 | +} |
| 88 | + |
| 89 | +impl<Fut, Req, Resp> OrFuture<Fut, Req, Resp> { |
| 90 | + fn with_result(validation_result: Result<Request<Req>, Response<Resp>>) -> Self { |
| 91 | + let validation_result = Some(validation_result); |
| 92 | + Self { |
| 93 | + kind: Kind::ReturnResult { validation_result }, |
| 94 | + } |
| 95 | + } |
| 96 | + |
| 97 | + fn with_future(fut: Fut) -> Self { |
| 98 | + Self { |
| 99 | + kind: Kind::QueryResult { fut }, |
| 100 | + } |
| 101 | + } |
| 102 | +} |
| 103 | + |
| 104 | +impl<Fut, Req, Resp> Future for OrFuture<Fut, Req, Resp> |
| 105 | +where |
| 106 | + Fut: Future<Output = Result<Request<Req>, Response<Resp>>>, |
| 107 | +{ |
| 108 | + type Output = Result<Request<Req>, Response<Resp>>; |
| 109 | + |
| 110 | + fn poll( |
| 111 | + self: Pin<&mut Self>, |
| 112 | + cx: &mut std::task::Context<'_>, |
| 113 | + ) -> std::task::Poll<Self::Output> { |
| 114 | + let this = self.project(); |
| 115 | + match this.kind.project() { |
| 116 | + KindProj::QueryResult { fut } => fut.poll(cx), |
| 117 | + KindProj::ReturnResult { validation_result } => { |
| 118 | + Poll::Ready(validation_result.take().expect("cannot poll twice")) |
| 119 | + } |
| 120 | + } |
| 121 | + } |
| 122 | +} |
0 commit comments