diff --git a/crates/fluss/src/client/credentials.rs b/crates/fluss/src/client/credentials.rs
index ffb682e..ab401bc 100644
--- a/crates/fluss/src/client/credentials.rs
+++ b/crates/fluss/src/client/credentials.rs
@@ -156,3 +156,55 @@ impl CredentialsCache {
         Ok(props)
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::client::metadata::Metadata;
+    use crate::cluster::Cluster;
+
+    #[test]
+    fn convert_hadoop_key_to_opendal_maps_known_keys() {
+        let (key, invert) = convert_hadoop_key_to_opendal("fs.s3a.endpoint").expect("key");
+        assert_eq!(key, "endpoint");
+        assert!(!invert);
+
+        let (key, invert) =
+            convert_hadoop_key_to_opendal("fs.s3a.path.style.access").expect("key");
+        assert_eq!(key, "enable_virtual_host_style");
+        assert!(invert);
+
+        assert!(convert_hadoop_key_to_opendal("fs.s3a.connection.ssl.enabled").is_none());
+        assert!(convert_hadoop_key_to_opendal("unknown.key").is_none());
+    }
+
+    #[tokio::test]
+    async fn credentials_cache_returns_cached_props() -> Result<()> {
+        let cached = CachedToken {
+            access_key_id: "ak".to_string(),
+            secret_access_key: "sk".to_string(),
+            security_token: Some("token".to_string()),
+            addition_infos: HashMap::from([(
+                "fs.s3a.path.style.access".to_string(),
+                "true".to_string(),
+            )]),
+            cached_at: Instant::now(),
+        };
+
+        let cache = CredentialsCache {
+            inner: RwLock::new(Some(cached)),
+            rpc_client: Arc::new(RpcClient::new()),
+            metadata: Arc::new(Metadata::new_for_test(Arc::new(Cluster::default()))),
+        };
+
+        let props = cache.get_or_refresh().await?;
+        assert_eq!(props.get("access_key_id"), Some(&"ak".to_string()));
+        assert_eq!(props.get("secret_access_key"), Some(&"sk".to_string()));
+        assert_eq!(props.get("security_token"), Some(&"token".to_string()));
+        assert_eq!(
+            props.get("enable_virtual_host_style"),
+            Some(&"false".to_string())
+        );
+        Ok(())
+    }
+}
diff --git a/crates/fluss/src/client/metadata.rs b/crates/fluss/src/client/metadata.rs
index a514422..d7324b3 100644
--- a/crates/fluss/src/client/metadata.rs
+++ b/crates/fluss/src/client/metadata.rs
@@ -135,7 +135,93 @@ impl Metadata {
         guard.clone()
     }
 
-    pub fn leader_for(&self, _table_bucket: &TableBucket) -> Option<&ServerNode> {
-        todo!()
+    pub fn leader_for(&self, table_bucket: &TableBucket) -> Option<ServerNode> {
+        let cluster = self.cluster.read();
+        cluster.leader_for(table_bucket).cloned()
+    }
+}
+
+#[cfg(test)]
+impl Metadata {
+    pub(crate) fn new_for_test(cluster: Arc<Cluster>) -> Self {
+        Metadata {
+            cluster: RwLock::new(cluster),
+            connections: Arc::new(RpcClient::new()),
+            bootstrap: Arc::from(""),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::cluster::{BucketLocation, Cluster, ServerNode, ServerType};
+    use crate::metadata::{DataField, DataTypes, Schema, TableDescriptor, TableInfo, TablePath};
+    use std::collections::HashMap;
+
+    fn build_table_info(table_path: TablePath, table_id: i64) -> TableInfo {
+        let row_type = DataTypes::row(vec![DataField::new(
+            "id".to_string(),
+            DataTypes::int(),
+            None,
+        )]);
+        let mut schema_builder = Schema::builder().with_row_type(&row_type);
+        let schema = schema_builder.build().expect("schema build");
+        let table_descriptor = TableDescriptor::builder()
+            .schema(schema)
+            .distributed_by(Some(1), vec![])
+            .build()
+            .expect("descriptor build");
+        TableInfo::of(table_path, table_id, 1, table_descriptor, 0, 0)
+    }
+
+    fn build_cluster(table_path: &TablePath, table_id: i64) -> Arc<Cluster> {
+        let server = ServerNode::new(1, "127.0.0.1".to_string(), 9092, ServerType::TabletServer);
+        let table_bucket = TableBucket::new(table_id, 0);
+        let bucket_location =
+            BucketLocation::new(table_bucket.clone(), Some(server.clone()), table_path.clone());
+
+        let mut servers = HashMap::new();
+        servers.insert(server.id(), server);
+
+        let mut locations_by_path = HashMap::new();
+        locations_by_path.insert(table_path.clone(), vec![bucket_location.clone()]);
+
+        let mut locations_by_bucket = HashMap::new();
+        locations_by_bucket.insert(table_bucket, bucket_location);
+
+        let mut table_id_by_path = HashMap::new();
+        table_id_by_path.insert(table_path.clone(), table_id);
+
+        let mut table_info_by_path = HashMap::new();
+        table_info_by_path.insert(table_path.clone(), build_table_info(table_path.clone(), table_id));
+
+        Arc::new(Cluster::new(
+            None,
+            servers,
+            locations_by_path,
+            locations_by_bucket,
+            table_id_by_path,
+            table_info_by_path,
+        ))
+    }
+
+    #[test]
+    fn leader_for_returns_server() {
+        let table_path = TablePath::new("db".to_string(), "tbl".to_string());
+        let cluster = build_cluster(&table_path, 1);
+        let metadata = Metadata::new_for_test(cluster);
+        let leader = metadata.leader_for(&TableBucket::new(1, 0)).expect("leader");
+        assert_eq!(leader.id(), 1);
+    }
+
+    #[test]
+    fn invalidate_server_removes_leader() {
+        let table_path = TablePath::new("db".to_string(), "tbl".to_string());
+        let cluster = build_cluster(&table_path, 1);
+        let metadata = Metadata::new_for_test(cluster);
+        metadata.invalidate_server(&1, vec![1]);
+        let cluster = metadata.get_cluster();
+        assert!(cluster.get_tablet_server(1).is_none());
     }
 }
diff --git a/crates/fluss/src/client/table/log_fetch_buffer.rs b/crates/fluss/src/client/table/log_fetch_buffer.rs
index cee104e..001567b 100644
--- a/crates/fluss/src/client/table/log_fetch_buffer.rs
+++ b/crates/fluss/src/client/table/log_fetch_buffer.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::error::Result;
+use crate::error::{Error, Result};
 use crate::metadata::TableBucket;
 use crate::record::{
     LogRecordBatch, LogRecordIterator, LogRecordsBatches, ReadContext, ScanRecord,
@@ -32,6 +32,7 @@ pub trait CompletedFetch: Send + Sync {
     fn table_bucket(&self) -> &TableBucket;
     fn fetch_records(&mut self, max_records: usize) -> Result<Vec<ScanRecord>>;
     fn is_consumed(&self) -> bool;
+    fn records_read(&self) -> usize;
     fn drain(&mut self);
     fn size_in_bytes(&self) -> usize;
     fn high_watermark(&self) -> i64;
@@ -54,6 +55,7 @@ pub struct LogFetchBuffer {
     next_in_line_fetch: Mutex<Option<Box<dyn CompletedFetch>>>,
     not_empty_notify: Notify,
     woken_up: Arc<AtomicBool>,
+    error: Mutex<Option<Error>>,
 }
 
 impl LogFetchBuffer {
@@ -64,6 +66,7 @@ impl LogFetchBuffer {
             next_in_line_fetch: Mutex::new(None),
             not_empty_notify: Notify::new(),
             woken_up: Arc::new(AtomicBool::new(false)),
+            error: Mutex::new(None),
         }
     }
 
@@ -72,26 +75,32 @@ impl LogFetchBuffer {
         self.completed_fetches.lock().is_empty()
     }
 
-    /// Wait for the buffer to become non-empty, with timeout
-    /// Returns true if data became available, false if timeout
-    pub async fn await_not_empty(&self, timeout: Duration) -> bool {
+    /// Wait for the buffer to become non-empty, with timeout.
+    /// Returns `Ok(true)` if data became available, `Ok(false)` on timeout, or an error on wakeup or a recorded fetch failure.
+    pub async fn await_not_empty(&self, timeout: Duration) -> Result<bool> {
         let deadline = std::time::Instant::now() + timeout;
 
         loop {
+            if let Some(error) = self.take_error() {
+                return Err(error);
+            }
+
             // Check if buffer is not empty
             if !self.is_empty() {
-                return true;
+                return Ok(true);
             }
 
             // Check if woken up
             if self.woken_up.swap(false, Ordering::Acquire) {
-                return true;
+                return Err(Error::WakeupError {
+                    message: "The wait was woken up.".to_string(),
+                });
             }
 
             // Check if timeout
             let now = std::time::Instant::now();
             if now >= deadline {
-                return false;
+                return Ok(false);
             }
 
             // Wait for notification with remaining time
@@ -99,7 +108,7 @@
             let notified = self.not_empty_notify.notified();
             tokio::select! {
                 _ = tokio::time::sleep(remaining) => {
-                    return false; // Timeout
+                    return Ok(false); // Timeout
                 }
                 _ = notified => {
                     // Got notification, check again
@@ -116,6 +125,18 @@
         self.not_empty_notify.notify_waiters();
     }
 
+    pub(crate) fn set_error(&self, error: Error) {
+        let mut guard = self.error.lock();
+        if guard.is_none() {
+            *guard = Some(error);
+        }
+        self.not_empty_notify.notify_waiters();
+    }
+
+    pub(crate) fn take_error(&self) -> Option<Error> {
+        self.error.lock().take()
+    }
+
     /// Add a pending fetch to the buffer
     pub fn pend(&self, pending_fetch: Box<dyn PendingFetch>) {
         let table_bucket = pending_fetch.table_bucket().clone();
@@ -133,6 +154,7 @@
         // holding both locks simultaneously.
         let mut completed_to_push: Vec<Box<dyn CompletedFetch>> = Vec::new();
         let mut has_completed = false;
+        let mut pending_error: Option<Error> = None;
         {
             let mut pending_map = self.pending_fetches.lock();
             if let Some(pendings) = pending_map.get_mut(table_bucket) {
@@ -145,8 +167,9 @@
                             has_completed = true;
                         }
                         Err(e) => {
-                            // todo: handle exception?
-                            log::error!("Error when completing: {e}");
+                            pending_error = Some(e);
+                            has_completed = true;
+                            break;
                         }
                     }
                 } else {
@@ -166,6 +189,11 @@
             }
         }
 
+        if let Some(error) = pending_error {
+            self.set_error(error);
+            return;
+        }
+
         if has_completed {
             // Signal that buffer is not empty
             self.not_empty_notify.notify_waiters();
@@ -269,6 +297,9 @@ pub struct DefaultCompletedFetch {
     records_read: usize,
     current_record_iterator: Option<LogRecordIterator>,
     current_record_batch: Option<LogRecordBatch>,
+    last_record: Option<ScanRecord>,
+    cached_record_error: Option<String>,
+    corrupt_last_record: bool,
 }
 
 impl DefaultCompletedFetch {
@@ -292,6 +323,9 @@ impl DefaultCompletedFetch {
             records_read: 0,
             current_record_iterator: None,
             current_record_batch: None,
+            last_record: None,
+            cached_record_error: None,
+            corrupt_last_record: false,
         })
     }
 
@@ -318,6 +352,20 @@
             }
         }
     }
+
+    fn fetch_error(&self) -> Error {
+        let mut message = format!(
+            "Received exception when fetching the next record from {table_bucket}. If needed, please seek past the record to continue scanning.",
+            table_bucket = self.table_bucket
+        );
+        if let Some(cause) = self.cached_record_error.as_deref() {
+            message.push_str(&format!(" Cause: {cause}"));
+        }
+        Error::UnexpectedError {
+            message,
+            source: None,
+        }
+    }
 }
 
 impl CompletedFetch for DefaultCompletedFetch {
@@ -326,7 +374,10 @@
     }
 
     fn fetch_records(&mut self, max_records: usize) -> Result<Vec<ScanRecord>> {
-        // todo: handle corrupt_last_record
+        if self.corrupt_last_record {
+            return Err(self.fetch_error());
+        }
+
         if self.consumed {
             return Ok(Vec::new());
         }
@@ -334,13 +385,34 @@
         let mut scan_records = Vec::new();
 
         for _ in 0..max_records {
-            if let Some(record) = self.next_fetched_record()? {
-                self.next_fetch_offset = record.offset() + 1;
-                self.records_read += 1;
-                scan_records.push(record);
-            } else {
-                break;
+            if self.cached_record_error.is_none() {
+                self.corrupt_last_record = true;
+                match self.next_fetched_record() {
+                    Ok(Some(record)) => {
+                        self.corrupt_last_record = false;
+                        self.last_record = Some(record);
+                    }
+                    Ok(None) => {
+                        self.corrupt_last_record = false;
+                        self.last_record = None;
+                    }
+                    Err(e) => {
+                        self.cached_record_error = Some(e.to_string());
+                    }
+                }
             }
+
+            let Some(record) = self.last_record.take() else {
+                break;
+            };
+
+            self.next_fetch_offset = record.offset() + 1;
+            self.records_read += 1;
+            scan_records.push(record);
+        }
+
+        if self.cached_record_error.is_some() && scan_records.is_empty() {
+            return Err(self.fetch_error());
         }
 
         Ok(scan_records)
@@ -350,8 +422,15 @@
         self.consumed
     }
 
+    fn records_read(&self) -> usize {
+        self.records_read
+    }
+
     fn drain(&mut self) {
         self.consumed = true;
+        self.cached_record_error = None;
+        self.corrupt_last_record = false;
+        self.last_record = None;
     }
 
     fn size_in_bytes(&self) -> usize {
@@ -374,3 +453,106 @@
         self.next_fetch_offset
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::client::WriteRecord;
+    use crate::compression::{
+        ArrowCompressionInfo, ArrowCompressionType, DEFAULT_NON_ZSTD_COMPRESSION_LEVEL,
+    };
+    use crate::metadata::{DataField, DataTypes, TablePath};
+    use crate::record::{MemoryLogRecordsArrowBuilder, to_arrow_schema};
+    use crate::row::GenericRow;
+    use std::sync::Arc;
+    use std::time::Duration;
+
+    struct ErrorPendingFetch {
+        table_bucket: TableBucket,
+    }
+
+    impl PendingFetch for ErrorPendingFetch {
+        fn table_bucket(&self) -> &TableBucket {
+            &self.table_bucket
+        }
+
+        fn is_completed(&self) -> bool {
+            true
+        }
+
+        fn to_completed_fetch(self: Box<Self>) -> Result<Box<dyn CompletedFetch>> {
+            Err(Error::UnexpectedError {
+                message: "pending fetch failure".to_string(),
+                source: None,
+            })
+        }
+    }
+
+    #[tokio::test]
+    async fn await_not_empty_returns_wakeup_error() {
+        let buffer = LogFetchBuffer::new();
+        buffer.wakeup();
+
+        let result = buffer.await_not_empty(Duration::from_millis(10)).await;
+        assert!(matches!(result, Err(Error::WakeupError { .. })));
+    }
+
+    #[tokio::test]
+    async fn await_not_empty_returns_pending_error() {
+        let buffer = LogFetchBuffer::new();
+        let table_bucket = TableBucket::new(1, 0);
+        buffer.pend(Box::new(ErrorPendingFetch {
+            table_bucket: table_bucket.clone(),
+        }));
+        buffer.try_complete(&table_bucket);
+
+        let result = buffer.await_not_empty(Duration::from_millis(10)).await;
+        assert!(matches!(result, Err(Error::UnexpectedError { ..
}))); + } + + #[test] + fn default_completed_fetch_reads_records() -> Result<()> { + let row_type = DataTypes::row(vec![ + DataField::new("id".to_string(), DataTypes::int(), None), + DataField::new("name".to_string(), DataTypes::string(), None), + ]); + let table_path = Arc::new(TablePath::new("db".to_string(), "tbl".to_string())); + + let mut builder = MemoryLogRecordsArrowBuilder::new( + 1, + &row_type, + false, + ArrowCompressionInfo { + compression_type: ArrowCompressionType::None, + compression_level: DEFAULT_NON_ZSTD_COMPRESSION_LEVEL, + }, + ); + + let mut row = GenericRow::new(); + row.set_field(0, 1_i32); + row.set_field(1, "alice"); + let record = WriteRecord::new(table_path, row); + builder.append(&record)?; + + let data = builder.build()?; + let log_records = LogRecordsBatches::new(data.clone()); + let read_context = ReadContext::new(to_arrow_schema(&row_type), false); + let mut fetch = DefaultCompletedFetch::new( + TableBucket::new(1, 0), + log_records, + data.len(), + read_context, + 0, + 0, + )?; + + let records = fetch.fetch_records(10)?; + assert_eq!(records.len(), 1); + assert_eq!(records[0].offset(), 0); + + let empty = fetch.fetch_records(10)?; + assert!(empty.is_empty()); + + Ok(()) + } +} diff --git a/crates/fluss/src/client/table/scanner.rs b/crates/fluss/src/client/table/scanner.rs index 0acaac8..5afdabd 100644 --- a/crates/fluss/src/client/table/scanner.rs +++ b/crates/fluss/src/client/table/scanner.rs @@ -24,7 +24,7 @@ use crate::client::table::log_fetch_buffer::{ use crate::client::table::remote_log::{ RemoteLogDownloader, RemoteLogFetchInfo, RemotePendingFetch, }; -use crate::error::{Error, Result, RpcError}; +use crate::error::{ApiError, Error, FlussError, Result, RpcError}; use crate::metadata::{TableBucket, TableInfo, TablePath}; use crate::proto::{FetchLogRequest, PbFetchLogReqForBucket, PbFetchLogReqForTable}; use crate::record::{LogRecordsBatches, ReadContext, ScanRecord, ScanRecords, to_arrow_schema}; @@ -73,7 +73,12 @@ impl<'a> TableScan<'a> { /// /// # Example /// ``` - /// let scanner = table.new_scan().project(&[0, 2, 3])?.create_log_scanner(); + /// # use fluss::client::FlussTable; + /// # use fluss::error::Result; + /// # fn example(table: &FlussTable<'_>) -> Result<()> { + /// let _scanner = table.new_scan().project(&[0, 2, 3])?.create_log_scanner()?; + /// # Ok(()) + /// # } /// ``` pub fn project(mut self, column_indices: &[usize]) -> Result { if column_indices.is_empty() { @@ -107,7 +112,15 @@ impl<'a> TableScan<'a> { /// /// # Example /// ``` - /// let scanner = table.new_scan().project_by_name(&["col1", "col3"])?.create_log_scanner(); + /// # use fluss::client::FlussTable; + /// # use fluss::error::Result; + /// # fn example(table: &FlussTable<'_>) -> Result<()> { + /// let _scanner = table + /// .new_scan() + /// .project_by_name(&["col1", "col3"])? 
+ /// .create_log_scanner()?; + /// # Ok(()) + /// # } /// ``` pub fn project_by_name(mut self, column_names: &[&str]) -> Result { if column_names.is_empty() { @@ -202,7 +215,7 @@ impl LogScanner { .log_fetcher .log_fetch_buffer .await_not_empty(remaining) - .await; + .await?; if !has_data { // Timeout while waiting @@ -284,10 +297,13 @@ impl LogFetcher { projected_fields: Option>, ) -> Result { let full_arrow_schema = to_arrow_schema(table_info.get_row_type()); - let read_context = - Self::create_read_context(full_arrow_schema.clone(), projected_fields.clone(), false); + let read_context = Self::create_read_context( + full_arrow_schema.clone(), + projected_fields.clone(), + false, + )?; let remote_read_context = - Self::create_read_context(full_arrow_schema, projected_fields.clone(), true); + Self::create_read_context(full_arrow_schema, projected_fields.clone(), true)?; let tmp_dir = TempDir::with_prefix("fluss-remote-logs")?; @@ -310,23 +326,18 @@ impl LogFetcher { full_arrow_schema: SchemaRef, projected_fields: Option>, is_from_remote: bool, - ) -> ReadContext { + ) -> Result { match projected_fields { - None => ReadContext::new(full_arrow_schema, is_from_remote), - Some(fields) => { - ReadContext::with_projection_pushdown(full_arrow_schema, fields, is_from_remote) - } + None => Ok(ReadContext::new(full_arrow_schema, is_from_remote)), + Some(fields) => ReadContext::with_projection_pushdown( + full_arrow_schema, + fields, + is_from_remote, + ), } } async fn check_and_update_metadata(&self) -> Result<()> { - if self.is_partitioned { - // TODO: Implement partition-aware metadata refresh for buckets whose leaders are unknown. - // The implementation will likely need to collect partition IDs for such buckets and - // perform targeted metadata updates. Until then, we avoid computing unused partition_ids. - return Ok(()); - } - let need_update = self .fetchable_buckets() .iter() @@ -336,6 +347,24 @@ impl LogFetcher { return Ok(()); } + if self.is_partitioned { + // Fallback to full table metadata refresh until partition-aware updates are available. + self.metadata + .update_tables_metadata(&HashSet::from([&self.table_path])) + .await + .or_else(|e| { + if let Error::RpcError { source, .. } = &e + && matches!(source, RpcError::ConnectionError(_) | RpcError::Poisoned(_)) + { + warn!("Retrying after encountering error while updating table metadata: {e}"); + Ok(()) + } else { + Err(e) + } + })?; + return Ok(()); + } + // TODO: Handle PartitionNotExist error self.metadata .update_tables_metadata(&HashSet::from([&self.table_path])) @@ -375,6 +404,7 @@ impl LogFetcher { let creds_cache = self.credentials_cache.clone(); let nodes_with_pending = self.nodes_with_pending_fetch_requests.clone(); let metadata = self.metadata.clone(); + let table_path = self.table_path.clone(); // Spawn async task to handle the fetch request // Note: These tasks are not explicitly tracked or cancelled when LogFetcher is dropped. 
@@ -425,6 +455,8 @@ impl LogFetcher { fetch_response, &log_fetch_buffer, &log_scanner_status, + &metadata, + &table_path, &read_context, &remote_read_context, &remote_log_downloader, @@ -433,6 +465,7 @@ impl LogFetcher { .await { error!("Fail to handle fetch response: {e:?}"); + log_fetch_buffer.set_error(e); } }); } @@ -454,6 +487,8 @@ impl LogFetcher { fetch_response: crate::proto::FetchLogResponse, log_fetch_buffer: &Arc, log_scanner_status: &Arc, + metadata: &Arc, + table_path: &TablePath, read_context: &ReadContext, remote_read_context: &ReadContext, remote_log_downloader: &Arc, @@ -475,6 +510,90 @@ impl LogFetcher { continue; }; + let error_code = fetch_log_for_bucket + .error_code + .unwrap_or(FlussError::None.code()); + if error_code != FlussError::None.code() { + let error = FlussError::for_code(error_code); + let error_message = fetch_log_for_bucket + .error_message + .clone() + .unwrap_or_else(|| error.message().to_string()); + + log_scanner_status.move_bucket_to_end(table_bucket.clone()); + + match error { + FlussError::NotLeaderOrFollower + | FlussError::LogStorageException + | FlussError::KvStorageException + | FlussError::StorageException + | FlussError::FencedLeaderEpochException => { + debug!( + "Error in fetch for bucket {table_bucket}: {error:?}: {error_message}" + ); + if let Err(e) = metadata + .update_tables_metadata(&HashSet::from([table_path])) + .await + { + warn!( + "Failed to update metadata for {table_path} after fetch error {error:?}: {e:?}" + ); + } + } + FlussError::UnknownTableOrBucketException => { + warn!( + "Received unknown table or bucket error in fetch for bucket {table_bucket}" + ); + if let Err(e) = metadata + .update_tables_metadata(&HashSet::from([table_path])) + .await + { + warn!( + "Failed to update metadata for {table_path} after unknown table or bucket error: {e:?}" + ); + } + } + FlussError::LogOffsetOutOfRangeException => { + log_fetch_buffer.set_error(Error::UnexpectedError { + message: format!( + "The fetching offset {fetch_offset} is out of range: {error_message}" + ), + source: None, + }); + } + FlussError::AuthorizationException => { + log_fetch_buffer.set_error(Error::FlussAPIError { + api_error: ApiError { + code: error_code, + message: error_message.clone(), + }, + }); + } + FlussError::UnknownServerError => { + warn!( + "Unknown server error while fetching offset {fetch_offset} for bucket {table_bucket}: {error_message}" + ); + } + FlussError::CorruptMessage => { + log_fetch_buffer.set_error(Error::UnexpectedError { + message: format!( + "Encountered corrupt message when fetching offset {fetch_offset} for bucket {table_bucket}: {error_message}" + ), + source: None, + }); + } + _ => { + log_fetch_buffer.set_error(Error::UnexpectedError { + message: format!( + "Unexpected error code {error:?} while fetching at offset {fetch_offset} from bucket {table_bucket}: {error_message}" + ), + source: None, + }); + } + } + continue; + } + // Check if this is a remote log fetch if let Some(ref remote_log_fetch_info) = fetch_log_for_bucket.remote_log_fetch_info { @@ -514,8 +633,8 @@ impl LogFetcher { log_fetch_buffer.add(Box::new(completed_fetch)); } Err(e) => { - // todo: handle error - log::warn!("Failed to create completed fetch: {e:?}"); + warn!("Failed to create completed fetch: {e:?}"); + log_fetch_buffer.set_error(e); } } } @@ -576,6 +695,7 @@ impl LogFetcher { const MAX_POLL_RECORDS: usize = 500; // Default max poll records let mut result: HashMap> = HashMap::new(); let mut records_remaining = MAX_POLL_RECORDS; + let mut pending_error = 
self.log_fetch_buffer.take_error(); while records_remaining > 0 { // Get the next in line fetch, or get a new one from buffer @@ -601,7 +721,11 @@ impl LogFetcher { // todo: do we need to consider it like java ? // self.log_fetch_buffer.poll(); } - return Err(e); + if result.is_empty() { + return Err(e); + } + self.log_fetch_buffer.set_error(e); + break; } } } else { @@ -617,7 +741,16 @@ impl LogFetcher { // Fetch records from next_in_line if let Some(mut next_fetch) = next_in_line { let records = - self.fetch_records_from_fetch(&mut next_fetch, records_remaining)?; + match self.fetch_records_from_fetch(&mut next_fetch, records_remaining) { + Ok(records) => records, + Err(e) => { + if result.is_empty() { + return Err(e); + } + self.log_fetch_buffer.set_error(e); + break; + } + }; if !records.is_empty() { let table_bucket = next_fetch.table_bucket().clone(); @@ -639,6 +772,17 @@ impl LogFetcher { } } + if pending_error.is_none() { + pending_error = self.log_fetch_buffer.take_error(); + } + + if let Some(error) = pending_error { + if result.is_empty() { + return Err(error); + } + self.log_fetch_buffer.set_error(error); + } + Ok(result) } @@ -708,6 +852,11 @@ impl LogFetcher { .update_offset(&table_bucket, next_fetch_offset); } + if next_in_line_fetch.is_consumed() && next_in_line_fetch.records_read() > 0 { + self.log_scanner_status + .move_bucket_to_end(table_bucket.clone()); + } + Ok(records) } else { // These records aren't next in line, ignore them @@ -943,3 +1092,232 @@ impl BucketScanStatus { *self.high_watermark.write() = high_watermark } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::client::metadata::Metadata; + use crate::client::WriteRecord; + use crate::cluster::{BucketLocation, Cluster, ServerNode, ServerType}; + use crate::compression::{ + ArrowCompressionInfo, ArrowCompressionType, DEFAULT_NON_ZSTD_COMPRESSION_LEVEL, + }; + use crate::metadata::{DataField, DataTypes, Schema, TableDescriptor, TableInfo, TablePath}; + use crate::record::MemoryLogRecordsArrowBuilder; + use crate::row::{Datum, GenericRow}; + use crate::rpc::FlussError; + + fn build_table_info(table_path: TablePath, table_id: i64) -> TableInfo { + let row_type = DataTypes::row(vec![DataField::new( + "id".to_string(), + DataTypes::int(), + None, + )]); + let mut schema_builder = Schema::builder().with_row_type(&row_type); + let schema = schema_builder.build().expect("schema build"); + let table_descriptor = TableDescriptor::builder() + .schema(schema) + .distributed_by(Some(1), vec![]) + .build() + .expect("descriptor build"); + TableInfo::of(table_path, table_id, 1, table_descriptor, 0, 0) + } + + fn build_cluster(table_path: &TablePath, table_id: i64) -> Arc { + let server = ServerNode::new(1, "127.0.0.1".to_string(), 9092, ServerType::TabletServer); + let table_bucket = TableBucket::new(table_id, 0); + let bucket_location = + BucketLocation::new(table_bucket.clone(), Some(server.clone()), table_path.clone()); + + let mut servers = HashMap::new(); + servers.insert(server.id(), server); + + let mut locations_by_path = HashMap::new(); + locations_by_path.insert(table_path.clone(), vec![bucket_location.clone()]); + + let mut locations_by_bucket = HashMap::new(); + locations_by_bucket.insert(table_bucket, bucket_location); + + let mut table_id_by_path = HashMap::new(); + table_id_by_path.insert(table_path.clone(), table_id); + + let mut table_info_by_path = HashMap::new(); + table_info_by_path.insert( + table_path.clone(), + build_table_info(table_path.clone(), table_id), + ); + + Arc::new(Cluster::new( + 
None, + servers, + locations_by_path, + locations_by_bucket, + table_id_by_path, + table_info_by_path, + )) + } + + fn build_records(table_info: &TableInfo, table_path: Arc) -> Result> { + let mut builder = MemoryLogRecordsArrowBuilder::new( + 1, + table_info.get_row_type(), + false, + ArrowCompressionInfo { + compression_type: ArrowCompressionType::None, + compression_level: DEFAULT_NON_ZSTD_COMPRESSION_LEVEL, + }, + ); + let record = WriteRecord::new( + table_path, + GenericRow { + values: vec![Datum::Int32(1)], + }, + ); + builder.append(&record)?; + builder.build() + } + + #[tokio::test] + async fn collect_fetches_updates_offset() -> Result<()> { + let table_path = TablePath::new("db".to_string(), "tbl".to_string()); + let table_info = build_table_info(table_path.clone(), 1); + let cluster = build_cluster(&table_path, 1); + let metadata = Arc::new(Metadata::new_for_test(cluster)); + let status = Arc::new(LogScannerStatus::new()); + let fetcher = LogFetcher::new( + table_info.clone(), + Arc::new(RpcClient::new()), + metadata, + status.clone(), + None, + )?; + + let bucket = TableBucket::new(1, 0); + status.assign_scan_bucket(bucket.clone(), 0); + + let data = build_records(&table_info, Arc::new(table_path))?; + let log_records = LogRecordsBatches::new(data.clone()); + let read_context = ReadContext::new(to_arrow_schema(table_info.get_row_type()), false); + let completed = DefaultCompletedFetch::new( + bucket.clone(), + log_records, + data.len(), + read_context, + 0, + 0, + )?; + fetcher.log_fetch_buffer.add(Box::new(completed)); + + let fetched = fetcher.collect_fetches()?; + assert_eq!(fetched.get(&bucket).unwrap().len(), 1); + assert_eq!(status.get_bucket_offset(&bucket), Some(1)); + Ok(()) + } + + #[test] + fn fetch_records_from_fetch_drains_unassigned_bucket() -> Result<()> { + let table_path = TablePath::new("db".to_string(), "tbl".to_string()); + let table_info = build_table_info(table_path.clone(), 1); + let cluster = build_cluster(&table_path, 1); + let metadata = Arc::new(Metadata::new_for_test(cluster)); + let status = Arc::new(LogScannerStatus::new()); + let fetcher = LogFetcher::new( + table_info.clone(), + Arc::new(RpcClient::new()), + metadata, + status, + None, + )?; + + let bucket = TableBucket::new(1, 0); + let data = build_records(&table_info, Arc::new(table_path))?; + let log_records = LogRecordsBatches::new(data.clone()); + let read_context = ReadContext::new(to_arrow_schema(table_info.get_row_type()), false); + let mut completed: Box = Box::new(DefaultCompletedFetch::new( + bucket, + log_records, + data.len(), + read_context, + 0, + 0, + )?); + + let records = fetcher.fetch_records_from_fetch(&mut completed, 10)?; + assert!(records.is_empty()); + assert!(completed.is_consumed()); + Ok(()) + } + + #[tokio::test] + async fn prepare_fetch_log_requests_skips_pending() -> Result<()> { + let table_path = TablePath::new("db".to_string(), "tbl".to_string()); + let table_info = build_table_info(table_path.clone(), 1); + let cluster = build_cluster(&table_path, 1); + let metadata = Arc::new(Metadata::new_for_test(cluster)); + let status = Arc::new(LogScannerStatus::new()); + status.assign_scan_bucket(TableBucket::new(1, 0), 0); + let fetcher = LogFetcher::new( + table_info, + Arc::new(RpcClient::new()), + metadata, + status, + None, + )?; + + fetcher + .nodes_with_pending_fetch_requests + .lock() + .insert(1); + + let requests = fetcher.prepare_fetch_log_requests().await; + assert!(requests.is_empty()); + Ok(()) + } + + #[tokio::test] + async fn 
handle_fetch_response_sets_error() -> Result<()> { + let table_path = TablePath::new("db".to_string(), "tbl".to_string()); + let table_info = build_table_info(table_path.clone(), 1); + let cluster = build_cluster(&table_path, 1); + let metadata = Arc::new(Metadata::new_for_test(cluster)); + let status = Arc::new(LogScannerStatus::new()); + status.assign_scan_bucket(TableBucket::new(1, 0), 5); + let fetcher = LogFetcher::new( + table_info.clone(), + Arc::new(RpcClient::new()), + metadata.clone(), + status.clone(), + None, + )?; + + let response = crate::proto::FetchLogResponse { + tables_resp: vec![crate::proto::PbFetchLogRespForTable { + table_id: 1, + buckets_resp: vec![crate::proto::PbFetchLogRespForBucket { + bucket_id: 0, + error_code: Some(FlussError::AuthorizationException.code()), + error_message: Some("denied".to_string()), + ..Default::default() + }], + ..Default::default() + }], + }; + + LogFetcher::handle_fetch_response( + response, + &fetcher.log_fetch_buffer, + &fetcher.log_scanner_status, + &metadata, + &TablePath::new("db".to_string(), "tbl".to_string()), + &fetcher.read_context, + &fetcher.remote_read_context, + &fetcher.remote_log_downloader, + &fetcher.credentials_cache, + ) + .await?; + + let error = fetcher.log_fetch_buffer.take_error().expect("error"); + assert!(matches!(error, Error::FlussAPIError { .. })); + Ok(()) + } +} diff --git a/crates/fluss/src/client/write/accumulator.rs b/crates/fluss/src/client/write/accumulator.rs index 215adbe..d3ef382 100644 --- a/crates/fluss/src/client/write/accumulator.rs +++ b/crates/fluss/src/client/write/accumulator.rs @@ -250,7 +250,7 @@ impl RecordAccumulator { cluster: Arc, nodes: &HashSet, max_size: i32, - ) -> Result>>> { + ) -> Result>> { if nodes.is_empty() { return Ok(HashMap::new()); } @@ -272,7 +272,7 @@ impl RecordAccumulator { cluster: &Cluster, node: &ServerNode, max_size: i32, - ) -> Result>> { + ) -> Result> { let mut size = 0; let buckets = self.get_all_buckets_in_current_node(node, cluster); let mut ready = Vec::new(); @@ -323,10 +323,10 @@ impl RecordAccumulator { // mark the batch as drained. 
batch.drained(current_time_ms()); - ready.push(Arc::new(ReadyWriteBatch { + ready.push(ReadyWriteBatch { table_bucket, write_batch: batch, - })); + }); } } if current_index == start { @@ -340,6 +340,29 @@ impl RecordAccumulator { self.incomplete_batches.write().remove(&batch_id); } + pub async fn re_enqueue(&self, ready_write_batch: ReadyWriteBatch) { + ready_write_batch.write_batch.re_enqueued(); + let table_path = ready_write_batch.write_batch.table_path().clone(); + let bucket_id = ready_write_batch.table_bucket.bucket_id(); + let table_id = u64::try_from(ready_write_batch.table_bucket.table_id()).unwrap_or(0); + let mut binding = self + .write_batches + .entry(table_path) + .or_insert_with(|| BucketAndWriteBatches { + table_id, + is_partitioned_table: false, + partition_id: None, + batches: Default::default(), + }); + let bucket_and_batches = binding.value_mut(); + let dq = bucket_and_batches + .batches + .entry(bucket_id) + .or_insert_with(|| Mutex::new(VecDeque::new())); + let mut dq_guard = dq.lock().await; + dq_guard.push_front(ready_write_batch.write_batch); + } + fn get_all_buckets_in_current_node( &self, current: &ServerNode, @@ -381,6 +404,93 @@ pub struct ReadyWriteBatch { pub write_batch: WriteBatch, } +#[cfg(test)] +mod tests { + use super::*; + use crate::cluster::{BucketLocation, Cluster, ServerNode, ServerType}; + use crate::metadata::{DataField, DataTypes, Schema, TableDescriptor, TableInfo, TablePath}; + use crate::row::{Datum, GenericRow}; + use std::sync::Arc; + + fn build_table_info(table_path: TablePath, table_id: i64) -> TableInfo { + let row_type = DataTypes::row(vec![DataField::new( + "id".to_string(), + DataTypes::int(), + None, + )]); + let mut schema_builder = Schema::builder().with_row_type(&row_type); + let schema = schema_builder.build().expect("schema build"); + let table_descriptor = TableDescriptor::builder() + .schema(schema) + .distributed_by(Some(1), vec![]) + .build() + .expect("descriptor build"); + TableInfo::of(table_path, table_id, 1, table_descriptor, 0, 0) + } + + fn build_cluster(table_path: &TablePath, table_id: i64) -> Cluster { + let server = ServerNode::new(1, "127.0.0.1".to_string(), 9092, ServerType::TabletServer); + let table_bucket = TableBucket::new(table_id, 0); + let bucket_location = + BucketLocation::new(table_bucket.clone(), Some(server.clone()), table_path.clone()); + + let mut servers = HashMap::new(); + servers.insert(server.id(), server); + + let mut locations_by_path = HashMap::new(); + locations_by_path.insert(table_path.clone(), vec![bucket_location.clone()]); + + let mut locations_by_bucket = HashMap::new(); + locations_by_bucket.insert(table_bucket, bucket_location); + + let mut table_id_by_path = HashMap::new(); + table_id_by_path.insert(table_path.clone(), table_id); + + let mut table_info_by_path = HashMap::new(); + table_info_by_path.insert(table_path.clone(), build_table_info(table_path.clone(), table_id)); + + Cluster::new( + None, + servers, + locations_by_path, + locations_by_bucket, + table_id_by_path, + table_info_by_path, + ) + } + + #[tokio::test] + async fn re_enqueue_increments_attempts() -> Result<()> { + let config = Config::default(); + let accumulator = RecordAccumulator::new(config); + let table_path = Arc::new(TablePath::new("db".to_string(), "tbl".to_string())); + let cluster = Arc::new(build_cluster(table_path.as_ref(), 1)); + let record = WriteRecord::new( + table_path.clone(), + GenericRow { + values: vec![Datum::Int32(1)], + }, + ); + + accumulator.append(&record, 0, &cluster, false).await?; + + 
let server = cluster.get_tablet_server(1).expect("server"); + let nodes = HashSet::from([server.clone()]); + let mut batches = accumulator.drain(cluster.clone(), &nodes, 1024 * 1024).await?; + let mut drained = batches.remove(&1).expect("drained batches"); + let batch = drained.pop().expect("batch"); + assert_eq!(batch.write_batch.attempts(), 0); + + accumulator.re_enqueue(batch).await; + + let mut batches = accumulator.drain(cluster, &nodes, 1024 * 1024).await?; + let mut drained = batches.remove(&1).expect("drained batches"); + let batch = drained.pop().expect("batch"); + assert_eq!(batch.write_batch.attempts(), 1); + Ok(()) + } +} + #[allow(dead_code)] struct BucketAndWriteBatches { table_id: TableId, diff --git a/crates/fluss/src/client/write/batch.rs b/crates/fluss/src/client/write/batch.rs index ba04db4..5039aa7 100644 --- a/crates/fluss/src/client/write/batch.rs +++ b/crates/fluss/src/client/write/batch.rs @@ -22,7 +22,9 @@ use crate::compression::ArrowCompressionInfo; use crate::error::Result; use crate::metadata::{DataType, TablePath}; use crate::record::MemoryLogRecordsArrowBuilder; +use parking_lot::Mutex; use std::cmp::max; +use std::sync::atomic::{AtomicBool, AtomicI32, Ordering}; #[allow(dead_code)] pub struct InnerWriteBatch { @@ -31,7 +33,8 @@ pub struct InnerWriteBatch { create_ms: i64, bucket_id: BucketId, results: BroadcastOnce, - completed: bool, + completed: AtomicBool, + attempts: AtomicI32, drained_ms: i64, } @@ -43,7 +46,8 @@ impl InnerWriteBatch { create_ms, bucket_id, results: Default::default(), - completed: Default::default(), + completed: AtomicBool::new(false), + attempts: AtomicI32::new(0), drained_ms: -1, } } @@ -53,15 +57,36 @@ impl InnerWriteBatch { } fn complete(&self, write_result: BatchWriteResult) -> bool { - if !self.completed { - self.results.broadcast(write_result); + if self + .completed + .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire) + .is_err() + { + return false; } + self.results.broadcast(write_result); true } fn drained(&mut self, now_ms: i64) { self.drained_ms = max(self.drained_ms, now_ms); } + + fn table_path(&self) -> &TablePath { + &self.table_path + } + + fn attempts(&self) -> i32 { + self.attempts.load(Ordering::Acquire) + } + + fn re_enqueued(&self) { + self.attempts.fetch_add(1, Ordering::AcqRel); + } + + fn is_done(&self) -> bool { + self.completed.load(Ordering::Acquire) + } } pub enum WriteBatch { @@ -125,11 +150,28 @@ impl WriteBatch { pub fn batch_id(&self) -> i64 { self.inner_batch().batch_id } + + pub fn table_path(&self) -> &TablePath { + self.inner_batch().table_path() + } + + pub fn attempts(&self) -> i32 { + self.inner_batch().attempts() + } + + pub fn re_enqueued(&self) { + self.inner_batch().re_enqueued(); + } + + pub fn is_done(&self) -> bool { + self.inner_batch().is_done() + } } pub struct ArrowLogWriteBatch { pub write_batch: InnerWriteBatch, pub arrow_builder: MemoryLogRecordsArrowBuilder, + built_records: Mutex>>, } impl ArrowLogWriteBatch { @@ -153,6 +195,7 @@ impl ArrowLogWriteBatch { to_append_record_batch, arrow_compression_info, ), + built_records: Mutex::new(None), } } @@ -175,7 +218,13 @@ impl ArrowLogWriteBatch { } pub fn build(&self) -> Result> { - self.arrow_builder.build() + let mut cached = self.built_records.lock(); + if let Some(bytes) = cached.as_ref() { + return Ok(bytes.clone()); + } + let bytes = self.arrow_builder.build()?; + *cached = Some(bytes.clone()); + Ok(bytes) } pub fn is_closed(&self) -> bool { @@ -186,3 +235,34 @@ impl ArrowLogWriteBatch { self.arrow_builder.close() 
} } + +#[cfg(test)] +mod tests { + use super::*; + use crate::metadata::TablePath; + + #[test] + fn complete_only_once() { + let batch = InnerWriteBatch::new( + 1, + TablePath::new("db".to_string(), "tbl".to_string()), + 0, + 0, + ); + assert!(batch.complete(Ok(()))); + assert!(!batch.complete(Err(crate::client::broadcast::Error::Dropped))); + } + + #[test] + fn attempts_increment_on_reenqueue() { + let batch = InnerWriteBatch::new( + 1, + TablePath::new("db".to_string(), "tbl".to_string()), + 0, + 0, + ); + assert_eq!(batch.attempts(), 0); + batch.re_enqueued(); + assert_eq!(batch.attempts(), 1); + } +} diff --git a/crates/fluss/src/client/write/broadcast.rs b/crates/fluss/src/client/write/broadcast.rs index d2e7f0c..4b55385 100644 --- a/crates/fluss/src/client/write/broadcast.rs +++ b/crates/fluss/src/client/write/broadcast.rs @@ -28,6 +28,8 @@ pub type BatchWriteResult = Result<(), Error>; pub enum Error { #[error("BroadcastOnce dropped")] Dropped, + #[error("Write failed: {message} (code {code})")] + WriteFailed { code: i32, message: String }, } #[derive(Debug, Clone)] diff --git a/crates/fluss/src/client/write/bucket_assigner.rs b/crates/fluss/src/client/write/bucket_assigner.rs index 44b2673..256095a 100644 --- a/crates/fluss/src/client/write/bucket_assigner.rs +++ b/crates/fluss/src/client/write/bucket_assigner.rs @@ -146,3 +146,98 @@ impl BucketAssigner for HashBucketAssigner { self.bucketing_function.bucketing(key, self.num_buckets) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::bucketing::BucketingFunction; + use crate::cluster::{BucketLocation, Cluster, ServerNode, ServerType}; + use crate::metadata::TableBucket; + use crate::metadata::{DataField, DataTypes, Schema, TableDescriptor, TableInfo, TablePath}; + use std::collections::HashMap; + + fn build_table_info(table_path: TablePath, table_id: i64, buckets: i32) -> TableInfo { + let row_type = DataTypes::row(vec![DataField::new( + "id".to_string(), + DataTypes::int(), + None, + )]); + let mut schema_builder = Schema::builder().with_row_type(&row_type); + let schema = schema_builder.build().expect("schema build"); + let table_descriptor = TableDescriptor::builder() + .schema(schema) + .distributed_by(Some(buckets), vec![]) + .build() + .expect("descriptor build"); + TableInfo::of(table_path, table_id, 1, table_descriptor, 0, 0) + } + + fn build_cluster(table_path: &TablePath, table_id: i64, buckets: i32) -> Cluster { + let server = ServerNode::new(1, "127.0.0.1".to_string(), 9092, ServerType::TabletServer); + + let mut servers = HashMap::new(); + servers.insert(server.id(), server.clone()); + + let mut locations_by_path = HashMap::new(); + let mut locations_by_bucket = HashMap::new(); + let mut bucket_locations = Vec::new(); + + for bucket_id in 0..buckets { + let table_bucket = TableBucket::new(table_id, bucket_id); + let bucket_location = + BucketLocation::new(table_bucket.clone(), Some(server.clone()), table_path.clone()); + bucket_locations.push(bucket_location.clone()); + locations_by_bucket.insert(table_bucket, bucket_location); + } + locations_by_path.insert(table_path.clone(), bucket_locations); + + let mut table_id_by_path = HashMap::new(); + table_id_by_path.insert(table_path.clone(), table_id); + + let mut table_info_by_path = HashMap::new(); + table_info_by_path.insert( + table_path.clone(), + build_table_info(table_path.clone(), table_id, buckets), + ); + + Cluster::new( + None, + servers, + locations_by_path, + locations_by_bucket, + table_id_by_path, + table_info_by_path, + ) + } + + #[test] + fn 
sticky_bucket_assigner_picks_available_bucket() { + let table_path = TablePath::new("db".to_string(), "tbl".to_string()); + let cluster = build_cluster(&table_path, 1, 2); + let assigner = StickyBucketAssigner::new(table_path); + let bucket = assigner.assign_bucket(None, &cluster).expect("bucket"); + assert!(bucket >= 0 && bucket < 2); + + assigner.on_new_batch(&cluster, bucket); + let next_bucket = assigner.assign_bucket(None, &cluster).expect("bucket"); + assert!(next_bucket >= 0 && next_bucket < 2); + } + + #[test] + fn hash_bucket_assigner_requires_key() { + let assigner = HashBucketAssigner::new(3, ::of(None)); + let cluster = Cluster::default(); + let err = assigner.assign_bucket(None, &cluster).unwrap_err(); + assert!(matches!(err, crate::error::Error::IllegalArgument { .. })); + } + + #[test] + fn hash_bucket_assigner_hashes_key() { + let assigner = HashBucketAssigner::new(4, ::of(None)); + let cluster = Cluster::default(); + let bucket = assigner + .assign_bucket(Some(b"key"), &cluster) + .expect("bucket"); + assert!(bucket >= 0 && bucket < 4); + } +} diff --git a/crates/fluss/src/client/write/mod.rs b/crates/fluss/src/client/write/mod.rs index cd33586..19bf9c4 100644 --- a/crates/fluss/src/client/write/mod.rs +++ b/crates/fluss/src/client/write/mod.rs @@ -81,10 +81,16 @@ impl ResultHandle { } pub fn result(&self, batch_result: BatchWriteResult) -> Result<(), Error> { - // do nothing, just return empty result - batch_result.map_err(|e| Error::UnexpectedError { - message: format!("Fail to get write result {e:?}"), - source: None, + batch_result.map_err(|e| match e { + crate::client::broadcast::Error::WriteFailed { code, message } => { + Error::FlussAPIError { + api_error: crate::rpc::ApiError { code, message }, + } + } + crate::client::broadcast::Error::Dropped => Error::UnexpectedError { + message: "Fail to get write result because broadcast was dropped.".to_string(), + source: None, + }, }) } } diff --git a/crates/fluss/src/client/write/sender.rs b/crates/fluss/src/client/write/sender.rs index 462a846..1a193ae 100644 --- a/crates/fluss/src/client/write/sender.rs +++ b/crates/fluss/src/client/write/sender.rs @@ -15,15 +15,16 @@ // specific language governing permissions and limitations // under the License. 
+use crate::client::broadcast; use crate::client::metadata::Metadata; use crate::client::{ReadyWriteBatch, RecordAccumulator}; -use crate::error::Error; -use crate::error::Result; -use crate::metadata::TableBucket; +use crate::error::{FlussError, Result}; +use crate::metadata::{TableBucket, TablePath}; use crate::proto::ProduceLogResponse; use crate::rpc::message::ProduceLogRequest; +use log::warn; use parking_lot::Mutex; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; use std::time::Duration; @@ -32,7 +33,7 @@ pub struct Sender { running: bool, metadata: Arc, accumulator: Arc, - in_flight_batches: Mutex>>>, + in_flight_batches: Mutex>>, max_request_size: i32, ack: i16, max_request_timeout_ms: i32, @@ -99,31 +100,30 @@ impl Sender { if !batches.is_empty() { self.add_to_inflight_batches(&batches); - self.send_write_requests(&batches).await?; + self.send_write_requests(batches).await?; } Ok(()) } - fn add_to_inflight_batches(&self, batches: &HashMap>>) { + fn add_to_inflight_batches(&self, batches: &HashMap>) { let mut in_flight = self.in_flight_batches.lock(); for batch_list in batches.values() { for batch in batch_list { in_flight .entry(batch.table_bucket.clone()) .or_default() - .push(batch.clone()); + .push(batch.write_batch.batch_id()); } } } async fn send_write_requests( &self, - collated: &HashMap>>, + collated: HashMap>, ) -> Result<()> { for (leader_id, batches) in collated { - self.send_write_request(*leader_id, self.ack, batches) - .await?; + self.send_write_request(leader_id, self.ack, batches).await?; } Ok(()) } @@ -132,78 +132,530 @@ impl Sender { &self, destination: i32, acks: i16, - batches: &Vec>, + batches: Vec, ) -> Result<()> { if batches.is_empty() { return Ok(()); } - let mut records_by_bucket = HashMap::new(); - let mut write_batch_by_table = HashMap::new(); + let mut records_by_bucket: HashMap = HashMap::new(); + let mut write_batch_by_table: HashMap> = HashMap::new(); for batch in batches { - records_by_bucket.insert(batch.table_bucket.clone(), batch.clone()); + let table_bucket = batch.table_bucket.clone(); write_batch_by_table - .entry(batch.table_bucket.table_id()) - .or_insert_with(Vec::new) - .push(batch); + .entry(table_bucket.table_id()) + .or_default() + .push(table_bucket.clone()); + records_by_bucket.insert(table_bucket, batch); } let cluster = self.metadata.get_cluster(); let destination_node = - cluster - .get_tablet_server(destination) - .ok_or(Error::LeaderNotAvailable { - message: format!("destination node not found in metadata cache {destination}."), - })?; - let connection = self.metadata.get_connection(destination_node).await?; + match cluster.get_tablet_server(destination) { + Some(node) => node, + None => { + self.handle_batches_with_error( + records_by_bucket.into_values().collect(), + FlussError::LeaderNotAvailableException, + format!("Destination node not found in metadata cache {destination}."), + ) + .await?; + return Ok(()); + } + }; + let connection = match self.metadata.get_connection(destination_node).await { + Ok(connection) => connection, + Err(e) => { + self.handle_batches_with_error( + records_by_bucket.into_values().collect(), + FlussError::NetworkException, + format!("Failed to connect destination node {destination}: {e}"), + ) + .await?; + return Ok(()); + } + }; + + for (table_id, table_buckets) in write_batch_by_table { + let request_batches: Vec<&ReadyWriteBatch> = table_buckets + .iter() + .filter_map(|bucket| records_by_bucket.get(bucket)) + .collect(); + if 
request_batches.is_empty() { + continue; + } + let request = match ProduceLogRequest::new( + table_id, + acks, + self.max_request_timeout_ms, + request_batches.as_slice(), + ) { + Ok(request) => request, + Err(e) => { + self.handle_batches_with_error( + table_buckets + .iter() + .filter_map(|bucket| records_by_bucket.remove(bucket)) + .collect(), + FlussError::UnknownServerError, + format!("Failed to build produce request: {e}"), + ) + .await?; + continue; + } + }; + + let response = match connection.request(request).await { + Ok(response) => response, + Err(e) => { + self.handle_batches_with_error( + table_buckets + .iter() + .filter_map(|bucket| records_by_bucket.remove(bucket)) + .collect(), + FlussError::NetworkException, + format!("Failed to send produce request: {e}"), + ) + .await?; + continue; + } + }; - for (table_id, write_batches) in write_batch_by_table { - let request = - ProduceLogRequest::new(table_id, acks, self.max_request_timeout_ms, write_batches)?; - let response = connection.request(request).await?; - self.handle_produce_response(table_id, &records_by_bucket, response)? + self.handle_produce_response( + table_id, + &table_buckets, + &mut records_by_bucket, + response, + ) + .await?; } Ok(()) } - fn handle_produce_response( + async fn handle_produce_response( &self, table_id: i64, - records_by_bucket: &HashMap>, + request_buckets: &[TableBucket], + records_by_bucket: &mut HashMap, response: ProduceLogResponse, ) -> Result<()> { + let mut invalid_metadata_tables: HashSet = HashSet::new(); + let mut pending_buckets: HashSet = request_buckets.iter().cloned().collect(); for produce_log_response_for_bucket in response.buckets_resp.iter() { let tb = TableBucket::new(table_id, produce_log_response_for_bucket.bucket_id); - let ready_batch = records_by_bucket.get(&tb).unwrap(); + let Some(ready_batch) = records_by_bucket.remove(&tb) else { + warn!("Missing ready batch for table bucket {tb}"); + continue; + }; + pending_buckets.remove(&tb); + if let Some(error_code) = produce_log_response_for_bucket.error_code { - todo!("handle_produce_response error: {}", error_code) + if error_code == FlussError::None.code() { + self.complete_batch(ready_batch); + continue; + } + + let error = FlussError::for_code(error_code); + let message = produce_log_response_for_bucket + .error_message + .clone() + .unwrap_or_else(|| error.message().to_string()); + if let Some(table_path) = + self.handle_write_batch_error(ready_batch, error, message).await? + { + invalid_metadata_tables.insert(table_path); + } } else { self.complete_batch(ready_batch) } } + if !pending_buckets.is_empty() { + for bucket in pending_buckets { + if let Some(ready_batch) = records_by_bucket.remove(&bucket) { + let message = format!( + "Missing response for table bucket {bucket} in produce response." + ); + let error = FlussError::UnknownServerError; + if let Some(table_path) = + self.handle_write_batch_error(ready_batch, error, message).await? 
+ { + invalid_metadata_tables.insert(table_path); + } + } + } + } + self.update_metadata_if_needed(invalid_metadata_tables).await; Ok(()) } - fn complete_batch(&self, ready_write_batch: &Arc) { - if ready_write_batch.write_batch.complete(Ok(())) { - // remove from in flight batches - let mut in_flight_guard = self.in_flight_batches.lock(); - if let Some(in_flight) = in_flight_guard.get_mut(&ready_write_batch.table_bucket) { - in_flight.retain(|b| !Arc::ptr_eq(b, ready_write_batch)); - if in_flight.is_empty() { - in_flight_guard.remove(&ready_write_batch.table_bucket); - } - } + fn complete_batch(&self, ready_write_batch: ReadyWriteBatch) { + self.finish_batch(ready_write_batch, Ok(())); + } + + fn fail_batch(&self, ready_write_batch: ReadyWriteBatch, error: broadcast::Error) { + self.finish_batch(ready_write_batch, Err(error)); + } + + fn finish_batch(&self, ready_write_batch: ReadyWriteBatch, result: broadcast::Result<()>) { + if ready_write_batch.write_batch.complete(result) { + self.remove_from_inflight_batches(&ready_write_batch); // remove from incomplete batches self.accumulator .remove_incomplete_batches(ready_write_batch.write_batch.batch_id()) } } - pub async fn close(&mut self) { - self.running = false; + async fn handle_batches_with_error( + &self, + batches: Vec, + error: FlussError, + message: String, + ) -> Result<()> { + let mut invalid_metadata_tables: HashSet = HashSet::new(); + for batch in batches { + if let Some(table_path) = + self.handle_write_batch_error(batch, error, message.clone()).await? + { + invalid_metadata_tables.insert(table_path); + } + } + self.update_metadata_if_needed(invalid_metadata_tables).await; + Ok(()) + } + + async fn handle_write_batch_error( + &self, + ready_write_batch: ReadyWriteBatch, + error: FlussError, + message: String, + ) -> Result> { + let table_path = ready_write_batch.write_batch.table_path().clone(); + if self.can_retry(&ready_write_batch, error) { + warn!( + "Retrying write batch for {table_path} on bucket {} after error {error:?}: {message}", + ready_write_batch.table_bucket.bucket_id() + ); + self.re_enqueue_batch(ready_write_batch).await; + return Ok(Self::is_invalid_metadata_error(error).then_some(table_path)); + } + + if error == FlussError::DuplicateSequenceException { + self.complete_batch(ready_write_batch); + return Ok(None); + } + + self.fail_batch( + ready_write_batch, + broadcast::Error::WriteFailed { + code: error.code(), + message, + }, + ); + Ok(Self::is_invalid_metadata_error(error).then_some(table_path)) + } + + async fn re_enqueue_batch(&self, ready_write_batch: ReadyWriteBatch) { + self.remove_from_inflight_batches(&ready_write_batch); + self.accumulator.re_enqueue(ready_write_batch).await; + } + + fn remove_from_inflight_batches(&self, ready_write_batch: &ReadyWriteBatch) { + let batch_id = ready_write_batch.write_batch.batch_id(); + let mut in_flight_guard = self.in_flight_batches.lock(); + if let Some(in_flight) = in_flight_guard.get_mut(&ready_write_batch.table_bucket) { + in_flight.retain(|id| *id != batch_id); + if in_flight.is_empty() { + in_flight_guard.remove(&ready_write_batch.table_bucket); + } + } + } + + fn can_retry(&self, ready_write_batch: &ReadyWriteBatch, error: FlussError) -> bool { + ready_write_batch.write_batch.attempts() < self.retries + && !ready_write_batch.write_batch.is_done() + && Self::is_retriable_error(error) + } + + async fn update_metadata_if_needed(&self, table_paths: HashSet) { + if table_paths.is_empty() { + return; + } + let table_path_refs: HashSet<&TablePath> = 
table_paths.iter().collect(); + if let Err(e) = self.metadata.update_tables_metadata(&table_path_refs).await { + warn!("Failed to update metadata after write error: {e:?}"); + } + } + + fn is_invalid_metadata_error(error: FlussError) -> bool { + matches!( + error, + FlussError::NotLeaderOrFollower + | FlussError::UnknownTableOrBucketException + | FlussError::LeaderNotAvailableException + | FlussError::NetworkException + ) + } + + fn is_retriable_error(error: FlussError) -> bool { + matches!( + error, + FlussError::NetworkException + | FlussError::NotLeaderOrFollower + | FlussError::UnknownTableOrBucketException + | FlussError::LeaderNotAvailableException + | FlussError::LogStorageException + | FlussError::KvStorageException + | FlussError::StorageException + | FlussError::RequestTimeOut + | FlussError::NotEnoughReplicasAfterAppendException + | FlussError::NotEnoughReplicasException + | FlussError::CorruptMessage + | FlussError::CorruptRecordException + ) + } + +pub async fn close(&mut self) { + self.running = false; +} +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::client::WriteRecord; + use crate::cluster::{BucketLocation, Cluster, ServerNode, ServerType}; + use crate::config::Config; + use crate::metadata::{DataField, DataTypes, Schema, TableDescriptor, TableInfo, TablePath}; + use crate::row::{Datum, GenericRow}; + use crate::rpc::FlussError; + use crate::proto::{PbProduceLogRespForBucket, ProduceLogResponse}; + use std::collections::HashSet; + + fn build_table_info(table_path: TablePath, table_id: i64) -> TableInfo { + let row_type = DataTypes::row(vec![DataField::new( + "id".to_string(), + DataTypes::int(), + None, + )]); + let mut schema_builder = Schema::builder().with_row_type(&row_type); + let schema = schema_builder.build().expect("schema build"); + let table_descriptor = TableDescriptor::builder() + .schema(schema) + .distributed_by(Some(1), vec![]) + .build() + .expect("descriptor build"); + TableInfo::of(table_path, table_id, 1, table_descriptor, 0, 0) + } + + fn build_cluster(table_path: &TablePath, table_id: i64) -> Arc { + let server = ServerNode::new(1, "127.0.0.1".to_string(), 9092, ServerType::TabletServer); + let table_bucket = TableBucket::new(table_id, 0); + let bucket_location = + BucketLocation::new(table_bucket.clone(), Some(server.clone()), table_path.clone()); + + let mut servers = HashMap::new(); + servers.insert(server.id(), server); + + let mut locations_by_path = HashMap::new(); + locations_by_path.insert(table_path.clone(), vec![bucket_location.clone()]); + + let mut locations_by_bucket = HashMap::new(); + locations_by_bucket.insert(table_bucket, bucket_location); + + let mut table_id_by_path = HashMap::new(); + table_id_by_path.insert(table_path.clone(), table_id); + + let mut table_info_by_path = HashMap::new(); + table_info_by_path.insert( + table_path.clone(), + build_table_info(table_path.clone(), table_id), + ); + + Arc::new(Cluster::new( + None, + servers, + locations_by_path, + locations_by_bucket, + table_id_by_path, + table_info_by_path, + )) + } + + async fn build_ready_batch( + accumulator: &RecordAccumulator, + cluster: Arc, + table_path: Arc, + ) -> Result<(ReadyWriteBatch, crate::client::ResultHandle)> { + let record = WriteRecord::new( + table_path, + GenericRow { + values: vec![Datum::Int32(1)], + }, + ); + let result = accumulator.append(&record, 0, &cluster, false).await?; + let result_handle = result.result_handle.expect("result handle"); + let server = cluster.get_tablet_server(1).expect("server"); + let nodes = 
HashSet::from([server.clone()]); + let mut batches = accumulator.drain(cluster, &nodes, 1024 * 1024).await?; + let mut drained = batches.remove(&1).expect("drained batches"); + let batch = drained.pop().expect("batch"); + Ok((batch, result_handle)) + } + + #[tokio::test] + async fn handle_write_batch_error_retries() -> Result<()> { + let table_path = Arc::new(TablePath::new("db".to_string(), "tbl".to_string())); + let cluster = build_cluster(table_path.as_ref(), 1); + let metadata = Arc::new(Metadata::new_for_test(cluster.clone())); + let accumulator = Arc::new(RecordAccumulator::new(Config::default())); + let sender = Sender::new( + metadata, + accumulator.clone(), + 1024 * 1024, + 1000, + 1, + 1, + ); + + let (batch, _handle) = + build_ready_batch(accumulator.as_ref(), cluster.clone(), table_path.clone()).await?; + let mut inflight = HashMap::new(); + inflight.insert(1, vec![batch]); + sender.add_to_inflight_batches(&inflight); + let batch = inflight.remove(&1).unwrap().pop().unwrap(); + + sender + .handle_write_batch_error( + batch, + FlussError::RequestTimeOut, + "timeout".to_string(), + ) + .await?; + + let server = cluster.get_tablet_server(1).expect("server"); + let nodes = HashSet::from([server.clone()]); + let mut batches = accumulator.drain(cluster, &nodes, 1024 * 1024).await?; + let mut drained = batches.remove(&1).expect("drained batches"); + let batch = drained.pop().expect("batch"); + assert_eq!(batch.write_batch.attempts(), 1); + Ok(()) + } + + #[tokio::test] + async fn handle_write_batch_error_fails() -> Result<()> { + let table_path = Arc::new(TablePath::new("db".to_string(), "tbl".to_string())); + let cluster = build_cluster(table_path.as_ref(), 1); + let metadata = Arc::new(Metadata::new_for_test(cluster.clone())); + let accumulator = Arc::new(RecordAccumulator::new(Config::default())); + let sender = Sender::new( + metadata, + accumulator.clone(), + 1024 * 1024, + 1000, + 1, + 0, + ); + + let (batch, handle) = + build_ready_batch(accumulator.as_ref(), cluster.clone(), table_path).await?; + sender + .handle_write_batch_error( + batch, + FlussError::InvalidTableException, + "invalid".to_string(), + ) + .await?; + + let batch_result = handle.wait().await?; + assert!(matches!( + batch_result, + Err(broadcast::Error::WriteFailed { code, .. }) + if code == FlussError::InvalidTableException.code() + )); + Ok(()) + } + + #[tokio::test] + async fn handle_produce_response_missing_bucket_fails() -> Result<()> { + let table_path = Arc::new(TablePath::new("db".to_string(), "tbl".to_string())); + let cluster = build_cluster(table_path.as_ref(), 1); + let metadata = Arc::new(Metadata::new_for_test(cluster.clone())); + let accumulator = Arc::new(RecordAccumulator::new(Config::default())); + let sender = Sender::new( + metadata, + accumulator.clone(), + 1024 * 1024, + 1000, + 1, + 0, + ); + + let (batch, handle) = + build_ready_batch(accumulator.as_ref(), cluster, table_path).await?; + let request_buckets = vec![batch.table_bucket.clone()]; + let mut records_by_bucket = HashMap::new(); + records_by_bucket.insert(batch.table_bucket.clone(), batch); + + let response = ProduceLogResponse { + buckets_resp: vec![PbProduceLogRespForBucket { + bucket_id: 1, + error_code: None, + ..Default::default() + }], + }; + + sender + .handle_produce_response(1, &request_buckets, &mut records_by_bucket, response) + .await?; + + let batch_result = handle.wait().await?; + assert!(matches!( + batch_result, + Err(broadcast::Error::WriteFailed { code, .. 
}) + if code == FlussError::UnknownServerError.code() + )); + Ok(()) + } + + #[tokio::test] + async fn handle_produce_response_duplicate_sequence_completes() -> Result<()> { + let table_path = Arc::new(TablePath::new("db".to_string(), "tbl".to_string())); + let cluster = build_cluster(table_path.as_ref(), 1); + let metadata = Arc::new(Metadata::new_for_test(cluster.clone())); + let accumulator = Arc::new(RecordAccumulator::new(Config::default())); + let sender = Sender::new( + metadata, + accumulator.clone(), + 1024 * 1024, + 1000, + 1, + 0, + ); + + let (batch, handle) = + build_ready_batch(accumulator.as_ref(), cluster, table_path).await?; + let request_buckets = vec![batch.table_bucket.clone()]; + let mut records_by_bucket = HashMap::new(); + records_by_bucket.insert(batch.table_bucket.clone(), batch); + + let response = ProduceLogResponse { + buckets_resp: vec![PbProduceLogRespForBucket { + bucket_id: 0, + error_code: Some(FlussError::DuplicateSequenceException.code()), + error_message: Some("dup".to_string()), + ..Default::default() + }], + }; + + sender + .handle_produce_response(1, &request_buckets, &mut records_by_bucket, response) + .await?; + + let batch_result = handle.wait().await?; + assert!(matches!(batch_result, Ok(()))); + Ok(()) } } diff --git a/crates/fluss/src/error.rs b/crates/fluss/src/error.rs index 0f4b1b6..ac6d145 100644 --- a/crates/fluss/src/error.rs +++ b/crates/fluss/src/error.rs @@ -98,6 +98,12 @@ pub enum Error { )] IoUnsupported { message: String }, + #[snafu( + visibility(pub(crate)), + display("Fluss hitting wakeup error {}.", message) + )] + WakeupError { message: String }, + #[snafu( visibility(pub(crate)), display("Fluss hitting leader not available error {}.", message) diff --git a/crates/fluss/src/record/arrow.rs b/crates/fluss/src/record/arrow.rs index 5a5115e..f0d4b28 100644 --- a/crates/fluss/src/record/arrow.rs +++ b/crates/fluss/src/record/arrow.rs @@ -17,7 +17,7 @@ use crate::client::{Record, WriteRecord}; use crate::compression::ArrowCompressionInfo; -use crate::error::Result; +use crate::error::{Error, Result}; use crate::metadata::DataType; use crate::record::{ChangeType, ScanRecord}; use crate::row::{ColumnarRow, GenericRow}; @@ -446,8 +446,18 @@ impl LogRecordBatch { } pub fn ensure_valid(&self) -> Result<()> { - // todo - Ok(()) + if self.is_valid() { + return Ok(()); + } + Err(Error::UnexpectedError { + message: format!( + "Record batch at offset {} is invalid (checksum={}, computed={}).", + self.base_log_offset(), + self.checksum(), + self.compute_checksum() + ), + source: None, + }) } pub fn is_valid(&self) -> bool { @@ -457,8 +467,7 @@ impl LogRecordBatch { fn compute_checksum(&self) -> u32 { let start = SCHEMA_ID_OFFSET; - let end = start + self.data.len(); - crc32c(&self.data[start..end]) + crc32c(&self.data[start..]) } fn attributes(&self) -> u8 { @@ -471,12 +480,12 @@ impl LogRecordBatch { pub fn checksum(&self) -> u32 { let offset = CRC_OFFSET; - LittleEndian::read_u32(&self.data[offset..offset + CRC_OFFSET]) + LittleEndian::read_u32(&self.data[offset..offset + CRC_LENGTH]) } pub fn schema_id(&self) -> i16 { let offset = SCHEMA_ID_OFFSET; - LittleEndian::read_i16(&self.data[offset..offset + SCHEMA_ID_OFFSET]) + LittleEndian::read_i16(&self.data[offset..offset + SCHEMA_ID_LENGTH]) } pub fn base_log_offset(&self) -> i64 { @@ -508,6 +517,10 @@ impl LogRecordBatch { return Ok(LogRecordIterator::empty()); } + if !read_context.is_projection_pushdowned() { + self.ensure_valid()?; + } + let data = &self.data[RECORDS_OFFSET..]; let 
record_batch = read_context.record_batch(data)?; @@ -528,6 +541,10 @@ impl LogRecordBatch { return Ok(LogRecordIterator::empty()); } + if !read_context.is_projection_pushdowned() { + self.ensure_valid()?; + } + let data = &self.data[RECORDS_OFFSET..]; let record_batch = read_context.record_batch_for_remote_log(data)?; @@ -758,8 +775,10 @@ impl ReadContext { arrow_schema: SchemaRef, projected_fields: Vec<usize>, is_from_remote: bool, - ) -> ReadContext { - let target_schema = Self::project_schema(arrow_schema.clone(), projected_fields.as_slice()); + ) -> Result<ReadContext> { + Self::validate_projection(&arrow_schema, projected_fields.as_slice())?; + let target_schema = + Self::project_schema(arrow_schema.clone(), projected_fields.as_slice())?; // the logic is little bit of hard to understand, to refactor it to follow // java side let (need_do_reorder, sorted_fields) = { @@ -782,16 +801,20 @@ impl ReadContext { // Calculate reordering indexes to transform from sorted order to user-requested order let mut reordering_indexes = Vec::with_capacity(projected_fields.len()); for &original_idx in &projected_fields { - let pos = sorted_fields - .binary_search(&original_idx) - .expect("projection index should exist in sorted list"); + let pos = sorted_fields.binary_search(&original_idx).map_err(|_| { + Error::IllegalArgument { + message: format!( + "Projection index {original_idx} is invalid for the current schema." + ), + } + })?; reordering_indexes.push(pos); } Projection { ordered_schema: Self::project_schema( arrow_schema.clone(), sorted_fields.as_slice(), - ), + )?, projected_fields, ordered_fields: sorted_fields, reordering_indexes, @@ -802,7 +825,7 @@ impl ReadContext { ordered_schema: Self::project_schema( arrow_schema.clone(), projected_fields.as_slice(), - ), + )?, ordered_fields: projected_fields.clone(), projected_fields, reordering_indexes: vec![], @@ -811,21 +834,34 @@ impl ReadContext { } }; - ReadContext { + Ok(ReadContext { target_schema, full_schema: arrow_schema, projection: Some(project), is_from_remote, + }) + } + + fn validate_projection(schema: &SchemaRef, projected_fields: &[usize]) -> Result<()> { + let field_count = schema.fields().len(); + for &index in projected_fields { + if index >= field_count { + return Err(Error::IllegalArgument { + message: format!( + "Projection index {index} is out of bounds for schema with {field_count} fields." + ), + }); + } + } + Ok(()) + } - pub fn project_schema(schema: SchemaRef, projected_fields: &[usize]) -> SchemaRef { - // todo: handle the exception - SchemaRef::new( - schema - .project(projected_fields) - .expect("can't project schema"), - ) + pub fn project_schema(schema: SchemaRef, projected_fields: &[usize]) -> Result<SchemaRef> { + Ok(SchemaRef::new( + schema.project(projected_fields).map_err(|e| Error::IllegalArgument { + message: format!("Invalid projection: {e}"), + })?, + )) + } pub fn project_fields(&self) -> Option<&[usize]> { @@ -840,6 +876,10 @@ .map(|p| p.ordered_fields.as_slice()) } + pub fn is_projection_pushdowned(&self) -> bool { + !self.is_from_remote && self.projection.is_some() + } + pub fn record_batch(&self, data: &[u8]) -> Result { let (batch_metadata, body_buffer, version) = parse_ipc_message(data)?; @@ -1014,6 +1054,8 @@ pub struct MyVec(pub StreamReader); mod tests { use super::*; use crate::metadata::DataTypes; + use crate::metadata::DataField; + use crate::error::Error; #[test] fn test_to_array_type() { @@ -1185,6 +1227,18 @@ ); } + #[test] + fn projection_rejects_out_of_bounds_index() { + let row_type = DataTypes::row(vec![ + DataField::new("id".to_string(), DataTypes::int(), None), + DataField::new("name".to_string(), DataTypes::string(), None), + ]); + let schema = to_arrow_schema(&row_type); + let result = ReadContext::with_projection_pushdown(schema, vec![0, 2], false); + + assert!(matches!(result, Err(Error::IllegalArgument { .. }))); + } + fn le_bytes(vals: &[u32]) -> Vec<u8> { let mut out = Vec::with_capacity(vals.len() * 4); for &v in vals { diff --git a/crates/fluss/src/record/mod.rs b/crates/fluss/src/record/mod.rs index 35928ea..7931080 100644 --- a/crates/fluss/src/record/mod.rs +++ b/crates/fluss/src/record/mod.rs @@ -181,3 +181,65 @@ impl IntoIterator for ScanRecords { .into_iter() } } + +#[cfg(test)] +mod tests { + use super::*; + use ::arrow::array::{Int32Array, RecordBatch}; + use ::arrow::datatypes::{DataType, Field, Schema}; + use std::sync::Arc; + + fn make_row(values: Vec<i32>, row_id: usize) -> ColumnarRow { + let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, false)])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(values))]) + .expect("record batch"); + ColumnarRow::new_with_row_id(Arc::new(batch), row_id) + } + + #[test] + fn change_type_round_trip() { + let cases = [ + (ChangeType::AppendOnly, "+A", 0), + (ChangeType::Insert, "+I", 1), + (ChangeType::UpdateBefore, "-U", 2), + (ChangeType::UpdateAfter, "+U", 3), + (ChangeType::Delete, "-D", 4), + ]; + + for (change_type, short, byte) in cases { + assert_eq!(change_type.short_string(), short); + assert_eq!(change_type.to_byte_value(), byte); + assert_eq!(ChangeType::from_byte_value(byte).unwrap(), change_type); + } + + let err = ChangeType::from_byte_value(9).unwrap_err(); + assert!(err.contains("Unsupported byte value")); + } + + #[test] + fn scan_records_counts_and_iterates() { + let bucket0 = TableBucket::new(1, 0); + let bucket1 = TableBucket::new(1, 1); + let record0 = ScanRecord::new(make_row(vec![10, 11], 0), 5, 7, ChangeType::Insert); + let record1 = ScanRecord::new(make_row(vec![10, 11], 1), 6, 8, ChangeType::Delete); + + let mut records = HashMap::new(); + records.insert(bucket0.clone(), vec![record0.clone(), record1.clone()]); + + let scan_records = ScanRecords::new(records); + assert_eq!(scan_records.records(&bucket0).len(), 2); + assert!(scan_records.records(&bucket1).is_empty()); + assert_eq!(scan_records.count(), 2);
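+ // Consuming ScanRecords via IntoIterator should flatten the per-bucket lists into individual records.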
+ + let collected: Vec<_> = scan_records.into_iter().collect(); + assert_eq!(collected.len(), 2); + } + + #[test] + fn scan_record_default_values() { + let record = ScanRecord::new_default(make_row(vec![1], 0)); + assert_eq!(record.offset(), -1); + assert_eq!(record.timestamp(), -1); + assert_eq!(record.change_type(), &ChangeType::Insert); + } +} diff --git a/crates/fluss/src/row/column.rs b/crates/fluss/src/row/column.rs index 31f0fdf..ddc2a3b 100644 --- a/crates/fluss/src/row/column.rs +++ b/crates/fluss/src/row/column.rs @@ -166,3 +166,67 @@ impl InternalRow for ColumnarRow { .value(self.row_id) } } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{ + BinaryArray, BooleanArray, FixedSizeBinaryArray, Float32Array, Float64Array, Int8Array, + Int16Array, Int32Array, Int64Array, StringArray, + }; + use arrow::datatypes::{DataType, Field, Schema}; + + #[test] + fn columnar_row_reads_values() { + let schema = Arc::new(Schema::new(vec![ + Field::new("b", DataType::Boolean, false), + Field::new("i8", DataType::Int8, false), + Field::new("i16", DataType::Int16, false), + Field::new("i32", DataType::Int32, false), + Field::new("i64", DataType::Int64, false), + Field::new("f32", DataType::Float32, false), + Field::new("f64", DataType::Float64, false), + Field::new("s", DataType::Utf8, false), + Field::new("bin", DataType::Binary, false), + Field::new("char", DataType::FixedSizeBinary(2), false), + ])); + + let batch = RecordBatch::try_new( + schema, + vec![ + Arc::new(BooleanArray::from(vec![true])), + Arc::new(Int8Array::from(vec![1])), + Arc::new(Int16Array::from(vec![2])), + Arc::new(Int32Array::from(vec![3])), + Arc::new(Int64Array::from(vec![4])), + Arc::new(Float32Array::from(vec![1.25])), + Arc::new(Float64Array::from(vec![2.5])), + Arc::new(StringArray::from(vec!["hello"])), + Arc::new(BinaryArray::from(vec![b"data".as_slice()])), + Arc::new( + FixedSizeBinaryArray::try_from_sparse_iter_with_size( + vec![Some(b"ab".as_slice())].into_iter(), + 2, + ) + .expect("fixed array"), + ), + ], + ) + .expect("record batch"); + + let mut row = ColumnarRow::new(Arc::new(batch)); + assert_eq!(row.get_field_count(), 10); + assert_eq!(row.get_boolean(0), true); + assert_eq!(row.get_byte(1), 1); + assert_eq!(row.get_short(2), 2); + assert_eq!(row.get_int(3), 3); + assert_eq!(row.get_long(4), 4); + assert_eq!(row.get_float(5), 1.25); + assert_eq!(row.get_double(6), 2.5); + assert_eq!(row.get_string(7), "hello"); + assert_eq!(row.get_bytes(8), b"data"); + assert_eq!(row.get_char(9, 2), "ab"); + row.set_row_id(0); + assert_eq!(row.get_row_id(), 0); + } +} diff --git a/crates/fluss/src/row/datum.rs b/crates/fluss/src/row/datum.rs index 1ea3933..16d8935 100644 --- a/crates/fluss/src/row/datum.rs +++ b/crates/fluss/src/row/datum.rs @@ -432,3 +432,69 @@ impl Date { date.day() } } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Array, Int32Builder, StringBuilder}; + + #[test] + fn datum_accessors_and_conversions() { + let datum = Datum::String("value"); + assert_eq!(datum.as_str(), "value"); + assert!(!datum.is_null()); + + let blob = Blob::from(vec![1, 2, 3]); + let datum = Datum::Blob(blob); + assert_eq!(datum.as_blob(), &[1, 2, 3]); + + assert!(Datum::Null.is_null()); + + let datum = Datum::Int32(42); + let value: i32 = (&datum).try_into().unwrap(); + assert_eq!(value, 42); + let value: std::result::Result = (&datum).try_into(); + assert!(value.is_err()); + } + + #[test] + fn datum_append_to_builder() { + let mut builder = Int32Builder::new(); + Datum::Null.append_to(&mut 
builder).unwrap(); + Datum::Int32(5).append_to(&mut builder).unwrap(); + let array = builder.finish(); + assert!(array.is_null(0)); + assert_eq!(array.value(1), 5); + + let mut builder = StringBuilder::new(); + let err = Datum::Int32(1).append_to(&mut builder).unwrap_err(); + assert!(matches!(err, crate::error::Error::RowConvertError { .. })); + + let mut builder = Int32Builder::new(); + let err = Datum::Date(Date::new(0)) + .append_to(&mut builder) + .unwrap_err(); + assert!(matches!(err, crate::error::Error::RowConvertError { .. })); + } + + #[test] + #[should_panic] + fn datum_as_str_panics_on_non_string() { + let _ = Datum::Int32(1).as_str(); + } + + #[test] + #[should_panic] + fn datum_as_blob_panics_on_non_blob() { + let _ = Datum::Int16(1).as_blob(); + } + + #[test] + fn date_components() { + let date = Date::new(0); + assert_eq!(date.get_inner(), 0); + assert_eq!(date.year(), 1970); + assert_eq!(date.month(), 1); + assert_eq!(date.day(), 1); + } +} diff --git a/crates/fluss/src/rpc/api_key.rs b/crates/fluss/src/rpc/api_key.rs index b11647f..c515396 100644 --- a/crates/fluss/src/rpc/api_key.rs +++ b/crates/fluss/src/rpc/api_key.rs @@ -85,3 +85,41 @@ impl From for i16 { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn api_key_round_trip() { + let cases = [ + (1001, ApiKey::CreateDatabase), + (1002, ApiKey::DropDatabase), + (1003, ApiKey::ListDatabases), + (1004, ApiKey::DatabaseExists), + (1005, ApiKey::CreateTable), + (1006, ApiKey::DropTable), + (1007, ApiKey::GetTable), + (1008, ApiKey::ListTables), + (1010, ApiKey::TableExists), + (1012, ApiKey::MetaData), + (1014, ApiKey::ProduceLog), + (1015, ApiKey::FetchLog), + (1021, ApiKey::ListOffsets), + (1025, ApiKey::GetFileSystemSecurityToken), + (1032, ApiKey::GetLatestLakeSnapshot), + (1035, ApiKey::GetDatabaseInfo), + ]; + + for (raw, key) in cases { + assert_eq!(ApiKey::from(raw), key); + let mapped: i16 = key.into(); + assert_eq!(mapped, raw); + } + + let unknown = ApiKey::from(9999); + assert_eq!(unknown, ApiKey::Unknown(9999)); + let mapped: i16 = unknown.into(); + assert_eq!(mapped, 9999); + } +} diff --git a/crates/fluss/src/rpc/api_version.rs b/crates/fluss/src/rpc/api_version.rs index 395c45c..f009d69 100644 --- a/crates/fluss/src/rpc/api_version.rs +++ b/crates/fluss/src/rpc/api_version.rs @@ -52,3 +52,28 @@ impl std::fmt::Display for ApiVersionRange { write!(f, "{}:{}", self.min, self.max) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn api_version_display() { + let version = ApiVersion(3); + assert_eq!(version.to_string(), "3"); + } + + #[test] + fn api_version_range_accessors() { + let range = ApiVersionRange::new(ApiVersion(1), ApiVersion(4)); + assert_eq!(range.min(), ApiVersion(1)); + assert_eq!(range.max(), ApiVersion(4)); + assert_eq!(range.to_string(), "1:4"); + } + + #[test] + #[should_panic] + fn api_version_range_panics_on_invalid_bounds() { + let _ = ApiVersionRange::new(ApiVersion(4), ApiVersion(1)); + } +} diff --git a/crates/fluss/src/rpc/convert.rs b/crates/fluss/src/rpc/convert.rs index 6feb7eb..1862589 100644 --- a/crates/fluss/src/rpc/convert.rs +++ b/crates/fluss/src/rpc/convert.rs @@ -41,3 +41,51 @@ pub fn from_pb_table_path(pb_table_path: &PbTablePath) -> TablePath { pb_table_path.table_name.to_string(), ) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::proto::{PbServerNode, PbTablePath}; + + #[test] + fn table_path_round_trip() { + let table_path = TablePath::new("db".to_string(), "table".to_string()); + let pb = to_table_path(&table_path); + 
assert_eq!(pb.database_name, "db"); + assert_eq!(pb.table_name, "table"); + + let restored = from_pb_table_path(&pb); + assert_eq!(restored, table_path); + + let manual = PbTablePath { + database_name: "db2".to_string(), + table_name: "table2".to_string(), + }; + let restored = from_pb_table_path(&manual); + assert_eq!(restored.database(), "db2"); + assert_eq!(restored.table(), "table2"); + } + + #[test] + fn server_node_from_pb() { + let pb = PbServerNode { + node_id: 7, + host: "127.0.0.1".to_string(), + port: 9092, + listeners: None, + }; + let node = from_pb_server_node(pb, ServerType::TabletServer); + assert_eq!(node.id(), 7); + assert_eq!(node.url(), "127.0.0.1:9092"); + assert_eq!(node.uid(), "ts-7"); + + let pb = PbServerNode { + node_id: 3, + host: "localhost".to_string(), + port: 8123, + listeners: None, + }; + let node = from_pb_server_node(pb, ServerType::CoordinatorServer); + assert_eq!(node.uid(), "cs-3"); + } +} diff --git a/crates/fluss/src/rpc/fluss_api_error.rs b/crates/fluss/src/rpc/fluss_api_error.rs index b26eb72..a501b99 100644 --- a/crates/fluss/src/rpc/fluss_api_error.rs +++ b/crates/fluss/src/rpc/fluss_api_error.rs @@ -369,3 +369,38 @@ impl From for FlussError { FlussError::for_code(api_error.code) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn for_code_maps_known_and_unknown() { + assert_eq!(FlussError::for_code(0), FlussError::None); + assert_eq!( + FlussError::for_code(FlussError::AuthorizationException.code()), + FlussError::AuthorizationException + ); + assert_eq!(FlussError::for_code(9999), FlussError::UnknownServerError); + } + + #[test] + fn to_api_error_uses_message() { + let err = FlussError::InvalidTableException.to_api_error(None); + assert_eq!(err.code, FlussError::InvalidTableException.code()); + assert!(err.message.contains("invalid table")); + } + + #[test] + fn error_response_conversion_round_trip() { + let response = ErrorResponse { + error_code: FlussError::TableNotExist.code(), + error_message: Some("missing".to_string()), + }; + let api_error = ApiError::from(response); + assert_eq!(api_error.code, FlussError::TableNotExist.code()); + assert_eq!(api_error.message, "missing"); + let fluss_error = FlussError::from(api_error); + assert_eq!(fluss_error, FlussError::TableNotExist); + } +} diff --git a/crates/fluss/src/rpc/message/list_offsets.rs b/crates/fluss/src/rpc/message/list_offsets.rs index 9ab1f14..1ee25d4 100644 --- a/crates/fluss/src/rpc/message/list_offsets.rs +++ b/crates/fluss/src/rpc/message/list_offsets.rs @@ -17,7 +17,7 @@ use crate::{impl_read_version_type, impl_write_version_type, proto}; -use crate::error::Error; +use crate::error::{ApiError, Error, FlussError}; use crate::error::Result as FlussResult; use crate::proto::ListOffsetsResponse; use crate::rpc::frame::ReadError; @@ -108,22 +108,55 @@ impl ListOffsetsResponse { self.buckets_resp .iter() .map(|resp| { - if resp.error_code.is_some() { - // todo: consider use another suitable error - Err(Error::UnexpectedError { + if let Some(error_code) = resp.error_code { + if error_code != FlussError::None.code() { + let error = FlussError::for_code(error_code); + let message = resp + .error_message + .clone() + .unwrap_or_else(|| error.message().to_string()); + return Err(Error::FlussAPIError { + api_error: ApiError { + code: error_code, + message, + }, + }); + } + } + + // if no error msg, offset must exists + resp.offset + .map(|offset| (resp.bucket_id, offset)) + .ok_or_else(|| Error::UnexpectedError { message: format!( - "Missing offset, error message: {}", - 
resp.error_message - .as_deref() - .unwrap_or("unknown server exception") + "Missing offset for bucket {} without error code.", + resp.bucket_id ), source: None, }) - } else { - // if no error msg, offset must exists - Ok((resp.bucket_id, resp.offset.unwrap())) - } }) .collect() } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::proto::{ListOffsetsResponse, PbListOffsetsRespForBucket}; + + #[test] + fn offsets_returns_api_error_on_error_code() { + let response = ListOffsetsResponse { + buckets_resp: vec![PbListOffsetsRespForBucket { + bucket_id: 1, + error_code: Some(FlussError::TableNotExist.code()), + error_message: Some("missing".to_string()), + offset: None, + }], + ..Default::default() + }; + + let result = response.offsets(); + assert!(matches!(result, Err(Error::FlussAPIError { .. }))); + } +} diff --git a/crates/fluss/src/rpc/message/produce_log.rs b/crates/fluss/src/rpc/message/produce_log.rs index 39bfb3f..b377819 100644 --- a/crates/fluss/src/rpc/message/produce_log.rs +++ b/crates/fluss/src/rpc/message/produce_log.rs @@ -24,8 +24,6 @@ use crate::rpc::api_version::ApiVersion; use crate::rpc::frame::WriteError; use crate::rpc::message::{ReadVersionedType, RequestBody, WriteVersionedType}; use crate::{impl_read_version_type, impl_write_version_type, proto}; -use std::sync::Arc; - use crate::client::ReadyWriteBatch; use bytes::{Buf, BufMut}; use prost::Message; @@ -39,7 +37,7 @@ impl ProduceLogRequest { table_id: i64, ack: i16, max_request_timeout_ms: i32, - ready_batches: Vec<&Arc<ReadyWriteBatch>>, + ready_batches: &[&ReadyWriteBatch], ) -> FlussResult { let mut request = proto::ProduceLogRequest { table_id, diff --git a/crates/fluss/src/util/mod.rs b/crates/fluss/src/util/mod.rs index 5f67290..43d92a8 100644 --- a/crates/fluss/src/util/mod.rs +++ b/crates/fluss/src/util/mod.rs @@ -183,3 +183,57 @@ impl Default for FairBucketStatusMap { Self::new() } } + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + #[test] + fn fair_bucket_status_map_tracks_order_and_size() { + let bucket0 = TableBucket::new(1, 0); + let bucket1 = TableBucket::new(1, 1); + + let mut map = FairBucketStatusMap::new(); + map.update_and_move_to_end(bucket0.clone(), 10); + map.update_and_move_to_end(bucket1.clone(), 20); + assert_eq!(map.size(), 2); + + let values: Vec<i32> = map + .bucket_status_values() + .into_iter() + .map(|value| **value) + .collect(); + assert_eq!(values, vec![10, 20]); + + map.move_to_end(bucket0.clone()); + let values: Vec<i32> = map + .bucket_status_values() + .into_iter() + .map(|value| **value) + .collect(); + assert_eq!(values, vec![20, 10]); + } + + #[test] + fn fair_bucket_status_map_mutations() { + let bucket0 = TableBucket::new(1, 0); + let bucket1 = TableBucket::new(2, 1); + + let mut map = FairBucketStatusMap::new(); + let mut input = HashMap::new(); + input.insert(bucket0.clone(), Arc::new(1)); + input.insert(bucket1.clone(), Arc::new(2)); + map.set(input); + + assert!(map.contains(&bucket0)); + assert!(map.contains(&bucket1)); + assert_eq!(map.bucket_set().len(), 2); + + map.remove(&bucket1); + assert_eq!(map.size(), 1); + + map.clear(); + assert_eq!(map.size(), 0); + } +}