diff --git a/src/authorize.rs b/src/authorize.rs new file mode 100644 index 0000000..47a8ae6 --- /dev/null +++ b/src/authorize.rs @@ -0,0 +1,53 @@ +use core::fmt; +use std::sync::Arc; + +use reqwest::Request; + +use crate::credentials::{AuthorizationError, Credentials}; + +// AUTHORIZE /////////////////////////////////////////////////////////////////// + +pub trait Authorize: fmt::Debug + Send + Sync + 'static { + type Error: Send + 'static; + + /// Appends an `Authorization` header to the `request`, if this authorizer + /// has credentials and unless it is already set. + /// + /// Returns `true` if the `Authorization` header was inserted. + /// + /// # Errors + /// + /// Upon failure to produce the `Authorization` header value or if the request + /// has too many headers. + fn authorize(&self, request: &mut Request) -> Result; +} + +impl Authorize for Credentials { + type Error = AuthorizationError; + + #[inline] + fn authorize(&self, request: &mut Request) -> Result { + self.as_ref().authorize(request) + } +} + +impl Authorize for Option { + type Error = T::Error; + + #[inline] + fn authorize(&self, request: &mut Request) -> Result { + match self.as_ref() { + None => Ok(false), + Some(authorizer) => authorizer.authorize(request), + } + } +} + +impl Authorize for Arc { + type Error = T::Error; + + #[inline] + fn authorize(&self, request: &mut Request) -> Result { + (&**self).authorize(request) + } +} diff --git a/src/client.rs b/src/client.rs index b1a3165..f55b675 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,146 +1,83 @@ -//! HTTP client - use core::fmt; use reqwest::{Request, Response}; -use crate::{ - credentials::{AuthorizationError, Credentials}, - execute::ExecuteRequest, -}; +use crate::{authorize::Authorize, credentials::Credentials, execute::ExecuteRequest}; // CLIENT ERROR //////////////////////////////////////////////////////////////// #[derive(Debug, thiserror::Error)] -pub enum ClientError { +pub enum ClientError { #[error("failed to authorize request, {0}")] - Authorize(#[from] AuthorizationError), + Authorize(A), #[error("failed to execute request, {0}")] Execute(E), } // CLIENT ////////////////////////////////////////////////////////////////////// -/// HTTP client with optional [`Credentials`]. -/// -/// When credentials are provided, the client, will ensure requests are authorized -/// before they are executed. -#[derive(Debug, Default, Clone)] +/// HTTP client that authorize requests before execution. +#[derive(Debug, Clone)] #[must_use] -pub struct Client { - inner: T, - credentials: Option, +pub struct Client> { + executer: T, + authorizer: A, } -impl From for Client { +impl From for Client { fn from(value: T) -> Self { Self { - inner: value, - credentials: None, + executer: value, + authorizer: A::default(), } } } -impl From<&Credentials> for Client { - fn from(value: &Credentials) -> Self { - Self::from(value.clone()) +impl Default for Client { + fn default() -> Self { + Self::from(reqwest::Client::default()) } } -impl>> From> for Client { - fn from(value: Credentials) -> Self { +impl Client { + #[cfg_attr(feature = "tracing", tracing::instrument)] + pub fn new(executer: T, authorizer: A) -> Self { Self { - inner: reqwest::Client::new(), - credentials: Some(value.into()), + executer, + authorizer, } } -} -impl Client { - pub fn new() -> Self { - Self::default() + pub fn executer(&self) -> &T { + &self.executer } -} -impl Client { - /// Sets the `credentials` to be used by the client to authorize HTTP request, - /// discarding the current value, if any. - #[cfg_attr(feature = "tracing", tracing::instrument)] - pub fn set_credentials(&mut self, credentials: Option) { - self.credentials = credentials; + pub fn executer_mut(&mut self) -> &mut T { + &mut self.executer } - #[cfg_attr(feature = "tracing", tracing::instrument)] - fn set_credentials_from> + fmt::Debug>( - &mut self, - credentials: Option>, - ) { - self.credentials = credentials.map(Credentials::into); + pub fn authorizer(&self) -> &A { + &self.authorizer } - /// Fills the `credentials` to be used by the client to authorize HTTP request, - /// discarding the current value, if any. - pub fn with_credentials> + fmt::Debug>( - mut self, - credentials: impl Into>>, - ) -> Self { - self.set_credentials_from(credentials.into()); - self - } - - /// Returns the credentials that will be used by this client to authorized - /// subsequent HTTP requests. - #[cfg_attr(feature = "tracing", tracing::instrument)] - pub fn credentials(&self) -> Option> { - self.credentials.as_ref().map(Credentials::as_ref) - } - - /// Returns a shared reference to the inner HTTP client. - #[cfg_attr(feature = "tracing", tracing::instrument)] - pub fn inner(&self) -> &T { - &self.inner - } - - /// Appends an `Authorization` header to the `request`, if this client has credentials and unless it is already set. - /// - /// Returns `true` if the `Authorization` header was inserted. - /// - /// # Errors - /// - /// Upon failure to produce the header value. - /// - /// If the client doesn't have credentials, this method is infallible. - #[cfg_attr(feature = "tracing", tracing::instrument)] - pub fn authorize(&self, request: &mut Request) -> Result { - match self.credentials() { - None => Ok(false), - Some(credentials) => credentials.authorize(request), - } + pub fn authorizer_mut(&mut self) -> &mut A { + &mut self.authorizer } } -impl ExecuteRequest for Client { - type Error = ClientError; +impl ExecuteRequest for Client { + type Error = ClientError; fn execute_request( &self, mut request: Request, ) -> impl Future> + Send + 'static { let result = self + .authorizer .authorize(&mut request) - .map(|_| self.inner.execute_request(request)); + .map_err(ClientError::Authorize) + .map(|_| self.executer.execute_request(request)); async move { result?.await.map_err(ClientError::Execute) } } } - -#[cfg(feature = "zeroize")] -impl Drop for Client { - fn drop(&mut self) { - use zeroize::Zeroize; - - if let Some(mut credentials) = self.credentials.take() { - credentials.zeroize(); - } - } -} diff --git a/src/credentials.rs b/src/credentials.rs index f6910fe..39f962e 100644 --- a/src/credentials.rs +++ b/src/credentials.rs @@ -3,20 +3,10 @@ use core::fmt; use base64::{Engine, engine::general_purpose::STANDARD as BASE64_ENGINE}; -use reqwest::{Method, Url, header}; +use reqwest::{Request, header}; use crate::signer::{HmacSha512, SignError, Signer}; -// CREDENTIALS ERROR /////////////////////////////////////////////////////////// - -#[derive(Debug, thiserror::Error)] -pub enum CredentialsError { - #[error("missing consumer key")] - MissingConsumerKey, - #[error("missing consumer secret")] - MissingConsumerSecret, -} - // AUTHORIZATION ERROR ///////////////////////////////////////////////////////// #[derive(Debug, thiserror::Error)] @@ -39,100 +29,10 @@ pub enum CredentialsKind { OAuth1, } -// CREDENTIALS BUILDER ///////////////////////////////////////////////////////// - -/// Utility type for handling credentials with optional consumer key and secret. -pub type CredentialsBuilder> = Credentials>; - -impl CredentialsBuilder { - /// Builds the credentials, filling missing consumer data with the provided default values. - pub fn with_consumer( - self, - default_consumer_key: impl Into, - default_consumer_secret: impl Into, - ) -> Credentials { - match self { - Self::Bearer { token } => Credentials::Bearer { token }, - Self::Basic { username, password } => Credentials::Basic { username, password }, - Self::OAuth1 { - token, - secret, - consumer_key, - consumer_secret, - } => Credentials::OAuth1 { - token, - secret, - consumer_key: consumer_key.unwrap_or_else(|| default_consumer_key.into()), - consumer_secret: consumer_secret.unwrap_or_else(|| default_consumer_secret.into()), - }, - } - } - - /// Builds the credentials. - /// - /// # Errors - /// - /// If one of `consumer_key` and `consumer_secret` is missing. - pub fn build(self) -> Result, CredentialsError> { - self.try_into() - } -} - -impl> From> for CredentialsBuilder { - fn from(value: Credentials) -> Self { - match value.into() { - Credentials::Bearer { token } => Self::Bearer { token }, - Credentials::Basic { username, password } => Self::Basic { username, password }, - Credentials::OAuth1 { - token, - secret, - consumer_key, - consumer_secret, - } => Self::OAuth1 { - token, - secret, - consumer_key: Some(consumer_key), - consumer_secret: Some(consumer_secret), - }, - } - } -} - -impl> TryFrom> for Credentials { - type Error = CredentialsError; - - fn try_from(value: CredentialsBuilder) -> Result { - Ok(match value { - Credentials::Bearer { token } => Credentials::Bearer { - token: token.into(), - }, - Credentials::Basic { username, password } => Credentials::Basic { - username: username.into(), - password: password.map(Into::into), - }, - Credentials::OAuth1 { - token, - secret, - consumer_key, - consumer_secret, - } => Credentials::OAuth1 { - token: token.into(), - secret: secret.into(), - consumer_key: consumer_key - .ok_or(CredentialsError::MissingConsumerKey)? - .into(), - consumer_secret: consumer_secret - .ok_or(CredentialsError::MissingConsumerSecret)? - .into(), - }, - }) - } -} - // CREDENTIALS ///////////////////////////////////////////////////////////////// /// Credentials used to authorize an HTTP request. -#[derive(Clone, Copy, PartialEq, Eq)] +#[derive(Clone, Copy, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "serde", serde(untagged))] #[cfg_attr(feature = "zeroize", derive(zeroize::Zeroize))] @@ -141,25 +41,49 @@ impl> TryFrom> for Credentials { zeroize(bound = "T: zeroize::Zeroize, U: zeroize::Zeroize") )] pub enum Credentials, U = T> { - Bearer { - #[cfg_attr(feature = "serde", serde(rename = "token"))] + OAuth1 { + #[cfg_attr( + feature = "serde", + serde(rename = "token", alias = "oauth-token", alias = "oauth_token") + )] token: T, + #[cfg_attr( + feature = "serde", + serde(rename = "secret", alias = "oauth-secret", alias = "oauth_secret") + )] + secret: T, + #[cfg_attr( + feature = "serde", + serde( + rename = "consumer-key", + alias = "consumer_key", + alias = "oauth-consumer-key", + alias = "oauth_consumer_key", + default + ) + )] + consumer_key: U, + #[cfg_attr( + feature = "serde", + serde( + rename = "consumer-secret", + alias = "consumer_secret", + alias = "oauth-consumer-secret", + alias = "oauth_consumer_secret", + default + ) + )] + consumer_secret: U, }, Basic { #[cfg_attr(feature = "serde", serde(rename = "username"))] username: T, - #[cfg_attr(feature = "serde", serde(rename = "password"))] + #[cfg_attr(feature = "serde", serde(rename = "password", default))] password: Option, }, - OAuth1 { + Bearer { #[cfg_attr(feature = "serde", serde(rename = "token"))] token: T, - #[cfg_attr(feature = "serde", serde(rename = "secret"))] - secret: T, - #[cfg_attr(feature = "serde", serde(rename = "consumer-key"))] - consumer_key: U, - #[cfg_attr(feature = "serde", serde(rename = "consumer-secret"))] - consumer_secret: U, }, } @@ -186,7 +110,7 @@ impl Credentials { Self::bearer(token.into()) } - pub fn basic(username: T, password: Option) -> Self { + pub const fn basic(username: T, password: Option) -> Self { Self::Basic { username, password } } @@ -194,7 +118,7 @@ impl Credentials { Self::basic(username.into(), password.map(Into::into)) } - pub fn oauth1(token: T, secret: T, consumer_key: U, consumer_secret: U) -> Self { + pub const fn oauth1(token: T, secret: T, consumer_key: U, consumer_secret: U) -> Self { Self::OAuth1 { token, secret, @@ -280,21 +204,20 @@ impl Credentials { } impl Credentials<&str> { - /// Returns the value for the `Authorization` header. - /// - /// Note: currently, only `HMAC-SHA512` signature method is supported. - /// - /// # Errors - /// - /// Upon failure to produce the header value. #[cfg_attr(feature = "tracing", tracing::instrument)] - pub fn authorization( - self, - method: &Method, - endpoint: &Url, - ) -> Result { - Ok(match self { - Self::Bearer { token } => format!("Bearer {token}"), + pub fn authorize(&self, request: &mut Request) -> Result { + if request.headers().contains_key(header::AUTHORIZATION) { + return Ok(false); + } + + let authorization = match *self { + Self::OAuth1 { + token, + secret, + consumer_key, + consumer_secret, + } => Signer::::new(token, secret, consumer_key, consumer_secret)? + .sign(request.method(), request.url())?, Self::Basic { username, password } => { let input = if let Some(password) = password { format!("{username}:{password}") @@ -303,40 +226,18 @@ impl Credentials<&str> { }; format!("Basic {}", BASE64_ENGINE.encode(input.as_bytes())) } - Self::OAuth1 { - token, - secret, - consumer_key, - consumer_secret, - } => Signer::::new(token, secret, consumer_key, consumer_secret)? - .sign(method, endpoint)?, - }) - } - - /// Appends an `Authorization` header to the `request`, unless it is already set. - /// - /// Returns `true` if the `Authorization` header was inserted. - /// - /// # Errors - /// - /// Upon failure to produce the header value. - #[cfg_attr(feature = "tracing", tracing::instrument)] - pub fn authorize(&self, request: &mut reqwest::Request) -> Result { - if !request.headers().contains_key(header::AUTHORIZATION) { - let authorization = self.authorization(request.method(), request.url())?; + Self::Bearer { token } => format!("Bearer {token}"), + }; - let mut value = authorization.parse::()?; - value.set_sensitive(true); + let mut value = authorization.parse::()?; + value.set_sensitive(true); - request - .headers_mut() - .try_append(header::AUTHORIZATION, value)?; + request + .headers_mut() + .try_append(header::AUTHORIZATION, value)?; - #[cfg(feature = "logging")] - trace!(request = ?request, "authorized request"); + trace!(request = ?request, "authorized request"); - return Ok(true); - } - Ok(false) + Ok(true) } } diff --git a/src/execute.rs b/src/execute.rs index f9116ab..e07b885 100644 --- a/src/execute.rs +++ b/src/execute.rs @@ -43,7 +43,6 @@ async fn execute_request( #[cfg(any(feature = "logging", feature = "metrics"))] let (endpoint, method) = (request.url().to_string(), request.method().to_string()); - #[cfg(feature = "logging")] trace!(%endpoint, %method, "execute request"); #[cfg(feature = "metrics")] @@ -54,7 +53,6 @@ async fn execute_request( #[cfg(any(feature = "logging", feature = "metrics"))] let status_code = response.status(); - #[cfg(feature = "logging")] trace!( endpoint, method = %method, diff --git a/src/lib.rs b/src/lib.rs index f742b9a..d672e50 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,7 +7,6 @@ pub use reqwest; pub use url; -#[cfg(feature = "logging")] #[macro_use] mod logging; @@ -18,6 +17,8 @@ pub mod signer; pub mod credentials; +pub mod authorize; + #[cfg(feature = "execute")] pub mod execute; diff --git a/src/signer.rs b/src/signer.rs index e0fa65d..cf9d535 100644 --- a/src/signer.rs +++ b/src/signer.rs @@ -28,7 +28,7 @@ pub trait SignatureMethod { fn digest(key: &str, signature: &str) -> Result; } -// HMAX-SHA512 ///////////////////////////////////////////////////////////////// +// HMAC-SHA512 ///////////////////////////////////////////////////////////////// pub type HmacSha512 = Hmac; @@ -39,9 +39,13 @@ impl SignatureMethod for HmacSha512 { #[inline] fn digest(key: &str, signature: &str) -> Result { - let mut hasher = Hmac::::new_from_slice(key.as_bytes())?; - hasher.update(signature.as_bytes()); - Ok(BASE64_ENGINE.encode(hasher.finalize().into_bytes())) + let key = key.as_bytes(); + let hash_value = { + let mut hasher = HmacSha512::new_from_slice(key)?; + hasher.update(signature.as_bytes()); + hasher.finalize().into_bytes() + }; + Ok(BASE64_ENGINE.encode(hash_value)) } } diff --git a/tests/sse.rs b/tests/sse.rs index 9c37fd1..2c43706 100644 --- a/tests/sse.rs +++ b/tests/sse.rs @@ -157,7 +157,7 @@ mod sse_integration_tests { .data("Message 3"), ]; - let client = Client::new(); + let client = Client::default(); let mut event_stream = client.untyped_sse(&endpoint).max_loop(0).stream()?; @@ -290,7 +290,7 @@ mod sse_integration_tests { ), ]; - let client = Client::new(); + let client = Client::default(); let mut event_stream = client .sse::>(&endpoint) @@ -356,18 +356,18 @@ mod sse_e2e_tests { use tracing::warn; fn oauth_env() -> Option { - if let Ok(token) = std::env::var("CC_TOKEN") { - if let Ok(secret) = std::env::var("CC_SECRET") { - if let Ok(consumer_key) = std::env::var("CC_CONSUMER_KEY") { - if let Ok(consumer_secret) = std::env::var("CC_CONSUMER_SECRET") { - return Some(Credentials::oauth1_from( - token, - secret, - consumer_key, - consumer_secret, - )); - } - } + if let Ok(token) = std::env::var("CLEVER_TOKEN") { + if let Ok(secret) = std::env::var("CLEVER_SECRET") { + let consumer_key = std::env::var("CLEVER_CONSUMER_KEY") + .unwrap_or("T5nFjKeHH4AIlEveuGhB5S3xg8T19e".into()); + let consumer_secret = std::env::var("CLEVER_CONSUMER_SECRET") + .unwrap_or("MgVMqTr6fWlf2M0tkC2MXOnhfqBWDT".into()); + return Some(Credentials::oauth1_from( + token, + secret, + consumer_key, + consumer_secret, + )); } } None @@ -382,7 +382,7 @@ mod sse_e2e_tests { return Ok(()); }; - let client = Client::new().with_credentials(credentials); + let client = Client::new(reqwest::Client::new(), credentials); let mut event_stream = client .untyped_sse(ENDPOINT)