diff --git a/datafusion/catalog/src/streaming.rs b/datafusion/catalog/src/streaming.rs
index 31669171b291a..f8ef68d077af3 100644
--- a/datafusion/catalog/src/streaming.rs
+++ b/datafusion/catalog/src/streaming.rs
@@ -20,19 +20,18 @@
 use std::any::Any;
 use std::sync::Arc;
 
-use crate::Session;
-use crate::TableProvider;
-
 use arrow::datatypes::SchemaRef;
+use async_trait::async_trait;
 use datafusion_common::{DFSchema, Result, plan_err};
 use datafusion_expr::{Expr, SortExpr, TableType};
+use datafusion_physical_expr::equivalence::project_ordering;
 use datafusion_physical_expr::{LexOrdering, create_physical_sort_exprs};
 use datafusion_physical_plan::ExecutionPlan;
 use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec};
-
-use async_trait::async_trait;
 use log::debug;
 
+use crate::{Session, TableProvider};
+
 /// A [`TableProvider`] that streams a set of [`PartitionStream`]
 #[derive(Debug)]
 pub struct StreamingTable {
@@ -105,7 +104,18 @@ impl TableProvider for StreamingTable {
             let df_schema = DFSchema::try_from(Arc::clone(&self.schema))?;
             let eqp = state.execution_props();
 
-            create_physical_sort_exprs(&self.sort_order, &df_schema, eqp)?
+            let original_sort_exprs =
+                create_physical_sort_exprs(&self.sort_order, &df_schema, eqp)?;
+
+            if let Some(p) = projection {
+                let schema = Arc::new(self.schema.project(p)?);
+                LexOrdering::new(original_sort_exprs)
+                    .and_then(|lex_ordering| project_ordering(&lex_ordering, &schema))
+                    .map(|lex_ordering| lex_ordering.to_vec())
+                    .unwrap_or_default()
+            } else {
+                original_sort_exprs
+            }
         } else {
             vec![]
         };
diff --git a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs
index 4b74aebdf5deb..6349ff1cd109f 100644
--- a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs
+++ b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs
@@ -29,11 +29,11 @@ use crate::physical_optimizer::test_utils::{
     spr_repartition_exec, stream_exec_ordered, union_exec,
 };
 
-use arrow::compute::SortOptions;
+use arrow::compute::{SortOptions};
 use arrow::datatypes::{DataType, SchemaRef};
 use datafusion_common::config::{ConfigOptions, CsvOptions};
 use datafusion_common::tree_node::{TreeNode, TransformedResult};
-use datafusion_common::{Result, TableReference};
+use datafusion_common::{create_array, Result, TableReference};
 use datafusion_datasource::file_scan_config::FileScanConfigBuilder;
 use datafusion_datasource::source::DataSourceExec;
 use datafusion_expr_common::operator::Operator;
@@ -58,7 +58,7 @@ use datafusion_physical_optimizer::enforce_distribution::EnforceDistribution;
 use datafusion_physical_optimizer::output_requirements::OutputRequirementExec;
 use datafusion_physical_optimizer::PhysicalOptimizerRule;
 use datafusion::prelude::*;
-use arrow::array::{Int32Array, RecordBatch};
+use arrow::array::{record_batch, ArrayRef, Int32Array, RecordBatch};
 use arrow::datatypes::{Field};
 use arrow_schema::Schema;
 use datafusion_execution::TaskContext;
@@ -2805,3 +2805,47 @@ async fn test_partial_sort_with_homogeneous_batches() -> Result<()> {
 
     Ok(())
 }
+
+#[tokio::test]
+async fn test_sort_with_streaming_table() -> Result<()> {
+    let batch = record_batch!(("a", Int32, [1, 2, 3]), ("b", Int32, [1, 2, 3]))?;
+
+    let ctx = SessionContext::new();
+
+    let sort_order = vec![
+        SortExpr::new(
+            Expr::Column(datafusion_common::Column::new(
+                Option::<TableReference>::None,
+                "a",
+            )),
+            true,
+            false,
+        ),
+        SortExpr::new(
+            Expr::Column(datafusion_common::Column::new(
+                Option::<TableReference>::None,
+                "b",
+            )),
+            true,
+            false,
+        ),
+    ];
+    let schema = batch.schema();
+    let batches = Arc::new(DummyStreamPartition {
+        schema: schema.clone(),
+        batches: vec![batch],
+    }) as _;
+    let provider = StreamingTable::try_new(schema.clone(), vec![batches])?
+        .with_sort_order(sort_order);
+    ctx.register_table("test_table", Arc::new(provider))?;
+
+    let sql = "SELECT a FROM test_table GROUP BY a ORDER BY a";
+    let results = ctx.sql(sql).await?.collect().await?;
+
+    assert_eq!(results.len(), 1);
+    assert_eq!(results[0].num_columns(), 1);
+    let expected = create_array!(Int32, vec![1, 2, 3]) as ArrayRef;
+    assert_eq!(results[0].column(0), &expected);
+
+    Ok(())
+}
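
For reviewers: the catalog change relies on `project_ordering` to narrow the table's declared ordering to the columns that survive a scan projection. Below is a minimal, self-contained sketch of that behavior; it is not part of the PR, and the standalone program layout, the `col` helper usage, and the expectation that the result is the longest still-valid prefix of the ordering are assumptions inferred from the imports and call sites visible in the diff.

```rust
use std::sync::Arc;

use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_physical_expr::equivalence::project_ordering;
use datafusion_physical_expr::expressions::col;
use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr};

fn main() -> Result<()> {
    // Table schema with a declared lexicographic ordering on (a, b).
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ]));
    let ordering = LexOrdering::new(vec![
        PhysicalSortExpr::new(col("a", &schema)?, SortOptions::default()),
        PhysicalSortExpr::new(col("b", &schema)?, SortOptions::default()),
    ])
    .expect("ordering is non-empty");

    // Project down to column "a" only; the ordering should shrink to the
    // prefix that remains valid on the projected schema, mirroring what
    // StreamingTable::scan now does when `projection` is Some.
    let projected_schema = Arc::new(schema.project(&[0])?);
    let projected = project_ordering(&ordering, &projected_schema);
    println!("{projected:?}");
    Ok(())
}
```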