Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Fields, Schema};
use arrow_schema::SchemaRef;
use datafusion::physical_plan::projection::new_projections_for_columns;
use datafusion::physical_plan::projection::{new_projections_for_columns, ProjectionExpr};
use datafusion::{
common::tree_node::{Transformed, TreeNode, TreeNodeRecursion},
config::ConfigOptions,
Expand All @@ -21,15 +21,21 @@ use datafusion::{
},
error::DataFusionError,
logical_expr::Operator,
physical_expr::{PhysicalExpr, expressions::{BinaryExpr, Column}},
parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder,
physical_expr::{
expressions::{BinaryExpr, Column},
PhysicalExpr,
},
physical_optimizer::PhysicalOptimizerRule,
physical_plan::{ExecutionPlan, filter::FilterExec, projection::{ProjectionExec, ProjectionExpr}},
physical_plan::{filter::FilterExec, projection::ProjectionExec, ExecutionPlan},
};

#[derive(Debug)]
pub struct ProjectRowIdOptimizer;
pub struct AbsoluteRowIdOptimizer;
pub const ROW_ID_FIELD_NAME: &'static str = "___row_id";
pub const ROW_BASE_FIELD_NAME: &'static str = "row_base";

impl ProjectRowIdOptimizer {
impl AbsoluteRowIdOptimizer {
/// Helper to build new schema and projection info with added `row_base` column.
fn build_updated_file_source_schema(
&self,
Expand All @@ -53,27 +59,19 @@ impl ProjectRowIdOptimizer {
// }
// }

if !projections.contains(&file_source_schema.index_of("___row_id").unwrap()) {
new_projections.push(file_source_schema.index_of("___row_id").unwrap());
if !projections.contains(&file_source_schema.index_of(ROW_ID_FIELD_NAME).unwrap()) {
new_projections.push(file_source_schema.index_of(ROW_ID_FIELD_NAME).unwrap());

// let field = file_source_schema.field_with_name(&*"___row_id").expect("Field ___row_id not found in file_source_schema");
// fields.push(Arc::new(Field::new("___row_id", field.data_type().clone(), field.is_nullable())));
// let field = file_source_schema.field_with_name(&*ROW_ID_FIELD_NAME).expect("Field ___row_id not found in file_source_schema");
// fields.push(Arc::new(Field::new(ROW_ID_FIELD_NAME, field.data_type().clone(), field.is_nullable())));
}
new_projections.push(file_source_schema.fields.len());
// fields.push(Arc::new(Field::new("row_base", file_source_schema.field_with_name("___row_id").unwrap().data_type().clone(), true)));
// fields.push(Arc::new(Field::new("row_base", file_source_schema.field_with_name(ROW_ID_FIELD_NAME).unwrap().data_type().clone(), true)));

// Add row_base field to schema

let mut new_fields = file_source_schema.fields().clone().to_vec();
new_fields.push(Arc::new(Field::new(
"row_base",
file_source_schema
.field_with_name("___row_id")
.unwrap()
.data_type()
.clone(),
true,
)));
new_fields.push(Arc::new(Field::new(ROW_BASE_FIELD_NAME, file_source_schema.field_with_name(ROW_ID_FIELD_NAME).unwrap().data_type().clone(), true)));

let new_schema = Arc::new(Schema {
metadata: file_source_schema.metadata().clone(),
Expand All @@ -84,31 +82,23 @@ impl ProjectRowIdOptimizer {
}

/// Creates a projection expression that adds `row_base` to `___row_id`.
fn build_projection_exprs(
&self,
new_schema: &SchemaRef,
) -> Result<Vec<(Arc<dyn PhysicalExpr>, String)>, DataFusionError> {
let row_id_idx = new_schema
.index_of("___row_id")
.expect("Field ___row_id missing");
let row_base_idx = new_schema
.index_of("row_base")
.expect("Field row_base missing");

fn build_projection_exprs(&self, new_schema: &SchemaRef) -> Result<Vec<(Arc<dyn PhysicalExpr>, String)>, DataFusionError> {
let row_id_idx = new_schema.index_of(ROW_ID_FIELD_NAME).expect("Field ___row_id missing");
let row_base_idx = new_schema.index_of(ROW_BASE_FIELD_NAME).expect("Field row_base missing");
let sum_expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
Arc::new(Column::new("___row_id", row_id_idx)),
Arc::new(Column::new(ROW_ID_FIELD_NAME, row_id_idx)),
Operator::Plus,
Arc::new(Column::new("row_base", row_base_idx)),
Arc::new(Column::new(ROW_BASE_FIELD_NAME, row_base_idx)),
));

let mut projection_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = Vec::new();

let mut has_row_id = false;
for field_name in new_schema.fields().to_vec() {
if field_name.name() == "___row_id" {
if field_name.name() == ROW_ID_FIELD_NAME {
projection_exprs.push((sum_expr.clone(), field_name.name().clone()));
has_row_id = true;
} else if (field_name.name() != "row_base") {
} else if(field_name.name() != ROW_BASE_FIELD_NAME) {
// Match the column by name from new_schema
let idx = new_schema
.index_of(&*field_name.name().clone())
Expand All @@ -120,7 +110,7 @@ impl ProjectRowIdOptimizer {
}
}
if !has_row_id {
projection_exprs.push((sum_expr.clone(), "___row_id".parse().unwrap()));
projection_exprs.push((sum_expr.clone(), ROW_ID_FIELD_NAME.parse().unwrap()));
}
Ok(projection_exprs)
}
Expand All @@ -147,7 +137,7 @@ impl ProjectRowIdOptimizer {
}
}

impl PhysicalOptimizerRule for ProjectRowIdOptimizer {
impl PhysicalOptimizerRule for AbsoluteRowIdOptimizer {
fn optimize(
&self,
plan: Arc<dyn ExecutionPlan>,
Expand All @@ -162,34 +152,16 @@ impl PhysicalOptimizerRule for ProjectRowIdOptimizer {
.downcast_ref::<FileScanConfig>()
.expect("DataSource not found");
let schema = datasource.file_schema.clone();
schema
.field_with_name("___row_id")
.expect("Field ___row_id missing");
let projection = self
.create_datasource_projection(datasource, datasource_exec.schema())
.expect("Failed to create ProjectionExec from datasource");
return Ok(Transformed::new(
Arc::new(projection),
true,
TreeNodeRecursion::Continue,
));
schema.field_with_name(ROW_ID_FIELD_NAME).expect("Field ___row_id missing");
let projection = self.create_datasource_projection(datasource, datasource_exec.schema()).expect("Failed to create ProjectionExec from datasource");
return Ok(Transformed::new(Arc::new(projection), true, TreeNodeRecursion::Continue));

} else if let Some(projection_exec) = node.as_any().downcast_ref::<ProjectionExec>() {
if !projection_exec
.schema()
.field_with_name("___row_id")
.is_ok()
{
if !projection_exec.schema().field_with_name(ROW_ID_FIELD_NAME).is_ok() {

let mut projection_exprs = projection_exec.expr().to_vec();
if (projection_exec
.input()
.schema()
.index_of("___row_id")
.is_ok())
{
if projection_exec.input().schema().index_of("___row_id").is_ok() {
let row_id_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("___row_id", projection_exec.input().schema().index_of("___row_id").unwrap()));
projection_exprs.push(ProjectionExpr::new(row_id_col, "___row_id".to_string()));
}
if(projection_exec.input().schema().index_of(ROW_ID_FIELD_NAME).is_ok()) {
projection_exprs.push(ProjectionExpr::new(Arc::new(Column::new(ROW_ID_FIELD_NAME, projection_exec.input().schema().index_of(ROW_ID_FIELD_NAME).unwrap())), ROW_ID_FIELD_NAME.to_string()));
}

let projection =
Expand Down
37 changes: 24 additions & 13 deletions plugins/engine-datafusion/jni/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use std::num::NonZeroUsize;
use std::ptr::addr_of_mut;
use jni::objects::{JByteArray, JClass, JObject};
use jni::objects::JLongArray;
use jni::sys::{jbyteArray, jint, jlong, jstring};
use jni::sys::{jboolean, jbyteArray, jint, jlong, jstring};
use jni::{JNIEnv, JavaVM};
use std::sync::{Arc, OnceLock};
use arrow_array::{Array, StructArray};
Expand All @@ -32,7 +32,7 @@ use std::default::Default;
use std::time::{Duration, Instant};

mod util;
mod row_id_optimizer;
mod absolute_row_id_optimizer;
mod listing_table;
mod cache;
mod custom_cache_manager;
Expand All @@ -44,6 +44,7 @@ mod runtime_manager;
mod cache_jni;
mod partial_agg_optimizer;
mod query_executor;
mod project_row_id_analyzer;

use crate::custom_cache_manager::CustomCacheManager;
use crate::util::{create_file_meta_from_filenames, parse_string_arr, set_action_listener_error, set_action_listener_error_global, set_action_listener_ok, set_action_listener_ok_global};
Expand Down Expand Up @@ -330,7 +331,7 @@ pub extern "system" fn Java_org_opensearch_datafusion_jni_NativeBridge_createDat
}
};

let files: Vec<String> = match parse_string_arr(&mut env, files) {
let mut files: Vec<String> = match parse_string_arr(&mut env, files) {
Ok(files) => files,
Err(e) => {
let _ = env.throw_new(
Expand All @@ -341,6 +342,8 @@ pub extern "system" fn Java_org_opensearch_datafusion_jni_NativeBridge_createDat
}
};

// TODO: Sorting works here only because the files are named similarly and end with an
// incremental generation count. Preferably move this ordering up to DatafusionReaderManager
// so that the file order is maintained there.
files.sort();
let files_metadata = match create_file_meta_from_filenames(&table_path, files.clone()) {
Ok(metadata) => metadata,
Err(err) => {
Expand Down Expand Up @@ -450,6 +453,7 @@ pub extern "system" fn Java_org_opensearch_datafusion_jni_NativeBridge_executeQu
shard_view_ptr: jlong,
table_name: JString,
substrait_bytes: jbyteArray,
is_aggregation_query: jboolean,
runtime_ptr: jlong,
listener: JObject,
) {
Expand All @@ -458,7 +462,7 @@ pub extern "system" fn Java_org_opensearch_datafusion_jni_NativeBridge_executeQu
None => {
error!("Runtime manager not initialized");
set_action_listener_error(&mut env, listener,
&DataFusionError::Execution("Runtime manager not initialized".to_string()));
&DataFusionError::Execution("Runtime manager not initialized".to_string()));
return;
}
};
Expand All @@ -469,18 +473,20 @@ pub extern "system" fn Java_org_opensearch_datafusion_jni_NativeBridge_executeQu
Err(e) => {
error!("Failed to get table name: {}", e);
set_action_listener_error(&mut env, listener,
&DataFusionError::Execution(format!("Failed to get table name: {}", e)));
&DataFusionError::Execution(format!("Failed to get table name: {}", e)));
return;
}
};

let is_aggregation_query: bool = is_aggregation_query !=0;

let plan_bytes_obj = unsafe { JByteArray::from_raw(substrait_bytes) };
let plan_bytes_vec = match env.convert_byte_array(plan_bytes_obj) {
Ok(bytes) => bytes,
Err(e) => {
error!("Failed to convert plan bytes: {}", e);
set_action_listener_error(&mut env, listener,
&DataFusionError::Execution(format!("Failed to convert plan bytes: {}", e)));
&DataFusionError::Execution(format!("Failed to convert plan bytes: {}", e)));
return;
}
};
Expand All @@ -491,7 +497,7 @@ pub extern "system" fn Java_org_opensearch_datafusion_jni_NativeBridge_executeQu
Err(e) => {
error!("Failed to create global ref: {}", e);
set_action_listener_error(&mut env, listener,
&DataFusionError::Execution(format!("Failed to create global ref: {}", e)));
&DataFusionError::Execution(format!("Failed to create global ref: {}", e)));
return;
}
};
Expand All @@ -511,6 +517,7 @@ pub extern "system" fn Java_org_opensearch_datafusion_jni_NativeBridge_executeQu
files_meta,
table_name,
plan_bytes_vec,
is_aggregation_query,
runtime,
cpu_executor,
).await;
Expand Down Expand Up @@ -559,7 +566,7 @@ pub extern "system" fn Java_org_opensearch_datafusion_jni_NativeBridge_streamNex
Err(e) => {
error!("Failed to create global ref: {}", e);
set_action_listener_error(&mut env, listener,
&DataFusionError::Execution(format!("Failed to create global ref: {}", e)));
&DataFusionError::Execution(format!("Failed to create global ref: {}", e)));
return;
}
};
Expand Down Expand Up @@ -644,7 +651,8 @@ pub extern "system" fn Java_org_opensearch_datafusion_jni_NativeBridge_executeFe
_class: JClass,
shard_view_ptr: jlong,
values: JLongArray,
projections: JObjectArray,
include_fields: JObjectArray,
exclude_fields: JObjectArray,
runtime_ptr: jlong,
callback: JObject,
) -> jlong {
Expand All @@ -654,8 +662,10 @@ pub extern "system" fn Java_org_opensearch_datafusion_jni_NativeBridge_executeFe
let table_path = shard_view.table_path();
let files_metadata = shard_view.files_metadata();

let projections: Vec<String> =
parse_string_arr(&mut env, projections).expect("Expected list of files");
let include_fields: Vec<String> =
parse_string_arr(&mut env, include_fields).expect("Expected list of files");
let exclude_fields: Vec<String> =
parse_string_arr(&mut env, exclude_fields).expect("Expected list of files");

// Safety checks first
if values.is_null() {
Expand Down Expand Up @@ -697,7 +707,7 @@ pub extern "system" fn Java_org_opensearch_datafusion_jni_NativeBridge_executeFe
None => {
error!("Runtime manager not initialized");
set_action_listener_error(&mut env, callback,
&DataFusionError::Execution("Runtime manager not initialized".to_string()));
&DataFusionError::Execution("Runtime manager not initialized".to_string()));
return 0;
}
};
Expand All @@ -710,7 +720,8 @@ pub extern "system" fn Java_org_opensearch_datafusion_jni_NativeBridge_executeFe
table_path,
files_metadata,
row_ids,
projections,
include_fields,
exclude_fields,
runtime,
cpu_executor,
).await {
Expand Down
5 changes: 3 additions & 2 deletions plugins/engine-datafusion/jni/src/listing_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ use futures::{future, stream, Stream, StreamExt, TryStreamExt};
use itertools::Itertools;
use object_store::ObjectStore;
use regex::Regex;
use crate::absolute_row_id_optimizer::ROW_ID_FIELD_NAME;
use std::fs::File;
use std::{any::Any, collections::HashMap, str::FromStr, sync::Arc};

Expand Down Expand Up @@ -302,7 +303,7 @@ impl ListingTableConfig {
/// # Errors
/// * if `self.options` is not set. See [`Self::with_listing_options`]
pub async fn infer_schema(self, state: &dyn Session) -> Result<Self> {
match self.options {
match self.options {
Some(options) => {
let ListingTableConfig {
table_paths,
Expand Down Expand Up @@ -1144,7 +1145,7 @@ impl ListingTable {
}
let row_id_field_datatype = self
.file_schema
.field_with_name("___row_id")
.field_with_name(ROW_ID_FIELD_NAME)
.expect("Field ___row_id not found")
.data_type();
if !(row_id_field_datatype.equals_datatype(&DataType::Int32)
Expand Down
Loading
Loading