Skip to content

Commit 3852dcc

Browse files
committed
chore: catch panics on too many headers and prevent unnecessary allocations
1 parent 6492e86 commit 3852dcc

File tree

2 files changed

+126
-95
lines changed

2 files changed

+126
-95
lines changed

src/client/mod.rs

Lines changed: 114 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
//! to interact with the Clever-Cloud's api, but has been extended to be more
55
//! generic.
66
7-
use core::{error::Error, fmt, future::Future};
7+
use core::{error::Error, fmt, future::Future, time::Duration};
88

99
use std::{
10+
borrow::Cow,
1011
collections::BTreeMap,
1112
time::{SystemTime, SystemTimeError},
1213
};
@@ -224,7 +225,7 @@ pub trait OAuth1: fmt::Debug {
224225
type Error;
225226

226227
// `params` returns OAuth1 parameters without the signature one
227-
fn params(&self) -> BTreeMap<String, String>;
228+
fn params(&self) -> BTreeMap<Cow<'_, str>, Cow<'_, str>>;
228229

229230
// `signature` returns the computed signature from given parameters
230231
fn signature(&self, method: &str, endpoint: &str) -> Result<String, Self::Error>;
@@ -238,9 +239,9 @@ pub trait OAuth1: fmt::Debug {
238239
let signature = self.signature(method, endpoint)?;
239240
let mut params = self.params();
240241

241-
params.insert(
242-
OAUTH1_SIGNATURE.to_string(),
243-
urlencoding::encode(&signature).into_owned(),
242+
let _ = params.insert(
243+
Cow::Borrowed(OAUTH1_SIGNATURE),
244+
urlencoding::encode(&signature),
244245
);
245246

246247
let mut base = params
@@ -299,50 +300,90 @@ pub enum SignerError {
299300
// -----------------------------------------------------------------------------
300301
// Signer structure
301302

302-
#[derive(Debug, Clone, PartialEq, Eq)]
303-
pub struct Signer {
304-
pub nonce: String,
305-
pub timestamp: u64,
306-
pub token: String,
307-
pub secret: String,
308-
pub consumer_key: String,
309-
pub consumer_secret: String,
303+
#[derive(Clone, PartialEq, Eq)]
304+
pub struct Signer<T = String> {
305+
pub nonce: Uuid,
306+
pub timestamp: Duration,
307+
pub token: T,
308+
pub secret: T,
309+
pub consumer_key: T,
310+
pub consumer_secret: T,
310311
}
311312

312-
impl OAuth1 for Signer {
313+
impl<T> fmt::Debug for Signer<T> {
314+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
315+
f.debug_struct("Signer")
316+
.field("nonce", &self.nonce)
317+
.field("timestamp", &self.timestamp)
318+
.finish_non_exhaustive()
319+
}
320+
}
321+
322+
impl<T: fmt::Debug> Signer<T> {
323+
#[cfg_attr(feature = "tracing", tracing::instrument)]
324+
fn new(token: T, secret: T, consumer_key: T, consumer_secret: T) -> Result<Self, SignerError> {
325+
let nonce = Uuid::new_v4();
326+
let timestamp = SystemTime::now()
327+
.duration_since(SystemTime::UNIX_EPOCH)
328+
.map_err(SignerError::UnixEpochTime)?;
329+
Ok(Self {
330+
nonce,
331+
timestamp,
332+
token,
333+
secret,
334+
consumer_key,
335+
consumer_secret,
336+
})
337+
}
338+
}
339+
340+
impl<T: AsRef<str> + fmt::Debug> OAuth1 for Signer<T> {
313341
type Error = SignerError;
314342

315343
#[cfg_attr(feature = "tracing", tracing::instrument)]
316-
fn params(&self) -> BTreeMap<String, String> {
344+
fn params(&self) -> BTreeMap<Cow<'_, str>, Cow<'_, str>> {
317345
let mut params = BTreeMap::new();
318-
319-
params.insert(
320-
OAUTH1_CONSUMER_KEY.to_string(),
321-
self.consumer_key.to_string(),
346+
let _ = params.insert(
347+
Cow::Borrowed(OAUTH1_CONSUMER_KEY),
348+
self.consumer_key.as_ref().into(),
349+
);
350+
let _ = params.insert(
351+
Cow::Borrowed(OAUTH1_NONCE),
352+
Cow::Owned(self.nonce.to_string()),
353+
);
354+
let _ = params.insert(
355+
Cow::Borrowed(OAUTH1_SIGNATURE_METHOD),
356+
Cow::Borrowed(OAUTH1_SIGNATURE_HMAC_SHA512),
322357
);
323-
params.insert(OAUTH1_NONCE.to_string(), self.nonce.to_string());
324-
params.insert(
325-
OAUTH1_SIGNATURE_METHOD.to_string(),
326-
OAUTH1_SIGNATURE_HMAC_SHA512.to_string(),
358+
let _ = params.insert(
359+
Cow::Borrowed(OAUTH1_TIMESTAMP),
360+
Cow::Owned(self.timestamp.as_secs().to_string()),
361+
);
362+
let _ = params.insert(
363+
Cow::Borrowed(OAUTH1_VERSION),
364+
Cow::Borrowed(OAUTH1_VERSION_1),
365+
);
366+
let _ = params.insert(
367+
Cow::Borrowed(OAUTH1_TOKEN),
368+
Cow::Borrowed(self.token.as_ref()),
327369
);
328-
params.insert(OAUTH1_TIMESTAMP.to_string(), self.timestamp.to_string());
329-
params.insert(OAUTH1_VERSION.to_string(), OAUTH1_VERSION_1.to_string());
330-
params.insert(OAUTH1_TOKEN.to_string(), self.token.to_string());
331370
params
332371
}
333372

334373
#[cfg_attr(feature = "tracing", tracing::instrument)]
335374
fn signature(&self, method: &str, endpoint: &str) -> Result<String, Self::Error> {
336375
let mut params = self.params();
337376

377+
// TODO: we could use query_pairs on Url
378+
338379
let host = match endpoint.split_once('?') {
339380
None => endpoint,
340381
Some((host, query)) => {
341382
for qparam in query.split('&') {
342383
let (k, v) = qparam.split_once('=').ok_or_else(|| {
343384
SignerError::Parse(format!("failed to parse query parameter, {qparam}"))
344385
})?;
345-
params.entry(k.to_owned()).or_insert(v.to_owned());
386+
let _ = params.entry(Cow::Borrowed(k)).or_insert(Cow::Borrowed(v));
346387
}
347388
host
348389
}
@@ -375,8 +416,8 @@ impl OAuth1 for Signer {
375416
fn signing_key(&self) -> String {
376417
format!(
377418
"{}&{}",
378-
urlencoding::encode(&self.consumer_secret),
379-
urlencoding::encode(&self.secret)
419+
urlencoding::encode(self.consumer_secret.as_ref()),
420+
urlencoding::encode(self.secret.as_ref())
380421
)
381422
}
382423
}
@@ -386,26 +427,13 @@ impl TryFrom<Credentials> for Signer {
386427

387428
#[cfg_attr(feature = "tracing", tracing::instrument)]
388429
fn try_from(credentials: Credentials) -> Result<Self, Self::Error> {
389-
let nonce = Uuid::new_v4().to_string();
390-
let timestamp = SystemTime::now()
391-
.duration_since(SystemTime::UNIX_EPOCH)
392-
.map_err(SignerError::UnixEpochTime)?
393-
.as_secs();
394-
395430
match credentials {
396431
Credentials::OAuth1 {
397432
token,
398433
secret,
399434
consumer_key,
400435
consumer_secret,
401-
} => Ok(Self {
402-
nonce,
403-
timestamp,
404-
token,
405-
secret,
406-
consumer_key,
407-
consumer_secret,
408-
}),
436+
} => Self::new(token, secret, consumer_key, consumer_secret),
409437
_ => Err(SignerError::InvalidCredentials),
410438
}
411439
}
@@ -432,6 +460,8 @@ pub enum ClientError {
432460
Digest(SignerError),
433461
#[error("failed to serialize signature as header value, {0}")]
434462
SerializeHeaderValue(header::InvalidHeaderValue),
463+
#[error("too many headers")]
464+
TooManyHeaders(#[from] reqwest::header::MaxSizeReached),
435465
}
436466

437467
// -----------------------------------------------------------------------------
@@ -462,36 +492,33 @@ impl Execute for Client {
462492
let method = request.method().to_string();
463493
let endpoint = request.url().to_string();
464494

465-
if !request.headers().contains_key(&header::AUTHORIZATION) {
466-
match &client.credentials {
467-
Some(Credentials::Bearer { token }) => {
468-
request.headers_mut().insert(
469-
header::AUTHORIZATION,
495+
if let Some(credentials) = &client.credentials {
496+
if let header::Entry::Vacant(vacant_entry) =
497+
request.headers_mut().entry(header::AUTHORIZATION)
498+
{
499+
let header_value = match credentials {
500+
Credentials::OAuth1 {
501+
token,
502+
secret,
503+
consumer_key,
504+
consumer_secret,
505+
} => Signer::new(token, secret, consumer_key, consumer_secret)
506+
.map_err(ClientError::Signer)?
507+
.sign(&method, &endpoint)
508+
.map_err(ClientError::Digest)?
509+
.parse()
510+
.map_err(ClientError::SerializeHeaderValue)?,
511+
Credentials::Basic { username, password } => {
512+
let token = BASE64_ENGINE.encode(format!("{username}:{password}"));
513+
HeaderValue::from_str(&format!("Basic {token}"))
514+
.map_err(ClientError::SerializeHeaderValue)?
515+
}
516+
Credentials::Bearer { token } => {
470517
HeaderValue::from_str(&format!("Bearer {token}"))
471-
.map_err(ClientError::SerializeHeaderValue)?,
472-
);
473-
}
474-
Some(Credentials::Basic { username, password }) => {
475-
let token = BASE64_ENGINE.encode(format!("{username}:{password}"));
476-
477-
request.headers_mut().insert(
478-
header::AUTHORIZATION,
479-
HeaderValue::from_str(&format!("Basic {token}",))
480-
.map_err(ClientError::SerializeHeaderValue)?,
481-
);
482-
}
483-
Some(credentials) => {
484-
request.headers_mut().insert(
485-
header::AUTHORIZATION,
486-
Signer::try_from(credentials.to_owned())
487-
.map_err(ClientError::Signer)?
488-
.sign(&method, &endpoint)
489-
.map_err(ClientError::Digest)?
490-
.parse()
491-
.map_err(ClientError::SerializeHeaderValue)?,
492-
);
493-
}
494-
_ => {}
518+
.map_err(ClientError::SerializeHeaderValue)?
519+
}
520+
};
521+
let _ = vacant_entry.try_insert(header_value)?;
495522
}
496523
}
497524

@@ -551,10 +578,10 @@ impl<X: IntoUrl + fmt::Debug + Send> RestClient<X> for Client {
551578
let mut request = reqwest::Request::new(method.to_owned(), url);
552579

553580
let headers = request.headers_mut();
554-
headers.insert(header::CONTENT_TYPE, APPLICATION_JSON);
555-
headers.insert(header::CONTENT_LENGTH, HeaderValue::from(buf.len()));
556-
headers.insert(header::ACCEPT_CHARSET, UTF8);
557-
headers.insert(header::ACCEPT, APPLICATION_JSON);
581+
let _ = headers.try_insert(header::CONTENT_TYPE, APPLICATION_JSON)?;
582+
let _ = headers.try_insert(header::CONTENT_LENGTH, HeaderValue::from(buf.len()))?;
583+
let _ = headers.try_insert(header::ACCEPT_CHARSET, UTF8)?;
584+
let _ = headers.try_insert(header::ACCEPT, APPLICATION_JSON)?;
558585

559586
*request.body_mut() = Some(buf.into());
560587

@@ -585,15 +612,18 @@ impl<X: IntoUrl + fmt::Debug + Send> RestClient<X> for Client {
585612
{
586613
let url = endpoint.into_url().map_err(ClientError::Request)?;
587614

588-
let mut req = reqwest::Request::new(Method::GET, url);
589-
590-
req.headers_mut().insert(header::ACCEPT_CHARSET, UTF8);
591-
592-
req.headers_mut().insert(header::ACCEPT, APPLICATION_JSON);
615+
let mut request = reqwest::Request::new(Method::GET, url);
593616

594-
let res = self.execute(req).await?;
595-
let status = res.status();
596-
let buf = res.bytes().await.map_err(ClientError::BodyAggregation)?;
617+
let headers = request.headers_mut();
618+
let _ = headers.try_insert(header::ACCEPT_CHARSET, UTF8)?;
619+
let _ = headers.try_insert(header::ACCEPT, APPLICATION_JSON)?;
620+
621+
let response = self.execute(request).await?;
622+
let status = response.status();
623+
let buf = response
624+
.bytes()
625+
.await
626+
.map_err(ClientError::BodyAggregation)?;
597627

598628
if !status.is_success() {
599629
return Err(ClientError::StatusCode(

src/client/sse.rs

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ pub type SseResult<C, K = String, V = String> = Result<Event<K, V>, SseErrorOf<C
3838
pub type SseBuildResult<C, K = String, V = String> =
3939
Result<SseStream<C, K, V>, SseErrorOf<C, K, V>>;
4040

41+
pub const MAX_CAPACITY: usize = isize::MAX as usize;
42+
4143
/// Default initial capacity of the buffer of the [`SseStream`].
4244
pub const DEFAULT_INITIAL_CAPACITY: usize = 512;
4345

@@ -154,7 +156,7 @@ impl<K, V> EventParser<K, V> {
154156
initial_capacity: usize,
155157
max_capacity: usize,
156158
) -> Self {
157-
let max_capacity = max_capacity.min(isize::MAX as usize);
159+
let max_capacity = max_capacity.min(MAX_CAPACITY);
158160
let initial_capacity = initial_capacity.min(max_capacity);
159161
Self {
160162
buf: BytesMut::with_capacity(initial_capacity),
@@ -475,6 +477,8 @@ pub enum SseError<E, K = Infallible, V = Infallible> {
475477
/// Failed to parse [`Event`].
476478
#[error(transparent)]
477479
Parser(EventParseError<K, V>),
480+
#[error("too many header")]
481+
TooManyHeaders(#[from] reqwest::header::MaxSizeReached),
478482
}
479483

480484
// SSE STATE ///////////////////////////////////////////////////////////////////
@@ -785,29 +789,26 @@ impl<C, K, V> SseStreamBuilder<C, K, V> {
785789
let mut request = reqwest::Request::new(Method::GET, url);
786790

787791
let headers = request.headers_mut();
788-
789-
let _ = headers.insert(header::ACCEPT, TEXT_EVENT_STREAM);
790-
791-
let _ = headers.insert(header::CACHE_CONTROL, NO_STORE);
792+
let _ = headers.try_insert(header::ACCEPT, TEXT_EVENT_STREAM)?;
793+
let _ = headers.try_insert(header::CACHE_CONTROL, NO_STORE)?;
792794

793795
if let Some(last_event_id) = last_event_id.header_value() {
794796
let _ = headers.insert(LAST_EVENT_ID, last_event_id);
795797
}
796798

797799
// TODO: request's "initiator" type should be set to "other"
798800

799-
let first_request = match request.try_clone() {
800-
None => return Err(SseError::RequestBodyNotCloneable),
801-
Some(v) => v,
802-
};
801+
let first_request = request
802+
.try_clone()
803+
.ok_or(SseError::RequestBodyNotCloneable)?;
803804

804805
Ok(SseStream {
805806
state: SseState::Connecting(client.execute(first_request).boxed()),
806807
parser: EventParser::new(
807808
request.url().clone(),
808809
last_event_id,
809810
initial_capacity,
810-
max_capacity.unwrap_or(isize::MAX as usize),
811+
max_capacity.unwrap_or(MAX_CAPACITY),
811812
),
812813
max_retry: max_retry.map(|n| (n, n)),
813814
max_loop: max_loop.map(|n| (n, n)),
@@ -819,7 +820,7 @@ impl<C, K, V> SseStreamBuilder<C, K, V> {
819820

820821
// SSE CLIENT //////////////////////////////////////////////////////////////////
821822

822-
/// Extension trait for [`Request`]s clients that support subscribing to Server-Sent Events (SSE).
823+
/// Extension trait for HTTP clients that support subscribing to Server-Sent Events (SSE).
823824
pub trait SseClient<U> {
824825
/// Sends a GET HTTP request to the provided `endpoint`,
825826
/// which is expected to serve a stream of Server-Sent Events (SSE).

0 commit comments

Comments
 (0)