diff --git a/native-engine/auron-planner/proto/auron.proto b/native-engine/auron-planner/proto/auron.proto
index 788be3526..99f8078bf 100644
--- a/native-engine/auron-planner/proto/auron.proto
+++ b/native-engine/auron-planner/proto/auron.proto
@@ -113,6 +113,9 @@ message PhysicalExprNode {
     // RowNum
     RowNumExprNode row_num_expr = 20100;
 
+    // SparkPartitionID
+    SparkPartitionIdExprNode spark_partition_id_expr = 20101;
+
     // BloomFilterMightContain
     BloomFilterMightContainExprNode bloom_filter_might_contain_expr = 20200;
   }
@@ -914,3 +917,5 @@ message ArrowType {
 // }
 //}
 message EmptyMessage{}
+
+message SparkPartitionIdExprNode {}
diff --git a/native-engine/auron-planner/src/planner.rs b/native-engine/auron-planner/src/planner.rs
index f081e32c7..9cad27541 100644
--- a/native-engine/auron-planner/src/planner.rs
+++ b/native-engine/auron-planner/src/planner.rs
@@ -52,7 +52,7 @@ use datafusion::{
 use datafusion_ext_exprs::{
     bloom_filter_might_contain::BloomFilterMightContainExpr, cast::TryCastExpr,
     get_indexed_field::GetIndexedFieldExpr, get_map_value::GetMapValueExpr,
-    named_struct::NamedStructExpr, row_num::RowNumExpr,
+    named_struct::NamedStructExpr, row_num::RowNumExpr, spark_partition_id::SparkPartitionIdExpr,
     spark_scalar_subquery_wrapper::SparkScalarSubqueryWrapperExpr,
     spark_udf_wrapper::SparkUDFWrapperExpr, string_contains::StringContainsExpr,
     string_ends_with::StringEndsWithExpr, string_starts_with::StringStartsWithExpr,
@@ -967,6 +967,9 @@ impl PhysicalPlanner {
                 Arc::new(StringContainsExpr::new(expr, e.infix.clone()))
             }
             ExprType::RowNumExpr(_) => Arc::new(RowNumExpr::default()),
+            ExprType::SparkPartitionIdExpr(_) => {
+                Arc::new(SparkPartitionIdExpr::new(self.partition_id))
+            }
             ExprType::BloomFilterMightContainExpr(e) => Arc::new(BloomFilterMightContainExpr::new(
                 e.uuid.clone(),
                 self.try_parse_physical_expr_box_required(&e.bloom_filter_expr, input_schema)?,
diff --git a/native-engine/datafusion-ext-exprs/src/lib.rs b/native-engine/datafusion-ext-exprs/src/lib.rs
index 3a685a41f..bb2757f00 100644
--- a/native-engine/datafusion-ext-exprs/src/lib.rs
+++ b/native-engine/datafusion-ext-exprs/src/lib.rs
@@ -23,6 +23,7 @@ pub mod get_indexed_field;
 pub mod get_map_value;
 pub mod named_struct;
 pub mod row_num;
+pub mod spark_partition_id;
 pub mod spark_scalar_subquery_wrapper;
 pub mod spark_udf_wrapper;
 pub mod string_contains;
diff --git a/native-engine/datafusion-ext-exprs/src/spark_partition_id.rs b/native-engine/datafusion-ext-exprs/src/spark_partition_id.rs
new file mode 100644
index 000000000..d34150dbb
--- /dev/null
+++ b/native-engine/datafusion-ext-exprs/src/spark_partition_id.rs
@@ -0,0 +1,189 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
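+
+//! Native implementation of Spark's `SparkPartitionID` expression (the
+//! `spark_partition_id()` SQL function): evaluates to the id of the current
+//! partition as a non-nullable Int32 column.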
+
+use std::{
+    any::Any,
+    fmt::{Debug, Display, Formatter},
+    hash::{Hash, Hasher},
+    sync::Arc,
+};
+
+use arrow::{
+    array::{Int32Array, RecordBatch},
+    datatypes::{DataType, Schema},
+};
+use datafusion::{
+    common::Result,
+    logical_expr::ColumnarValue,
+    physical_expr::{PhysicalExpr, PhysicalExprRef},
+};
+
+pub struct SparkPartitionIdExpr {
+    partition_id: i32,
+}
+
+impl SparkPartitionIdExpr {
+    pub fn new(partition_id: usize) -> Self {
+        Self {
+            partition_id: partition_id as i32,
+        }
+    }
+}
+
+impl Display for SparkPartitionIdExpr {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "SparkPartitionID")
+    }
+}
+
+impl Debug for SparkPartitionIdExpr {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "SparkPartitionID")
+    }
+}
+
+impl PartialEq for SparkPartitionIdExpr {
+    fn eq(&self, other: &Self) -> bool {
+        self.partition_id == other.partition_id
+    }
+}
+
+impl Eq for SparkPartitionIdExpr {}
+
+impl Hash for SparkPartitionIdExpr {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.partition_id.hash(state)
+    }
+}
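+
+// Spark's spark_partition_id() is nondeterministic: its value depends on how
+// rows are assigned to tasks. The native expression mirrors those semantics by
+// capturing the partition index at plan time and emitting it as a constant
+// Int32 column for every batch evaluated within that partition.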
+impl PhysicalExpr for SparkPartitionIdExpr {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
+        Ok(DataType::Int32)
+    }
+
+    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
+        Ok(false)
+    }
+
+    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
+        let num_rows = batch.num_rows();
+        let array = Int32Array::from_value(self.partition_id, num_rows);
+        Ok(ColumnarValue::Array(Arc::new(array)))
+    }
+
+    fn children(&self) -> Vec<&PhysicalExprRef> {
+        vec![]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        _children: Vec<PhysicalExprRef>,
+    ) -> Result<PhysicalExprRef> {
+        Ok(self)
+    }
+
+    fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "fmt_sql not used")
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use arrow::{
+        array::Int32Array,
+        datatypes::{Field, Schema},
+        record_batch::RecordBatch,
+    };
+
+    use super::*;
+
+    #[test]
+    fn test_data_type_and_nullable() {
+        let expr = SparkPartitionIdExpr::new(0);
+        let schema = Schema::new(vec![] as Vec<Field>);
+        assert_eq!(
+            expr.data_type(&schema).expect("data_type failed"),
+            DataType::Int32
+        );
+        assert!(!expr.nullable(&schema).expect("nullable failed"));
+    }
+
+    #[test]
+    fn test_evaluate_returns_constant_partition_id() {
+        let expr = SparkPartitionIdExpr::new(5);
+        let schema = Schema::new(vec![Field::new("col", DataType::Int32, false)]);
+        let batch = RecordBatch::try_new(
+            Arc::new(schema),
+            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
+        )
+        .expect("RecordBatch creation failed");
+
+        let result = expr.evaluate(&batch).expect("evaluate failed");
+        match result {
+            ColumnarValue::Array(arr) => {
+                let int_arr = arr
+                    .as_any()
+                    .downcast_ref::<Int32Array>()
+                    .expect("downcast failed");
+                assert_eq!(int_arr.len(), 3);
+                for i in 0..3 {
+                    assert_eq!(int_arr.value(i), 5);
+                }
+            }
+            _ => panic!("Expected Array result"),
+        }
+    }
+
+    #[test]
+    fn test_evaluate_different_partition_ids() {
+        let schema = Schema::new(vec![Field::new("col", DataType::Int32, false)]);
+        let batch = RecordBatch::try_new(
+            Arc::new(schema),
+            vec![Arc::new(Int32Array::from(vec![1, 2]))],
+        )
+        .expect("RecordBatch creation failed");
+
+        for partition_id in [0, 1, 100, 999] {
+            let expr = SparkPartitionIdExpr::new(partition_id);
+            let result = expr.evaluate(&batch).expect("evaluate failed");
+            match result {
+                ColumnarValue::Array(arr) => {
+                    let int_arr = arr
+                        .as_any()
+                        .downcast_ref::<Int32Array>()
+                        .expect("downcast failed");
+                    for i in 0..int_arr.len() {
+                        assert_eq!(int_arr.value(i), partition_id as i32);
+                    }
+                }
+                _ => panic!("Expected Array result"),
+            }
+        }
+    }
+
+    #[test]
+    fn test_equality() {
+        let expr1 = SparkPartitionIdExpr::new(5);
+        let expr2 = SparkPartitionIdExpr::new(5);
+        let expr3 = SparkPartitionIdExpr::new(3);
+
+        assert_eq!(expr1, expr2);
+        assert_ne!(expr1, expr3);
+    }
+}
diff --git a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
index cb9492c9c..1427e01d3 100644
--- a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
+++ b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
@@ -43,6 +43,7 @@ import org.apache.spark.sql.catalyst.expressions.Like
 import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.expressions.NamedExpression
 import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.catalyst.expressions.SparkPartitionID
 import org.apache.spark.sql.catalyst.expressions.StringSplit
 import org.apache.spark.sql.catalyst.expressions.TaggingExpression
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
@@ -521,6 +522,13 @@ class ShimsImpl extends Shims with Logging {
       isPruningExpr: Boolean,
       fallback: Expression => pb.PhysicalExprNode): Option[pb.PhysicalExprNode] = {
     e match {
+      case _: SparkPartitionID =>
+        Some(
+          pb.PhysicalExprNode
+            .newBuilder()
+            .setSparkPartitionIdExpr(pb.SparkPartitionIdExprNode.newBuilder())
+            .build())
+
       case StringSplit(str, pat @ Literal(_, StringType), Literal(-1, IntegerType))
           // native StringSplit implementation does not support regex, so only most frequently
           // used cases without regex are supported
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
index 13a627f24..29b9386ae 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
@@ -1125,6 +1125,11 @@ object NativeConverters extends Logging {
         _.setRowNumExpr(pb.RowNumExprNode.newBuilder())
       }
 
+    case StubExpr("SparkPartitionID", _, _) =>
+      buildExprNode {
+        _.setSparkPartitionIdExpr(pb.SparkPartitionIdExprNode.newBuilder())
+      }
+
     // hive UDFJson
     case e
         if udfJsonEnabled && (
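
For reference, a minimal Spark-side sketch of the query pattern this change
offloads (illustrative only: it assumes a running SparkSession named `spark`
with the Auron extension enabled):

    import org.apache.spark.sql.functions.spark_partition_id

    // spark_partition_id() is backed by the SparkPartitionID catalyst
    // expression; with this patch it is converted to the native
    // SparkPartitionIdExpr instead of forcing a fallback to JVM evaluation.
    val df = spark.range(0L, 100L, 1L, numPartitions = 4)
      .withColumn("pid", spark_partition_id())
    df.groupBy("pid").count().show()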