|
| 1 | +use datafusion::arrow::array::{Array, RecordBatch, UInt32Array}; |
| 2 | +use datafusion::arrow::compute::{concat_batches, take_record_batch}; |
| 3 | +use datafusion::arrow::error::ArrowError; |
| 4 | +use std::sync::Arc; |
| 5 | + |
/// Default number of rows fetched/displayed per page of results.
pub const PAGE_SIZE: usize = 100;
| 7 | + |
/// Calculate the half-open row range `[start, end)` needed for a given page.
///
/// Uses saturating arithmetic so that absurdly large `page`/`page_size`
/// combinations clamp at `usize::MAX` instead of overflowing — plain `*`/`+`
/// would panic in debug builds and silently wrap (yielding a bogus small
/// range) in release builds. Callers downstream clamp `end` to the rows
/// actually loaded, so a saturated range degrades gracefully.
pub fn page_row_range(page: usize, page_size: usize) -> (usize, usize) {
    let start = page.saturating_mul(page_size);
    let end = start.saturating_add(page_size);
    (start, end)
}
| 14 | + |
| 15 | +/// Check if we have enough rows loaded to display the requested page |
| 16 | +pub fn has_sufficient_rows(loaded_rows: usize, page: usize, page_size: usize) -> bool { |
| 17 | + let (_start, end) = page_row_range(page, page_size); |
| 18 | + loaded_rows >= end |
| 19 | +} |
| 20 | + |
| 21 | +/// Extract a page of rows from loaded batches |
| 22 | +/// This handles pagination across batch boundaries by concatenating only what's needed |
| 23 | +pub fn extract_page( |
| 24 | + batches: &[RecordBatch], |
| 25 | + page: usize, |
| 26 | + page_size: usize, |
| 27 | +) -> Result<RecordBatch, ArrowError> { |
| 28 | + if batches.is_empty() { |
| 29 | + return Ok(RecordBatch::new_empty(Arc::new( |
| 30 | + datafusion::arrow::datatypes::Schema::empty(), |
| 31 | + ))); |
| 32 | + } |
| 33 | + |
| 34 | + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); |
| 35 | + let (start, end) = page_row_range(page, page_size); |
| 36 | + |
| 37 | + // Clamp end to available rows |
| 38 | + let end = end.min(total_rows); |
| 39 | + |
| 40 | + if start >= total_rows { |
| 41 | + // Page is beyond available data |
| 42 | + return Ok(RecordBatch::new_empty(batches[0].schema())); |
| 43 | + } |
| 44 | + |
| 45 | + // Create indices for the rows we want |
| 46 | + let indices = UInt32Array::from_iter_values((start as u32)..(end as u32)); |
| 47 | + |
| 48 | + // Extract rows from batches |
| 49 | + extract_rows_from_batches(batches, &indices) |
| 50 | +} |
| 51 | + |
| 52 | +/// Extract specific rows (by global indices) from batches |
| 53 | +/// Handles batch boundaries by concatenating only necessary batches |
| 54 | +fn extract_rows_from_batches( |
| 55 | + batches: &[RecordBatch], |
| 56 | + indices: &dyn Array, |
| 57 | +) -> Result<RecordBatch, ArrowError> { |
| 58 | + match batches.len() { |
| 59 | + 0 => Ok(RecordBatch::new_empty(Arc::new( |
| 60 | + datafusion::arrow::datatypes::Schema::empty(), |
| 61 | + ))), |
| 62 | + 1 => take_record_batch(&batches[0], indices), |
| 63 | + _ => { |
| 64 | + // Multiple batches: concat then extract rows |
| 65 | + // Only concat the batches we've loaded (lazy loading ensures minimal concat) |
| 66 | + let schema = batches[0].schema(); |
| 67 | + let concatenated = concat_batches(&schema, batches)?; |
| 68 | + take_record_batch(&concatenated, indices) |
| 69 | + } |
| 70 | + } |
| 71 | +} |
| 72 | + |
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_page_row_range() {
        assert_eq!(page_row_range(0, 100), (0, 100));
        assert_eq!(page_row_range(1, 100), (100, 200));
        assert_eq!(page_row_range(2, 50), (100, 150));
        // Degenerate page size yields an empty range.
        assert_eq!(page_row_range(5, 0), (0, 0));
    }

    #[test]
    fn test_has_sufficient_rows() {
        assert!(has_sufficient_rows(100, 0, 100)); // Exactly enough
        assert!(has_sufficient_rows(150, 0, 100)); // More than enough
        assert!(!has_sufficient_rows(50, 0, 100)); // Not enough
        assert!(!has_sufficient_rows(150, 1, 100)); // Need 200, only have 150
        assert!(!has_sufficient_rows(0, 0, 100)); // Nothing loaded yet
        assert!(has_sufficient_rows(0, 0, 0)); // Zero-size page needs zero rows
    }
}
0 commit comments