Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 83 additions & 20 deletions src/governor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
};
use axum::body::Body;
use governor::{
clock::{DefaultClock, QuantaInstant},
clock::{Clock, DefaultClock, QuantaInstant},
middleware::{NoOpMiddleware, RateLimitingMiddleware, StateInformationMiddleware},
state::keyed::DefaultKeyedStateStore,
Quota, RateLimiter,
Expand All @@ -17,8 +17,7 @@ pub const DEFAULT_BURST_SIZE: u32 = 8;

// Required by Governor's RateLimiter to share it across threads
// See Governor User Guide: https://docs.rs/governor/0.6.0/governor/_guide/index.html
pub type SharedRateLimiter<Key, M> =
Arc<RateLimiter<Key, DefaultKeyedStateStore<Key>, DefaultClock, M>>;
pub type SharedRateLimiter<K, M, C> = Arc<RateLimiter<K, DefaultKeyedStateStore<K>, C, M>>;

/// Helper struct for building a configuration for the governor middleware.
///
Expand Down Expand Up @@ -50,13 +49,35 @@ pub type SharedRateLimiter<Key, M> =
/// .unwrap();
/// ```
#[derive(Debug, Eq, Clone, PartialEq)]
pub struct GovernorConfigBuilder<K: KeyExtractor, M: RateLimitingMiddleware<QuantaInstant>> {
pub struct GovernorConfigBuilder<
K: KeyExtractor,
M: RateLimitingMiddleware<C::Instant>,
C: Clock + Clone + std::fmt::Debug = DefaultClock,
> {
period: Duration,
burst_size: u32,
methods: Option<Vec<Method>>,
key_extractor: K,
error_handler: ErrorHandler,
middleware: PhantomData<M>,
clock: C,
}

impl<C: Clock + Clone + std::fmt::Debug>
GovernorConfigBuilder<PeerIpKeyExtractor, NoOpMiddleware<C::Instant>, C>
{
/// Creates a `GovernorConfigBuilder` with default parameters and a provided `Clock``
pub fn default_with_clock(clock: C) -> Self {
Self {
period: DEFAULT_PERIOD,
burst_size: DEFAULT_BURST_SIZE,
methods: None,
key_extractor: PeerIpKeyExtractor,
error_handler: ErrorHandler::default(),
middleware: PhantomData,
clock,
}
}
}

// function for handling GovernorError and produce valid http Response type.
Expand Down Expand Up @@ -93,7 +114,12 @@ impl Default for GovernorConfigBuilder<PeerIpKeyExtractor, NoOpMiddleware> {
}
}

impl<K: KeyExtractor, M: RateLimitingMiddleware<QuantaInstant>> GovernorConfigBuilder<K, M> {
impl<
K: KeyExtractor,
M: RateLimitingMiddleware<C::Instant>,
C: Clock + Clone + std::fmt::Debug,
> GovernorConfigBuilder<K, M, C>
{
/// Set handler function for handling [GovernorError]
/// # Example
/// ```rust
Expand Down Expand Up @@ -127,8 +153,10 @@ impl<M: RateLimitingMiddleware<QuantaInstant>> GovernorConfigBuilder<PeerIpKeyEx
key_extractor: PeerIpKeyExtractor,
error_handler: ErrorHandler::default(),
middleware: PhantomData,
clock: DefaultClock::default(),
}
}

/// Set the interval after which one element of the quota is replenished.
///
/// **The interval must not be zero.**
Expand Down Expand Up @@ -169,7 +197,12 @@ impl<M: RateLimitingMiddleware<QuantaInstant>> GovernorConfigBuilder<PeerIpKeyEx
}

