diff --git a/crates/iceberg/src/arrow/delete_filter.rs b/crates/iceberg/src/arrow/delete_filter.rs
index b853baa99..b29f80886 100644
--- a/crates/iceberg/src/arrow/delete_filter.rs
+++ b/crates/iceberg/src/arrow/delete_filter.rs
@@ -339,6 +339,7 @@ pub(crate) mod tests {
                 project_field_ids: vec![],
                 predicate: None,
                 deletes: vec![pos_del_1, pos_del_2.clone()],
+                limit: None,
             },
             FileScanTask {
                 start: 0,
@@ -350,6 +351,7 @@ pub(crate) mod tests {
                 project_field_ids: vec![],
                 predicate: None,
                 deletes: vec![pos_del_3],
+                limit: None,
             },
         ];
diff --git a/crates/iceberg/src/arrow/reader.rs b/crates/iceberg/src/arrow/reader.rs
index 05aa6a4c9..c6f005474 100644
--- a/crates/iceberg/src/arrow/reader.rs
+++ b/crates/iceberg/src/arrow/reader.rs
@@ -176,8 +176,9 @@ impl ArrowReader {
         row_group_filtering_enabled: bool,
         row_selection_enabled: bool,
     ) -> Result<ArrowRecordBatchStream> {
-        let should_load_page_index =
-            (row_selection_enabled && task.predicate.is_some()) || !task.deletes.is_empty();
+        let should_load_page_index = (row_selection_enabled && task.predicate.is_some())
+            || !task.deletes.is_empty()
+            || task.limit.is_some();

         let delete_filter_rx = delete_file_loader.load_deletes(&task.deletes, task.schema.clone());
@@ -310,6 +311,10 @@ impl ArrowReader {
                 record_batch_stream_builder.with_row_groups(selected_row_group_indices);
         }

+        if let Some(limit) = task.limit {
+            record_batch_stream_builder = record_batch_stream_builder.with_limit(limit);
+        }
+
         // Build the batch stream and send all the RecordBatches that it generates
         // to the requester.
         let record_batch_stream =
@@ -341,7 +346,7 @@ impl ArrowReader {
         // Create the record batch stream builder, which wraps the parquet file reader
         let record_batch_stream_builder = ParquetRecordBatchStreamBuilder::new_with_options(
             parquet_file_reader,
-            ArrowReaderOptions::new(),
+            ArrowReaderOptions::new().with_page_index(should_load_page_index),
         )
         .await?;
         Ok(record_batch_stream_builder)
@@ -1745,6 +1750,7 @@ message schema {
             project_field_ids: vec![1],
             predicate: Some(predicate.bind(schema, true).unwrap()),
             deletes: vec![],
+            limit: None,
         })]
         .into_iter(),
     )) as FileScanTaskStream;
diff --git a/crates/iceberg/src/scan/context.rs b/crates/iceberg/src/scan/context.rs
index 3f7c29dbf..0f39e1845 100644
--- a/crates/iceberg/src/scan/context.rs
+++ b/crates/iceberg/src/scan/context.rs
@@ -42,6 +42,7 @@ pub(crate) struct ManifestFileContext {
     field_ids: Arc<Vec<i32>>,
     bound_predicates: Option<Arc<BoundPredicates>>,
+    limit: Option<usize>,
     object_cache: Arc<ObjectCache>,
     snapshot_schema: SchemaRef,
     expression_evaluator_cache: Arc<ExpressionEvaluatorCache>,
@@ -59,6 +60,7 @@ pub(crate) struct ManifestEntryContext {
     pub partition_spec_id: i32,
     pub snapshot_schema: SchemaRef,
     pub delete_file_index: DeleteFileIndex,
+    pub limit: Option<usize>,
 }

 impl ManifestFileContext {
@@ -74,6 +76,7 @@ impl ManifestFileContext {
             mut sender,
             expression_evaluator_cache,
             delete_file_index,
+            limit,
             ..
         } = self;
@@ -89,6 +92,7 @@ impl ManifestFileContext {
                 bound_predicates: bound_predicates.clone(),
                 snapshot_schema: snapshot_schema.clone(),
                 delete_file_index: delete_file_index.clone(),
+                limit,
             };

             sender
@@ -128,6 +132,8 @@ impl ManifestEntryContext {
                 .map(|x| x.as_ref().snapshot_bound_predicate.clone()),

             deletes,
+
+            limit: self.limit,
         })
     }
 }
@@ -142,6 +148,7 @@ pub(crate) struct PlanContext {
     pub snapshot_schema: SchemaRef,
     pub case_sensitive: bool,
     pub predicate: Option<Arc<Predicate>>,
+    pub limit: Option<usize>,
     pub snapshot_bound_predicate: Option<Arc<BoundPredicate>>,
     pub object_cache: Arc<ObjectCache>,
     pub field_ids: Arc<Vec<i32>>,
@@ -255,6 +262,7 @@ impl PlanContext {
             manifest_file: manifest_file.clone(),
             bound_predicates,
             sender,
+            limit: self.limit,
             object_cache: self.object_cache.clone(),
             snapshot_schema: self.snapshot_schema.clone(),
             field_ids: self.field_ids.clone(),
diff --git a/crates/iceberg/src/scan/mod.rs b/crates/iceberg/src/scan/mod.rs
index 3d14b3cce..e9055bdd6 100644
--- a/crates/iceberg/src/scan/mod.rs
+++ b/crates/iceberg/src/scan/mod.rs
@@ -59,6 +59,8 @@ pub struct TableScanBuilder<'a> {
     concurrency_limit_manifest_files: usize,
     row_group_filtering_enabled: bool,
     row_selection_enabled: bool,
+
+    limit: Option<usize>,
 }

 impl<'a> TableScanBuilder<'a> {
@@ -77,9 +79,16 @@ impl<'a> TableScanBuilder<'a> {
             concurrency_limit_manifest_files: num_cpus,
             row_group_filtering_enabled: true,
             row_selection_enabled: false,
+            limit: None,
         }
     }

+    /// Sets the maximum number of records to return
+    pub fn with_limit(mut self, limit: Option<usize>) -> Self {
+        self.limit = limit;
+        self
+    }
+
     /// Sets the desired size of batches in the response
     /// to something other than the default
     pub fn with_batch_size(mut self, batch_size: Option<usize>) -> Self {
@@ -281,6 +290,7 @@ impl<'a> TableScanBuilder<'a> {
             snapshot_schema: schema,
             case_sensitive: self.case_sensitive,
             predicate: self.filter.map(Arc::new),
+            limit: self.limit,
             snapshot_bound_predicate: snapshot_bound_predicate.map(Arc::new),
             object_cache: self.table.object_cache(),
             field_ids: Arc::new(field_ids),
@@ -1406,6 +1416,122 @@ pub mod tests {
         assert_eq!(int64_arr.value(0), 2);
     }
+
+    #[tokio::test]
+    async fn test_limit() {
+        let mut fixture = TableTestFixture::new();
+        fixture.setup_manifest_files().await;
+
+        let mut builder = fixture.table.scan();
+        builder = builder.with_limit(Some(1));
+        let table_scan = builder.build().unwrap();
+
+        let batch_stream = table_scan.to_arrow().await.unwrap();
+
+        let batches: Vec<_> = batch_stream.try_collect().await.unwrap();
+
+        assert_eq!(batches.len(), 2);
+        assert_eq!(batches[0].num_rows(), 1);
+        assert_eq!(batches[1].num_rows(), 1);
+
+        let col = batches[0].column_by_name("x").unwrap();
+        let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
+        assert_eq!(int64_arr.value(0), 1);
+
+        let col = batches[0].column_by_name("y").unwrap();
+        let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
+        assert_eq!(int64_arr.value(0), 2);
+    }
+
+    #[tokio::test]
+    async fn test_limit_with_predicate() {
+        let mut fixture = TableTestFixture::new();
+        fixture.setup_manifest_files().await;
+
+        // Filter: y > 3
+        let mut builder = fixture.table.scan();
+        let predicate = Reference::new("y").greater_than(Datum::long(3));
+        builder = builder.with_filter(predicate).with_limit(Some(1));
+        let table_scan = builder.build().unwrap();
+
+        let batch_stream = table_scan.to_arrow().await.unwrap();
+
+        let batches: Vec<_> = batch_stream.try_collect().await.unwrap();
+
+        assert_eq!(batches.len(), 2);
+        assert_eq!(batches[0].num_rows(), 1);
+        assert_eq!(batches[1].num_rows(), 1);
+
+        let col = batches[0].column_by_name("x").unwrap();
+        let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
+        assert_eq!(int64_arr.value(0), 1);
+
+        let col = batches[0].column_by_name("y").unwrap();
+        let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
+        assert_eq!(int64_arr.value(0), 4);
+    }
+
+    #[tokio::test]
+    async fn test_limit_with_predicate_and_row_selection() {
+        let mut fixture = TableTestFixture::new();
+        fixture.setup_manifest_files().await;
+
+        // Filter: y > 3
+        let mut builder = fixture.table.scan();
+        let predicate = Reference::new("y").greater_than(Datum::long(3));
+        builder = builder
+            .with_filter(predicate)
+            .with_limit(Some(1))
+            .with_row_selection_enabled(true);
+        let table_scan = builder.build().unwrap();
+
+        let batch_stream = table_scan.to_arrow().await.unwrap();
+
+        let batches: Vec<_> = batch_stream.try_collect().await.unwrap();
+
+        assert_eq!(batches.len(), 2);
+        assert_eq!(batches[0].num_rows(), 1);
+        assert_eq!(batches[1].num_rows(), 1);
+
+        let col = batches[0].column_by_name("x").unwrap();
+        let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
+        assert_eq!(int64_arr.value(0), 1);
+
+        let col = batches[0].column_by_name("y").unwrap();
+        let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
+        assert_eq!(int64_arr.value(0), 4);
+    }
+
+    #[tokio::test]
+    async fn test_limit_higher_than_total_rows() {
+        let mut fixture = TableTestFixture::new();
+        fixture.setup_manifest_files().await;
+
+        // Filter: y > 3
+        let mut builder = fixture.table.scan();
+        let predicate = Reference::new("y").greater_than(Datum::long(3));
+        builder = builder
+            .with_filter(predicate)
+            .with_limit(Some(100_000_000))
+            .with_row_selection_enabled(true);
+        let table_scan = builder.build().unwrap();
+
+        let batch_stream = table_scan.to_arrow().await.unwrap();
+
+        let batches: Vec<_> = batch_stream.try_collect().await.unwrap();
+
+        assert_eq!(batches.len(), 2);
+        assert_eq!(batches[0].num_rows(), 312);
+        assert_eq!(batches[1].num_rows(), 312);
+
+        let col = batches[0].column_by_name("x").unwrap();
+        let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
+        assert_eq!(int64_arr.value(0), 1);
+
+        let col = batches[0].column_by_name("y").unwrap();
+        let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
+        assert_eq!(int64_arr.value(0), 4);
+    }
+
     #[tokio::test]
     async fn test_filter_on_arrow_gt_eq() {
         let mut fixture = TableTestFixture::new();
         fixture.setup_manifest_files().await;
@@ -1780,6 +1914,7 @@ pub mod tests {
             record_count: Some(100),
             data_file_format: DataFileFormat::Parquet,
             deletes: vec![],
+            limit: None,
         };
         test_fn(task);

@@ -1794,6 +1929,7 @@ pub mod tests {
             record_count: None,
             data_file_format: DataFileFormat::Avro,
             deletes: vec![],
+            limit: None,
         };
         test_fn(task);
     }
diff --git a/crates/iceberg/src/scan/task.rs b/crates/iceberg/src/scan/task.rs
index 32fe3ae30..17116ef0b 100644
--- a/crates/iceberg/src/scan/task.rs
+++ b/crates/iceberg/src/scan/task.rs
@@ -54,6 +54,9 @@ pub struct FileScanTask {

     /// The list of delete files that may need to be applied to this data file
     pub deletes: Vec<FileScanTaskDeleteFile>,
+
+    /// Maximum number of records to return, None means no limit
+    pub limit: Option<usize>,
 }

 impl FileScanTask {
diff --git a/crates/integrations/datafusion/src/physical_plan/scan.rs b/crates/integrations/datafusion/src/physical_plan/scan.rs
index d4751a19c..87a1a4b1e 100644
--- a/crates/integrations/datafusion/src/physical_plan/scan.rs
+++ b/crates/integrations/datafusion/src/physical_plan/scan.rs
@@ -51,6 +51,8 @@ pub struct IcebergTableScan {
     projection: Option<Vec<String>>,
     /// Filters to apply to the table scan
     predicates: Option<Predicate>,
+    /// Maximum number of records to return, None means no limit
+    limit: Option<usize>,
 }

 impl IcebergTableScan {
@@ -61,6 +63,7 @@ impl IcebergTableScan {
         schema: ArrowSchemaRef,
         projection: Option<&Vec<usize>>,
         filters: &[Expr],
+        limit: Option<usize>,
     ) -> Self {
         let output_schema = match projection {
             None => schema.clone(),
@@ -76,6 +79,7 @@ impl IcebergTableScan {
             plan_properties,
             projection,
             predicates,
+            limit,
         }
     }
@@ -143,6 +147,7 @@ impl ExecutionPlan for IcebergTableScan {
             self.snapshot_id,
             self.projection.clone(),
             self.predicates.clone(),
+            self.limit,
         );
         let stream = futures::stream::once(fut).try_flatten();
@@ -161,13 +166,14 @@ impl DisplayAs for IcebergTableScan {
     ) -> std::fmt::Result {
         write!(
             f,
-            "IcebergTableScan projection:[{}] predicate:[{}]",
+            "IcebergTableScan projection:[{}] predicate:[{}] limit:[{}]",
             self.projection
                 .clone()
                 .map_or(String::new(), |v| v.join(",")),
             self.predicates
                 .clone()
-                .map_or(String::from(""), |p| format!("{}", p))
+                .map_or(String::from(""), |p| format!("{}", p)),
+            self.limit.map_or(String::from(""), |p| format!("{}", p)),
         )
     }
 }
@@ -182,6 +188,7 @@ async fn get_batch_stream(
     snapshot_id: Option<i64>,
     column_names: Option<Vec<String>>,
     predicates: Option<Predicate>,
+    limit: Option<usize>,
 ) -> DFResult<Pin<Box<dyn Stream<Item = DFResult<RecordBatch>> + Send>>> {
     let mut scan_builder = match snapshot_id {
         Some(snapshot_id) => table.scan().snapshot_id(snapshot_id),
@@ -195,6 +202,9 @@ async fn get_batch_stream(
     if let Some(pred) = predicates {
         scan_builder = scan_builder.with_filter(pred);
     }
+
+    scan_builder = scan_builder.with_limit(limit);
+
     let table_scan = scan_builder.build().map_err(to_datafusion_error)?;

     let stream = table_scan
diff --git a/crates/integrations/datafusion/src/table/mod.rs b/crates/integrations/datafusion/src/table/mod.rs
index a8c49837c..80f070c01 100644
--- a/crates/integrations/datafusion/src/table/mod.rs
+++ b/crates/integrations/datafusion/src/table/mod.rs
@@ -149,7 +149,7 @@ impl TableProvider for IcebergTableProvider {
         _state: &dyn Session,
         projection: Option<&Vec<usize>>,
         filters: &[Expr],
-        _limit: Option<usize>,
+        limit: Option<usize>,
     ) -> DFResult<Arc<dyn ExecutionPlan>> {
         Ok(Arc::new(IcebergTableScan::new(
             self.table.clone(),
@@ -157,6 +157,7 @@ impl TableProvider for IcebergTableProvider {
             self.schema.clone(),
             projection,
             filters,
+            limit,
         )))
     }
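Taken together, the patch threads a row limit from DataFusion's `TableProvider::scan` (which previously ignored its `_limit` argument) through `TableScanBuilder` and each `FileScanTask`, and finally pushes it into the parquet reader via `ParquetRecordBatchStreamBuilder::with_limit`. A minimal sketch of the new builder API from the iceberg side; the function name and the already-loaded `table` are hypothetical setup, not part of the patch:

```rust
use futures::TryStreamExt;
use iceberg::table::Table;

// Sketch: cap a table scan at 10 records with the new with_limit() option.
// `table` is assumed to be an already-loaded iceberg::table::Table.
async fn preview(table: &Table) -> iceberg::Result<()> {
    let scan = table
        .scan()
        .with_limit(Some(10)) // pushed down to each parquet record batch stream
        .build()?;

    // to_arrow() yields a stream of Result<RecordBatch>, as in the new tests.
    let batches: Vec<_> = scan.to_arrow().await?.try_collect().await?;
    for batch in &batches {
        println!("read a batch of {} rows", batch.num_rows());
    }
    Ok(())
}
```

Note that the limit is applied per `FileScanTask` rather than globally: as `test_limit` shows, scanning two data files with `with_limit(Some(1))` still yields two one-row batches, one per file. The pushdown avoids reading more than `limit` rows from any single file; callers that need an exact global cap must still truncate the merged stream themselves.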