diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 0c259be439..d04c58d94a 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -657,6 +657,15 @@ object CometConf extends ShimCometConf { .longConf .createWithDefault(3000L) + val COMET_ENABLE_GROUPING_ON_MAP_TYPE: ConfigEntry[Boolean] = + conf("spark.comet.enableGroupingOnMapType") + .doc( + "An experimental feature with limited capabilities to enable grouping on the Spark Map type. " + + "Requires Spark 4.0 or later, along with scan support for the Map type. " + + s"Set this config to true to enable grouping on map type. $COMPAT_GUIDE.") + .booleanConf + .createWithDefault(false) + /** Create a config to enable a specific operator */ private def createExecEnabledConfig( exec: String, diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 434c1934fb..56f5886f69 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -37,6 +37,7 @@ Comet provides the following configuration settings. | spark.comet.convert.parquet.enabled | When enabled, data from Spark (non-native) Parquet v1 and v2 scans will be converted to Arrow format. Note that to enable native vectorized execution, both this config and 'spark.comet.exec.enabled' need to be enabled. | false | | spark.comet.debug.enabled | Whether to enable debug mode for Comet. When enabled, Comet will do additional checks for debugging purpose. For example, validating array when importing arrays from JVM at native side. Note that these checks may be expensive in performance and should only be enabled for debugging purpose. | false | | spark.comet.dppFallback.enabled | Whether to fall back to Spark for queries that use DPP. | true | +| spark.comet.enableGroupingOnMapType | An experimental feature with limited capabilities to enable grouping on the Spark Map type. Requires Spark 4.0 or later, along with scan support for the Map type. Set this config to true to enable grouping on map type. For more information, refer to the Comet Compatibility Guide (https://datafusion.apache.org/comet/user-guide/compatibility.html). | false | | spark.comet.enabled | Whether to enable Comet extension for Spark. When this is turned on, Spark will use Comet to read Parquet data source. Note that to enable native vectorized execution, both this config and 'spark.comet.exec.enabled' need to be enabled. By default, this config is the value of the env var `ENABLE_COMET` if set, or true otherwise. | true | | spark.comet.exceptionOnDatetimeRebase | Whether to throw exception when seeing dates/timestamps from the legacy hybrid (Julian + Gregorian) calendar. Since Spark 3, dates/timestamps were written according to the Proleptic Gregorian calendar. When this is true, Comet will throw exceptions when seeing these dates/timestamps that were written by Spark version before 3.0. If this is false, these dates/timestamps will be read as if they were written to the Proleptic Gregorian calendar and will not be rebased. | false | | spark.comet.exec.aggregate.enabled | Whether to enable aggregate by default.
| true | diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 0c3d345c8e..ef6e897fc5 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1207,7 +1207,21 @@ impl PhysicalPlanner { .map(|r| (r, format!("col_{idx}"))) }) .collect(); - let group_by = PhysicalGroupBy::new_single(group_exprs?); + + let mut map_converter = + crate::execution::utils::HashAggregateMapConverter::default(); + + // DataFusion does not currently support grouping on the Map type, so pass the + // `group_exprs` through `maybe_wrap_map_type_in_grouping_exprs`, which canonicalizes + // any Map type to a List of Structs for grouping. + let maybe_wrapped_group_exprs = map_converter + .maybe_wrap_map_type_in_grouping_exprs( + &self.session_ctx.state(), + group_exprs?, + child.schema(), + )?; + + let group_by = PhysicalGroupBy::new_single(maybe_wrapped_group_exprs); let schema = child.schema(); let mode = if agg.mode == 0 { @@ -1234,12 +1248,24 @@ Arc::clone(&schema), )?, ); + + // To maintain schema consistency, the `AggregateExec` output is passed through + // `maybe_project_map_type_with_aggregation`, which adds a projection that converts + // any canonicalized Map back to its original Map type. Not doing so would + // result in a schema mismatch between Spark and DataFusion. + let maybe_aggregate_with_project = map_converter + .maybe_project_map_type_with_aggregation( + &self.session_ctx.state(), + agg, + aggregate, + )?; + let result_exprs: PhyExprResult = agg .result_exprs .iter() .enumerate() .map(|(idx, expr)| { - self.create_expr(expr, aggregate.schema()) + self.create_expr(expr, maybe_aggregate_with_project.schema()) .map(|r| (r, format!("col_{idx}"))) }) .collect(); @@ -1247,7 +1273,11 @@ if agg.result_exprs.is_empty() { Ok(( scans, - Arc::new(SparkPlan::new(spark_plan.plan_id, aggregate, vec![child])), + Arc::new(SparkPlan::new( + spark_plan.plan_id, + maybe_aggregate_with_project, + vec![child], + )), )) } else { // For final aggregation, DF's hash aggregate exec doesn't support Spark's @@ -1259,7 +1289,7 @@ // Spark side. let projection = Arc::new(ProjectionExec::try_new( result_exprs?, - Arc::clone(&aggregate), + Arc::clone(&maybe_aggregate_with_project), )?); Ok(( scans, @@ -1267,7 +1297,7 @@ spark_plan.plan_id, projection, vec![child], - vec![aggregate], + vec![maybe_aggregate_with_project], )), )) } diff --git a/native/core/src/execution/utils.rs b/native/core/src/execution/utils.rs index 838c8523bb..f8c3ac47e3 100644 --- a/native/core/src/execution/utils.rs +++ b/native/core/src/execution/utils.rs @@ -18,11 +18,21 @@ /// Utils for array vector, etc.
use crate::errors::ExpressionError; use crate::execution::operators::ExecutionError; +use arrow::datatypes::{DataType, Field, SchemaRef}; use arrow::{ array::ArrayData, error::ArrowError, ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema}, }; +use datafusion::execution::FunctionRegistry; +use datafusion::physical_expr::expressions::Column; +use datafusion::physical_expr::{PhysicalExpr, ScalarFunctionExpr}; +use datafusion::physical_plan::projection::ProjectionExec; +use datafusion::physical_plan::ExecutionPlan; +use datafusion_comet_proto::spark_operator::HashAggregate; +use datafusion_comet_spark_expr::create_comet_physical_fun; +use std::collections::HashMap; +use std::sync::Arc; impl From<ArrowError> for ExecutionError { fn from(error: ArrowError) -> ExecutionError { @@ -127,3 +137,133 @@ pub fn bytes_to_i128(slice: &[u8]) -> i128 { i128::from_le_bytes(bytes) } + +type GroupingExprs = Vec<(Arc<dyn PhysicalExpr>, String)>; +type GroupingExprResult = Result<GroupingExprs, ExecutionError>; + +/// Provides utilities to support grouping on Map type in HashAggregate. +pub struct HashAggregateMapConverter { + // Maps the index of a grouping expression to its original Map type. This is used to convert a + // grouping expression return type back to Map type after aggregation. + expr_index_to_map_type: HashMap<usize, DataType>, +} + +impl HashAggregateMapConverter { + pub fn default() -> Self { + Self { + expr_index_to_map_type: HashMap::new(), + } + } + + /// Iterates through the grouping expressions and wraps those with Map type in the `map_to_list` + /// scalar function. + pub fn maybe_wrap_map_type_in_grouping_exprs( + &mut self, + fn_registry: &dyn FunctionRegistry, + grouping_exprs: GroupingExprs, + child_schema: SchemaRef, + ) -> GroupingExprResult { + grouping_exprs + .into_iter() + .enumerate() + .map(|(idx, (physical_expr, expr_name))| { + let expr_data_type = physical_expr.data_type(&child_schema)?; + + if let DataType::Map(field_ref, _) = &expr_data_type { + let list_type = DataType::List(Arc::clone(field_ref)); + + // Update the map with the grouping expression index and its original Map type. + self.expr_index_to_map_type + .insert(idx, expr_data_type.clone()); + + // Create `map_to_list` expression to wrap the original grouping expression. + let map_to_list_func = create_comet_physical_fun( + "map_to_list", + list_type.clone(), + fn_registry, + None, + )?; + let map_to_list_expr = ScalarFunctionExpr::new( + "map_to_list", + map_to_list_func, + vec![physical_expr], + Arc::new(Field::new("map_to_list", list_type, true)), + ); + + // Return the scalar function expression. + Ok(( + Arc::new(map_to_list_expr) as Arc<dyn PhysicalExpr>, + expr_name, + )) + } else { + Ok((physical_expr, expr_name)) + } + }) + .collect() + } + + /// Iterates over the aggregate schema, finds the grouping expressions with Map type, and + /// wraps them in the `map_from_list` scalar function to convert them back to Map type. It returns + /// a new ProjectionExec stacked on top of the original aggregate execution plan. If there was + /// no grouping expression with Map type, it returns the original aggregate execution plan. + pub fn maybe_project_map_type_with_aggregation( + &self, + fn_registry: &dyn FunctionRegistry, + hash_agg: &HashAggregate, + aggregate: Arc<dyn ExecutionPlan>, + ) -> Result<Arc<dyn ExecutionPlan>, ExecutionError> { + // If there was no grouping expression with Map type, return the original aggregate plan. + if self.expr_index_to_map_type.is_empty() { + return Ok(aggregate); + } + + // Collect the projection expressions into this vector.
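+        // As an illustration (hypothetical plan, types simplified): for a query like
+        // `SELECT m, COUNT(*) FROM t GROUP BY m`, the aggregate output schema would be
+        // [col_0: List<Struct<key, value>>, col_1: Int64], and the projected schema
+        // [col_0: Map<key, value>, col_1: Int64].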
+ let mut projection_exprs = Vec::new(); + + let num_grouping_cols = hash_agg.grouping_exprs.len(); + let agg_schema = aggregate.schema(); + + // Iterate through the aggregate schema. The aggregate schema contains both grouping + // expressions and aggregate expressions. The grouping expressions are at the beginning of + // the schema. + for (field_idx, field) in agg_schema.fields().iter().enumerate() { + let opt_map_type = self.expr_index_to_map_type.get(&field_idx); + + // If the current field is not a grouping expression, or the grouping expression does not + // have Map type, then project the current field as is. + if field_idx >= num_grouping_cols || opt_map_type.is_none() { + let col_expr = + Arc::new(Column::new(field.name(), field_idx)) as Arc<dyn PhysicalExpr>; + projection_exprs.push((col_expr, field.name().to_string())); + continue; + } + + let map_type = opt_map_type.unwrap(); + + // Create a `map_from_list` expression to convert the List type back to Map type. This + // column was previously wrapped with `map_to_list` during grouping. + let map_from_list_func = + create_comet_physical_fun("map_from_list", map_type.clone(), fn_registry, None)?; + let col_expr = Arc::new(Column::new(field.name(), field_idx)); + let map_from_list_expr = Arc::new(ScalarFunctionExpr::new( + "map_from_list", + map_from_list_func, + vec![col_expr], + Arc::new(Field::new( + field.name(), + map_type.clone(), + field.is_nullable(), + )), + )) as Arc<dyn PhysicalExpr>; + + // Add the `map_from_list` expression to the projection expressions. + projection_exprs.push((map_from_list_expr, field.name().to_string())); + } + + // Return a new ProjectionExec on top of the original aggregate plan. + Ok(Arc::new(ProjectionExec::try_new( + projection_exprs, + Arc::clone(&aggregate), + )?) as Arc<dyn ExecutionPlan>) + } +} diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 75f5689ad5..d7933cd575 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -16,6 +16,7 @@ // under the License.
use crate::hash_funcs::*; +use crate::map_funcs::{map_from_list, map_to_list, spark_map_sort}; use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub}; use crate::math_funcs::modulo_expr::spark_modulo; use crate::{ @@ -157,6 +158,18 @@ pub fn create_comet_physical_fun( let fail_on_error = fail_on_error.unwrap_or(false); make_comet_scalar_udf!("spark_modulo", func, without data_type, fail_on_error) } + "map_sort" => { + let func = Arc::new(spark_map_sort); + make_comet_scalar_udf!("spark_map_sort", func, without data_type) + } + "map_to_list" => { + let func = Arc::new(map_to_list); + make_comet_scalar_udf!("map_to_list", func, without data_type) + } + "map_from_list" => { + let func = Arc::new(map_from_list); + make_comet_scalar_udf!("map_from_list", func, without data_type) + } _ => registry.udf(fun_name).map_err(|e| { DataFusionError::Execution(format!( "Function {fun_name} not found in the registry: {e}", diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index 4b29b61775..27672ded5e 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -55,6 +55,7 @@ pub use bloom_filter::{BloomFilterAgg, BloomFilterMightContain}; mod conditional_funcs; mod conversion_funcs; +mod map_funcs; mod math_funcs; mod nondetermenistic_funcs; diff --git a/native/spark-expr/src/map_funcs/map_from_list.rs b/native/spark-expr/src/map_funcs/map_from_list.rs new file mode 100644 index 0000000000..498151a418 --- /dev/null +++ b/native/spark-expr/src/map_funcs/map_from_list.rs @@ -0,0 +1,337 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, ArrayData, ArrayRef, ListArray, MapArray, StructArray}; +use arrow::datatypes::DataType; +use datafusion::common::{exec_err, DataFusionError}; +use datafusion::physical_plan::ColumnarValue; +use std::sync::Arc; + +/// Converts a ListArray of Structs of key-value pairs to a MapArray preserving the original layout. +/// One use case is to re-construct the original Map type after grouping on its canonicalized form +/// during hash aggregation.
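+///
+/// As an illustration (logical view, hypothetical data):
+///
+/// ```text
+/// input  (List<Struct<key, value>>): [[{a, 1}, {b, 2}], []]
+/// output (Map<key, value>):          [{a: 1, b: 2}, {}]
+/// ```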
+pub fn map_from_list(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> { + if args.len() != 1 { + return exec_err!("map_from_list expects exactly one argument"); + } + + let arr_arg: ArrayRef = match &args[0] { + ColumnarValue::Array(array) => Arc::<dyn Array>::clone(array), + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(1)?, + }; + + let list_arg = match arr_arg.data_type() { + DataType::List(_) => arr_arg.as_any().downcast_ref::<ListArray>().unwrap(), + _ => return exec_err!("map_from_list expects ListArray type as argument"), + }; + + let list_struct_array = list_arg + .values() + .as_any() + .downcast_ref::<StructArray>() + .ok_or_else(|| { + DataFusionError::Execution( + "map_from_list expects ListArray to contain StructArray values".to_string(), + ) + })?; + let list_struct_array_data = list_struct_array.to_data(); + + let list_data = list_arg.to_data(); + + let list_offset_buffer = if list_data.buffers().len() == 1 { + list_data.buffers()[0].clone() + } else { + return exec_err!("map_from_list expects ListArray to have a single offset buffer"); + }; + + // TODO: Do not hard-code the field name and nullability. + // This creates the field for the MapArray entries. + let map_entries_field = Arc::new(arrow::datatypes::Field::new( + "entries", + list_struct_array_data.data_type().clone(), + false, + )); + + // Build a MapArray preserving the same layout as the ListArray. + let mut map_builder = ArrayData::builder(DataType::Map(map_entries_field, false)) + .len(list_arg.len()) + .offset(list_arg.offset()) + .add_buffer(list_offset_buffer) + .child_data(vec![list_struct_array_data]); + + // Copy the null bitmaps if they exist. + if let Some(list_nulls) = list_data.nulls() { + map_builder = map_builder.nulls(Some(list_nulls.clone())); + } + + let map_data = map_builder.build()?; + let map_array = Arc::new(MapArray::from(map_data)); + Ok(ColumnarValue::Array(map_array)) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + test_create_list_array, test_create_map_array, test_create_nested_list_array, + test_create_nested_map_array, test_verify_result_equals_map, + }; + use arrow::array::builder::{Int32Builder, MapBuilder, StringBuilder}; + use arrow::array::{Int32Array, MapArray, MapFieldNames, StringArray}; + use datafusion::common::ScalarValue; + use std::sync::Arc; + + #[test] + fn test_map_from_list_basic() { + let keys_arg = [ + vec![ + Some("b".to_string()), + Some("a".to_string()), + Some("c".to_string()), + ], + Vec::<Option<String>>::new(), + ]; + let values_arg = [vec![Some(2), Some(1), Some(3)], Vec::<Option<i32>>::new()]; + let validity = [true, true]; + let list_input = test_create_list_array!( + StringBuilder::new(), + Int32Builder::new(), + keys_arg, + values_arg, + validity + ); + let expected_map_arr = test_create_map_array!( + StringBuilder::new(), + Int32Builder::new(), + [ + vec![ + Some("b".to_string()), + Some("a".to_string()), + Some("c".to_string()), + ], + Vec::<Option<String>>::new(), + ], + [vec![Some(2), Some(1), Some(3)], Vec::<Option<i32>>::new()], + [true, true] + ); + let args = vec![ColumnarValue::Array(Arc::new(list_input))]; + let result = map_from_list(&args).unwrap(); + + test_verify_result_equals_map!(StringArray, Int32Array, result, expected_map_arr); + } + + #[test] + fn test_map_from_list_with_scalar_argument() { + let list_input = test_create_list_array!( + StringBuilder::new(), + Int32Builder::new(), + vec![vec![Some("b".to_string()), Some("a".to_string())]], + vec![vec![Some(2), Some(1)]], + vec![true] + ); + let args = vec![ColumnarValue::Scalar( + ScalarValue::try_from_array(&list_input, 0).unwrap(), + )]; + let result =
map_from_list(&args).unwrap(); + let expected_map_arr = test_create_map_array!( + StringBuilder::new(), + Int32Builder::new(), + vec![vec![Some("b".to_string()), Some("a".to_string())]], + vec![vec![Some(2), Some(1)]], + vec![true] + ); + test_verify_result_equals_map!(StringArray, Int32Array, result, expected_map_arr); + } + + #[test] + fn test_map_from_list_with_invalid_arguments() { + let res = map_from_list(&[]); + assert!(res + .unwrap_err() + .to_string() + .contains("map_from_list expects exactly one argument")); + + let list_input = test_create_list_array!( + StringBuilder::new(), + Int32Builder::new(), + vec![vec![Some("a".to_string())]], + vec![vec![Some(1)]], + vec![true] + ); + let args = vec![ + ColumnarValue::Array(Arc::new(list_input.clone())), + ColumnarValue::Array(Arc::new(list_input)), + ]; + let res = map_from_list(&args); + assert!(res + .unwrap_err() + .to_string() + .contains("map_from_list expects exactly one argument")); + + let int_array = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; + let res = map_from_list(&[ColumnarValue::Array(int_array)]); + assert!(res + .unwrap_err() + .to_string() + .contains("map_from_list expects ListArray type as argument")); + } + + #[test] + fn test_map_from_list_with_nested_maps() { + let outer_keys = [ + vec![Some("key_b_1".to_string()), Some("key_a_1".to_string())], + vec![Some("key_b_2".to_string()), Some("key_a_2".to_string())], + ]; + let inner_map_data = [ + vec![ + ( + vec![ + Some("key_b_1=key1".to_string()), + Some("key_b_1=key0".to_string()), + ], + vec![ + Some("key_b_1=value1".to_string()), + Some("key_b_1=value0".to_string()), + ], + true, + ), + ( + vec![ + Some("key_a_1=key0".to_string()), + Some("key_a_1=key1".to_string()), + ], + vec![ + Some("key_a_1=value0".to_string()), + Some("key_a_1=value1".to_string()), + ], + true, + ), + ], + vec![ + ( + vec![ + Some("key_b_2=key1".to_string()), + Some("key_b_2=key0".to_string()), + ], + vec![ + Some("key_b_2=value1".to_string()), + Some("key_b_2=value0".to_string()), + ], + true, + ), + ( + vec![ + Some("key_a_2=key1".to_string()), + Some("key_a_2=key0".to_string()), + ], + vec![ + Some("key_a_2=value1".to_string()), + Some("key_a_2=value0".to_string()), + ], + true, + ), + ], + ]; + let validity = [true, true]; + + let expected_map_arr = test_create_nested_map_array!( + StringBuilder::new(), + MapBuilder::new( + Some(MapFieldNames { + entry: "entries".into(), + key: "key".into(), + value: "value".into(), + }), + StringBuilder::new(), + StringBuilder::new(), + ), + outer_keys, + inner_map_data, + validity + ); + let list_input = test_create_nested_list_array!( + StringBuilder::new(), + MapBuilder::new( + Some(MapFieldNames { + entry: "entries".into(), + key: "key".into(), + value: "value".into(), + }), + StringBuilder::new(), + StringBuilder::new(), + ), + vec![ + vec![Some("key_b_1".to_string()), Some("key_a_1".to_string())], + vec![Some("key_b_2".to_string()), Some("key_a_2".to_string())], + ], + vec![ + vec![ + ( + vec![ + Some("key_b_1=key1".to_string()), + Some("key_b_1=key0".to_string()), + ], + vec![ + Some("key_b_1=value1".to_string()), + Some("key_b_1=value0".to_string()), + ], + true, + ), + ( + vec![ + Some("key_a_1=key0".to_string()), + Some("key_a_1=key1".to_string()), + ], + vec![ + Some("key_a_1=value0".to_string()), + Some("key_a_1=value1".to_string()), + ], + true, + ), + ], + vec![ + ( + vec![ + Some("key_b_2=key1".to_string()), + Some("key_b_2=key0".to_string()), + ], + vec![ + Some("key_b_2=value1".to_string()), + 
Some("key_b_2=value0".to_string()), + ], + true, + ), + ( + vec![ + Some("key_a_2=key1".to_string()), + Some("key_a_2=key0".to_string()), + ], + vec![ + Some("key_a_2=value1".to_string()), + Some("key_a_2=value0".to_string()), + ], + true, + ), + ], + ], + vec![true, true] + ); + let args = vec![ColumnarValue::Array(Arc::new(list_input))]; + let result = map_from_list(&args).unwrap(); + test_verify_result_equals_map!(StringArray, MapArray, result, expected_map_arr); + } +} diff --git a/native/spark-expr/src/map_funcs/map_sort.rs b/native/spark-expr/src/map_funcs/map_sort.rs new file mode 100644 index 0000000000..5c86b8b6d2 --- /dev/null +++ b/native/spark-expr/src/map_funcs/map_sort.rs @@ -0,0 +1,518 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, ArrayRef, MapArray, StructArray}; +use arrow::compute::{concat, sort_to_indices, take, SortOptions}; +use arrow::datatypes::DataType; +use datafusion::common::{exec_err, DataFusionError}; +use datafusion::physical_plan::ColumnarValue; +use std::sync::Arc; + +/// Spark compatible `MapSort` implementation. +/// Sorts each entries of a MapArray by keys in ascending order without changing the ordering of the +/// maps in the array. +/// +/// For eg. If the input is a MapArray with entries: +/// ```text +/// [ +/// {"c": 3, "a": 1, "b": 2} +/// {"x": 1, "z": 3, "y": 2} +/// {"a": 1, "b": 2, "c": 3} +/// ] +/// ``` +/// The output will be: +/// ```text +/// [ +/// {"a": 1, "b": 2, "c": 3} +/// {"x": 1, "y": 2, "z": 3} +/// {"a": 1, "b": 2, "c": 3} +/// ] +/// ``` +pub fn spark_map_sort(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return exec_err!("spark_map_sort expects exactly one argument"); + } + + let arr_arg: ArrayRef = match &args[0] { + ColumnarValue::Array(array) => Arc::::clone(array), + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(1)?, + }; + + let (maps_arg, map_field, is_sorted) = match arr_arg.data_type() { + DataType::Map(map_field, is_sorted) => { + let maps_arg = arr_arg.as_any().downcast_ref::().unwrap(); + (maps_arg, map_field, is_sorted) + } + _ => return exec_err!("spark_map_sort expects Map type as argument"), + }; + + let maps_arg_entries = maps_arg.entries(); + let maps_arg_offsets = maps_arg.offsets(); + + let mut sorted_map_entries_vec: Vec = Vec::with_capacity(maps_arg.len()); + + // Iterate over each map in the MapArray and build a vector of sorted map entries. + for idx in 0..maps_arg.len() { + // Retrieve the start and end of the current map entries from the offset buffer. + let map_start = maps_arg_offsets[idx] as usize; + let map_end = maps_arg_offsets[idx + 1] as usize; + let map_len = map_end - map_start; + + // Get the current map entries. 
+ let map_entries = maps_arg_entries.slice(map_start, map_len); + + if map_len == 0 { + sorted_map_entries_vec.push(Arc::new(map_entries)); + continue; + } + + // Sort the entry-indices of the map by their keys in ascending order. + let map_keys = map_entries.column(0); + let sort_options = SortOptions { + descending: false, + nulls_first: true, + }; + let sorted_indices = sort_to_indices(&map_keys, Some(sort_options), None)?; + + // Get the sorted map entries using the sorted indices and add it to the sorted map vector. + let sorted_map_entries = take(&map_entries, &sorted_indices, None)?; + sorted_map_entries_vec.push(sorted_map_entries); + } + + // Flatten the sorted map entries into a single StructArray. + let sorted_map_entries_arr: Vec<&dyn Array> = sorted_map_entries_vec + .iter() + .map(|arr| arr.as_ref()) + .collect(); + let combined_sorted_map_entries = concat(&sorted_map_entries_arr)?; + let sorted_map_struct = combined_sorted_map_entries + .as_any() + .downcast_ref::<StructArray>() + .unwrap(); + + // Create a new MapArray with the sorted entries while preserving the original metadata. + // Note that even though the map is now sorted, the `is_sorted` flag is taken from the + // original MapArray and may be false. Although this may be less efficient, it keeps the + // schema consistent. + let sorted_map_arr = Arc::new(MapArray::try_new( + Arc::clone(map_field), + maps_arg.offsets().clone(), + sorted_map_struct.clone(), + maps_arg.nulls().cloned(), + *is_sorted, + )?); + + Ok(ColumnarValue::Array(sorted_map_arr)) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + test_create_map_array, test_create_nested_map_array, test_verify_result_equals_map, + }; + use arrow::array::builder::{Int32Builder, MapBuilder, StringBuilder}; + use arrow::array::{Int32Array, MapFieldNames, StringArray}; + use datafusion::common::ScalarValue; + use std::sync::Arc; + + #[test] + fn test_map_sort_with_string_keys() { + let keys_arg = [ + vec![ + Some("c".to_string()), + Some("a".to_string()), + Some("b".to_string()), + ], + vec![ + Some("z".to_string()), + Some("y".to_string()), + Some("x".to_string()), + ], + vec![ + Some("a".to_string()), + Some("b".to_string()), + Some("c".to_string()), + ], + vec![ + Some("fusion".to_string()), + Some("comet".to_string()), + Some("data".to_string()), + ], + ]; + let values_arg = [ + vec![Some(3), Some(1), Some(2)], + vec![Some(30), Some(20), Some(10)], + vec![Some(1), Some(2), Some(3)], + vec![Some(300), Some(100), Some(200)], + ]; + let validity = [true, true, true, true]; + + let map_arr_arg = test_create_map_array!( + StringBuilder::new(), + Int32Builder::new(), + keys_arg, + values_arg, + validity + ); + let args = vec![ColumnarValue::Array(Arc::new(map_arr_arg))]; + let result = spark_map_sort(&args).unwrap(); + + let expected_keys = [ + vec![ + Some("a".to_string()), + Some("b".to_string()), + Some("c".to_string()), + ], + vec![ + Some("x".to_string()), + Some("y".to_string()), + Some("z".to_string()), + ], + vec![ + Some("a".to_string()), + Some("b".to_string()), + Some("c".to_string()), + ], + vec![ + Some("comet".to_string()), + Some("data".to_string()), + Some("fusion".to_string()), + ], + ]; + let expected_values = [ + vec![Some(1), Some(2), Some(3)], + vec![Some(10), Some(20), Some(30)], + vec![Some(1), Some(2), Some(3)], + vec![Some(100), Some(200), Some(300)], + ]; + let expected_validity = [true, true, true, true]; + + let expected_map_arr = test_create_map_array!( + StringBuilder::new(), + Int32Builder::new(),
expected_keys, + expected_values, + expected_validity + ); + test_verify_result_equals_map!(StringArray, Int32Array, result, expected_map_arr); + } + + #[test] + fn test_map_sort_with_int_keys() { + let keys_arg = [ + vec![Some(3), Some(2), Some(1)], + vec![Some(100), Some(50), Some(20)], + vec![Some(20), Some(50), Some(100)], + vec![Some(-5), Some(0), Some(-1)], + ]; + let values_arg = [ + vec![ + Some("three".to_string()), + Some("two".to_string()), + Some("one".to_string()), + ], + vec![ + Some("hundred".to_string()), + Some("fifty".to_string()), + Some("twenty".to_string()), + ], + vec![ + Some("twenty".to_string()), + Some("fifty".to_string()), + Some("hundred".to_string()), + ], + vec![ + Some("minus five".to_string()), + Some("zero".to_string()), + Some("minus one".to_string()), + ], + ]; + let validity = [true, true, true, true]; + + let map_arr_arg = test_create_map_array!( + Int32Builder::new(), + StringBuilder::new(), + keys_arg, + values_arg, + validity + ); + let args = vec![ColumnarValue::Array(Arc::new(map_arr_arg))]; + let result = spark_map_sort(&args).unwrap(); + + let expected_keys = [ + vec![Some(1), Some(2), Some(3)], + vec![Some(20), Some(50), Some(100)], + vec![Some(20), Some(50), Some(100)], + vec![Some(-5), Some(-1), Some(0)], + ]; + let expected_values = [ + vec![ + Some("one".to_string()), + Some("two".to_string()), + Some("three".to_string()), + ], + vec![ + Some("twenty".to_string()), + Some("fifty".to_string()), + Some("hundred".to_string()), + ], + vec![ + Some("twenty".to_string()), + Some("fifty".to_string()), + Some("hundred".to_string()), + ], + vec![ + Some("minus five".to_string()), + Some("minus one".to_string()), + Some("zero".to_string()), + ], + ]; + let expected_validity = [true, true, true, true]; + + let expected_map_arr = test_create_map_array!( + Int32Builder::new(), + StringBuilder::new(), + expected_keys, + expected_values, + expected_validity + ); + test_verify_result_equals_map!(Int32Array, StringArray, result, expected_map_arr); + } + + #[test] + fn test_map_sort_with_nested_maps() { + let outer_keys = [ + // Map 1 keys. + [Some("key_b_1".to_string()), Some("key_a_1".to_string())], + // Map 2 keys. + [Some("key_b_2".to_string()), Some("key_a_2".to_string())], + ]; + let inner_map_data = [ + // Map 1 values, which is another map. + [ + ( + vec![ + Some("key_b_1=key1".to_string()), + Some("key_b_1=key0".to_string()), + ], + vec![Some("key_b_1=value1"), Some("key_b_1=value0")], + true, + ), + ( + vec![ + Some("key_a_1=key0".to_string()), + Some("key_a_1=key1".to_string()), + ], + vec![Some("key_a_1=value0"), Some("key_a_1=value1")], + true, + ), + ], + // Map 2 values, which is another map. + [ + ( + vec![ + Some("key_b_2=key1".to_string()), + Some("key_b_2=key0".to_string()), + ], + vec![Some("key_b_2=value1"), Some("key_b_2=value0")], + true, + ), + ( + vec![ + Some("key_a_2=key1".to_string()), + Some("key_a_2=key0".to_string()), + ], + vec![Some("key_a_2=value1"), Some("key_a_2=value0")], + true, + ), + ], + ]; + let validity = [true, true]; + + let map_arr_arg = test_create_nested_map_array!( + StringBuilder::new(), + MapBuilder::new( + Some(MapFieldNames { + entry: "entries".into(), + key: "key".into(), + value: "value".into(), + }), + StringBuilder::new(), + StringBuilder::new(), + ), + outer_keys, + inner_map_data, + validity + ); + + let args = vec![ColumnarValue::Array(Arc::new(map_arr_arg))]; + let result = spark_map_sort(&args).unwrap(); + + let expected_outer_keys = [ + // Map 1 keys. 
+ [Some("key_a_1".to_string()), Some("key_b_1".to_string())], + // Map 2 keys. + [Some("key_a_2".to_string()), Some("key_b_2".to_string())], + ]; + let expected_inner_map_data = vec![ + // Map 1 values are reordered with the outer keys with getting sorted. + [ + ( + vec![ + Some("key_a_1=key0".to_string()), + Some("key_a_1=key1".to_string()), + ], + vec![Some("key_a_1=value0"), Some("key_a_1=value1")], + true, + ), + ( + vec![ + Some("key_b_1=key1".to_string()), + Some("key_b_1=key0".to_string()), + ], + vec![Some("key_b_1=value1"), Some("key_b_1=value0")], + true, + ), + ], + // Map 2 values are reordered with the outer keys with getting sorted. + [ + ( + vec![ + Some("key_a_2=key1".to_string()), + Some("key_a_2=key0".to_string()), + ], + vec![Some("key_a_2=value1"), Some("key_a_2=value0")], + true, + ), + ( + vec![ + Some("key_b_2=key1".to_string()), + Some("key_b_2=key0".to_string()), + ], + vec![Some("key_b_2=value1"), Some("key_b_2=value0")], + true, + ), + ], + ]; + let expected_validity = [true, true]; + + let expected_map_arr = test_create_nested_map_array!( + StringBuilder::new(), + MapBuilder::new( + Some(MapFieldNames { + entry: "entries".into(), + key: "key".into(), + value: "value".into(), + }), + StringBuilder::new(), + StringBuilder::new(), + ), + expected_outer_keys, + expected_inner_map_data, + expected_validity + ); + + test_verify_result_equals_map!(StringArray, MapArray, result, expected_map_arr); + } + + #[test] + fn test_map_sort_with_scalar_argument() { + let map_array = test_create_map_array!( + StringBuilder::new(), + Int32Builder::new(), + vec![vec![Some("b".to_string()), Some("a".to_string())]], + vec![vec![Some(2), Some(1)]], + vec![true] + ); + + let args = vec![ColumnarValue::Scalar( + ScalarValue::try_from_array(&map_array, 0).unwrap(), + )]; + let result = spark_map_sort(&args).unwrap(); + + let expected_map_arr = test_create_map_array!( + StringBuilder::new(), + Int32Builder::new(), + vec![vec![Some("a".to_string()), Some("b".to_string())]], + vec![vec![Some(1), Some(2)]], + vec![true] + ); + test_verify_result_equals_map!(StringArray, Int32Array, result, expected_map_arr); + } + + #[test] + fn test_map_sort_with_empty_map() { + let map_arr_arg = test_create_map_array!( + StringBuilder::new(), + Int32Builder::new(), + vec![Vec::>::new()], + vec![Vec::>::new()], + vec![false] + ); + + let args = vec![ColumnarValue::Array(Arc::new(map_arr_arg))]; + let result = spark_map_sort(&args).unwrap(); + let expected_map_arr = test_create_map_array!( + StringBuilder::new(), + Int32Builder::new(), + vec![Vec::>::new()], + vec![Vec::>::new()], + vec![false] + ); + test_verify_result_equals_map!(StringArray, Int32Array, result, expected_map_arr); + } + + #[test] + fn test_map_sort_with_invalid_arguments() { + let result = spark_map_sort(&[]); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("expects exactly one argument")); + + let map_array = test_create_map_array!( + StringBuilder::new(), + Int32Builder::new(), + vec![vec![Some("a".to_string())]], + vec![vec![Some(1)]], + vec![true] + ); + + let args = vec![ + ColumnarValue::Array(Arc::new(map_array.clone())), + ColumnarValue::Array(Arc::new(map_array)), + ]; + let result = spark_map_sort(&args); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("spark_map_sort expects exactly one argument")); + + let int_array = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; + let args = vec![ColumnarValue::Array(int_array)]; + + let result = 
spark_map_sort(&args); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("spark_map_sort expects Map type as argument")); + } +} diff --git a/native/spark-expr/src/map_funcs/map_test_helpers.rs b/native/spark-expr/src/map_funcs/map_test_helpers.rs new file mode 100644 index 0000000000..88df40da03 --- /dev/null +++ b/native/spark-expr/src/map_funcs/map_test_helpers.rs @@ -0,0 +1,322 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[macro_export] +macro_rules! test_create_map_array { + ($key_builder:expr, $value_builder:expr, $keys:expr, $values:expr, $validity:expr) => {{ + let mut map_builder = MapBuilder::new( + Some(MapFieldNames { + entry: "entries".into(), + key: "key".into(), + value: "value".into(), + }), + $key_builder, + $value_builder, + ); + + assert_eq!($keys.len(), $values.len()); + assert_eq!($keys.len(), $validity.len()); + + let total_maps = $keys.len(); + for map_idx in 0..total_maps { + let map_keys = &$keys[map_idx]; + let map_values = &$values[map_idx]; + assert_eq!(map_keys.len(), map_values.len()); + + let map_entries = map_keys.len(); + for entry_idx in 0..map_entries { + let key_val = &map_keys[entry_idx]; + match key_val { + Some(key) => map_builder.keys().append_value(key.clone()), + None => panic!("Unexpected None key found"), + } + + let value = &map_values[entry_idx]; + map_builder.values().append_value(value.clone().unwrap()); + } + + let is_valid = $validity[map_idx]; + map_builder.append(is_valid).unwrap(); + } + + map_builder.finish() + }}; +} + +#[macro_export] +macro_rules! 
test_create_nested_map_array { + ($key_builder:expr, $value_builder:expr, $keys:expr, $values:expr, $validity:expr) => {{ + let mut map_builder = MapBuilder::new( + Some(MapFieldNames { + entry: "entries".into(), + key: "key".into(), + value: "value".into(), + }), + $key_builder, + $value_builder, + ); + + assert_eq!($keys.len(), $values.len()); + assert_eq!($keys.len(), $validity.len()); + + let total_maps = $keys.len(); + for map_idx in 0..total_maps { + let map_keys = &$keys[map_idx]; + let map_values = &$values[map_idx]; + assert_eq!(map_keys.len(), map_values.len()); + + let map_entries = map_keys.len(); + for entry_idx in 0..map_entries { + let key_val = &map_keys[entry_idx]; + match key_val { + Some(key) => map_builder.keys().append_value(key.clone()), + None => panic!("Unexpected None key found in outer map"), + } + + let (inner_keys, inner_values, inner_valid) = &map_values[entry_idx]; + + let inner_entries = inner_keys.len(); + for inner_idx in 0..inner_entries { + let inner_key_val = &inner_keys[inner_idx]; + match inner_key_val { + Some(inner_key) => { + map_builder.values().keys().append_value(inner_key.clone()) + } + None => panic!("Unexpected None key found in inner map"), + } + + let inner_value = &inner_values[inner_idx]; + map_builder + .values() + .values() + .append_value(inner_value.clone().unwrap()); + } + + map_builder.values().append(*inner_valid).unwrap(); + } + + let is_valid = $validity[map_idx]; + map_builder.append(is_valid).unwrap(); + } + + map_builder.finish() + }}; +} + +#[macro_export] +macro_rules! test_verify_result_equals_map { + ($keyType:ty, $valueType:ty, $result:expr, $expected_map_arr:expr) => {{ + match $result { + ColumnarValue::Array(actual_arr) => { + let actual_map_arr = actual_arr.as_any().downcast_ref::<MapArray>().unwrap(); + + assert_eq!(actual_map_arr.len(), $expected_map_arr.len()); + assert_eq!(actual_map_arr.offsets(), $expected_map_arr.offsets()); + assert_eq!(actual_map_arr.nulls(), $expected_map_arr.nulls()); + assert_eq!(actual_map_arr.data_type(), $expected_map_arr.data_type()); + + let actual_entries = actual_map_arr.entries(); + let actual_keys = actual_entries + .column(0) + .as_any() + .downcast_ref::<$keyType>() + .unwrap(); + let actual_values = actual_entries + .column(1) + .as_any() + .downcast_ref::<$valueType>() + .unwrap(); + + let expected_entries = $expected_map_arr.entries(); + let expected_keys = expected_entries + .column(0) + .as_any() + .downcast_ref::<$keyType>() + .unwrap(); + let expected_values = expected_entries + .column(1) + .as_any() + .downcast_ref::<$valueType>() + .unwrap(); + + for idx in 0..actual_entries.len() { + assert_eq!(actual_keys.value(idx), expected_keys.value(idx)); + assert_eq!(actual_values.value(idx), expected_values.value(idx)); + } + } + unexpected_arr => { + panic!("Actual result: {unexpected_arr:?} is not an Array ColumnarValue") + } + } + }}; +} + +#[macro_export] +macro_rules!
test_create_list_array { + ($key_builder:expr, $value_builder:expr, $keys:expr, $values:expr, $validity:expr) => {{ + use arrow::array::{ArrayRef, GenericListArray, StructArray}; + use arrow::buffer::OffsetBuffer; + use arrow::datatypes::Field; + use std::sync::Arc; + + assert_eq!($keys.len(), $values.len()); + assert_eq!($keys.len(), $validity.len()); + + let total_lists = $keys.len(); + + let mut key_builder = $key_builder; + let mut value_builder = $value_builder; + + let mut offsets: Vec<i32> = Vec::with_capacity(total_lists + 1); + offsets.push(0); + + for list_idx in 0..total_lists { + let list_keys = &$keys[list_idx]; + let list_values = &$values[list_idx]; + assert_eq!(list_keys.len(), list_values.len()); + + let entries = list_keys.len(); + for entry_idx in 0..entries { + let key_val = &list_keys[entry_idx]; + match key_val { + Some(key) => key_builder.append_value(key.clone()), + None => panic!("Unexpected None key found"), + } + + let value = &list_values[entry_idx]; + value_builder.append_value(value.clone().unwrap()); + } + + let last = *offsets.last().unwrap(); + offsets.push(last + (entries as i32)); + } + + let keys_array = key_builder.finish(); + let values_array = value_builder.finish(); + + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("key", keys_array.data_type().clone(), false)), + Arc::new(keys_array) as ArrayRef, + ), + ( + Arc::new(Field::new("value", values_array.data_type().clone(), true)), + Arc::new(values_array) as ArrayRef, + ), + ]); + + let list_field = Arc::new(Field::new( + "entries", + struct_array.data_type().clone(), + true, + )); + + GenericListArray::<i32>::try_new( + list_field, + OffsetBuffer::new(offsets.into()), + Arc::new(struct_array) as ArrayRef, + None, + ) + .unwrap() + }}; +} + +#[macro_export] +macro_rules!
test_create_nested_list_array { + ($key_builder:expr, $map_value_builder:expr, $keys:expr, $values:expr, $validity:expr) => {{ + use arrow::array::{ArrayRef, GenericListArray, StructArray}; + use arrow::buffer::OffsetBuffer; + use arrow::datatypes::Field; + use std::sync::Arc; + + assert_eq!($keys.len(), $values.len()); + assert_eq!($keys.len(), $validity.len()); + + let total_lists = $keys.len(); + + let mut key_builder = $key_builder; + let mut map_builder = $map_value_builder; + + let mut offsets: Vec<i32> = Vec::with_capacity(total_lists + 1); + offsets.push(0); + + let mut total_entries = 0; + for list_idx in 0..total_lists { + let list_keys = &$keys[list_idx]; + let list_values = &$values[list_idx]; + assert_eq!(list_keys.len(), list_values.len()); + + let entries = list_keys.len(); + for entry_idx in 0..entries { + let key_val = &list_keys[entry_idx]; + match key_val { + Some(key) => key_builder.append_value(key.clone()), + None => panic!("Unexpected None key found in outer list"), + } + + let (inner_keys, inner_values, inner_valid) = &list_values[entry_idx]; + assert_eq!(inner_keys.len(), inner_values.len()); + + for inner_idx in 0..inner_keys.len() { + let inner_key_val = &inner_keys[inner_idx]; + match inner_key_val { + Some(inner_key) => map_builder.keys().append_value(inner_key.clone()), + None => panic!("Unexpected None key found in inner map"), + } + + let inner_value = &inner_values[inner_idx]; + map_builder + .values() + .append_value(inner_value.clone().unwrap()); + } + + map_builder.append(*inner_valid).unwrap(); + } + + total_entries += entries as i32; + offsets.push(total_entries); + } + + let keys_array = key_builder.finish(); + let values_array = map_builder.finish(); + + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("key", keys_array.data_type().clone(), false)), + Arc::new(keys_array) as ArrayRef, + ), + ( + Arc::new(Field::new("value", values_array.data_type().clone(), true)), + Arc::new(values_array) as ArrayRef, + ), + ]); + + let list_field = Arc::new(Field::new( + "entries", + struct_array.data_type().clone(), + true, + )); + + GenericListArray::<i32>::try_new( + list_field, + OffsetBuffer::new(offsets.into()), + Arc::new(struct_array) as ArrayRef, + None, + ) + .unwrap() + }}; +} diff --git a/native/spark-expr/src/map_funcs/map_to_list.rs b/native/spark-expr/src/map_funcs/map_to_list.rs new file mode 100644 index 0000000000..96f67fb158 --- /dev/null +++ b/native/spark-expr/src/map_funcs/map_to_list.rs @@ -0,0 +1,368 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
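+//
+// Note: this conversion relies on Arrow's Map layout being physically identical to a
+// List of Structs (an offsets buffer over an entries StructArray), so `map_to_list`
+// and its inverse `map_from_list` only rewrap buffers and copy no data.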
+ +use arrow::array::{Array, ArrayData, ArrayRef, ListArray, MapArray}; +use arrow::datatypes::DataType; +use datafusion::common::{exec_err, DataFusionError}; +use datafusion::physical_plan::ColumnarValue; +use std::sync::Arc; + +/// Converts a MapArray to a ListArray of Structs of key-value pairs preserving the original layout. +/// This method can be used to canonicalize map representations. +/// One use case is to wrap the Map type with this function before doing the grouping. +pub fn map_to_list(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> { + if args.len() != 1 { + return exec_err!("map_to_list expects exactly one argument"); + } + + let arr_arg: ArrayRef = match &args[0] { + ColumnarValue::Array(array) => Arc::<dyn Array>::clone(array), + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(1)?, + }; + + // TODO: Add `is_sorted` field to the metadata. + let (maps_arg, map_field) = match arr_arg.data_type() { + DataType::Map(field, _) => (arr_arg.as_any().downcast_ref::<MapArray>().unwrap(), field), + _ => return exec_err!("map_to_list expects MapArray type as argument"), + }; + + let maps_data = maps_arg.to_data(); + + // A Map only has a single top-level buffer, which is the offset buffer. + let offset_buffer = maps_data.buffers()[0].clone(); + + // These are the entries of the map, which is a StructArray of key-value pairs. + let maps_entries = maps_arg.entries(); + let map_entries_data = maps_entries.to_data(); + + // Build a ListArray preserving the same layout as the MapArray. + let mut list_builder = ArrayData::builder(DataType::List( + Arc::clone(map_field), + )) + .len(maps_arg.len()) + .offset(maps_arg.offset()) + .add_buffer(offset_buffer) + .child_data(vec![map_entries_data]); + + // Copy the null bitmaps if they exist. + if let Some(maps_nulls) = maps_data.nulls() { + list_builder = list_builder.nulls(Some(maps_nulls.clone())); + } + + let list_data = list_builder.build()?; + let list_array = Arc::new(ListArray::from(list_data)); + Ok(ColumnarValue::Array(list_array)) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{test_create_map_array, test_create_nested_map_array}; + use arrow::array::builder::{Int32Builder, MapBuilder, StringBuilder}; + use arrow::array::{Int32Array, ListArray, MapArray, MapFieldNames, StringArray, StructArray}; + use datafusion::common::ScalarValue; + use std::sync::Arc; + + macro_rules!
verify_result { + ($key_type:ty, $value_type:ty, $result:expr, $expected_map_arr:expr) => {{ + match $result { + ColumnarValue::Array(actual_arr) => { + let actual_list = actual_arr.as_any().downcast_ref::<ListArray>().unwrap(); + + verify_metadata!(actual_list, $expected_map_arr); + + let actual_entries = actual_list + .values() + .as_any() + .downcast_ref::<StructArray>() + .unwrap(); + let expected_entries = $expected_map_arr.entries(); + verify_entries!($key_type, $value_type, actual_entries, expected_entries); + } + unexpected => { + panic!("Actual result: {unexpected:?} is not an Array ColumnarValue") + } + } + }}; + ($outer_key_type:ty, $inner_key_type:ty, $inner_value_type:ty, $result:expr, $expected_map_arr:expr) => {{ + match $result { + ColumnarValue::Array(actual_arr) => { + let list_arr = actual_arr.as_any().downcast_ref::<ListArray>().unwrap(); + + verify_metadata!(list_arr, $expected_map_arr); + + let actual_entries = list_arr + .values() + .as_any() + .downcast_ref::<StructArray>() + .unwrap(); + let expected_entries = $expected_map_arr.entries(); + + let actual_keys = actual_entries + .column(0) + .as_any() + .downcast_ref::<$outer_key_type>() + .unwrap(); + let expected_keys = expected_entries + .column(0) + .as_any() + .downcast_ref::<$outer_key_type>() + .unwrap(); + + assert_eq!(actual_keys.len(), expected_keys.len()); + + for idx in 0..actual_entries.len() { + assert_eq!(actual_keys.is_null(idx), expected_keys.is_null(idx)); + if !actual_keys.is_null(idx) { + assert_eq!(actual_keys.value(idx), expected_keys.value(idx)); + } + } + + let actual_map_values = actual_entries + .column(1) + .as_any() + .downcast_ref::<MapArray>() + .unwrap(); + let expected_map_values = expected_entries + .column(1) + .as_any() + .downcast_ref::<MapArray>() + .unwrap(); + + assert_eq!(actual_map_values.len(), expected_map_values.len()); + + let actual_map_offsets: Vec<i32> = + actual_map_values.offsets().iter().copied().collect(); + let expected_map_offsets: Vec<i32> = + expected_map_values.offsets().iter().copied().collect(); + + assert_eq!(actual_map_offsets, expected_map_offsets); + assert_eq!(actual_map_values.nulls(), expected_map_values.nulls()); + + verify_entries!( + $inner_key_type, + $inner_value_type, + actual_map_values.entries(), + expected_map_values.entries() + ); + } + unexpected => { + panic!("Actual result: {unexpected:?} is not an Array ColumnarValue") + } + } + }}; + } + + macro_rules! verify_metadata { + ($list_arr:ident, $map_arr:expr) => {{ + assert_eq!($list_arr.len(), $map_arr.len()); + assert_eq!($list_arr.offset(), $map_arr.offset()); + + let actual_offsets: Vec<i32> = $list_arr.offsets().iter().copied().collect(); + let expected_offsets: Vec<i32> = $map_arr.offsets().iter().copied().collect(); + + assert_eq!(actual_offsets, expected_offsets); + assert_eq!($list_arr.nulls(), $map_arr.nulls()); + }}; + } + + macro_rules!
verify_entries { + ($key_type:ty, $value_type:ty, $actual_entries:expr, $expected_entries:expr) => { + assert_eq!($actual_entries.data_type(), $expected_entries.data_type()); + assert_eq!($actual_entries.len(), $expected_entries.len()); + + let actual_keys = $actual_entries + .column(0) + .as_any() + .downcast_ref::<$key_type>() + .unwrap(); + let expected_keys = $expected_entries + .column(0) + .as_any() + .downcast_ref::<$key_type>() + .unwrap(); + + let actual_values = $actual_entries + .column(1) + .as_any() + .downcast_ref::<$value_type>() + .unwrap(); + let expected_values = $expected_entries + .column(1) + .as_any() + .downcast_ref::<$value_type>() + .unwrap(); + + for idx in 0..$actual_entries.len() { + assert_eq!(actual_keys.is_null(idx), expected_keys.is_null(idx)); + if !actual_keys.is_null(idx) { + assert_eq!(actual_keys.value(idx), expected_keys.value(idx)); + } + + assert_eq!(actual_values.is_null(idx), expected_values.is_null(idx)); + if !actual_values.is_null(idx) { + assert_eq!(actual_values.value(idx), expected_values.value(idx)); + } + } + }; + } + + #[test] + fn test_map_to_list_basic() { + let keys_arg = [ + vec![ + Some("b".to_string()), + Some("a".to_string()), + Some("c".to_string()), + ], + Vec::<Option<String>>::new(), + ]; + let values_arg = [vec![Some(2), Some(1), Some(3)], Vec::<Option<i32>>::new()]; + let validity = [true, false]; + + let map_arr_arg = test_create_map_array!( + StringBuilder::new(), + Int32Builder::new(), + keys_arg, + values_arg, + validity + ); + + let args = vec![ColumnarValue::Array(Arc::new(map_arr_arg.clone()))]; + let result = map_to_list(&args).unwrap(); + + verify_result!(StringArray, Int32Array, result, map_arr_arg); + } + + #[test] + fn test_map_to_list_with_scalar_argument() { + let map_arr = test_create_map_array!( + StringBuilder::new(), + Int32Builder::new(), + vec![vec![Some("b".to_string()), Some("a".to_string())]], + vec![vec![Some(2), Some(1)]], + vec![true] + ); + + let args = vec![ColumnarValue::Scalar( + ScalarValue::try_from_array(&map_arr, 0).unwrap(), + )]; + + let result = map_to_list(&args).unwrap(); + verify_result!(StringArray, Int32Array, result, map_arr); + } + + #[test] + fn test_map_to_list_with_invalid_arguments() { + let res = map_to_list(&[]); + assert!(res + .unwrap_err() + .to_string() + .contains("map_to_list expects exactly one argument")); + + let map_arr_arg = test_create_map_array!( + StringBuilder::new(), + Int32Builder::new(), + vec![vec![Some("a".to_string())]], + vec![vec![Some(1)]], + vec![true] + ); + let args = vec![ + ColumnarValue::Array(Arc::new(map_arr_arg.clone())), + ColumnarValue::Array(Arc::new(map_arr_arg)), + ]; + let res = map_to_list(&args); + assert!(res + .unwrap_err() + .to_string() + .contains("map_to_list expects exactly one argument")); + + let int_array = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; + let res = map_to_list(&[ColumnarValue::Array(int_array)]); + assert!(res + .unwrap_err() + .to_string() + .contains("map_to_list expects MapArray type as argument")); + } + + #[test] + fn test_map_to_list_with_nested_maps() { + let outer_keys = [ + vec![Some("key_b_1".to_string()), Some("key_a_1".to_string())], + vec![Some("key_b_2".to_string()), Some("key_a_2".to_string())], + ]; + let inner_map_data = [ + vec![ + ( + vec![ + Some("key_b_1=key1".to_string()), + Some("key_b_1=key0".to_string()), + ], + vec![Some("key_b_1=value1"), Some("key_b_1=value0")], + true, + ), + ( + vec![ + Some("key_a_1=key0".to_string()), + Some("key_a_1=key1".to_string()), + ], + vec![Some("key_a_1=value0"),
Some("key_a_1=value1")], + true, + ), + ], + vec![ + ( + vec![ + Some("key_b_2=key1".to_string()), + Some("key_b_2=key0".to_string()), + ], + vec![Some("key_b_2=value1"), Some("key_b_2=value0")], + true, + ), + ( + vec![ + Some("key_a_2=key1".to_string()), + Some("key_a_2=key0".to_string()), + ], + vec![Some("key_a_2=value1"), Some("key_a_2=value0")], + true, + ), + ], + ]; + let validity = [true, true]; + + let map_arr = test_create_nested_map_array!( + StringBuilder::new(), + MapBuilder::new( + Some(MapFieldNames { + entry: "entries".into(), + key: "key".into(), + value: "value".into(), + }), + StringBuilder::new(), + StringBuilder::new(), + ), + outer_keys, + inner_map_data, + validity + ); + + let args = vec![ColumnarValue::Array(Arc::new(map_arr.clone()))]; + let result = map_to_list(&args).unwrap(); + verify_result!(StringArray, StringArray, StringArray, result, map_arr); + } +} diff --git a/native/spark-expr/src/map_funcs/mod.rs b/native/spark-expr/src/map_funcs/mod.rs new file mode 100644 index 0000000000..b8972261d5 --- /dev/null +++ b/native/spark-expr/src/map_funcs/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod map_from_list; +mod map_sort; +mod map_test_helpers; +mod map_to_list; + +pub use map_from_list::map_from_list; +pub use map_sort::spark_map_sort; +pub use map_to_list::map_to_list; diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 23cf9d313e..e38e963307 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -48,7 +48,7 @@ import org.apache.spark.unsafe.types.UTF8String import com.google.protobuf.ByteString import org.apache.comet.{CometConf, ConfigEntry} -import org.apache.comet.CometSparkSessionExtensions.{isCometScan, withInfo} +import org.apache.comet.CometSparkSessionExtensions.{isCometScan, isSpark40Plus, withInfo} import org.apache.comet.DataTypeSupport.isComplexType import org.apache.comet.expressions._ import org.apache.comet.objectstore.NativeConfig @@ -73,7 +73,7 @@ object QueryPlanSerde extends Logging with CometExprShim { /** * Mapping of Spark expression class to Comet expression handler. 
    */
-  private val exprSerdeMap: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(
+  private val exprSerdeMap: Map[Class[_ <: Expression], CometExpressionSerde[_]] = (Map(
     classOf[Add] -> CometAdd,
     classOf[Subtract] -> CometSubtract,
     classOf[Multiply] -> CometMultiply,
@@ -168,7 +168,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[DateAdd] -> CometDateAdd,
     classOf[DateSub] -> CometDateSub,
     classOf[TruncDate] -> CometTruncDate,
-    classOf[TruncTimestamp] -> CometTruncTimestamp)
+    classOf[TruncTimestamp] -> CometTruncTimestamp) ++ versionSpecificExprSerdeMap).toMap
 
   /**
    * Mapping of Spark aggregate expression class to Comet expression handler.
@@ -1920,7 +1920,13 @@ object QueryPlanSerde extends Logging with CometExprShim {
 
     if (groupingExpressions.exists(expr =>
         expr.dataType match {
-          case _: MapType => true
+          case _: MapType =>
+            if (isSpark40Plus &&
+              CometConf.COMET_ENABLE_GROUPING_ON_MAP_TYPE.get(conf) &&
+              CometConf.COMET_NATIVE_SCAN_IMPL.get(conf) == CometConf.SCAN_NATIVE_DATAFUSION)
+              false
+            else
+              true
           case _ => false
         })) {
       withInfo(op, "Grouping on map types is not supported")
diff --git a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala
index ca53efc3db..3756875674 100644
--- a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala
+++ b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala
@@ -21,12 +21,16 @@ package org.apache.comet.shims
 import org.apache.comet.expressions.CometEvalMode
 import org.apache.comet.serde.CommonStringExprs
 import org.apache.comet.serde.ExprOuterClass.Expr
+import org.apache.comet.serde.CometExpressionSerde
 import org.apache.spark.sql.catalyst.expressions._
 
 /**
  * `CometExprShim` acts as a shim for parsing expressions from different Spark versions.
  */
 trait CometExprShim extends CommonStringExprs {
+  // Version specific expression serde map.
+  protected val versionSpecificExprSerdeMap: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map.empty
+
   /**
    * Returns a tuple of expressions for the `unhex` function.
    */
diff --git a/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala
index ca53efc3db..9f746c0df2 100644
--- a/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala
+++ b/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala
@@ -21,12 +21,16 @@ package org.apache.comet.shims
 import org.apache.comet.expressions.CometEvalMode
 import org.apache.comet.serde.CommonStringExprs
 import org.apache.comet.serde.ExprOuterClass.Expr
+import org.apache.comet.serde.CometExpressionSerde
 import org.apache.spark.sql.catalyst.expressions._
 
 /**
  * `CometExprShim` acts as a shim for parsing expressions from different Spark versions.
  */
 trait CometExprShim extends CommonStringExprs {
+  // Version specific expression serde map.
+  protected val versionSpecificExprSerdeMap: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map.empty
+
   /**
    * Returns a tuple of expressions for the `unhex` function.
    */
@@ -58,4 +62,3 @@ object CometEvalModeUtil {
     case EvalMode.ANSI => CometEvalMode.ANSI
   }
 }
-
diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala
index ddd53d6d8d..853e2efb2c 100644
--- a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala
+++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala
@@ -19,8 +19,9 @@ package org.apache.comet.shims
 
 import org.apache.comet.expressions.CometEvalMode
-import org.apache.comet.serde.CommonStringExprs
+import org.apache.comet.serde.{CometExpressionSerde, CommonStringExprs, ExprOuterClass}
 import org.apache.comet.serde.ExprOuterClass.Expr
+import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProtoWithReturnType}
 
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.internal.types.StringTypeWithCollation
@@ -30,6 +31,10 @@ import org.apache.spark.sql.types.{BinaryType, BooleanType, StringType}
  * `CometExprShim` acts as a shim for parsing expressions from different Spark versions.
  */
 trait CometExprShim extends CommonStringExprs {
+  // Version specific expression serde map.
+  protected val versionSpecificExprSerdeMap: Map[Class[_ <: Expression], CometExpressionSerde[_]] =
+    Map(classOf[MapSort] -> CometMapSort)
+
   /**
    * Returns a tuple of expressions for the `unhex` function.
    */
@@ -71,3 +76,17 @@ object CometEvalModeUtil {
   }
 }
 
+object CometMapSort extends CometExpressionSerde[MapSort] {
+
+  override def convert(
+      mapSortExpr: MapSort,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val childExpr = exprToProtoInternal(mapSortExpr.child, inputs, binding)
+    val returnType = mapSortExpr.child.dataType
+
+    val mapSortScalarExpr =
+      scalarFunctionExprToProtoWithReturnType("map_sort", returnType, childExpr)
+    optExprWithInfo(mapSortScalarExpr, mapSortExpr, mapSortExpr.children: _*)
+  }
+}
\ No newline at end of file
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index be7fe7ee52..54b67a6637 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.functions.{avg, count_distinct, sum}
 import org.apache.spark.sql.internal.SQLConf
 
 import org.apache.comet.CometConf
+import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus
 import org.apache.comet.testing.{DataGenOptions, ParquetGenerator}
 
 /**
@@ -1515,4 +1516,99 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     sparkPlan.collect { case s: CometHashAggregateExec => s }.size
   }
 
+  test("groupby with map column") {
+    assume(isSpark40Plus, "Groupby on map type is supported in Spark 4.0 and beyond")
+
+    def runTests(tableName: String): Unit = {
+      // Group on second map column with just aggregate.
+      checkSparkAnswerAndOperator(s"select count(*) AS cnt from $tableName group by _2")
+
+      // Group on second map column with just grouping column.
+      checkSparkAnswerAndOperator(s"select _2 from $tableName group by _2")
+
+      // Group on second map column with different aggregates.
+ checkSparkAnswerAndOperator(s"select _2, count(*) AS cnt from $tableName group by _2") + checkSparkAnswerAndOperator(s"select _2, sum(_1) AS total from $tableName group by _2") + checkSparkAnswerAndOperator( + s"select _2, count(*) AS cnt, sum(_1) AS sum_val from $tableName group by _2") + + // Group on second map column with aggregate and filtering. + checkSparkAnswerAndOperator( + s"select _2, count(*) AS cnt from $tableName group by _2 having count(*) > 1") + checkSparkAnswerAndOperator( + s"select _2, count(*) AS cnt from $tableName WHERE _2 IS not null group by _2") + + // Group on second map column with aggregate and order by. + checkSparkAnswerAndOperator( + s"select _2, count(*) AS cnt from $tableName group by _2 order by cnt DESC") + + // Group on third map column with aggregate. + checkSparkAnswerAndOperator(s"select _3, count(*) AS cnt from $tableName group by _3") + checkSparkAnswerAndOperator(s"select _3, sum(_1) AS total from $tableName group by _3") + + // Group on third map column with different aggregates. + checkSparkAnswerAndOperator( + s"select _3, count(*) AS cnt, sum(_1) AS sum_val from $tableName group by _3") + + // Group on third map column with aggregate and filtering. + checkSparkAnswerAndOperator( + s"select _3, count(*) AS cnt from $tableName WHERE _3 IS not null group by _3") + + // Group on third map column with aggregate and order by. + checkSparkAnswerAndOperator( + s"select _3, count(*) AS cnt from $tableName group by _3 order by cnt DESC") + + // Group on both map columns with aggregate. + checkSparkAnswerAndOperator( + s"select _2, _3, count(*) AS cnt from $tableName group by _2, _3") + + // Group on both map columns with aggregate. The columns are selected in different order. + checkSparkAnswerAndOperator( + s"select _3, count(*), _2 AS cnt from $tableName group by _2, _3") + + // Group on both map columns with different aggregates. + checkSparkAnswerAndOperator( + s"select _2, _3, count(*) AS cnt, sum(_1) AS sum_val from $tableName group by _2, _3") + + // Group on both map column with aggregate and filtering. 
+      checkSparkAnswerAndOperator(
+        s"select _2, _3, count(*) AS cnt from $tableName WHERE _2 IS not null group by _2, _3")
+      checkSparkAnswerAndOperator(
+        s"select _2, _3, count(*) AS cnt from $tableName WHERE _3 IS not null group by _2, _3")
+      checkSparkAnswerAndOperator(
+        s"select _2, _3, count(*) AS cnt from $tableName WHERE _2 IS not null AND _3 IS not null group by _2, _3")
+      checkSparkAnswerAndOperator(
+        s"select _2, _3, count(*) AS cnt from $tableName WHERE _2 IS not null group by _2, _3 order by cnt DESC")
+    }
+
+    withSQLConf(
+      CometConf.COMET_ENABLED.key -> "true",
+      CometConf.COMET_EXEC_ENABLED.key -> "true",
+      CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "true",
+      CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION,
+      CometConf.COMET_ENABLE_GROUPING_ON_MAP_TYPE.key -> "true") {
+
+      withParquetTable(
+        Seq(
+          (1, Map("a" -> 1, "b" -> 2), Map(1 -> "a", 2 -> "b")),
+          (2, Map("b" -> 2, "a" -> 1), Map(2 -> "b", 1 -> "a")),
+          (3, Map("a" -> 5, "b" -> 6), Map(1 -> "c", 2 -> "d")),
+          (4, Map("a" -> 1, "b" -> 2), Map(1 -> "a", 2 -> "b")),
+          (5, Map("c" -> 3), Map(3 -> "e")),
+          (6, Map("a" -> 1, "b" -> 2, "c" -> 3), Map(1 -> "a", 2 -> "b", 3 -> "e")),
+          (7, Map.empty[String, Int], Map.empty[Int, String]),
+          (8, Map("b" -> 3, "a" -> 5), Map(2 -> "b", 1 -> "a")),
+          (9, null, null),
+          (10, Map("d" -> 4, "e" -> 5, "f" -> 6), Map(4 -> "f", 5 -> "g", 6 -> "h")),
+          (11, Map("datafusion" -> 4, "comet" -> 5), Map(1 -> "datafusion", 2 -> "comet")),
+          (12, Map("comet" -> 5, "datafusion" -> 4), Map(2 -> "comet", 1 -> "datafusion")),
+          (13, Map("a" -> 1, "b" -> 2), Map(-1 -> "a", 2 -> "b")),
+          (14, Map("b" -> 2, "a" -> 1), Map(1 -> "a", 2 -> "b")),
+          (15, Map("b" -> 2, "a" -> 1), Map(2 -> "b", -1 -> "a"))),
+        "map_tbl") {
+        runTests("map_tbl")
+      }
+    }
+  }
 }
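
A note on the Map/List duality the native changes rely on: an Arrow MapArray is physically a ListArray of key/value structs sharing the same offsets and null buffer, which is why a map_to_list/map_from_list style pair can canonicalize a map for grouping without copying row data. The standalone sketch below shows that round trip against arrow-rs directly; it is illustrative only, and the builder setup, names, and assertions are this note's assumptions rather than the Comet kernels themselves.

// sketch_map_list_roundtrip.rs (illustrative, not part of this PR)
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, Int32Builder, ListArray, MapArray, MapBuilder, StringBuilder};
use arrow::datatypes::Field;

fn main() {
    // Build a map column with two rows: {"b": 2, "a": 1} and NULL.
    let mut builder = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new());
    builder.keys().append_value("b");
    builder.values().append_value(2);
    builder.keys().append_value("a");
    builder.values().append_value(1);
    builder.append(true).unwrap();
    builder.append(false).unwrap();
    let map: MapArray = builder.finish();

    // Map -> List<Struct<key, value>>: reuse the map's offsets, entries, and
    // null buffer verbatim, so no row data is copied.
    let entries_field = Arc::new(Field::new(
        "entries",
        map.entries().data_type().clone(),
        false,
    ));
    let as_list = ListArray::new(
        Arc::clone(&entries_field),
        map.offsets().clone(),
        Arc::new(map.entries().clone()) as ArrayRef,
        map.nulls().cloned(),
    );
    assert_eq!(as_list.len(), map.len());
    assert!(as_list.is_null(1)); // the NULL map row survives the conversion

    // List -> Map: the inverse step that restores the Spark-visible type after
    // aggregation. `false` = keys are not guaranteed to be sorted.
    let back = MapArray::new(
        entries_field,
        as_list.offsets().clone(),
        map.entries().clone(),
        as_list.nulls().cloned(),
        false,
    );
    assert_eq!(back.data_type(), map.data_type());
}

Because both constructions reuse the underlying buffers, the conversion is effectively free; the cost of the feature comes from the extra projection node, not from data movement.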
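Physical equivalence alone is not enough for grouping: {"a" -> 1, "b" -> 2} and {"b" -> 2, "a" -> 1} are the same logical map but hash differently unless their entries are first put in a deterministic order, which is the role of the map_sort scalar function that CometMapSort serializes. The per-row sketch below illustrates that canonicalization idea with stock arrow compute kernels; it assumes string keys and is not Comet's spark_map_sort implementation, which operates on whole arrays natively.

// sketch_map_sort_row.rs (illustrative, not part of this PR)
use arrow::array::{Array, Int32Builder, MapArray, MapBuilder, StringBuilder, StructArray};
use arrow::compute::{sort_to_indices, take};

/// Returns one row's entries reordered by key, so maps with equal contents
/// produce identical entry lists regardless of insertion order.
fn sorted_entries_for_row(map: &MapArray, row: usize) -> StructArray {
    // Slice this row's entries out of the flat child array.
    let start = map.value_offsets()[row] as usize;
    let end = map.value_offsets()[row + 1] as usize;
    let entries = map.entries().slice(start, end - start);

    // Sort on the key column, then reorder the whole entry struct with `take`.
    let indices = sort_to_indices(entries.column(0), None, None).unwrap();
    let sorted = take(&entries, &indices, None).unwrap();
    sorted
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap()
        .clone()
}

fn main() {
    // {"b": 2, "a": 1}: logically equal to {"a": 1, "b": 2}.
    let mut builder = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new());
    builder.keys().append_value("b");
    builder.values().append_value(2);
    builder.keys().append_value("a");
    builder.values().append_value(1);
    builder.append(true).unwrap();
    let map = builder.finish();

    // Entries come back as [{a, 1}, {b, 2}] whatever order they were built in,
    // so the sorted form is safe to feed into a hash aggregate as a list.
    println!("{:?}", sorted_entries_for_row(&map, 0));
}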