/// Sets configuration options when any Key Extractor is provided
impl<K: KeyExtractor, M: RateLimitingMiddleware<QuantaInstant>> GovernorConfigBuilder<K, M> {
impl<
K: KeyExtractor,
M: RateLimitingMiddleware<C::Instant>,
C: Clock + Clone + std::fmt::Debug,
> GovernorConfigBuilder<K, M, C>
{
/// Set the interval after which one element of the quota is replenished.
///
/// **The interval must not be zero.**
Expand Down Expand Up @@ -220,16 +253,18 @@ impl<K: KeyExtractor, M: RateLimitingMiddleware<QuantaInstant>> GovernorConfigBu
pub fn key_extractor<K2: KeyExtractor>(
&mut self,
key_extractor: K2,
) -> GovernorConfigBuilder<K2, M> {
) -> GovernorConfigBuilder<K2, M, C> {
GovernorConfigBuilder {
period: self.period,
burst_size: self.burst_size,
methods: self.methods.to_owned(),
key_extractor,
error_handler: self.error_handler.clone(),
middleware: PhantomData,
clock: self.clock.clone(),
}
}

/// Set ratelimit headers to response, the headers is
/// - `x-ratelimit-limit` - Request limit
/// - `x-ratelimit-remaining` - The number of requests left for the time window
Expand All @@ -241,28 +276,31 @@ impl<K: KeyExtractor, M: RateLimitingMiddleware<QuantaInstant>> GovernorConfigBu
///
/// [`methods`]: crate::GovernorConfigBuilder::methods()
/// [`use_headers`]: Self::use_headers
pub fn use_headers(&mut self) -> GovernorConfigBuilder<K, StateInformationMiddleware> {
pub fn use_headers(&mut self) -> GovernorConfigBuilder<K, StateInformationMiddleware, C> {
GovernorConfigBuilder {
period: self.period,
burst_size: self.burst_size,
methods: self.methods.to_owned(),
key_extractor: self.key_extractor.clone(),
error_handler: self.error_handler.clone(),
middleware: PhantomData,
clock: self.clock.clone(),
}
}

/// Finish building the configuration and return the configuration for the middleware.
/// Returns `None` if either burst size or period interval are zero.
pub fn finish(&mut self) -> Option<GovernorConfig<K, M>> {
pub fn finish(&mut self) -> Option<GovernorConfig<K, M, C>> {
if self.burst_size != 0 && self.period.as_nanos() != 0 {
Some(GovernorConfig {
key_extractor: self.key_extractor.clone(),
limiter: Arc::new(
RateLimiter::keyed(
RateLimiter::<_, _, _, M>::new(
Quota::with_period(self.period)
.unwrap()
.allow_burst(NonZeroU32::new(self.burst_size).unwrap()),
DefaultKeyedStateStore::new(),
self.clock.clone(),
)
.with_middleware::<M>(),
),
Expand All @@ -277,15 +315,24 @@ impl<K: KeyExtractor, M: RateLimitingMiddleware<QuantaInstant>> GovernorConfigBu

#[derive(Debug, Clone)]
/// Configuration for the Governor middleware.
pub struct GovernorConfig<K: KeyExtractor, M: RateLimitingMiddleware<QuantaInstant>> {
pub struct GovernorConfig<
K: KeyExtractor,
M: RateLimitingMiddleware<C::Instant>,
C: Clock + Clone + std::fmt::Debug = DefaultClock,
> {
key_extractor: K,
limiter: SharedRateLimiter<K::Key, M>,
limiter: SharedRateLimiter<K::Key, M, C>,
methods: Option<Vec<Method>>,
error_handler: ErrorHandler,
}

impl<K: KeyExtractor, M: RateLimitingMiddleware<QuantaInstant>> GovernorConfig<K, M> {
pub fn limiter(&self) -> &SharedRateLimiter<K::Key, M> {
impl<
K: KeyExtractor,
M: RateLimitingMiddleware<C::Instant>,
C: Clock + Clone + std::fmt::Debug,
> GovernorConfig<K, M, C>
{
pub fn limiter(&self) -> &SharedRateLimiter<K::Key, M, C> {
&self.limiter
}
}
Expand All @@ -312,6 +359,7 @@ impl<M: RateLimitingMiddleware<QuantaInstant>> GovernorConfig<PeerIpKeyExtractor
key_extractor: PeerIpKeyExtractor,
error_handler: ErrorHandler::default(),
middleware: PhantomData,
clock: DefaultClock::default(),
}
.finish()
.unwrap()
Expand All @@ -322,16 +370,25 @@ impl<M: RateLimitingMiddleware<QuantaInstant>> GovernorConfig<PeerIpKeyExtractor
/// contains everything needed to implement a middleware
/// https://stegosaurusdormant.com/understanding-derive-clone/
#[derive(Debug)]
pub struct Governor<K: KeyExtractor, M: RateLimitingMiddleware<QuantaInstant>, S> {
pub struct Governor<
K: KeyExtractor,
M: RateLimitingMiddleware<C::Instant>,
S,
C: Clock + Clone + std::fmt::Debug = DefaultClock,
> {
pub key_extractor: K,
pub limiter: SharedRateLimiter<K::Key, M>,
pub limiter: SharedRateLimiter<K::Key, M, C>,
pub methods: Option<Vec<Method>>,
pub inner: S,
error_handler: ErrorHandler,
}

impl<K: KeyExtractor, M: RateLimitingMiddleware<QuantaInstant>, S: Clone> Clone
for Governor<K, M, S>
impl<
K: KeyExtractor,
M: RateLimitingMiddleware<C::Instant>,
S: Clone,
C: Clock + Clone + std::fmt::Debug,
> Clone for Governor<K, M, S, C>
{
fn clone(&self) -> Self {
Self {
Expand All @@ -344,9 +401,15 @@ impl<K: KeyExtractor, M: RateLimitingMiddleware<QuantaInstant>, S: Clone> Clone
}
}

impl<K: KeyExtractor, M: RateLimitingMiddleware<QuantaInstant>, S> Governor<K, M, S> {
impl<
K: KeyExtractor,
M: RateLimitingMiddleware<C::Instant>,
S,
C: Clock + Clone + std::fmt::Debug,
> Governor<K, M, S, C>
{
/// Create new governor middleware factory from configuration.
pub fn new(inner: S, config: &GovernorConfig<K, M>) -> Self {
pub fn new(inner: S, config: &GovernorConfig<K, M, C>) -> Self {
Governor {
key_extractor: config.key_extractor.clone(),
limiter: config.limiter.clone(),
Expand Down
34 changes: 22 additions & 12 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ pub mod errors;
pub mod governor;
pub mod key_extractor;
use crate::governor::{Governor, GovernorConfig};
use ::governor::clock::{Clock, DefaultClock, QuantaInstant};
use ::governor::clock::Clock;
use ::governor::middleware::{NoOpMiddleware, RateLimitingMiddleware, StateInformationMiddleware};
use axum::body::Body;
pub use errors::GovernorError;
Expand All @@ -24,39 +24,48 @@ use std::{future::Future, pin::Pin, task::ready};
use tower::{Layer, Service};

/// The Layer type that implements tower::Layer and is passed into `.layer()`
pub struct GovernorLayer<K, M>
pub struct GovernorLayer<K, M, C>
where
K: KeyExtractor,
M: RateLimitingMiddleware<QuantaInstant>,
M: RateLimitingMiddleware<C::Instant>,
C: Clock + Clone + std::fmt::Debug,
{
pub config: Arc<GovernorConfig<K, M>>,
pub config: Arc<GovernorConfig<K, M, C>>,
}

impl<K, M, S> Layer<S> for GovernorLayer<K, M>
impl<K, M, S, C> Layer<S> for GovernorLayer<K, M, C>
where
K: KeyExtractor,
M: RateLimitingMiddleware<QuantaInstant>,
M: RateLimitingMiddleware<C::Instant>,
C: Clock + Clone + std::fmt::Debug,
{
type Service = Governor<K, M, S>;
type Service = Governor<K, M, S, C>;

fn layer(&self, inner: S) -> Self::Service {
Governor::new(inner, &self.config)
}
}

/// https://stegosaurusdormant.com/understanding-derive-clone/
impl<K: KeyExtractor, M: RateLimitingMiddleware<QuantaInstant>> Clone for GovernorLayer<K, M> {
impl<
K: KeyExtractor,
M: RateLimitingMiddleware<C::Instant>,
C: Clock + Clone + std::fmt::Debug,
> Clone for GovernorLayer<K, M, C>
{
fn clone(&self) -> Self {
Self {
config: self.config.clone(),
}
}
}

// Implement tower::Service for Governor
impl<K, S, ReqBody> Service<Request<ReqBody>> for Governor<K, NoOpMiddleware, S>
impl<K, S, ReqBody, C> Service<Request<ReqBody>> for Governor<K, NoOpMiddleware<C::Instant>, S, C>
where
K: KeyExtractor,
S: Service<Request<ReqBody>, Response = Response<Body>>,
C: Clock + Clone + std::fmt::Debug,
{
type Response = S::Response;
type Error = S::Error;
Expand Down Expand Up @@ -89,7 +98,7 @@ where

Err(negative) => {
let wait_time = negative
.wait_time_from(DefaultClock::default().now())
.wait_time_from(self.limiter.clock().now())
.as_secs();

#[cfg(feature = "tracing")]
Expand Down Expand Up @@ -214,10 +223,11 @@ where
}

// Implementation of Service for Governor using the StateInformationMiddleware.
impl<K, S, ReqBody> Service<Request<ReqBody>> for Governor<K, StateInformationMiddleware, S>
impl<K, S, ReqBody, C> Service<Request<ReqBody>> for Governor<K, StateInformationMiddleware, S, C>
where
K: KeyExtractor,
S: Service<Request<ReqBody>, Response = Response<Body>>,
C: Clock + Clone + std::fmt::Debug,
// Body type of response must impl From<String> trait to convert potential error
// produced by governor to re
{
Expand Down Expand Up @@ -258,7 +268,7 @@ where

Err(negative) => {
let wait_time = negative
.wait_time_from(DefaultClock::default().now())
.wait_time_from(self.limiter.clock().now())
.as_secs();

#[cfg(feature = "tracing")]
Expand Down
Loading