Cargo.toml (2 changes: 1 addition & 1 deletion)
@@ -158,7 +158,7 @@ bnum = "0.13.0"
clickhouse-macros = { version = "0.3.0", path = "macros" }
criterion = "0.6"
serde = { version = "1.0.106", features = ["derive"] }
tokio = { version = "1.0.1", features = ["full", "test-util"] }
tokio = { version = "1.0.1", features = ["full", "test-util", "io-util"] }
hyper = { version = "1.1", features = ["server"] }
indexmap = { version = "2.10.0", features = ["serde"] }
linked-hash-map = { version = "0.5.6", features = ["serde_impl"] }
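For context on the dev-dependency change above: tokio's `io-util` feature gates the `AsyncReadExt`/`AsyncWriteExt` extension traits (the `full` feature already implies it, so the explicit entry mainly documents the dependency). A minimal sketch of what the feature enables, using an in-memory pipe:

```rust
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};

// Both extension traits used here are gated behind tokio's `io-util` feature.
async fn roundtrip() -> std::io::Result<Vec<u8>> {
    let (mut tx, mut rx) = duplex(64); // in-memory byte pipe
    tx.write_all(b"chunk").await?;
    drop(tx); // close the writer so the reader sees EOF

    let mut out = Vec::new();
    rx.read_to_end(&mut out).await?;
    Ok(out)
}
```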
src/insert.rs (285 changes: 34 additions & 251 deletions)
@@ -1,24 +1,14 @@
use crate::headers::{with_authentication, with_request_headers};
use crate::insert_formatted::BufInsertFormatted;
use crate::row_metadata::RowMetadata;
use crate::rowbinary::{serialize_row_binary, serialize_with_validation};
use crate::{
Client, Compression, RowWrite,
error::{Error, Result},
Client, RowWrite,
error::Result,
formats,
request_body::{ChunkSender, RequestBody},
response::Response,
row::{self, Row},
settings,
};
use bytes::{Bytes, BytesMut};
use clickhouse_types::put_rbwnat_columns_header;
use hyper::{self, Request};
use std::{future::Future, marker::PhantomData, mem, panic, pin::Pin, time::Duration};
use tokio::{
task::JoinHandle,
time::{Instant, Sleep},
};
use url::Url;
use std::{future::Future, marker::PhantomData, time::Duration};

// The desired max frame size.
const BUFFER_SIZE: usize = 256 * 1024;
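These constants drive the chunking policy: rows accumulate in a local buffer and are handed to the HTTP body as one chunk once enough bytes pile up, keeping frames near the desired size. A standalone sketch of that policy (the `send` callback and the threshold value are illustrative; the real `MIN_CHUNK_SIZE` is computed in the `const` block below):

```rust
// Illustrative threshold only.
const MIN_CHUNK_SIZE: usize = 8 * 1024;

// Rows accumulate in `buf`; once the threshold is crossed, the whole
// buffer is shipped as a single chunk, spreading network load over time.
fn push_row(buf: &mut Vec<u8>, row: &[u8], send: &mut impl FnMut(Vec<u8>)) {
    buf.extend_from_slice(row);
    if buf.len() >= MIN_CHUNK_SIZE {
        send(std::mem::take(buf));
    }
}
```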
@@ -35,7 +25,7 @@ const MIN_CHUNK_SIZE: usize = const {
/// [`Insert::end`] must be called to finalize the `INSERT`.
/// Otherwise, the whole `INSERT` will be aborted.
///
/// Rows are being sent progressively to spread network load.
/// Rows are sent progressively to spread network load.
///
/// # Note: Metadata is Cached
/// If [validation is enabled][Client::with_validation],
@@ -48,92 +38,11 @@ const MIN_CHUNK_SIZE: usize = const {
/// after any changes to the current database schema.
#[must_use]
pub struct Insert<T> {
state: InsertState,
buffer: BytesMut,
insert: BufInsertFormatted,
row_metadata: Option<RowMetadata>,
#[cfg(feature = "lz4")]
compression: Compression,
send_timeout: Option<Duration>,
end_timeout: Option<Duration>,
// Use boxed `Sleep` to reuse a timer entry, it improves performance.
// Also, `tokio::time::timeout()` significantly increases a future's size.
sleep: Pin<Box<Sleep>>,
_marker: PhantomData<fn() -> T>, // TODO: test contravariance.
}
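To ground the doc comment above, typical use of this type looks roughly like the following; the row type and table name are hypothetical, and the calls mirror the crate's documented `insert`/`write`/`end` flow:

```rust
use clickhouse::{error::Result, Client, Row};
use serde::Serialize;

#[derive(Row, Serialize)]
struct Event {
    id: u64,
    name: String,
}

async fn ingest(client: &Client, events: &[Event]) -> Result<()> {
    let mut insert = client.insert::<Event>("events")?;
    for event in events {
        insert.write(event).await?; // serialized into the buffer, sent progressively
    }
    insert.end().await // mandatory; dropping without `end` aborts the INSERT
}
```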

enum InsertState {
NotStarted {
client: Box<Client>,
sql: String,
},
Active {
sender: ChunkSender,
handle: JoinHandle<Result<()>>,
},
Terminated {
handle: JoinHandle<Result<()>>,
},
Completed,
}

impl InsertState {
fn sender(&mut self) -> Option<&mut ChunkSender> {
match self {
InsertState::Active { sender, .. } => Some(sender),
_ => None,
}
}

fn handle(&mut self) -> Option<&mut JoinHandle<Result<()>>> {
match self {
InsertState::Active { handle, .. } | InsertState::Terminated { handle } => Some(handle),
_ => None,
}
}

fn client_with_sql(&self) -> Option<(&Client, &str)> {
match self {
InsertState::NotStarted { client, sql } => Some((client, sql)),
_ => None,
}
}

#[inline]
fn expect_client_mut(&mut self) -> &mut Client {
let Self::NotStarted { client, .. } = self else {
panic!("cannot modify client options while an insert is in-progress")
};

client
}

fn terminated(&mut self) {
match mem::replace(self, InsertState::Completed) {
InsertState::NotStarted { .. } | InsertState::Completed => (),
InsertState::Active { handle, .. } => {
*self = InsertState::Terminated { handle };
}
InsertState::Terminated { handle } => {
*self = InsertState::Terminated { handle };
}
}
}
}

// Ideally this would be a regular function, but that decreases performance.
macro_rules! timeout {
($self:expr, $timeout:ident, $fut:expr) => {{
if let Some(timeout) = $self.$timeout {
$self.sleep.as_mut().reset(Instant::now() + timeout);
}

tokio::select! {
res = $fut => Some(res),
_ = &mut $self.sleep, if $self.$timeout.is_some() => None,
}
}};
}

impl<T> Insert<T> {
pub(crate) fn new(client: &Client, table: &str, row_metadata: Option<RowMetadata>) -> Self
where
@@ -152,18 +61,11 @@ impl<T> Insert<T> {
let sql = format!("INSERT INTO {table}({fields}) FORMAT {format}");

Self {
state: InsertState::NotStarted {
client: Box::new(client.clone()),
sql,
},
buffer: BytesMut::with_capacity(BUFFER_SIZE),
#[cfg(feature = "lz4")]
compression: client.compression,
send_timeout: None,
end_timeout: None,
sleep: Box::pin(tokio::time::sleep(Duration::new(0, 0))),
_marker: PhantomData,
insert: client
.insert_formatted_with(sql)
.buffered_with_capacity(BUFFER_SIZE),
row_metadata,
_marker: PhantomData,
}
}
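Condensed, the new constructor just builds the SQL statement and delegates all buffering to `BufInsertFormatted`. The statement it produces is plain `INSERT ... FORMAT ...` text, e.g. (illustrative values, unquoted column names for brevity):

```rust
// With validation enabled the crate selects RowBinaryWithNamesAndTypes;
// plain RowBinary is used otherwise.
let (table, fields, format) = ("events", "id,name", "RowBinaryWithNamesAndTypes");
let sql = format!("INSERT INTO {table}({fields}) FORMAT {format}");
assert_eq!(sql, "INSERT INTO events(id,name) FORMAT RowBinaryWithNamesAndTypes");
```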

@@ -202,7 +104,7 @@ impl<T> Insert<T> {
/// # Panics
/// If called after the request is started, e.g., after [`Insert::write`].
pub fn with_roles(mut self, roles: impl IntoIterator<Item = impl Into<String>>) -> Self {
self.state.expect_client_mut().set_roles(roles);
self.insert.expect_client_mut().set_roles(roles);
self
}

@@ -216,7 +118,7 @@ impl<T> Insert<T> {
/// # Panics
/// If called after the request is started, e.g., after [`Insert::write`].
pub fn with_default_roles(mut self) -> Self {
self.state.expect_client_mut().clear_roles();
self.insert.expect_client_mut().clear_roles();
self
}

@@ -227,7 +129,7 @@ impl<T> Insert<T> {
/// If called after the request is started, e.g., after [`Insert::write`].
#[track_caller]
pub fn with_option(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
self.state.expect_client_mut().add_option(name, value);
self.insert.expect_client_mut().add_option(name, value);
self
}
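These builder-style setters only work on a not-yet-started insert, so they belong right after construction. For instance (a sketch reusing the hypothetical `Event` row from the earlier example; `async_insert` and `wait_for_async_insert` are real ClickHouse settings):

```rust
use clickhouse::{error::Result, insert::Insert, Client};

fn build_insert(client: &Client) -> Result<Insert<Event>> {
    Ok(client
        .insert::<Event>("events")?
        .with_option("async_insert", "1") // forwarded as a URL query parameter
        .with_option("wait_for_async_insert", "0"))
}
```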

@@ -236,8 +138,7 @@ impl<T> Insert<T> {
send_timeout: Option<Duration>,
end_timeout: Option<Duration>,
) {
self.send_timeout = send_timeout;
self.end_timeout = end_timeout;
self.insert.set_timeouts(send_timeout, end_timeout);
}

/// Serializes the provided row into an internal buffer.
@@ -270,30 +171,32 @@ impl<T> Insert<T> {

async move {
result?;
if self.buffer.len() >= MIN_CHUNK_SIZE {
self.send_chunk().await?;
if self.insert.buf_len() >= MIN_CHUNK_SIZE {
self.insert.flush().await?;
}
Ok(())
}
}

/// Returns the number of bytes written, not including the RBWNAT header.
#[inline(always)]
pub(crate) fn do_write(&mut self, row: &T::Value<'_>) -> Result<usize>
where
T: RowWrite,
{
match self.state {
InsertState::NotStarted { .. } => self.init_request(),
InsertState::Active { .. } => Ok(()),
_ => panic!("write() after error"),
}?;
// Start the request before the buffer fills up, so that any error
// surfaces as early as possible.
self.init_request_if_required()?;

let old_buf_size = self.buffer.len();
// The following calls need an `impl BufMut`
let buffer = self.insert.buffer_mut();

let old_buf_size = buffer.len();
let result = match &self.row_metadata {
Some(metadata) => serialize_with_validation(&mut self.buffer, row, metadata),
None => serialize_row_binary(&mut self.buffer, row),
Some(metadata) => serialize_with_validation(&mut *buffer, row, metadata),
None => serialize_row_binary(&mut *buffer, row),
};
let written = self.buffer.len() - old_buf_size;
let written = buffer.len() - old_buf_size;

if result.is_err() {
self.abort();
@@ -309,141 +212,21 @@ impl<T> Insert<T> {
///
/// NOTE: If it isn't called, the whole `INSERT` is aborted.
pub async fn end(mut self) -> Result<()> {
if !self.buffer.is_empty() {
self.send_chunk().await?;
}
self.state.terminated();
self.wait_handle().await
}

async fn send_chunk(&mut self) -> Result<()> {
debug_assert!(matches!(self.state, InsertState::Active { .. }));

// Hyper uses a non-trivial and inefficient scheme for buffering chunks.
// It's difficult to determine when allocations occur.
// So, instead we control it manually here and rely on the system allocator.
let chunk = self.take_and_prepare_chunk()?;
let sender = self.state.sender().unwrap(); // checked above

let is_timed_out = match timeout!(self, send_timeout, sender.send(chunk)) {
Some(true) => return Ok(()),
Some(false) => false, // an actual error will be returned from `wait_handle`
None => true,
};

// Error handling.

self.abort();

// TODO: is it required to wait for the handle in the case of a timeout?
let res = self.wait_handle().await;

if is_timed_out {
Err(Error::TimedOut)
} else {
res?; // a real error should be here.
Err(Error::Network("channel closed".into()))
}
}

async fn wait_handle(&mut self) -> Result<()> {
match self.state.handle() {
Some(handle) => {
let result = match timeout!(self, end_timeout, &mut *handle) {
Some(Ok(res)) => res,
Some(Err(err)) if err.is_panic() => panic::resume_unwind(err.into_panic()),
Some(Err(err)) => Err(Error::Custom(format!("unexpected error: {err}"))),
None => {
// We can do nothing useful here, so just shut down the background task.
handle.abort();
Err(Error::TimedOut)
}
};
self.state = InsertState::Completed;
result
}
_ => Ok(()),
}
}
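The join-error handling above follows a reusable pattern: resurface panics from the background task, wrap other join failures, and otherwise return the task's own result. In isolation (a sketch with `String` standing in for the error type):

```rust
use tokio::task::JoinHandle;

// Flatten a JoinHandle<Result<T, E>>: propagate panics, convert join
// errors into E, and pass the task's own Result through.
async fn join_flat<T>(handle: JoinHandle<Result<T, String>>) -> Result<T, String> {
    match handle.await {
        Ok(res) => res,
        Err(err) if err.is_panic() => std::panic::resume_unwind(err.into_panic()),
        Err(err) => Err(format!("unexpected error: {err}")),
    }
}
```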

#[cfg(feature = "lz4")]
fn take_and_prepare_chunk(&mut self) -> Result<Bytes> {
Ok(if self.compression.is_lz4() {
let compressed = crate::compression::lz4::compress(&self.buffer)?;
self.buffer.clear();
compressed
} else {
mem::replace(&mut self.buffer, BytesMut::with_capacity(BUFFER_SIZE)).freeze()
})
self.insert.end().await
}

#[cfg(not(feature = "lz4"))]
fn take_and_prepare_chunk(&mut self) -> Result<Bytes> {
Ok(mem::replace(&mut self.buffer, BytesMut::with_capacity(BUFFER_SIZE)).freeze())
}
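Both removed variants hand the buffer off as an immutable `Bytes` chunk; the lz4 path compresses first. A standalone sketch of that decision, with `lz4_flex` standing in for the crate's ClickHouse-framed LZ4 encoder (which additionally prepends a checksum and block header):

```rust
use bytes::{Bytes, BytesMut};

fn take_chunk(buffer: &mut BytesMut, lz4: bool, capacity: usize) -> Bytes {
    if lz4 {
        // Placeholder codec: the real encoder produces ClickHouse-framed LZ4.
        let compressed = lz4_flex::compress_prepend_size(&buffer[..]);
        buffer.clear();
        Bytes::from(compressed)
    } else {
        // Swap in a fresh buffer and freeze the filled one into `Bytes`.
        std::mem::replace(buffer, BytesMut::with_capacity(capacity)).freeze()
    }
}
```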

#[cold]
#[track_caller]
#[inline(never)]
fn init_request(&mut self) -> Result<()> {
debug_assert!(matches!(self.state, InsertState::NotStarted { .. }));
let (client, sql) = self.state.client_with_sql().unwrap(); // checked above

let mut url = Url::parse(&client.url).map_err(|err| Error::InvalidParams(err.into()))?;
let mut pairs = url.query_pairs_mut();
pairs.clear();
fn init_request_if_required(&mut self) -> Result<()> {
let fresh_request = self.insert.init_request_if_required()?;

if let Some(database) = &client.database {
pairs.append_pair(settings::DATABASE, database);
if fresh_request && let Some(metadata) = &self.row_metadata {
put_rbwnat_columns_header(&metadata.columns, self.insert.buffer_mut())
.inspect_err(|_| self.abort())?;
}

pairs.append_pair(settings::QUERY, sql);

if client.compression.is_lz4() {
pairs.append_pair(settings::DECOMPRESS, "1");
}

for (name, value) in &client.options {
pairs.append_pair(name, value);
}

drop(pairs);

let mut builder = Request::post(url.as_str());
builder = with_request_headers(builder, &client.headers, &client.products_info);
builder = with_authentication(builder, &client.authentication);

let (sender, body) = RequestBody::chunked();

let request = builder
.body(body)
.map_err(|err| Error::InvalidParams(Box::new(err)))?;

let future = client.http.request(request);
// TODO: introduce `Executor` to allow bookkeeping of spawned tasks.
let handle =
tokio::spawn(async move { Response::new(future, Compression::None).finish().await });

match self.row_metadata {
None => (), // RowBinary is used, no header is required.
Some(ref metadata) => {
put_rbwnat_columns_header(&metadata.columns, &mut self.buffer)?;
}
}

self.state = InsertState::Active { handle, sender };
Ok(())
}
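For context on the header written above: `RowBinaryWithNamesAndTypes` prefixes the stream with the column count, then all column names, then all column types, with each string LEB128-length-prefixed. A hand-rolled sketch of that layout (not the `clickhouse_types` implementation):

```rust
// varint (LEB128) encoding used for lengths and the column count
fn put_leb128(buf: &mut Vec<u8>, mut n: u64) {
    loop {
        let byte = (n & 0x7f) as u8;
        n >>= 7;
        buf.push(if n == 0 { byte } else { byte | 0x80 });
        if n == 0 {
            break;
        }
    }
}

fn put_str(buf: &mut Vec<u8>, s: &str) {
    put_leb128(buf, s.len() as u64);
    buf.extend_from_slice(s.as_bytes());
}

// Header layout: count, then all names, then all types.
fn put_columns_header(buf: &mut Vec<u8>, columns: &[(&str, &str)]) {
    put_leb128(buf, columns.len() as u64);
    for (name, _) in columns {
        put_str(buf, name);
    }
    for (_, ty) in columns {
        put_str(buf, ty);
    }
}
```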

fn abort(&mut self) {
if let Some(sender) = self.state.sender() {
sender.abort();
}
}
}

impl<T> Drop for Insert<T> {
fn drop(&mut self) {
self.abort();
self.insert.abort();
}
}
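A closing usage note: because `Drop` aborts, simply letting an `Insert` fall out of scope cancels the statement, which makes early returns safe. A sketch, with the hypothetical `Event` row as before:

```rust
async fn ingest_or_abort(
    client: &clickhouse::Client,
    rows: &[Event],
) -> clickhouse::error::Result<()> {
    let mut insert = client.insert::<Event>("events")?;
    for row in rows {
        // On error, `insert` is dropped here, aborting the in-flight
        // request instead of finalizing the INSERT.
        insert.write(row).await?;
    }
    insert.end().await // only a successful `end` finalizes the INSERT
}
```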