diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 0c3d345c8e..18d3bddf3c 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -28,7 +28,7 @@ use crate::{
     },
 };
 use arrow::compute::CastOptions;
-use arrow::datatypes::{DataType, Field, Schema, TimeUnit, DECIMAL128_MAX_PRECISION};
+use arrow::datatypes::{DataType, Field, FieldRef, Schema, TimeUnit, DECIMAL128_MAX_PRECISION};
 use datafusion::functions_aggregate::bit_and_or_xor::{bit_and_udaf, bit_or_udaf, bit_xor_udaf};
 use datafusion::functions_aggregate::min_max::max_udaf;
 use datafusion::functions_aggregate::min_max::min_udaf;
@@ -85,15 +85,16 @@ use datafusion::physical_expr::LexOrdering;
 use crate::parquet::parquet_exec::init_datasource_exec;
 use arrow::array::{
-    BinaryBuilder, BooleanArray, Date32Array, Decimal128Array, Float32Array, Float64Array,
-    Int16Array, Int32Array, Int64Array, Int8Array, NullArray, StringBuilder,
-    TimestampMicrosecondArray,
+    new_empty_array, Array, ArrayRef, BinaryBuilder, BooleanArray, Date32Array, Decimal128Array,
+    Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, ListArray,
+    NullArray, StringBuilder, TimestampMicrosecondArray,
 };
-use arrow::buffer::BooleanBuffer;
+use arrow::buffer::{BooleanBuffer, NullBuffer, OffsetBuffer};
 use datafusion::common::utils::SingleRowListArrayBuilder;
 use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
 use datafusion::physical_plan::filter::FilterExec;
 use datafusion::physical_plan::limit::GlobalLimitExec;
+use datafusion_comet_proto::spark_expression::ListLiteral;
 use datafusion_comet_proto::spark_operator::SparkFilePartition;
 use datafusion_comet_proto::{
     spark_expression::{
@@ -483,118 +484,8 @@ impl PhysicalPlanner {
                     }
                 },
                 Value::ListVal(values) => {
-                    if let DataType::List(f) = data_type {
-                        match f.data_type() {
-                            DataType::Null => {
-                                SingleRowListArrayBuilder::new(Arc::new(NullArray::new(values.clone().null_mask.len())))
-                                    .build_list_scalar()
-                            }
-                            DataType::Boolean => {
-                                let vals = values.clone();
-                                SingleRowListArrayBuilder::new(Arc::new(BooleanArray::new(BooleanBuffer::from(vals.boolean_values), Some(vals.null_mask.into()))))
-                                    .build_list_scalar()
-                            }
-                            DataType::Int8 => {
-                                let vals = values.clone();
-                                SingleRowListArrayBuilder::new(Arc::new(Int8Array::new(vals.byte_values.iter().map(|&x| x as i8).collect::<Vec<_>>().into(), Some(vals.null_mask.into()))))
-                                    .build_list_scalar()
-                            }
-                            DataType::Int16 => {
-                                let vals = values.clone();
-                                SingleRowListArrayBuilder::new(Arc::new(Int16Array::new(vals.short_values.iter().map(|&x| x as i16).collect::<Vec<_>>().into(), Some(vals.null_mask.into()))))
-                                    .build_list_scalar()
-                            }
-                            DataType::Int32 => {
-                                let vals = values.clone();
-                                SingleRowListArrayBuilder::new(Arc::new(Int32Array::new(vals.int_values.into(), Some(vals.null_mask.into()))))
-                                    .build_list_scalar()
-                            }
-                            DataType::Int64 => {
-                                let vals = values.clone();
-                                SingleRowListArrayBuilder::new(Arc::new(Int64Array::new(vals.long_values.into(), Some(vals.null_mask.into()))))
-                                    .build_list_scalar()
-                            }
-                            DataType::Float32 => {
-                                let vals = values.clone();
-                                SingleRowListArrayBuilder::new(Arc::new(Float32Array::new(vals.float_values.into(), Some(vals.null_mask.into()))))
-                                    .build_list_scalar()
-                            }
-                            DataType::Float64 => {
-                                let vals = values.clone();
-                                SingleRowListArrayBuilder::new(Arc::new(Float64Array::new(vals.double_values.into(), Some(vals.null_mask.into()))))
-                                    .build_list_scalar()
-                            }
-                            DataType::Timestamp(TimeUnit::Microsecond, None) => {
-                                let vals = values.clone();
-                                SingleRowListArrayBuilder::new(Arc::new(TimestampMicrosecondArray::new(vals.long_values.into(), Some(vals.null_mask.into()))))
-                                    .build_list_scalar()
-                            }
-                            DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => {
-                                let vals = values.clone();
-                                SingleRowListArrayBuilder::new(Arc::new(TimestampMicrosecondArray::new(vals.long_values.into(), Some(vals.null_mask.into())).with_timezone(Arc::clone(tz))))
-                                    .build_list_scalar()
-                            }
-                            DataType::Date32 => {
-                                let vals = values.clone();
-                                SingleRowListArrayBuilder::new(Arc::new(Date32Array::new(vals.int_values.into(), Some(vals.null_mask.into()))))
-                                    .build_list_scalar()
-                            }
-                            DataType::Binary => {
-                                // Using a builder as it is cumbersome to create BinaryArray from a vector with nulls
-                                // and calculate correct offsets
-                                let vals = values.clone();
-                                let item_capacity = vals.string_values.len();
-                                let data_capacity = vals.string_values.first().map(|s| s.len() * item_capacity).unwrap_or(0);
-                                let mut arr = BinaryBuilder::with_capacity(item_capacity, data_capacity);
-
-                                for (i, v) in vals.bytes_values.into_iter().enumerate() {
-                                    if vals.null_mask[i] {
-                                        arr.append_value(v);
-                                    } else {
-                                        arr.append_null();
-                                    }
-                                }
-
-                                SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
-                                    .build_list_scalar()
-                            }
-                            DataType::Utf8 => {
-                                // Using a builder as it is cumbersome to create StringArray from a vector with nulls
-                                // and calculate correct offsets
-                                let vals = values.clone();
-                                let item_capacity = vals.string_values.len();
-                                let data_capacity = vals.string_values.first().map(|s| s.len() * item_capacity).unwrap_or(0);
-                                let mut arr = StringBuilder::with_capacity(item_capacity, data_capacity);
-
-                                for (i, v) in vals.string_values.into_iter().enumerate() {
-                                    if vals.null_mask[i] {
-                                        arr.append_value(v);
-                                    } else {
-                                        arr.append_null();
-                                    }
-                                }
-
-                                SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
-                                    .build_list_scalar()
-                            }
-                            DataType::Decimal128(p, s) => {
-                                let vals = values.clone();
-                                SingleRowListArrayBuilder::new(Arc::new(Decimal128Array::new(vals.decimal_values.into_iter().map(|v| {
-                                    let big_integer = BigInt::from_signed_bytes_be(&v);
-                                    big_integer.to_i128().ok_or_else(|| {
-                                        GeneralError(format!(
-                                            "Cannot parse {big_integer:?} as i128 for Decimal literal"
-                                        ))
-                                    }).unwrap()
-                                }).collect::<Vec<_>>().into(), Some(vals.null_mask.into())).with_precision_and_scale(*p, *s)?)).build_list_scalar()
-                            }
-                            dt => {
-                                return Err(GeneralError(format!(
-                                    "DataType::List literal does not support {dt:?} type"
-                                )))
-                            }
-                        }
-
+                    if let DataType::List(_) = data_type {
+                        SingleRowListArrayBuilder::new(literal_to_array_ref(data_type, values.clone())?).build_list_scalar()
                     } else {
                         return Err(GeneralError(format!(
                             "Expected DataType::List but got {data_type:?}"
@@ -2792,13 +2683,188 @@ fn create_case_expr(
     }
 }
 
+fn literal_to_array_ref(
+    data_type: DataType,
+    list_literal: ListLiteral,
+) -> Result<ArrayRef, ExecutionError> {
+    let nulls = &list_literal.null_mask;
+    match data_type {
+        DataType::Null => Ok(Arc::new(NullArray::new(nulls.len()))),
+        DataType::Boolean => Ok(Arc::new(BooleanArray::new(
+            BooleanBuffer::from(list_literal.boolean_values),
+            Some(nulls.clone().into()),
+        ))),
+        DataType::Int8 => Ok(Arc::new(Int8Array::new(
+            list_literal
+                .byte_values
+                .iter()
+                .map(|&x| x as i8)
+                .collect::<Vec<_>>()
+                .into(),
+            Some(nulls.clone().into()),
+        ))),
+        DataType::Int16 => Ok(Arc::new(Int16Array::new(
+            list_literal
+                .short_values
+                .iter()
+                .map(|&x| x as i16)
+                .collect::<Vec<_>>()
+                .into(),
+            Some(nulls.clone().into()),
+        ))),
+        DataType::Int32 => Ok(Arc::new(Int32Array::new(
+            list_literal.int_values.into(),
+            Some(nulls.clone().into()),
+        ))),
+        DataType::Int64 => Ok(Arc::new(Int64Array::new(
+            list_literal.long_values.into(),
+            Some(nulls.clone().into()),
+        ))),
+        DataType::Float32 => Ok(Arc::new(Float32Array::new(
+            list_literal.float_values.into(),
+            Some(nulls.clone().into()),
+        ))),
+        DataType::Float64 => Ok(Arc::new(Float64Array::new(
+            list_literal.double_values.into(),
+            Some(nulls.clone().into()),
+        ))),
+        DataType::Date32 => Ok(Arc::new(Date32Array::new(
+            list_literal.int_values.into(),
+            Some(nulls.clone().into()),
+        ))),
+        DataType::Timestamp(TimeUnit::Microsecond, None) => {
+            Ok(Arc::new(TimestampMicrosecondArray::new(
+                list_literal.long_values.into(),
+                Some(nulls.clone().into()),
+            )))
+        }
+        DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => Ok(Arc::new(
+            TimestampMicrosecondArray::new(
+                list_literal.long_values.into(),
+                Some(nulls.clone().into()),
+            )
+            .with_timezone(Arc::clone(&tz)),
+        )),
+        DataType::Binary => {
+            // Using a builder as it is cumbersome to create BinaryArray from a vector with nulls
+            // and calculate correct offsets
+            let item_capacity = list_literal.bytes_values.len();
+            let data_capacity = list_literal
+                .bytes_values
+                .first()
+                .map(|s| s.len() * item_capacity)
+                .unwrap_or(0);
+            let mut arr = BinaryBuilder::with_capacity(item_capacity, data_capacity);
+
+            for (i, v) in list_literal.bytes_values.into_iter().enumerate() {
+                if nulls[i] {
+                    arr.append_value(v);
+                } else {
+                    arr.append_null();
+                }
+            }
+
+            Ok(Arc::new(arr.finish()))
+        }
+        DataType::Utf8 => {
+            // Using a builder as it is cumbersome to create StringArray from a vector with nulls
+            // and calculate correct offsets
+            let item_capacity = list_literal.string_values.len();
+            let data_capacity = list_literal
+                .string_values
+                .first()
+                .map(|s| s.len() * item_capacity)
+                .unwrap_or(0);
+            let mut arr = StringBuilder::with_capacity(item_capacity, data_capacity);
+
+            for (i, v) in list_literal.string_values.into_iter().enumerate() {
+                if nulls[i] {
+                    arr.append_value(v);
+                } else {
+                    arr.append_null();
+                }
+            }
+
+            Ok(Arc::new(arr.finish()))
+        }
+        DataType::Decimal128(p, s) => Ok(Arc::new(
+            Decimal128Array::new(
+                list_literal
+                    .decimal_values
+                    .into_iter()
+                    .map(|v| {
+                        let big_integer = BigInt::from_signed_bytes_be(&v);
+                        big_integer
+                            .to_i128()
+                            .ok_or_else(|| {
+                                GeneralError(format!(
+                                    "Cannot parse {big_integer:?} as i128 for Decimal literal"
+                                ))
+                            })
+                            .unwrap()
+                    })
+                    .collect::<Vec<_>>()
+                    .into(),
+                Some(nulls.clone().into()),
+            )
+            .with_precision_and_scale(p, s)?,
+        )),
+        // list of primitive types
+        DataType::List(f) if !matches!(f.data_type(), DataType::List(_)) => {
+            literal_to_array_ref(f.data_type().clone(), list_literal)
+        }
+        DataType::List(ref f) => {
+            let dt = f.data_type().clone();
+
+            // Build offsets and collect non-null child arrays
+            let mut offsets = Vec::with_capacity(list_literal.list_values.len() + 1);
+            offsets.push(0i32);
+            let mut child_arrays: Vec<ArrayRef> = Vec::new();
+
+            for (i, child_literal) in list_literal.list_values.iter().enumerate() {
+                // Check if the current child literal is non-null and not the empty array
+                if list_literal.null_mask[i] && *child_literal != ListLiteral::default() {
+                    // Non-null entry: process the child array
+                    let child_array = literal_to_array_ref(dt.clone(), child_literal.clone())?;
+                    let len = child_array.len() as i32;
+                    offsets.push(offsets.last().unwrap() + len);
+                    child_arrays.push(child_array);
+                } else {
+                    // Null entry: just repeat the last offset (empty slot)
+                    offsets.push(*offsets.last().unwrap());
+                }
+            }
+
+            // Concatenate all non-null child arrays' values into one array
+            let output_array = if !child_arrays.is_empty() {
+                let child_refs: Vec<&dyn Array> = child_arrays.iter().map(|a| a.as_ref()).collect();
+                arrow::compute::concat(&child_refs)?
+            } else {
+                // All entries are null or the list is empty
+                new_empty_array(&dt)
+            };
+
+            // Create and return the parent ListArray
+            Ok(Arc::new(ListArray::new(
+                FieldRef::from(Field::new("item", output_array.data_type().clone(), true)),
+                OffsetBuffer::new(offsets.into()),
+                output_array,
+                Some(NullBuffer::from(list_literal.null_mask.clone())),
+            )))
+        }
+        dt => Err(GeneralError(format!(
+            "DataType::List literal does not support {dt:?} type"
+        ))),
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use futures::{poll, StreamExt};
     use std::{sync::Arc, task::Poll};
 
-    use arrow::array::{Array, DictionaryArray, Int32Array, RecordBatch, StringArray};
-    use arrow::datatypes::{DataType, Field, Fields, Schema};
+    use arrow::array::{Array, DictionaryArray, Int32Array, ListArray, RecordBatch, StringArray};
+    use arrow::datatypes::{DataType, Field, FieldRef, Fields, Schema};
     use datafusion::catalog::memory::DataSourceExec;
     use datafusion::datasource::listing::PartitionedFile;
     use datafusion::datasource::object_store::ObjectStoreUrl;
@@ -2815,9 +2881,11 @@ mod tests {
 
     use crate::execution::{operators::InputBatch, planner::PhysicalPlanner};
     use crate::execution::operators::ExecutionError;
+    use crate::execution::planner::literal_to_array_ref;
     use crate::parquet::parquet_support::SparkParquetOptions;
     use crate::parquet::schema_adapter::SparkSchemaAdapterFactory;
     use datafusion_comet_proto::spark_expression::expr::ExprStruct;
+    use datafusion_comet_proto::spark_expression::ListLiteral;
    use datafusion_comet_proto::{
         spark_expression::expr::ExprStruct::*,
         spark_expression::Expr,
@@ -3595,4 +3663,127 @@ mod tests {
         assert_batches_eq!(expected, &[actual]);
         Ok(())
     }
+
+    #[tokio::test]
+    async fn test_literal_to_list() -> Result<(), DataFusionError> {
+        /*
+        [
+            [
+                [1, 2, 3],
+                [4, 5, 6],
+                [7, 8, 9, null],
+                null,
+                []
+            ],
+            [
+                [10, null, 11]
+            ],
+            null,
+            []
+        ]
+        */
+        let data = ListLiteral {
+            list_values: vec![
+                ListLiteral {
+                    list_values: vec![
+                        ListLiteral {
+                            int_values: vec![1, 2, 3],
+                            null_mask: vec![true, true, true],
+                            ..Default::default()
+                        },
+                        ListLiteral {
+                            int_values: vec![4, 5, 6],
+                            null_mask: vec![true, true, true],
+                            ..Default::default()
+                        },
+                        ListLiteral {
+                            int_values: vec![7, 8, 9, 0],
+                            null_mask: vec![true, true, true, false],
+                            ..Default::default()
+                        },
+                        ListLiteral {
+                            ..Default::default()
+                        },
+                        ListLiteral {
+                            ..Default::default()
+                        },
+                    ],
+                    null_mask: vec![true, true, true, false, true],
+                    ..Default::default()
+                },
+                ListLiteral {
+                    list_values: vec![ListLiteral {
+                        int_values: vec![10, 0, 11],
+                        null_mask: vec![true, false, true],
+                        ..Default::default()
+                    }],
+                    null_mask: vec![true],
+                    ..Default::default()
+                },
+                ListLiteral {
+                    ..Default::default()
+                },
+                ListLiteral {
+                    ..Default::default()
+                },
+            ],
+            null_mask: vec![true, true, false, true],
+            ..Default::default()
+        };
+
+        let nested_type = DataType::List(FieldRef::from(Field::new(
+            "item",
+            DataType::List(
+                Field::new(
+                    "item",
+                    DataType::List(
+                        Field::new(
+                            "item",
+                            DataType::Int32,
+                            true, // Int32 nullable
+                        )
+                        .into(),
+                    ),
+                    true, // inner list nullable
+                )
+                .into(),
+            ),
+            true, // outer list nullable
+        )));
+
+        let array = literal_to_array_ref(nested_type, data)?;
+
+        // Top-level should be ListArray<ListArray<Int32>>
+        let list_outer = array.as_any().downcast_ref::<ListArray>().unwrap();
+        assert_eq!(list_outer.len(), 4);
+
+        // First outer element: ListArray
+        let first_elem = list_outer.value(0);
+        let list_inner = first_elem.as_any().downcast_ref::<ListArray>().unwrap();
+        assert_eq!(list_inner.len(), 5);
+
+        // Inner values
+        let v0 = list_inner.value(0);
+        let vals0 = v0.as_any().downcast_ref::<Int32Array>().unwrap();
+        assert_eq!(vals0.values(), &[1, 2, 3]);
+
+        let v1 = list_inner.value(1);
+        let vals1 = v1.as_any().downcast_ref::<Int32Array>().unwrap();
+        assert_eq!(vals1.values(), &[4, 5, 6]);
+
+        let v2 = list_inner.value(2);
+        let vals2 = v2.as_any().downcast_ref::<Int32Array>().unwrap();
+        assert_eq!(vals2.values(), &[7, 8, 9, 0]);
+
+        // Second outer element
+        let second_elem = list_outer.value(1);
+        let list_inner2 = second_elem.as_any().downcast_ref::<ListArray>().unwrap();
+        assert_eq!(list_inner2.len(), 1);
+
+        let v3 = list_inner2.value(0);
+        let vals3 = v3.as_any().downcast_ref::<Int32Array>().unwrap();
+        assert_eq!(vals3.values(), &[10, 0, 11]);
+
+        Ok(())
+    }
 }
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 2a5b6d0750..a4f8afc31c 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -19,6 +19,8 @@
 
 package org.apache.comet.serde
 
+import java.lang
+
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ListBuffer
 
@@ -758,11 +760,12 @@ object QueryPlanSerde extends Logging with CometExprShim {
             allowComplex = value == null ||
               // Nested literal support for native reader
               // can be tracked https://github.com/apache/datafusion-comet/issues/1937
-              // now supports only Array of primitive
-              (Seq(CometConf.SCAN_NATIVE_ICEBERG_COMPAT, CometConf.SCAN_NATIVE_DATAFUSION)
-                .contains(CometConf.COMET_NATIVE_SCAN_IMPL.get()) && dataType
-                .isInstanceOf[ArrayType]) && !isComplexType(
-                dataType.asInstanceOf[ArrayType].elementType)) =>
+              (dataType
+                .isInstanceOf[ArrayType] && (!isComplexType(
+                dataType.asInstanceOf[ArrayType].elementType) || dataType
+                .asInstanceOf[ArrayType]
+                .elementType
+                .isInstanceOf[ArrayType]))) =>
             val exprBuilder = LiteralOuterClass.Literal.newBuilder()
 
             if (value == null) {
@@ -789,85 +792,9 @@ object QueryPlanSerde extends Logging with CometExprShim {
                 val byteStr =
                   com.google.protobuf.ByteString.copyFrom(value.asInstanceOf[Array[Byte]])
                 exprBuilder.setBytesVal(byteStr)
-              case a: ArrayType =>
-                val listLiteralBuilder = ListLiteral.newBuilder()
-                val array = value.asInstanceOf[GenericArrayData].array
-                a.elementType match {
-                  case NullType =>
-                    array.foreach(_ => listLiteralBuilder.addNullMask(true))
-                  case BooleanType =>
-                    array.foreach(v => {
-                      val casted = v.asInstanceOf[java.lang.Boolean]
-                      listLiteralBuilder.addBooleanValues(casted)
-                      listLiteralBuilder.addNullMask(casted != null)
-                    })
-                  case ByteType =>
-                    array.foreach(v => {
-                      val casted = v.asInstanceOf[java.lang.Integer]
-                      listLiteralBuilder.addByteValues(casted)
-                      listLiteralBuilder.addNullMask(casted != null)
-                    })
-                  case ShortType =>
-                    array.foreach(v => {
-                      val casted = v.asInstanceOf[java.lang.Short]
-                      listLiteralBuilder.addShortValues(
-                        if (casted != null) casted.intValue()
-                        else null.asInstanceOf[java.lang.Integer])
-                      listLiteralBuilder.addNullMask(casted != null)
-                    })
-                  case IntegerType | DateType =>
-                    array.foreach(v => {
-                      val casted = v.asInstanceOf[java.lang.Integer]
-                      listLiteralBuilder.addIntValues(casted)
-                      listLiteralBuilder.addNullMask(casted != null)
-                    })
-                  case LongType | TimestampType | TimestampNTZType =>
-                    array.foreach(v => {
-                      val casted = v.asInstanceOf[java.lang.Long]
-                      listLiteralBuilder.addLongValues(casted)
-                      listLiteralBuilder.addNullMask(casted != null)
-                    })
-                  case FloatType =>
-                    array.foreach(v => {
-                      val casted = v.asInstanceOf[java.lang.Float]
-                      listLiteralBuilder.addFloatValues(casted)
-                      listLiteralBuilder.addNullMask(casted != null)
-                    })
-                  case DoubleType =>
-                    array.foreach(v => {
-                      val casted = v.asInstanceOf[java.lang.Double]
-                      listLiteralBuilder.addDoubleValues(casted)
-                      listLiteralBuilder.addNullMask(casted != null)
-                    })
-                  case StringType =>
-                    array.foreach(v => {
-                      val casted = v.asInstanceOf[org.apache.spark.unsafe.types.UTF8String]
-                      listLiteralBuilder.addStringValues(
-                        if (casted != null) casted.toString else "")
-                      listLiteralBuilder.addNullMask(casted != null)
-                    })
-                  case _: DecimalType =>
-                    array
-                      .foreach(v => {
-                        val casted =
-                          v.asInstanceOf[Decimal]
-                        listLiteralBuilder.addDecimalValues(if (casted != null) {
-                          com.google.protobuf.ByteString
-                            .copyFrom(casted.toBigDecimal.underlying.unscaledValue.toByteArray)
-                        } else ByteString.EMPTY)
-                        listLiteralBuilder.addNullMask(casted != null)
-                      })
-                  case _: BinaryType =>
-                    array
-                      .foreach(v => {
-                        val casted =
-                          v.asInstanceOf[Array[Byte]]
-                        listLiteralBuilder.addBytesValues(if (casted != null) {
-                          com.google.protobuf.ByteString.copyFrom(casted)
-                        } else ByteString.EMPTY)
-                        listLiteralBuilder.addNullMask(casted != null)
-                      })
-                }
+              case arr: ArrayType =>
+                val listLiteralBuilder: ListLiteral.Builder =
+                  makeListLiteral(value.asInstanceOf[GenericArrayData].array, arr)
                 exprBuilder.setListVal(listLiteralBuilder.build())
                 exprBuilder.setDatatype(serializeDataType(dataType).get)
               case dt =>
@@ -1296,6 +1223,94 @@ object QueryPlanSerde extends Logging with CometExprShim {
     })
   }
 
+  private def makeListLiteral(array: Array[Any], arrayType: ArrayType): ListLiteral.Builder = {
+    val listLiteralBuilder = ListLiteral.newBuilder()
+    arrayType.elementType match {
+      case NullType =>
+        array.foreach(_ => listLiteralBuilder.addNullMask(true))
+      case BooleanType =>
+        array.foreach(v => {
+          val casted = v.asInstanceOf[lang.Boolean]
+          listLiteralBuilder.addBooleanValues(casted)
+          listLiteralBuilder.addNullMask(casted != null)
+        })
+      case ByteType =>
+        array.foreach(v => {
+          val casted = v.asInstanceOf[Integer]
+          listLiteralBuilder.addByteValues(casted)
+          listLiteralBuilder.addNullMask(casted != null)
+        })
+      case ShortType =>
+        array.foreach(v => {
+          val casted = v.asInstanceOf[lang.Short]
+          listLiteralBuilder.addShortValues(
+            if (casted != null) casted.intValue()
+            else null.asInstanceOf[Integer])
+          listLiteralBuilder.addNullMask(casted != null)
+        })
+      case IntegerType | DateType =>
+        array.foreach(v => {
+          val casted = v.asInstanceOf[Integer]
+          listLiteralBuilder.addIntValues(casted)
+          listLiteralBuilder.addNullMask(casted != null)
+        })
+      case LongType | TimestampType | TimestampNTZType =>
+        array.foreach(v => {
+          val casted = v.asInstanceOf[lang.Long]
+          listLiteralBuilder.addLongValues(casted)
+          listLiteralBuilder.addNullMask(casted != null)
+        })
+      case FloatType =>
+        array.foreach(v => {
+          val casted = v.asInstanceOf[lang.Float]
+          listLiteralBuilder.addFloatValues(casted)
+          listLiteralBuilder.addNullMask(casted != null)
+        })
+      case DoubleType =>
+        array.foreach(v => {
+          val casted = v.asInstanceOf[lang.Double]
+          listLiteralBuilder.addDoubleValues(casted)
+          listLiteralBuilder.addNullMask(casted != null)
+        })
+      case StringType =>
+        array.foreach(v => {
+          val casted = v.asInstanceOf[UTF8String]
+          listLiteralBuilder.addStringValues(if (casted != null) casted.toString else "")
+          listLiteralBuilder.addNullMask(casted != null)
+        })
+      case _: DecimalType =>
+        array
+          .foreach(v => {
+            val casted =
+              v.asInstanceOf[Decimal]
+            listLiteralBuilder.addDecimalValues(if (casted != null) {
+              com.google.protobuf.ByteString
+                .copyFrom(casted.toBigDecimal.underlying.unscaledValue.toByteArray)
+            } else ByteString.EMPTY)
+            listLiteralBuilder.addNullMask(casted != null)
+          })
+      case _: BinaryType =>
+        array
+          .foreach(v => {
+            val casted =
+              v.asInstanceOf[Array[Byte]]
+            listLiteralBuilder.addBytesValues(if (casted != null) {
+              com.google.protobuf.ByteString.copyFrom(casted)
+            } else ByteString.EMPTY)
+            listLiteralBuilder.addNullMask(casted != null)
+          })
+      case a: ArrayType =>
+        array.foreach(v => {
+          val casted = v.asInstanceOf[GenericArrayData]
+          listLiteralBuilder.addListValues(if (casted != null) {
+            makeListLiteral(casted.array, a)
+          } else ListLiteral.newBuilder())
+          listLiteralBuilder.addNullMask(casted != null)
+        })
+    }
+    listLiteralBuilder
+  }
+
   /**
    * Creates a UnaryExpr by calling exprToProtoInternal for the provided child expression and then
    * invokes the supplied function to wrap this UnaryExpr in a top-level Expr.
diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
index 025cc19e1b..56d9b3b429 100644
--- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
@@ -692,4 +692,20 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
       }
     }
   }
+
+  test("array literals") {
+    withSQLConf(
+      CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true",
+      CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "true") {
+      Seq(true, false).foreach { dictionaryEnabled =>
+        withTempDir { dir =>
+          val path = new Path(dir.toURI.toString, "test.parquet")
+          makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 100)
+          spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+          checkSparkAnswerAndOperator(
+            sql("SELECT array(array(1, 2, 3), null, array(), array(null), array(1)) from t1"))
+        }
+      }
+    }
+  }
 }
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala
index 8f1e7cfdf0..63164c0fc9 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala
@@ -557,4 +557,12 @@ class CometNativeReaderSuite extends CometTestBase with AdaptiveSparkPlanHelper
       assert(sql("SELECT * FROM array_tbl where arr = ARRAY(1L)").count == 1)
     }
   }
+
+  test("native reader - support ARRAY literal nested ARRAY fields") {
+    testSingleLineQuery(
+      """
+        |select 1 a
+        |""".stripMargin,
+      "select array(array(1, 2, null), array(), array(10), null, array(null)) from tbl")
+  }
 }
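Note: the `DataType::List(ref f)` arm of `literal_to_array_ref` above assembles the nested `ListArray` by hand from three pieces: the non-null child values concatenated into one flat array, an offset buffer in which a null (or default/empty) entry simply repeats the previous offset, and a validity bitmap taken from the protobuf `null_mask`. A minimal standalone sketch of that encoding, using only arrow-rs; the array contents here are illustrative and not taken from the patch:

```rust
use std::sync::Arc;

use arrow::array::{Array, Int32Array, ListArray};
use arrow::buffer::{NullBuffer, OffsetBuffer};
use arrow::datatypes::{DataType, Field};

fn main() {
    // Logical value being encoded: [[1, 2, 3], null, [], [10, 11]]

    // All non-null child values, concatenated: [1, 2, 3] ++ [10, 11]
    let values = Int32Array::from(vec![1, 2, 3, 10, 11]);

    // Offsets delimit each child list inside `values`; null and empty
    // entries repeat the previous offset, as the planner's List arm does.
    let offsets = OffsetBuffer::new(vec![0i32, 3, 3, 3, 5].into());

    // Validity bitmap: entry 1 is null; the empty list at entry 2 stays valid.
    let nulls = NullBuffer::from(vec![true, false, true, true]);

    let list = ListArray::new(
        Arc::new(Field::new("item", DataType::Int32, true)),
        offsets,
        Arc::new(values),
        Some(nulls),
    );

    assert_eq!(list.len(), 4);
    assert!(list.is_null(1));
    assert_eq!(list.value(2).len(), 0); // empty but valid
    assert_eq!(list.value(3).len(), 2);
    println!("{list:?}");
}
```

The same construction applies recursively for deeper nesting: each level concatenates its children with `arrow::compute::concat` and wraps them with fresh offsets and a null buffer, which is what the `test_literal_to_list` case in the patch exercises end to end.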