diff --git a/CHANGELOG.md b/CHANGELOG.md index 8ba909d..019c2ca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,14 @@ + +### v0.8.0 (2018-10-15) + +#### Features +* Update rand to 0.6 +* Upgrade native-tls to 0.2 +* Add a maximal size for fragments exposed via the `max_fragment_size` setting + +#### Bug fixes +* Don't try to parse response when the socket not ready + ### v0.7.9 (2018-10-15) diff --git a/Cargo.toml b/Cargo.toml index 0942b3c..f2d440c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ license = "MIT" name = "parity-ws" readme = "README.md" repository = "https://github.com/paritytech/ws-rs" -version = "0.8.0" +version = "0.10.0" [dependencies] byteorder = "1.2.1" @@ -22,10 +22,10 @@ httparse = "1.2.4" log = "0.4.1" mio = "0.6.14" mio-extras = "2.0" -rand = "0.4.2" -sha1 = "0.6.0" +rand = "0.7" +sha-1 = "0.8.0" slab = "0.4" -url = "1.7.0" +url = "2.0.0" [dependencies.libc] optional = true @@ -41,11 +41,11 @@ version = "0.10" [dependencies.native-tls] optional = true -version = "0.1.5" +version = "0.2" [dev-dependencies] clap = "2.31.2" -env_logger = "0.5.6" +env_logger = "0.6" term = "0.5.1" time = "0.1.39" diff --git a/README.md b/README.md index e6413ee..3892839 100644 --- a/README.md +++ b/README.md @@ -10,9 +10,10 @@ listen("127.0.0.1:3012", |out| { } }) ``` + # This fork -Note this is (hopefuly) temporary fork of the original crate until https://github.com/housleyjk/ws-rs/pull/252 gets merged. +Note this is (hopefuly) a temporary fork of the original crate until https://github.com/housleyjk/ws-rs/pull/328 gets merged. Introduction ------------ diff --git a/src/capped_buffer.rs b/src/capped_buffer.rs new file mode 100644 index 0000000..cb85b85 --- /dev/null +++ b/src/capped_buffer.rs @@ -0,0 +1,187 @@ +use bytes::BufMut; +use std::ops::Deref; +use std::io; + +/// Safe wrapper around Vec with custom `bytes::BufMut` and `std::io::Write` +/// implementations that ensure the buffer never exceeds maximum capacity. +pub struct CappedBuffer { + buf: Vec, + max: usize, +} + +impl CappedBuffer { + /// Create a new `CappedBuffer` with initial `capacity`, and a limit + /// capacity set to `max`. + pub fn new(mut capacity: usize, max: usize) -> Self { + if capacity > max { + capacity = max; + } + + Self { + buf: Vec::with_capacity(capacity), + max, + } + } + + /// Remaining amount of bytes that can be written to the buffer + /// before reaching max capacity + #[inline] + pub fn remaining(&self) -> usize { + self.max - self.buf.len() + } + + /// Shift the content of the buffer to the left by `shift`, + /// effectively forgetting the shifted out bytes. + /// New length of the buffer will be adjusted accordingly. + pub fn shift(&mut self, shift: usize) { + if shift >= self.buf.len() { + self.buf.clear(); + return; + } + + let src = self.buf[shift..].as_ptr(); + let dst = self.buf.as_mut_ptr(); + let new_len = self.buf.len() - shift; + + // This is a simple, potentially overlapping memcpy within + // the buffer, shifting `new_len` bytes at offset `shift` (`src`) + // to the beginning of the buffer (`dst`) + unsafe { + std::ptr::copy(src, dst, new_len); + self.buf.set_len(new_len); + } + } +} + +impl AsRef<[u8]> for CappedBuffer { + fn as_ref(&self) -> &[u8] { + &self.buf + } +} + +impl AsMut<[u8]> for CappedBuffer { + fn as_mut(&mut self) -> &mut [u8] { + &mut self.buf + } +} + +impl Deref for CappedBuffer { + type Target = Vec; + + fn deref(&self) -> &Vec { + &self.buf + } +} + +impl io::Write for CappedBuffer { + fn write(&mut self, mut buf: &[u8]) -> io::Result { + if buf.len() > self.remaining() { + buf = &buf[..self.remaining()]; + } + self.buf.extend_from_slice(buf); + Ok(buf.len()) + } + + fn write_all(&mut self, buf: &[u8]) -> io::Result<()> { + if buf.len() <= self.remaining() { + self.buf.extend_from_slice(buf); + Ok(()) + } else { + Err(io::Error::new(io::ErrorKind::InvalidInput, "Exceeded maximum buffer capacity")) + } + } + + fn flush(&mut self) -> io::Result<()> { + self.buf.flush() + } +} + +impl BufMut for CappedBuffer { + fn remaining_mut(&self) -> usize { + self.remaining() + } + + unsafe fn advance_mut(&mut self, cnt: usize) { + assert!(cnt <= self.remaining(), "Exceeded buffer capacity"); + + self.buf.advance_mut(cnt); + } + + unsafe fn bytes_mut(&mut self) -> &mut [u8] { + let remaining = self.remaining(); + + // `self.buf.bytes_mut` does an implicit allocation + if remaining == 0 { + return &mut []; + } + + let mut bytes = self.buf.bytes_mut(); + + if bytes.len() > remaining { + bytes = &mut bytes[..remaining]; + } + + bytes + } +} + +#[cfg(test)] +mod test { + use std::io::Write; + use super::*; + + #[test] + fn shift() { + let mut buffer = CappedBuffer::new(10, 20); + + buffer.write_all(b"Hello World").unwrap(); + buffer.shift(6); + + assert_eq!(&*buffer, b"World"); + assert_eq!(buffer.remaining(), 15); + } + + #[test] + fn shift_zero() { + let mut buffer = CappedBuffer::new(10, 20); + + buffer.write_all(b"Hello World").unwrap(); + buffer.shift(0); + + assert_eq!(&*buffer, b"Hello World"); + assert_eq!(buffer.remaining(), 9); + } + + #[test] + fn shift_all() { + let mut buffer = CappedBuffer::new(10, 20); + + buffer.write_all(b"Hello World").unwrap(); + buffer.shift(11); + + assert_eq!(&*buffer, b""); + assert_eq!(buffer.remaining(), 20); + } + + #[test] + fn shift_capacity() { + let mut buffer = CappedBuffer::new(10, 20); + + buffer.write_all(b"Hello World").unwrap(); + buffer.shift(20); + + assert_eq!(&*buffer, b""); + assert_eq!(buffer.remaining(), 20); + } + + #[test] + fn shift_over_capacity() { + let mut buffer = CappedBuffer::new(10, 20); + + buffer.write_all(b"Hello World").unwrap(); + buffer.shift(50); + + assert_eq!(&*buffer, b""); + assert_eq!(buffer.remaining(), 20); + } +} diff --git a/src/communication.rs b/src/communication.rs index a6020c6..2b2822f 100644 --- a/src/communication.rs +++ b/src/communication.rs @@ -11,6 +11,7 @@ use message; use protocol::CloseCode; use result::{Error, Result}; use std::cmp::PartialEq; +use std::hash::{Hash, Hasher}; use std::fmt; #[derive(Debug, Clone)] @@ -69,6 +70,16 @@ impl PartialEq for Sender { } } +impl Eq for Sender { } + +impl Hash for Sender { + fn hash(&self, state: &mut H) { + self.connection_id.hash(state); + self.token.hash(state); + } +} + + impl Sender { #[doc(hidden)] #[inline] diff --git a/src/connection.rs b/src/connection.rs index 9466b03..0178e0d 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -1,6 +1,6 @@ use std::borrow::Borrow; use std::collections::VecDeque; -use std::io::{Cursor, Read, Seek, SeekFrom, Write}; +use std::io::{Cursor, Read, Write}; use std::mem::replace; use std::net::SocketAddr; use std::str::from_utf8; @@ -15,6 +15,7 @@ use native_tls::HandshakeError; #[cfg(feature = "ssl")] use openssl::ssl::HandshakeError; +use capped_buffer::CappedBuffer; use frame::Frame; use handler::Handler; use handshake::{Handshake, Request, Response}; @@ -87,8 +88,8 @@ where fragments: VecDeque, - in_buffer: Cursor>, - out_buffer: Cursor>, + in_buffer: Cursor, + out_buffer: Cursor, handler: H, @@ -119,8 +120,8 @@ where endpoint: Endpoint::Server, events: Ready::empty(), fragments: VecDeque::with_capacity(settings.fragments_capacity), - in_buffer: Cursor::new(Vec::with_capacity(settings.in_buffer_capacity)), - out_buffer: Cursor::new(Vec::with_capacity(settings.out_buffer_capacity)), + in_buffer: Cursor::new(CappedBuffer::new(settings.in_buffer_capacity, settings.max_in_buffer_capacity)), + out_buffer: Cursor::new(CappedBuffer::new(settings.out_buffer_capacity, settings.max_out_buffer_capacity)), handler, addresses: Vec::new(), settings, @@ -182,7 +183,7 @@ where HandshakeError::Failure(_) => { Err(Error::new(Kind::SslHandshake(handshake_err), details)) } - HandshakeError::Interrupted(mid) => { + HandshakeError::WouldBlock(mid) => { self.socket = Stream::tls(mid); Ok(()) } @@ -252,7 +253,7 @@ where HandshakeError::Failure(_) => { Err(Error::new(Kind::SslHandshake(handshake_err), details)) } - HandshakeError::Interrupted(mid) => { + HandshakeError::WouldBlock(mid) => { self.socket = Stream::tls(mid); Ok(()) } @@ -426,8 +427,8 @@ where self.handler.on_error(err); if let Err(err) = self.send_close(CloseCode::Size, reason) { self.handler.on_error(err); - self.disconnect() } + self.disconnect() } Kind::Protocol => { if self.settings.panic_on_protocol { @@ -605,10 +606,13 @@ where if !data[..end].ends_with(b"\r\n\r\n") { return Ok(()); } - self.in_buffer.get_mut().extend(&data[end..]); + self.in_buffer.get_mut().write_all(&data[end..])?; end }; res.get_mut().truncate(end); + } else { + // NOTE: wait to be polled again; response not ready. + return Ok(()); } } } @@ -1172,29 +1176,24 @@ where trace!("Buffering frame to {}:\n{}", self.peer_addr(), frame); - let pos = self.out_buffer.position(); - self.out_buffer.seek(SeekFrom::End(0))?; - frame.format(&mut self.out_buffer)?; - self.out_buffer.seek(SeekFrom::Start(pos))?; + frame.format(self.out_buffer.get_mut())?; Ok(()) } fn check_buffer_out(&mut self, frame: &Frame) -> Result<()> { - if self.out_buffer.get_ref().capacity() <= self.out_buffer.get_ref().len() + frame.len() { - // extend - let mut new = Vec::with_capacity(self.out_buffer.get_ref().capacity()); - new.extend(&self.out_buffer.get_ref()[self.out_buffer.position() as usize..]); - if new.len() == new.capacity() { - if self.settings.out_buffer_grow { - new.reserve(self.settings.out_buffer_capacity) - } else { - return Err(Error::new( - Kind::Capacity, - "Maxed out output buffer for connection.", - )); - } + if self.out_buffer.get_ref().remaining() < frame.len() { + // There is no more room to grow, and we can't shift the buffer + if self.out_buffer.position() == 0 { + return Err(Error::new( + Kind::Capacity, + "Reached the limit of the output buffer for the connection.", + )); } - self.out_buffer = Cursor::new(new); + + // Shift the buffer + let prev_pos = self.out_buffer.position() as usize; + self.out_buffer.set_position(0); + self.out_buffer.get_mut().shift(prev_pos); } Ok(()) } @@ -1203,21 +1202,19 @@ where trace!("Reading buffer for connection to {}.", self.peer_addr()); if let Some(len) = self.socket.try_read_buf(self.in_buffer.get_mut())? { trace!("Buffered {}.", len); - if self.in_buffer.get_ref().len() == self.in_buffer.get_ref().capacity() { - // extend - let mut new = Vec::with_capacity(self.in_buffer.get_ref().capacity()); - new.extend(&self.in_buffer.get_ref()[self.in_buffer.position() as usize..]); - if new.len() == new.capacity() { - if self.settings.in_buffer_grow { - new.reserve(self.settings.in_buffer_capacity); - } else { - return Err(Error::new( - Kind::Capacity, - "Maxed out input buffer for connection.", - )); - } + if self.in_buffer.get_ref().remaining() == 0 { + // There is no more room to grow, and we can't shift the buffer + if self.in_buffer.position() == 0 { + return Err(Error::new( + Kind::Capacity, + "Reached the limit of the input buffer for the connection.", + )); } - self.in_buffer = Cursor::new(new); + + // Shift the buffer + let prev_pos = self.in_buffer.position() as usize; + self.in_buffer.set_position(0); + self.in_buffer.get_mut().shift(prev_pos); } Ok(Some(len)) } else { diff --git a/src/frame.rs b/src/frame.rs index 154816c..4c43a04 100644 --- a/src/frame.rs +++ b/src/frame.rs @@ -5,6 +5,7 @@ use std::io::{Cursor, ErrorKind, Read, Write}; use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; use rand; +use capped_buffer::CappedBuffer; use protocol::{CloseCode, OpCode}; use result::{Error, Kind, Result}; use stream::TryReadBuf; @@ -244,7 +245,7 @@ impl Frame { } /// Parse the input stream into a frame. - pub fn parse(cursor: &mut Cursor>, max_payload_length: u64) -> Result> { + pub fn parse(cursor: &mut Cursor, max_payload_length: u64) -> Result> { let size = cursor.get_ref().len() as u64 - cursor.position(); let initial = cursor.position(); trace!("Position in buffer {}", initial); diff --git a/src/handler.rs b/src/handler.rs index 68a7527..7c1ef21 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -317,13 +317,14 @@ pub trait Handler { Kind::Protocol, format!("Unable to parse domain from {}. Needed for SSL.", url), ))?; - let connector = TlsConnector::builder().and_then(|builder| builder.build()) - .map_err(|e| { - Error::new( - Kind::Internal, - format!("Failed to upgrade client to SSL: {}", e), - ) - })?; + + let connector = TlsConnector::new().map_err(|e| { + Error::new( + Kind::Internal, + format!("Failed to upgrade client to SSL: {}", e), + ) + })?; + connector.connect(domain, stream).map_err(Error::from) } /// A method for wrapping a server TcpStream with Ssl Authentication machinery diff --git a/src/handshake.rs b/src/handshake.rs index 59b91b1..b7520bd 100644 --- a/src/handshake.rs +++ b/src/handshake.rs @@ -5,7 +5,7 @@ use std::str::from_utf8; use httparse; use rand; -use sha1; +use sha1::{self, Digest}; use url; use result::{Error, Kind, Result}; @@ -22,10 +22,10 @@ fn generate_key() -> String { pub fn hash_key(key: &[u8]) -> String { let mut hasher = sha1::Sha1::new(); - hasher.update(key); - hasher.update(WS_GUID.as_bytes()); + hasher.input(key); + hasher.input(WS_GUID.as_bytes()); - encode_base64(&hasher.digest().bytes()) + encode_base64(&hasher.result()) } // This code is based on rustc_serialize base64 STANDARD diff --git a/src/lib.rs b/src/lib.rs index ac8d014..fe08689 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,6 +19,7 @@ extern crate url; #[macro_use] extern crate log; +mod capped_buffer; mod communication; mod connection; mod factory; @@ -161,22 +162,20 @@ pub struct Settings { /// The maximum length of acceptable incoming frames. Messages longer than this will be rejected. /// Default: unlimited pub max_fragment_size: usize, - /// The size of the incoming buffer. A larger buffer uses more memory but will allow for fewer - /// reallocations. + /// The initial size of the incoming buffer. A larger buffer uses more memory but will allow for + /// fewer reallocations. /// Default: 2048 pub in_buffer_capacity: usize, - /// Whether to reallocate the incoming buffer when `in_buffer_capacity` is reached. If this is - /// false, a Capacity error will be triggered instead. - /// Default: true - pub in_buffer_grow: bool, - /// The size of the outgoing buffer. A larger buffer uses more memory but will allow for fewer - /// reallocations. + /// The maximum size to which the incoming buffer can grow. + /// Default: 10,485,760 + pub max_in_buffer_capacity: usize, + /// The initial size of the outgoing buffer. A larger buffer uses more memory but will allow for + /// fewer reallocations. /// Default: 2048 pub out_buffer_capacity: usize, - /// Whether to reallocate the incoming buffer when `out_buffer_capacity` is reached. If this is - /// false, a Capacity error will be triggered instead. - /// Default: true - pub out_buffer_grow: bool, + /// The maximum size to which the outgoing buffer can grow. + /// Default: 10,485,760 + pub max_out_buffer_capacity: usize, /// Whether to panic when an Internal error is encountered. Internal errors should generally /// not occur, so this setting defaults to true as a debug measure, whereas production /// applications should consider setting it to false. @@ -250,9 +249,9 @@ impl Default for Settings { fragment_size: u16::max_value() as usize, max_fragment_size: usize::max_value(), in_buffer_capacity: 2048, - in_buffer_grow: true, + max_in_buffer_capacity: 10 * 1024 * 1024, out_buffer_capacity: 2048, - out_buffer_grow: true, + max_out_buffer_capacity: 10 * 1024 * 1024, panic_on_internal: true, panic_on_capacity: false, panic_on_protocol: false, diff --git a/src/result.rs b/src/result.rs index eaec515..eb3c151 100644 --- a/src/result.rs +++ b/src/result.rs @@ -61,7 +61,7 @@ pub enum Kind { /// A custom error kind for use by applications. This error kind involves extra overhead /// because it will allocate the memory on the heap. The WebSocket ignores such errors by /// default, simply passing them to the Connection Handler. - Custom(Box), + Custom(Box), } /// A struct indicating the kind of error that has occurred and any precise details of that error. @@ -81,7 +81,7 @@ impl Error { } } - pub fn into_box(self) -> Box { + pub fn into_box(self) -> Box { match self.kind { Kind::Custom(err) => err, _ => Box::new(self), @@ -127,7 +127,7 @@ impl StdError for Error { } } - fn cause(&self) -> Option<&StdError> { + fn cause(&self) -> Option<&dyn StdError> { match self.kind { Kind::Encoding(ref err) => Some(err), Kind::Io(ref err) => Some(err), diff --git a/src/stream.rs b/src/stream.rs index a2583a8..3b8d9c4 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -194,7 +194,7 @@ impl io::Read for Stream { err } #[cfg(feature = "nativetls")] - Err(HandshakeError::Interrupted(mid)) => { + Err(HandshakeError::WouldBlock(mid)) => { negotiating = true; *tls_stream = TlsStream::Handshake { sock: mid, @@ -264,7 +264,7 @@ impl io::Write for Stream { err } #[cfg(feature = "nativetls")] - Err(HandshakeError::Interrupted(mid)) => { + Err(HandshakeError::WouldBlock(mid)) => { negotiating = true; *tls_stream = TlsStream::Handshake { sock: mid,