From 535943917dbaac3325275a8b8e88f22ad392ac84 Mon Sep 17 00:00:00 2001 From: Aniket Modak Date: Wed, 12 Nov 2025 20:15:40 +0530 Subject: [PATCH] Fixed select * and added analyzer rules for projecting row_id in query phase --- ...imizer.rs => absolute_row_id_optimizer.rs} | 96 ++++------ plugins/engine-datafusion/jni/src/lib.rs | 37 ++-- .../jni/src/listing_table.rs | 5 +- .../jni/src/project_row_id_analyzer.rs | 119 ++++++++++++ .../jni/src/query_executor.rs | 178 +++++++++++++++--- .../datafusion/DatafusionEngine.java | 57 ++++-- .../datafusion/jni/NativeBridge.java | 4 +- .../datafusion/search/DatafusionQuery.java | 29 ++- .../datafusion/search/DatafusionReader.java | 2 + .../datafusion/search/DatafusionSearcher.java | 8 +- .../DataFusionReaderManagerTests.java | 2 +- .../datafusion/DataFusionServiceTests.java | 12 +- .../index/engine/SearchExecEngine.java | 2 +- .../index/mapper/NumberFieldMapper.java | 11 ++ .../org/opensearch/search/SearchService.java | 3 +- 15 files changed, 427 insertions(+), 138 deletions(-) rename plugins/engine-datafusion/jni/src/{row_id_optimizer.rs => absolute_row_id_optimizer.rs} (66%) create mode 100644 plugins/engine-datafusion/jni/src/project_row_id_analyzer.rs diff --git a/plugins/engine-datafusion/jni/src/row_id_optimizer.rs b/plugins/engine-datafusion/jni/src/absolute_row_id_optimizer.rs similarity index 66% rename from plugins/engine-datafusion/jni/src/row_id_optimizer.rs rename to plugins/engine-datafusion/jni/src/absolute_row_id_optimizer.rs index b2bdd0216868e..4810c91d9d0d2 100644 --- a/plugins/engine-datafusion/jni/src/row_id_optimizer.rs +++ b/plugins/engine-datafusion/jni/src/absolute_row_id_optimizer.rs @@ -11,7 +11,7 @@ use std::sync::Arc; use arrow::datatypes::{DataType, Field, Fields, Schema}; use arrow_schema::SchemaRef; -use datafusion::physical_plan::projection::new_projections_for_columns; +use datafusion::physical_plan::projection::{new_projections_for_columns, ProjectionExpr}; use datafusion::{ common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}, config::ConfigOptions, @@ -21,15 +21,21 @@ use datafusion::{ }, error::DataFusionError, logical_expr::Operator, - physical_expr::{PhysicalExpr, expressions::{BinaryExpr, Column}}, + parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder, + physical_expr::{ + expressions::{BinaryExpr, Column}, + PhysicalExpr, + }, physical_optimizer::PhysicalOptimizerRule, - physical_plan::{ExecutionPlan, filter::FilterExec, projection::{ProjectionExec, ProjectionExpr}}, + physical_plan::{filter::FilterExec, projection::ProjectionExec, ExecutionPlan}, }; #[derive(Debug)] -pub struct ProjectRowIdOptimizer; +pub struct AbsoluteRowIdOptimizer; +pub const ROW_ID_FIELD_NAME: &'static str = "___row_id"; +pub const ROW_BASE_FIELD_NAME: &'static str = "row_base"; -impl ProjectRowIdOptimizer { +impl AbsoluteRowIdOptimizer { /// Helper to build new schema and projection info with added `row_base` column. 
fn build_updated_file_source_schema( &self, @@ -53,27 +59,19 @@ impl ProjectRowIdOptimizer { // } // } - if !projections.contains(&file_source_schema.index_of("___row_id").unwrap()) { - new_projections.push(file_source_schema.index_of("___row_id").unwrap()); + if !projections.contains(&file_source_schema.index_of(ROW_ID_FIELD_NAME).unwrap()) { + new_projections.push(file_source_schema.index_of(ROW_ID_FIELD_NAME).unwrap()); - // let field = file_source_schema.field_with_name(&*"___row_id").expect("Field ___row_id not found in file_source_schema"); - // fields.push(Arc::new(Field::new("___row_id", field.data_type().clone(), field.is_nullable()))); + // let field = file_source_schema.field_with_name(&*ROW_ID_FIELD_NAME).expect("Field ___row_id not found in file_source_schema"); + // fields.push(Arc::new(Field::new(ROW_ID_FIELD_NAME, field.data_type().clone(), field.is_nullable()))); } new_projections.push(file_source_schema.fields.len()); - // fields.push(Arc::new(Field::new("row_base", file_source_schema.field_with_name("___row_id").unwrap().data_type().clone(), true))); + // fields.push(Arc::new(Field::new("row_base", file_source_schema.field_with_name(ROW_ID_FIELD_NAME).unwrap().data_type().clone(), true))); // Add row_base field to schema let mut new_fields = file_source_schema.fields().clone().to_vec(); - new_fields.push(Arc::new(Field::new( - "row_base", - file_source_schema - .field_with_name("___row_id") - .unwrap() - .data_type() - .clone(), - true, - ))); + new_fields.push(Arc::new(Field::new(ROW_BASE_FIELD_NAME, file_source_schema.field_with_name(ROW_ID_FIELD_NAME).unwrap().data_type().clone(), true))); let new_schema = Arc::new(Schema { metadata: file_source_schema.metadata().clone(), @@ -84,31 +82,23 @@ impl ProjectRowIdOptimizer { } /// Creates a projection expression that adds `row_base` to `___row_id`. 
- fn build_projection_exprs( - &self, - new_schema: &SchemaRef, - ) -> Result, String)>, DataFusionError> { - let row_id_idx = new_schema - .index_of("___row_id") - .expect("Field ___row_id missing"); - let row_base_idx = new_schema - .index_of("row_base") - .expect("Field row_base missing"); - + fn build_projection_exprs(&self, new_schema: &SchemaRef) -> Result, String)>, DataFusionError> { + let row_id_idx = new_schema.index_of(ROW_ID_FIELD_NAME).expect("Field ___row_id missing"); + let row_base_idx = new_schema.index_of(ROW_BASE_FIELD_NAME).expect("Field row_base missing"); let sum_expr: Arc = Arc::new(BinaryExpr::new( - Arc::new(Column::new("___row_id", row_id_idx)), + Arc::new(Column::new(ROW_ID_FIELD_NAME, row_id_idx)), Operator::Plus, - Arc::new(Column::new("row_base", row_base_idx)), + Arc::new(Column::new(ROW_BASE_FIELD_NAME, row_base_idx)), )); let mut projection_exprs: Vec<(Arc, String)> = Vec::new(); let mut has_row_id = false; for field_name in new_schema.fields().to_vec() { - if field_name.name() == "___row_id" { + if field_name.name() == ROW_ID_FIELD_NAME { projection_exprs.push((sum_expr.clone(), field_name.name().clone())); has_row_id = true; - } else if (field_name.name() != "row_base") { + } else if(field_name.name() != ROW_BASE_FIELD_NAME) { // Match the column by name from new_schema let idx = new_schema .index_of(&*field_name.name().clone()) @@ -120,7 +110,7 @@ impl ProjectRowIdOptimizer { } } if !has_row_id { - projection_exprs.push((sum_expr.clone(), "___row_id".parse().unwrap())); + projection_exprs.push((sum_expr.clone(), ROW_ID_FIELD_NAME.parse().unwrap())); } Ok(projection_exprs) } @@ -147,7 +137,7 @@ impl ProjectRowIdOptimizer { } } -impl PhysicalOptimizerRule for ProjectRowIdOptimizer { +impl PhysicalOptimizerRule for AbsoluteRowIdOptimizer { fn optimize( &self, plan: Arc, @@ -162,34 +152,16 @@ impl PhysicalOptimizerRule for ProjectRowIdOptimizer { .downcast_ref::() .expect("DataSource not found"); let schema = datasource.file_schema.clone(); - schema - .field_with_name("___row_id") - .expect("Field ___row_id missing"); - let projection = self - .create_datasource_projection(datasource, datasource_exec.schema()) - .expect("Failed to create ProjectionExec from datasource"); - return Ok(Transformed::new( - Arc::new(projection), - true, - TreeNodeRecursion::Continue, - )); + schema.field_with_name(ROW_ID_FIELD_NAME).expect("Field ___row_id missing"); + let projection = self.create_datasource_projection(datasource, datasource_exec.schema()).expect("Failed to create ProjectionExec from datasource"); + return Ok(Transformed::new(Arc::new(projection), true, TreeNodeRecursion::Continue)); + } else if let Some(projection_exec) = node.as_any().downcast_ref::() { - if !projection_exec - .schema() - .field_with_name("___row_id") - .is_ok() - { + if !projection_exec.schema().field_with_name(ROW_ID_FIELD_NAME).is_ok() { + let mut projection_exprs = projection_exec.expr().to_vec(); - if (projection_exec - .input() - .schema() - .index_of("___row_id") - .is_ok()) - { - if projection_exec.input().schema().index_of("___row_id").is_ok() { - let row_id_col: Arc = Arc::new(Column::new("___row_id", projection_exec.input().schema().index_of("___row_id").unwrap())); - projection_exprs.push(ProjectionExpr::new(row_id_col, "___row_id".to_string())); - } + if(projection_exec.input().schema().index_of(ROW_ID_FIELD_NAME).is_ok()) { + projection_exprs.push(ProjectionExpr::new(Arc::new(Column::new(ROW_ID_FIELD_NAME, projection_exec.input().schema().index_of(ROW_ID_FIELD_NAME).unwrap())), 
ROW_ID_FIELD_NAME.to_string())); } let projection = diff --git a/plugins/engine-datafusion/jni/src/lib.rs b/plugins/engine-datafusion/jni/src/lib.rs index 58c87e0ee66e2..abbac9d22d5dc 100644 --- a/plugins/engine-datafusion/jni/src/lib.rs +++ b/plugins/engine-datafusion/jni/src/lib.rs @@ -10,7 +10,7 @@ use std::num::NonZeroUsize; use std::ptr::addr_of_mut; use jni::objects::{JByteArray, JClass, JObject}; use jni::objects::JLongArray; -use jni::sys::{jbyteArray, jint, jlong, jstring}; +use jni::sys::{jboolean, jbyteArray, jint, jlong, jstring}; use jni::{JNIEnv, JavaVM}; use std::sync::{Arc, OnceLock}; use arrow_array::{Array, StructArray}; @@ -32,7 +32,7 @@ use std::default::Default; use std::time::{Duration, Instant}; mod util; -mod row_id_optimizer; +mod absolute_row_id_optimizer; mod listing_table; mod cache; mod custom_cache_manager; @@ -44,6 +44,7 @@ mod runtime_manager; mod cache_jni; mod partial_agg_optimizer; mod query_executor; +mod project_row_id_analyzer; use crate::custom_cache_manager::CustomCacheManager; use crate::util::{create_file_meta_from_filenames, parse_string_arr, set_action_listener_error, set_action_listener_error_global, set_action_listener_ok, set_action_listener_ok_global}; @@ -330,7 +331,7 @@ pub extern "system" fn Java_org_opensearch_datafusion_jni_NativeBridge_createDat } }; - let files: Vec = match parse_string_arr(&mut env, files) { + let mut files: Vec = match parse_string_arr(&mut env, files) { Ok(files) => files, Err(e) => { let _ = env.throw_new( @@ -341,6 +342,8 @@ pub extern "system" fn Java_org_opensearch_datafusion_jni_NativeBridge_createDat } }; + // TODO: This works since files are named similarly ending with incremental generation count, preferably move this up to DatafusionReaderManager to keep file order + files.sort(); let files_metadata = match create_file_meta_from_filenames(&table_path, files.clone()) { Ok(metadata) => metadata, Err(err) => { @@ -450,6 +453,7 @@ pub extern "system" fn Java_org_opensearch_datafusion_jni_NativeBridge_executeQu shard_view_ptr: jlong, table_name: JString, substrait_bytes: jbyteArray, + is_aggregation_query: jboolean, runtime_ptr: jlong, listener: JObject, ) { @@ -458,7 +462,7 @@ pub extern "system" fn Java_org_opensearch_datafusion_jni_NativeBridge_executeQu None => { error!("Runtime manager not initialized"); set_action_listener_error(&mut env, listener, - &DataFusionError::Execution("Runtime manager not initialized".to_string())); + &DataFusionError::Execution("Runtime manager not initialized".to_string())); return; } }; @@ -469,18 +473,20 @@ pub extern "system" fn Java_org_opensearch_datafusion_jni_NativeBridge_executeQu Err(e) => { error!("Failed to get table name: {}", e); set_action_listener_error(&mut env, listener, - &DataFusionError::Execution(format!("Failed to get table name: {}", e))); + &DataFusionError::Execution(format!("Failed to get table name: {}", e))); return; } }; + let is_aggregation_query: bool = is_aggregation_query !=0; + let plan_bytes_obj = unsafe { JByteArray::from_raw(substrait_bytes) }; let plan_bytes_vec = match env.convert_byte_array(plan_bytes_obj) { Ok(bytes) => bytes, Err(e) => { error!("Failed to convert plan bytes: {}", e); set_action_listener_error(&mut env, listener, - &DataFusionError::Execution(format!("Failed to convert plan bytes: {}", e))); + &DataFusionError::Execution(format!("Failed to convert plan bytes: {}", e))); return; } }; @@ -491,7 +497,7 @@ pub extern "system" fn Java_org_opensearch_datafusion_jni_NativeBridge_executeQu Err(e) => { error!("Failed to create 
global ref: {}", e); set_action_listener_error(&mut env, listener, - &DataFusionError::Execution(format!("Failed to create global ref: {}", e))); + &DataFusionError::Execution(format!("Failed to create global ref: {}", e))); return; } }; @@ -511,6 +517,7 @@ pub extern "system" fn Java_org_opensearch_datafusion_jni_NativeBridge_executeQu files_meta, table_name, plan_bytes_vec, + is_aggregation_query, runtime, cpu_executor, ).await; @@ -559,7 +566,7 @@ pub extern "system" fn Java_org_opensearch_datafusion_jni_NativeBridge_streamNex Err(e) => { error!("Failed to create global ref: {}", e); set_action_listener_error(&mut env, listener, - &DataFusionError::Execution(format!("Failed to create global ref: {}", e))); + &DataFusionError::Execution(format!("Failed to create global ref: {}", e))); return; } }; @@ -644,7 +651,8 @@ pub extern "system" fn Java_org_opensearch_datafusion_jni_NativeBridge_executeFe _class: JClass, shard_view_ptr: jlong, values: JLongArray, - projections: JObjectArray, + include_fields: JObjectArray, + exclude_fields: JObjectArray, runtime_ptr: jlong, callback: JObject, ) -> jlong { @@ -654,8 +662,10 @@ pub extern "system" fn Java_org_opensearch_datafusion_jni_NativeBridge_executeFe let table_path = shard_view.table_path(); let files_metadata = shard_view.files_metadata(); - let projections: Vec = - parse_string_arr(&mut env, projections).expect("Expected list of files"); + let include_fields: Vec = + parse_string_arr(&mut env, include_fields).expect("Expected list of files"); + let exclude_fields: Vec = + parse_string_arr(&mut env, exclude_fields).expect("Expected list of files"); // Safety checks first if values.is_null() { @@ -697,7 +707,7 @@ pub extern "system" fn Java_org_opensearch_datafusion_jni_NativeBridge_executeFe None => { error!("Runtime manager not initialized"); set_action_listener_error(&mut env, callback, - &DataFusionError::Execution("Runtime manager not initialized".to_string())); + &DataFusionError::Execution("Runtime manager not initialized".to_string())); return 0; } }; @@ -710,7 +720,8 @@ pub extern "system" fn Java_org_opensearch_datafusion_jni_NativeBridge_executeFe table_path, files_metadata, row_ids, - projections, + include_fields, + exclude_fields, runtime, cpu_executor, ).await { diff --git a/plugins/engine-datafusion/jni/src/listing_table.rs b/plugins/engine-datafusion/jni/src/listing_table.rs index 83728c2261adb..c7628350ccd9a 100644 --- a/plugins/engine-datafusion/jni/src/listing_table.rs +++ b/plugins/engine-datafusion/jni/src/listing_table.rs @@ -64,6 +64,7 @@ use futures::{future, stream, Stream, StreamExt, TryStreamExt}; use itertools::Itertools; use object_store::ObjectStore; use regex::Regex; +use crate::absolute_row_id_optimizer::ROW_ID_FIELD_NAME; use std::fs::File; use std::{any::Any, collections::HashMap, str::FromStr, sync::Arc}; @@ -302,7 +303,7 @@ impl ListingTableConfig { /// # Errors /// * if `self.options` is not set. 
See [`Self::with_listing_options`] pub async fn infer_schema(self, state: &dyn Session) -> Result { - match self.options { + match self.options { Some(options) => { let ListingTableConfig { table_paths, @@ -1144,7 +1145,7 @@ impl ListingTable { } let row_id_field_datatype = self .file_schema - .field_with_name("___row_id") + .field_with_name(ROW_ID_FIELD_NAME) .expect("Field ___row_id not found") .data_type(); if !(row_id_field_datatype.equals_datatype(&DataType::Int32) diff --git a/plugins/engine-datafusion/jni/src/project_row_id_analyzer.rs b/plugins/engine-datafusion/jni/src/project_row_id_analyzer.rs new file mode 100644 index 0000000000000..4448d166f02fe --- /dev/null +++ b/plugins/engine-datafusion/jni/src/project_row_id_analyzer.rs @@ -0,0 +1,119 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +use std::sync::Arc; +use arrow_schema::{DataType, Field, Schema}; +use datafusion::common::{Column, DFSchema}; +use datafusion::common::tree_node::{TreeNode, Transformed}; +use datafusion::config::ConfigOptions; +use datafusion::optimizer::{AnalyzerRule}; +use datafusion::logical_expr::{ + Expr, LogicalPlan, LogicalPlanBuilder, col +}; +use datafusion::error::Result; +use datafusion_expr::{Projection, UserDefinedLogicalNode}; +use crate::absolute_row_id_optimizer::ROW_ID_FIELD_NAME; + +#[derive(Debug)] +pub struct ProjectRowIdAnalyzer; + +impl ProjectRowIdAnalyzer { + + pub fn new() -> Self { + Self {} + } + + fn wrap_project_row_id(&self, plan: LogicalPlan) -> Result { + LogicalPlanBuilder::from(plan) + .project(vec![col(ROW_ID_FIELD_NAME)]) + .map(|b| b.build())? + } +} + +impl AnalyzerRule for ProjectRowIdAnalyzer { + fn analyze( + &self, + plan: LogicalPlan, + _config: &ConfigOptions + ) -> Result { + + let rewritten = plan.transform_up(|node| { + match &node { + LogicalPlan::TableScan(scan) => { + let mut proj = scan.projection.clone().unwrap_or_else(|| { + (0..scan.projected_schema.fields().len()).collect() + }); + + let mut new_projected_schema = (*scan.projected_schema).clone(); + + // Append ___row_id field if not already present + if scan.source.schema().index_of(ROW_ID_FIELD_NAME).is_ok() { + let row_id_idx = scan.source.schema().index_of(ROW_ID_FIELD_NAME).unwrap(); + if !proj.contains(&row_id_idx) { + proj.push(row_id_idx); + new_projected_schema = new_projected_schema + .join(&DFSchema::try_from_qualified_schema( + scan.projected_schema.qualified_field(0).0.expect("Failed to get qualified name").clone(), + &Schema::new(vec![ + Field::new(ROW_ID_FIELD_NAME, DataType::Int64, false), + ]), + )?) 
+ .expect("Failed to join schemas"); + } + } + + // Optionally, add row_base similarly + let new_scan = LogicalPlan::TableScan(datafusion_expr::TableScan { + table_name: scan.table_name.clone(), + source: scan.source.clone(), + projection: Some(proj), + projected_schema: Arc::new(new_projected_schema), + filters: scan.filters.clone(), + fetch: scan.fetch, + }); + println!("new_scan: {:?}", new_scan); + return Ok(Transformed::yes(new_scan)); + } + + LogicalPlan::Projection(p) => { + if !p.expr.iter().any(|e| matches!(e, Expr::Column(c) if c.name == ROW_ID_FIELD_NAME)) + && p.input.schema().index_of_column(&Column::from_name("___row_id")).is_ok() + { + let mut new_exprs = p.expr.to_vec(); + new_exprs.push(col(ROW_ID_FIELD_NAME)); + // new_exprs.push(self.build_row_id_expr()); + let mut new_fields = vec![];//p.schema.fields().to_vec(); + new_fields.push(Field::new(ROW_ID_FIELD_NAME, DataType::Int64, false)); + let new_schema = DFSchema::try_from_qualified_schema( + p.schema.qualified_field(0).0.expect("Failed to get qualified name").clone(), + &Schema::new(new_fields), + )?; + + // if p.input.schema().index_of_column(&Column::from_name("___row_id")).is_ok() + + let merged_schema = if p.schema.index_of_column(&Column::from_name(ROW_ID_FIELD_NAME)).is_ok() { p.schema.clone() } else { Arc::new(p.schema.clone().join(&new_schema).expect("Failed to join schemas")) }; + let new_proj = LogicalPlan::Projection(Projection::try_new_with_schema(new_exprs, p.input.clone(), merged_schema).expect("Failed to create projection")); + // println!("new_proj: {:?}", new_proj); + + return Ok(Transformed::yes(new_proj)); + } + Ok(Transformed::no(node)) + } + + _ => {Ok(Transformed::no(node))} + } + })?; + + // rewritten.data is the updated logical plan + Ok(rewritten.data) + } + + fn name(&self) -> &str { + "project_row_id_logical_optimizer" + } +} diff --git a/plugins/engine-datafusion/jni/src/query_executor.rs b/plugins/engine-datafusion/jni/src/query_executor.rs index 418f19045f9eb..4da83c2bb5663 100644 --- a/plugins/engine-datafusion/jni/src/query_executor.rs +++ b/plugins/engine-datafusion/jni/src/query_executor.rs @@ -7,7 +7,7 @@ */ use std::sync::Arc; -use std::collections::{BTreeSet, HashMap}; +use std::collections::{BTreeSet, HashMap, HashSet}; use jni::sys::jlong; use datafusion::{ common::DataFusionError, @@ -34,22 +34,32 @@ use datafusion_substrait::logical_plan::consumer::from_substrait_plan; use datafusion_substrait::substrait::proto::{Plan, extensions::simple_extension_declaration::MappingType}; use object_store::ObjectMeta; use prost::Message; -use arrow_schema::DataType; +use arrow_schema::{DataType, Field, SchemaRef}; +use chrono::TimeZone; +use datafusion::common::ScalarValue; +use datafusion::logical_expr::Operator; +use datafusion::optimizer::AnalyzerRule; +use datafusion::physical_expr::expressions::BinaryExpr; +use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_plan::projection::ProjectionExec; +use datafusion_expr::{LogicalPlan, Projection}; use log::error; - +use object_store::path::Path; use crate::listing_table::{ListingOptions, ListingTable, ListingTableConfig}; use crate::partial_agg_optimizer::PartialAggregationOptimizer; use crate::executor::DedicatedExecutor; use crate::cross_rt_stream::CrossRtStream; use crate::CustomFileMeta; use crate::DataFusionRuntime; -use crate::row_id_optimizer::ProjectRowIdOptimizer; +use crate::project_row_id_analyzer::ProjectRowIdAnalyzer; +use crate::absolute_row_id_optimizer::{AbsoluteRowIdOptimizer, ROW_BASE_FIELD_NAME, 
ROW_ID_FIELD_NAME}; pub async fn execute_query_with_cross_rt_stream( table_path: ListingTableUrl, files_meta: Arc>, table_name: String, plan_bytes_vec: Vec, + is_aggregation_query: bool, runtime: &DataFusionRuntime, cpu_executor: DedicatedExecutor, ) -> Result { @@ -87,15 +97,17 @@ pub async fn execute_query_with_cross_rt_stream( config.options_mut().execution.target_partitions = 1; config.options_mut().execution.batch_size = 1024; - let state = datafusion::execution::SessionStateBuilder::new() + let mut state_builder = datafusion::execution::SessionStateBuilder::new() .with_config(config.clone()) .with_runtime_env(Arc::from(runtime_env)) .with_default_features() - //.with_physical_optimizer_rule(Arc::new(ProjectRowIdOptimizer)) // TODO : uncomment this after fix - .with_physical_optimizer_rule(Arc::new(PartialAggregationOptimizer)) - .build(); + .with_physical_optimizer_rule(Arc::new(PartialAggregationOptimizer)); - let ctx = SessionContext::new_with_state(state); + if(!is_aggregation_query) { + state_builder = state_builder.with_physical_optimizer_rule(Arc::new(AbsoluteRowIdOptimizer)); // Uses row_base from partition cols to evaluate ___row_id + row_base as ___row_id + } + + let ctx = SessionContext::new_with_state(state_builder.build()); // Register table let file_format = ParquetFormat::new(); @@ -103,7 +115,7 @@ pub async fn execute_query_with_cross_rt_stream( .with_file_extension(".parquet") .with_files_metadata(files_meta) .with_session_config_options(&config) - .with_table_partition_cols(vec![("row_base".to_string(), DataType::Int64)]); + .with_table_partition_cols(if is_aggregation_query { vec![] } else { vec![(ROW_BASE_FIELD_NAME.to_string(), DataType::Int64)] }); let resolved_schema = match listing_options .infer_schema(&ctx.state(), &table_path) @@ -152,7 +164,7 @@ pub async fn execute_query_with_cross_rt_stream( } } - let logical_plan = match from_substrait_plan(&ctx.state(), &modified_plan).await { + let mut logical_plan = match from_substrait_plan(&ctx.state(), &modified_plan).await { Ok(plan) => plan, Err(e) => { error!("Failed to convert Substrait plan: {}", e); @@ -160,7 +172,15 @@ pub async fn execute_query_with_cross_rt_stream( } }; - let dataframe = match ctx.execute_logical_plan(logical_plan).await { + if !is_aggregation_query { + logical_plan = ProjectRowIdAnalyzer.analyze(logical_plan, ctx.state().config_options())?; // Only keeps ___row_id in projections + logical_plan = LogicalPlan::Projection(Projection::try_new( + vec![col(ROW_ID_FIELD_NAME.to_string())], + Arc::new(logical_plan), + ).expect("Failed to create top level projection with ___row_id")); + } + + let mut dataframe = match ctx.execute_logical_plan(logical_plan).await { Ok(df) => df, Err(e) => { error!("Failed to execute logical plan: {}", e); @@ -168,6 +188,10 @@ pub async fn execute_query_with_cross_rt_stream( } }; + // println!("Explain show"); + // let clone_df = dataframe.clone().explain(false, true); + // clone_df?.show().await?; + let df_stream = match dataframe.execute_stream().await { Ok(stream) => stream, Err(e) => { @@ -197,7 +221,8 @@ pub async fn execute_fetch_phase( table_path: ListingTableUrl, files_metadata: Arc>, row_ids: Vec, - projections: Vec, + include_fields: Vec, + exclude_fields: Vec, runtime: &DataFusionRuntime, cpu_executor: DedicatedExecutor, ) -> Result { @@ -223,54 +248,161 @@ pub async fn execute_fetch_phase( .with_metadata_cache_limit(file_metadata_cache.cache_limit()), ) .build()?; - let ctx = SessionContext::new_with_config_rt(SessionConfig::new(), 
Arc::new(runtime_env)); + + let mut config = SessionConfig::new(); + config.options_mut().execution.parquet.pushdown_filters = true; + config.options_mut().execution.target_partitions = 1; + + let state = datafusion::execution::SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(Arc::from(runtime_env)) + .with_default_features() + // .with_physical_optimizer_rule(Arc::new(ProjectRowIdOptimizer)) + .build(); + + let ctx = SessionContext::new_with_state(state); let file_format = ParquetFormat::new(); let listing_options = ListingOptions::new(Arc::new(file_format)).with_file_extension(".parquet"); let parquet_schema = listing_options.infer_schema(&ctx.state(), &table_path).await?; + let projections = create_projections(include_fields, exclude_fields, parquet_schema.clone()); let partitioned_files: Vec = files_metadata .iter() .zip(access_plans.iter()) .map(|(meta, access_plan)| { - PartitionedFile::new( - meta.object_meta().location.to_string(), - meta.object_meta.size, - ) - .with_extensions(Arc::new(access_plan.clone())) + PartitionedFile { + object_meta: ObjectMeta { + location: Path::from(meta.object_meta().location.to_string()), + last_modified: chrono::Utc.timestamp_nanos(0), + size: meta.object_meta.size, + e_tag: None, + version: None, + }, + partition_values: vec![ScalarValue::Int64(Some(*meta.row_base))], + range: None, + statistics: None, + extensions: None, + metadata_size_hint: None, + } + .with_extensions(Arc::new(access_plan.clone())) }) .collect(); let file_group = FileGroup::new(partitioned_files); - let file_source = Arc::new(ParquetSource::default()); + + let file_source = Arc::new( + ParquetSource::default(), // provide the factory to create parquet reader without re-reading metadata + //.with_parquet_file_reader_factory(Arc::new(reader_factory)), + ); let mut projection_index = vec![]; + for field_name in projections.iter() { projection_index.push( parquet_schema .index_of(field_name) - .map_err(|_| DataFusionError::Execution(format!("Projected field {} not found in Schema", field_name)))?, + .expect(format!("Projected field {} not found in Schema", field_name).as_str()), ); } + if(!projections.contains(&ROW_ID_FIELD_NAME.to_string())) { + projection_index.push(parquet_schema.index_of(ROW_ID_FIELD_NAME).unwrap()); + } + projection_index.push(parquet_schema.fields.len()); + let file_scan_config = FileScanConfigBuilder::new( ObjectStoreUrl::local_filesystem(), parquet_schema.clone(), file_source, ) + .with_table_partition_cols(vec![Field::new(ROW_BASE_FIELD_NAME, DataType::Int64, false)]) .with_projection(Option::from(projection_index.clone())) .with_file_group(file_group) .build(); - let parquet_exec = DataSourceExec::from_data_source(file_scan_config); - let optimized_plan: Arc = parquet_exec.clone(); + let parquet_exec = DataSourceExec::from_data_source(file_scan_config.clone()); + + let projection_exprs = build_projection_exprs(file_scan_config.projected_schema()) + .expect("Failed to build projection expressions"); + + let projection_exec = Arc::new(ProjectionExec::try_new(projection_exprs, parquet_exec) + .expect("Failed to create ProjectionExec")); + let optimized_plan: Arc = projection_exec.clone(); let task_ctx = Arc::new(TaskContext::default()); let stream = optimized_plan.execute(0, task_ctx)?; Ok(get_cross_rt_stream(cpu_executor, stream)) } +pub fn create_projections( + include_fields: Vec, + exclude_fields: Vec, + schema: SchemaRef, +) -> Vec { + + // Get all field names from schema + let all_fields: Vec = + 
schema.fields().to_vec().iter().map(|f| f.name().to_string()).filter(|f| f.eq(ROW_ID_FIELD_NAME) || !f.starts_with("_")).collect(); //exclude metadata fields + + match (include_fields.is_empty(), exclude_fields.is_empty()) { + + // includes empty, excludes empty → all fields + (true, true) => all_fields.clone(), + + // includes non-empty → include only these fields + (false, _) => include_fields + .into_iter() + .filter(|f| schema.field_with_name(f).is_ok()) // keep valid fields + .collect(), + + // includes empty, excludes non-empty → remove excludes + (true, false) => { + let exclude_set: HashSet = + exclude_fields.into_iter().collect(); + + all_fields + .into_iter() + .filter(|f| !exclude_set.contains(f)) + .collect() + } + } +} + +fn build_projection_exprs(new_schema: SchemaRef) -> std::result::Result, String)>, DataFusionError> { + let row_id_idx = new_schema.index_of(ROW_ID_FIELD_NAME).expect("Field ___row_id missing"); + let row_base_idx = new_schema.index_of(ROW_BASE_FIELD_NAME).expect("Field ___row_id missing"); + let sum_expr: Arc = Arc::new(BinaryExpr::new( + Arc::new(datafusion::physical_expr::expressions::Column::new(ROW_ID_FIELD_NAME, row_id_idx)), + Operator::Plus, + Arc::new(datafusion::physical_expr::expressions::Column::new(ROW_BASE_FIELD_NAME, row_base_idx)), + )); + + let mut projection_exprs: Vec<(Arc, String)> = Vec::new(); + + let mut has_row_id = false; + for field_name in new_schema.fields().to_vec() { + if field_name.name() == ROW_ID_FIELD_NAME { + projection_exprs.push((sum_expr.clone(), field_name.name().clone())); + has_row_id = true; + } else if(field_name.name() != ROW_BASE_FIELD_NAME) { + // Match the column by name from new_schema + let idx = new_schema + .index_of(&*field_name.name().clone()) + .unwrap_or_else(|_| panic!("Field {field_name} missing in schema")); + projection_exprs.push(( + Arc::new(datafusion::physical_expr::expressions::Column::new(&*field_name.name(), idx)), + field_name.name().clone(), + )); + } + } + if !has_row_id { + projection_exprs.push((sum_expr.clone(), ROW_ID_FIELD_NAME.parse().unwrap())); + } + Ok(projection_exprs) +} + async fn create_access_plans( row_ids: Vec, files_metadata: Arc>, diff --git a/plugins/engine-datafusion/src/main/java/org/opensearch/datafusion/DatafusionEngine.java b/plugins/engine-datafusion/src/main/java/org/opensearch/datafusion/DatafusionEngine.java index 11ab030bdc9ab..b07649f43d32c 100644 --- a/plugins/engine-datafusion/src/main/java/org/opensearch/datafusion/DatafusionEngine.java +++ b/plugins/engine-datafusion/src/main/java/org/opensearch/datafusion/DatafusionEngine.java @@ -10,6 +10,7 @@ import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.TimeStampMilliVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.ViewVarCharVector; import org.apache.arrow.vector.types.pojo.Field; @@ -42,7 +43,9 @@ import org.opensearch.search.SearchHits; import org.opensearch.search.SearchShardTarget; import org.opensearch.search.aggregations.SearchResultsCollector; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.FetchSubPhase; +import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.search.internal.ReaderContext; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.internal.ShardSearchRequest; @@ -92,7 +95,7 @@ public DatafusionEngine(DataFormat dataFormat, Collection formatCa public DatafusionContext 
createContext(ReaderContext readerContext, ShardSearchRequest request, SearchShardTarget searchShardTarget, SearchShardTask task, BigArrays bigArrays, SearchContext originalContext) throws IOException {
         DatafusionContext datafusionContext = new DatafusionContext(readerContext, request, searchShardTarget, task, this, bigArrays, originalContext);
         // Parse source
-        datafusionContext.datafusionQuery(new DatafusionQuery(request.shardId().getIndexName(), request.source().queryPlanIR(), new ArrayList<>()));
+        datafusionContext.datafusionQuery(new DatafusionQuery(request.shardId().getIndexName(), request.source().queryPlanIR(), new ArrayList<>(), request.source().aggregations() != null));
         return datafusionContext;
     }
 
@@ -187,7 +190,7 @@ public void close() {
     }
 
     @Override
-    public Map executeQueryPhase(DatafusionContext context) {
+    public void executeQueryPhase(DatafusionContext context) {
         Map finalRes = new HashMap<>();
         List rowIdResult = new ArrayList<>();
         RecordBatchStream stream = null;
@@ -253,9 +256,8 @@ public Map executeQueryPhase(DatafusionContext context) {
                 throw new RuntimeException(e);
             }
         }
-
+        context.setDFResults(finalRes);
         context.queryResult().topDocs(new TopDocsAndMaxScore(new TopDocs(new TotalHits(rowIdResult.size(), TotalHits.Relation.EQUAL_TO), rowIdResult.stream().map(d-> new ScoreDoc(d.intValue(), Float.NaN, context.indexShard().shardId().getId())).toList().toArray(ScoreDoc[]::new)) , Float.NaN), new DocValueFormat[0]);
-        return finalRes;
     }
 
     @Override
@@ -374,9 +376,26 @@ public void executeFetchPhase(DatafusionContext context) throws IOException {
 
         // preprocess
         context.getDatafusionQuery().setFetchPhaseContext(rowIds);
-        List projections = new ArrayList<>(List.of(context.request().source().fetchSource().includes()));
-        projections.add(CompositeDataFormatWriter.ROW_ID);
-        context.getDatafusionQuery().setProjections(projections);
+
+        List includeFields =
+            Optional.ofNullable(context.request().source())
+                .map(SearchSourceBuilder::fetchSource)
+                .map(FetchSourceContext::includes)
+                .map(list -> new ArrayList<>(Arrays.asList(list)))
+                .orElseGet(ArrayList::new);
+
+        List excludeFields =
+            Optional.ofNullable(context.request().source())
+                .map(SearchSourceBuilder::fetchSource)
+                .map(FetchSourceContext::excludes)
+                .map(list -> new ArrayList<>(Arrays.asList(list)))
+                .orElseGet(ArrayList::new);
+
+        if(!includeFields.isEmpty()) {
+            includeFields.add(CompositeDataFormatWriter.ROW_ID);
+        }
+
+        context.getDatafusionQuery().setSource(includeFields, excludeFields);
         DatafusionSearcher datafusionSearcher = context.getEngineSearcher();
         long streamPointer = datafusionSearcher.search(context.getDatafusionQuery(), datafusionService.getRuntimePointer());
         RecordBatchStream stream = new RecordBatchStream(streamPointer, datafusionService.getRuntimePointer(), rootAllocator);
@@ -410,15 +429,22 @@ public void executeFetchPhase(DatafusionContext context) throws IOException {
                         DerivedFieldGenerator derivedFieldGenerator = mapper.derivedFieldGenerator();
                         Object value = valueVectors.getObject(i);
 
-                        if(valueVectors instanceof ViewVarCharVector) {
-                            BytesRef bytesRef = new BytesRef(((ViewVarCharVector) valueVectors).get(i));
-                            derivedFieldGenerator.generate(builder, List.of(bytesRef)); // TODO: // Currently keyword field mapper do not have derived field converter from byte[] to BytesRef
+                        if(value == null) {
+                            builder.nullField(valueVectors.getName());
                         } else {
-                            derivedFieldGenerator.generate(builder, List.of(value));
-                        }
-                        if (valueVectors.getName().equals(IdFieldMapper.NAME)) {
-                            BytesRef idRef = new
BytesArray((byte[]) value).toBytesRef(); - _id = Uid.decodeId(idRef.bytes, idRef.offset, idRef.length); + if(valueVectors instanceof ViewVarCharVector) { + BytesRef bytesRef = new BytesRef(((ViewVarCharVector) valueVectors).get(i)); + derivedFieldGenerator.generate(builder, List.of(bytesRef)); + } else if (valueVectors instanceof TimeStampMilliVector) { + long timestamp = ((TimeStampMilliVector) valueVectors).get(i); + derivedFieldGenerator.generate(builder, List.of(timestamp)); + } else { + derivedFieldGenerator.generate(builder, List.of(value)); + } + if (valueVectors.getName().equals(IdFieldMapper.NAME)) { + BytesRef idRef = new BytesArray((byte[]) value).toBytesRef(); + _id = Uid.decodeId(idRef.bytes, idRef.offset, idRef.length); + } } } } catch (Exception e) { @@ -428,6 +454,7 @@ public void executeFetchPhase(DatafusionContext context) throws IOException { builder.endObject(); } assert row_id != null || rowIds.get(i) != null; + assert rowIdToIndex.containsKey(row_id); assert _id != null; BytesReference document = BytesReference.bytes(builder); byteRefs.add(document); diff --git a/plugins/engine-datafusion/src/main/java/org/opensearch/datafusion/jni/NativeBridge.java b/plugins/engine-datafusion/src/main/java/org/opensearch/datafusion/jni/NativeBridge.java index 0bc1054c92029..70a92d5b98edb 100644 --- a/plugins/engine-datafusion/src/main/java/org/opensearch/datafusion/jni/NativeBridge.java +++ b/plugins/engine-datafusion/src/main/java/org/opensearch/datafusion/jni/NativeBridge.java @@ -35,8 +35,8 @@ private NativeBridge() {} public static native void shutdownTokioRuntimeManager(); // Query execution - public static native void executeQueryPhaseAsync(long readerPtr, String tableName, byte[] plan, long runtimePtr, ActionListener listener); - public static native long executeFetchPhase(long readerPtr, long[] rowIds, String[] projections, long runtimePtr); + public static native void executeQueryPhaseAsync(long readerPtr, String tableName, byte[] plan, boolean isAggregationQuery, long runtimePtr, ActionListener listener); + public static native long executeFetchPhase(long readerPtr, long[] rowIds, String[] includeFields, String[] excludeFields, long runtimePtr); // Stream operations public static native void streamNext(long runtime, long stream, ActionListener listener); diff --git a/plugins/engine-datafusion/src/main/java/org/opensearch/datafusion/search/DatafusionQuery.java b/plugins/engine-datafusion/src/main/java/org/opensearch/datafusion/search/DatafusionQuery.java index 3ba9682059a64..acd7789a47308 100644 --- a/plugins/engine-datafusion/src/main/java/org/opensearch/datafusion/search/DatafusionQuery.java +++ b/plugins/engine-datafusion/src/main/java/org/opensearch/datafusion/search/DatafusionQuery.java @@ -8,7 +8,6 @@ package org.opensearch.datafusion.search; -import java.util.Iterator; import java.util.List; public class DatafusionQuery { @@ -19,17 +18,21 @@ public class DatafusionQuery { private final List searchExecutors; private Boolean isFetchPhase; private List queryPhaseRowIds; - private List projections; + private List includeFields; + private List excludeFields; + private boolean isAggregationQuery; - public DatafusionQuery(String indexName, byte[] substraitBytes, List searchExecutors) { + public DatafusionQuery(String indexName, byte[] substraitBytes, List searchExecutors, boolean isAggregationQuery) { this.indexName = indexName; this.substraitBytes = substraitBytes; this.searchExecutors = searchExecutors; this.isFetchPhase = false; + this.isAggregationQuery = 
isAggregationQuery; } - public void setProjections(List projections) { - this.projections = projections; + public void setSource(List includeFields, List excludeFields) { + this.includeFields = includeFields; + this.excludeFields = excludeFields; } public void setFetchPhaseContext(List queryPhaseRowIds) { @@ -41,12 +44,24 @@ public boolean isFetchPhase() { return this.isFetchPhase; } + public boolean isAggregationQuery() { + return isAggregationQuery; + } + + public void setAggregationQuery(boolean aggregationQuery) { + isAggregationQuery = aggregationQuery; + } + public List getQueryPhaseRowIds() { return this.queryPhaseRowIds; } - public List getProjections() { - return this.projections; + public List getIncludeFields() { + return this.includeFields; + } + + public List getExcludeFields() { + return this.excludeFields; } public byte[] getSubstraitBytes() { diff --git a/plugins/engine-datafusion/src/main/java/org/opensearch/datafusion/search/DatafusionReader.java b/plugins/engine-datafusion/src/main/java/org/opensearch/datafusion/search/DatafusionReader.java index 7f74e337155e6..65eb84843d88b 100644 --- a/plugins/engine-datafusion/src/main/java/org/opensearch/datafusion/search/DatafusionReader.java +++ b/plugins/engine-datafusion/src/main/java/org/opensearch/datafusion/search/DatafusionReader.java @@ -16,6 +16,8 @@ import java.io.Closeable; import java.util.Arrays; import java.util.Collection; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; /** * DataFusion reader for JNI operations. */ diff --git a/plugins/engine-datafusion/src/main/java/org/opensearch/datafusion/search/DatafusionSearcher.java b/plugins/engine-datafusion/src/main/java/org/opensearch/datafusion/search/DatafusionSearcher.java index 3691f663a649b..c791f746663a6 100644 --- a/plugins/engine-datafusion/src/main/java/org/opensearch/datafusion/search/DatafusionSearcher.java +++ b/plugins/engine-datafusion/src/main/java/org/opensearch/datafusion/search/DatafusionSearcher.java @@ -10,7 +10,6 @@ import org.apache.lucene.store.AlreadyClosedException; import org.opensearch.core.action.ActionListener; -import org.opensearch.datafusion.ErrorUtil; import org.opensearch.datafusion.jni.NativeBridge; import org.opensearch.index.engine.EngineSearcher; import org.opensearch.vectorized.execution.search.spi.RecordBatchStream; @@ -45,9 +44,10 @@ public long search(DatafusionQuery datafusionQuery, Long runtimePtr) { .stream() .mapToLong(Long::longValue) .toArray(); - String[] projections = Objects.isNull(datafusionQuery.getProjections()) ? new String[]{} : datafusionQuery.getProjections().toArray(String[]::new); + String[] includeFields = Objects.isNull(datafusionQuery.getIncludeFields()) ? new String[]{} : datafusionQuery.getIncludeFields().toArray(String[]::new); + String[] excludeFields = Objects.isNull(datafusionQuery.getExcludeFields()) ? 
new String[]{} : datafusionQuery.getExcludeFields().toArray(String[]::new); - return NativeBridge.executeFetchPhase(reader.getReaderPtr(), row_ids, projections, runtimePtr); + return NativeBridge.executeFetchPhase(reader.getReaderPtr(), row_ids, includeFields, excludeFields, runtimePtr); } throw new RuntimeException("Can be only called for fetch phase"); } @@ -55,7 +55,7 @@ public long search(DatafusionQuery datafusionQuery, Long runtimePtr) { @Override public CompletableFuture searchAsync(DatafusionQuery datafusionQuery, Long runtimePtr) { CompletableFuture result = new CompletableFuture<>(); - NativeBridge.executeQueryPhaseAsync(reader.getReaderPtr(), datafusionQuery.getIndexName(), datafusionQuery.getSubstraitBytes(), runtimePtr, new ActionListener() { + NativeBridge.executeQueryPhaseAsync(reader.getReaderPtr(), datafusionQuery.getIndexName(), datafusionQuery.getSubstraitBytes(), datafusionQuery.isAggregationQuery(), runtimePtr, new ActionListener() { @Override public void onResponse(Long streamPointer) { if (streamPointer == 0) { diff --git a/plugins/engine-datafusion/src/test/java/org/opensearch/datafusion/DataFusionReaderManagerTests.java b/plugins/engine-datafusion/src/test/java/org/opensearch/datafusion/DataFusionReaderManagerTests.java index 483aa3c9d8990..63f5891195b80 100644 --- a/plugins/engine-datafusion/src/test/java/org/opensearch/datafusion/DataFusionReaderManagerTests.java +++ b/plugins/engine-datafusion/src/test/java/org/opensearch/datafusion/DataFusionReaderManagerTests.java @@ -349,7 +349,7 @@ public void testSearch() throws Exception { throw new RuntimeException(e); } - DatafusionQuery datafusionQuery = new DatafusionQuery("index-7", protoContent, new java.util.ArrayList<>()); + DatafusionQuery datafusionQuery = new DatafusionQuery("index-7", protoContent, new java.util.ArrayList<>(), true); Map expectedResults = new HashMap<>(); expectedResults.put("min", 2L); expectedResults.put("max", 4L); diff --git a/plugins/engine-datafusion/src/test/java/org/opensearch/datafusion/DataFusionServiceTests.java b/plugins/engine-datafusion/src/test/java/org/opensearch/datafusion/DataFusionServiceTests.java index 639b0d724ef35..7163314f19b2d 100644 --- a/plugins/engine-datafusion/src/test/java/org/opensearch/datafusion/DataFusionServiceTests.java +++ b/plugins/engine-datafusion/src/test/java/org/opensearch/datafusion/DataFusionServiceTests.java @@ -158,7 +158,7 @@ public void testQueryPhaseExecutor() throws IOException { throw new RuntimeException(e); } - long streamPointer = datafusionSearcher.search(new DatafusionQuery(index.getName(), protoContent, new ArrayList<>()), service.getRuntimePointer()); + long streamPointer = datafusionSearcher.search(new DatafusionQuery(index.getName(), protoContent, new ArrayList<>(), false), service.getRuntimePointer()); RootAllocator allocator = new RootAllocator(Long.MAX_VALUE); RecordBatchStream stream = new RecordBatchStream(streamPointer, service.getRuntimePointer(), allocator); @@ -217,7 +217,7 @@ public void testQueryThenFetchExecutor() throws IOException, URISyntaxException throw new RuntimeException(e); } - DatafusionQuery query = new DatafusionQuery(index.getName(), protoContent, new ArrayList<>()); + DatafusionQuery query = new DatafusionQuery(index.getName(), protoContent, new ArrayList<>(), false); long streamPointer = datafusionSearcher.search(query, service.getRuntimePointer()); RootAllocator allocator = new RootAllocator(Long.MAX_VALUE); RecordBatchStream stream = new RecordBatchStream(streamPointer, service.getRuntimePointer(), 
allocator); @@ -240,7 +240,7 @@ public void testQueryThenFetchExecutor() throws IOException, URISyntaxException logger.info("Final row_ids count: {}", row_ids_res); List projections = List.of("message"); - query.setProjections(projections); + query.setSource(projections, List.of()); query.setFetchPhaseContext(row_ids_res); long fetchPhaseStreamPointer = datafusionSearcher.search(query, service.getRuntimePointer()); @@ -326,15 +326,15 @@ public void testQueryThenFetchE2ETest() throws IOException, URISyntaxException, DatafusionContext datafusionContext = new DatafusionContext(readerContext, shardSearchRequest, searchShardTarget, searchShardTask, engine, null, null); byte[] protoContent; - try (InputStream is = getClass().getResourceAsStream("/substrait_plan.pb")) { + try (InputStream is = getClass().getResourceAsStream("/substrait_plan_test.pb")) { protoContent = is.readAllBytes(); } catch (IOException e) { throw new RuntimeException(e); } - DatafusionQuery query = new DatafusionQuery(index.getName(), protoContent, new ArrayList<>()); + DatafusionQuery query = new DatafusionQuery(index.getName(), protoContent, new ArrayList<>(), false); List projections = List.of("message"); - query.setProjections(projections); + query.setSource(projections, List.of()); datafusionContext.datafusionQuery(query); diff --git a/server/src/main/java/org/opensearch/index/engine/SearchExecEngine.java b/server/src/main/java/org/opensearch/index/engine/SearchExecEngine.java index c0f461cbe3ad8..531e5ffa61b5d 100644 --- a/server/src/main/java/org/opensearch/index/engine/SearchExecEngine.java +++ b/server/src/main/java/org/opensearch/index/engine/SearchExecEngine.java @@ -40,7 +40,7 @@ public abstract class SearchExecEngine executeQueryPhase(C context) throws IOException; + public abstract void executeQueryPhase(C context) throws IOException; public abstract void executeQueryPhaseAsync(C context, Executor executor, ActionListener> listener); diff --git a/server/src/main/java/org/opensearch/index/mapper/NumberFieldMapper.java b/server/src/main/java/org/opensearch/index/mapper/NumberFieldMapper.java index a03e69a05acce..754d27856bc85 100644 --- a/server/src/main/java/org/opensearch/index/mapper/NumberFieldMapper.java +++ b/server/src/main/java/org/opensearch/index/mapper/NumberFieldMapper.java @@ -224,6 +224,17 @@ public Object convert(Object value) { case BYTE, SHORT, INTEGER, LONG -> val; case UNSIGNED_LONG -> Numbers.toUnsignedBigInteger(val); }; + } else if (value instanceof Short) { + Short val = (Short) value; + + return switch (type) { + case HALF_FLOAT -> HalfFloatPoint.sortableShortToHalfFloat(val); + case FLOAT -> val.floatValue(); + case DOUBLE -> val.doubleValue(); + case BYTE, SHORT -> val; + case INTEGER, LONG -> val.longValue(); + case UNSIGNED_LONG -> Numbers.toUnsignedBigInteger(val); + }; } Long val = (Long) value; if (val == null) { diff --git a/server/src/main/java/org/opensearch/search/SearchService.java b/server/src/main/java/org/opensearch/search/SearchService.java index ee5061dd09e51..e96be124c4726 100644 --- a/server/src/main/java/org/opensearch/search/SearchService.java +++ b/server/src/main/java/org/opensearch/search/SearchService.java @@ -861,8 +861,7 @@ private SearchPhaseResult executeQueryPhase( context.queryResult().size(context.size()); if (substraitQuery != null) { // setDFResults in context - Map result = searchExecEngine.executeQueryPhase(context); - context.setDFResults(result); + searchExecEngine.executeQueryPhase(context); } return executeQueryPhase( context,