Commit 59fae94

feat: add read array support (#1456)
* feat: add read array support
1 parent 4ed00cf · commit 59fae94

File tree

8 files changed: +469, -151 lines

native/Cargo.lock

Lines changed: 255 additions & 142 deletions
(Generated file; the diff is not rendered by default.)

native/core/Cargo.toml

Lines changed: 1 addition & 0 deletions

@@ -77,6 +77,7 @@ jni = { version = "0.21", features = ["invocation"] }
 lazy_static = "1.4"
 assertables = "7"
 hex = "0.4.3"
+datafusion-functions-nested = "46.0.0"
 
 [features]
 default = []

native/core/src/execution/planner.rs

Lines changed: 129 additions & 1 deletion

@@ -2686,12 +2686,14 @@ mod tests {
 
     use arrow::array::{DictionaryArray, Int32Array, StringArray};
    use arrow::datatypes::DataType;
-    use datafusion::{physical_plan::common::collect, prelude::SessionContext};
+    use datafusion::logical_expr::ScalarUDF;
+    use datafusion::{assert_batches_eq, physical_plan::common::collect, prelude::SessionContext};
     use tokio::sync::mpsc;
 
     use crate::execution::{operators::InputBatch, planner::PhysicalPlanner};
 
     use crate::execution::operators::ExecutionError;
+    use datafusion_comet_proto::spark_expression::expr::ExprStruct;
     use datafusion_comet_proto::{
         spark_expression::expr::ExprStruct::*,
         spark_expression::Expr,
@@ -3004,4 +3006,130 @@
             type_info: None,
         }
     }
+
+    #[test]
+    fn test_create_array() {
+        let session_ctx = SessionContext::new();
+        session_ctx.register_udf(ScalarUDF::from(
+            datafusion_functions_nested::make_array::MakeArray::new(),
+        ));
+        let task_ctx = session_ctx.task_ctx();
+        let planner = PhysicalPlanner::new(Arc::from(session_ctx));
+
+        // Create a plan for
+        // ProjectionExec: expr=[make_array(col_0@0) as col_0]
+        //   ScanExec: source=[CometScan parquet (unknown)], schema=[col_0: Int32]
+        let op_scan = Operator {
+            plan_id: 0,
+            children: vec![],
+            op_struct: Some(OpStruct::Scan(spark_operator::Scan {
+                fields: vec![
+                    spark_expression::DataType {
+                        type_id: 3, // Int32
+                        type_info: None,
+                    },
+                    spark_expression::DataType {
+                        type_id: 3, // Int32
+                        type_info: None,
+                    },
+                    spark_expression::DataType {
+                        type_id: 3, // Int32
+                        type_info: None,
+                    },
+                ],
+                source: "".to_string(),
+            })),
+        };
+
+        let array_col = spark_expression::Expr {
+            expr_struct: Some(Bound(spark_expression::BoundReference {
+                index: 0,
+                datatype: Some(spark_expression::DataType {
+                    type_id: 3,
+                    type_info: None,
+                }),
+            })),
+        };
+
+        let array_col_1 = spark_expression::Expr {
+            expr_struct: Some(Bound(spark_expression::BoundReference {
+                index: 1,
+                datatype: Some(spark_expression::DataType {
+                    type_id: 3,
+                    type_info: None,
+                }),
+            })),
+        };
+
+        let projection = Operator {
+            children: vec![op_scan],
+            plan_id: 0,
+            op_struct: Some(OpStruct::Projection(spark_operator::Projection {
+                project_list: vec![spark_expression::Expr {
+                    expr_struct: Some(ExprStruct::ScalarFunc(spark_expression::ScalarFunc {
+                        func: "make_array".to_string(),
+                        args: vec![array_col, array_col_1],
+                        return_type: None,
+                    })),
+                }],
+            })),
+        };
+
+        let a = Int32Array::from(vec![0, 3]);
+        let b = Int32Array::from(vec![1, 4]);
+        let c = Int32Array::from(vec![2, 5]);
+        let input_batch = InputBatch::Batch(vec![Arc::new(a), Arc::new(b), Arc::new(c)], 2);
+
+        let (mut scans, datafusion_plan) =
+            planner.create_plan(&projection, &mut vec![], 1).unwrap();
+        scans[0].set_input_batch(input_batch);
+
+        let mut stream = datafusion_plan.native_plan.execute(0, task_ctx).unwrap();
+
+        let runtime = tokio::runtime::Runtime::new().unwrap();
+        let (tx, mut rx) = mpsc::channel(1);
+
+        // Separate task to send the input batch and then the EOF signal
+        runtime.spawn(async move {
+            // Re-create the same two-row input batch and follow it with EOF.
+            let a = Int32Array::from(vec![0, 3]);
+            let b = Int32Array::from(vec![1, 4]);
+            let c = Int32Array::from(vec![2, 5]);
+            let input_batch1 = InputBatch::Batch(vec![Arc::new(a), Arc::new(b), Arc::new(c)], 2);
+            let input_batch2 = InputBatch::EOF;
+
+            let batches = vec![input_batch1, input_batch2];
+
+            for batch in batches.into_iter() {
+                tx.send(batch).await.unwrap();
+            }
+        });
+
+        runtime.block_on(async move {
+            loop {
+                let batch = rx.recv().await.unwrap();
+                scans[0].set_input_batch(batch);
+                match poll!(stream.next()) {
+                    Poll::Ready(Some(batch)) => {
+                        assert!(batch.is_ok(), "got error {}", batch.unwrap_err());
+                        let batch = batch.unwrap();
+                        assert_eq!(batch.num_rows(), 2);
+                        let expected = [
+                            "+--------+",
+                            "| col_0  |",
+                            "+--------+",
+                            "| [0, 1] |",
+                            "| [3, 4] |",
+                            "+--------+",
+                        ];
+                        assert_batches_eq!(expected, &[batch]);
+                    }
+                    Poll::Ready(None) => {
+                        break;
+                    }
+                    _ => {}
+                }
+            }
+        });
+    }
 }
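For reference, the hand-built plan above projects make_array over two Int32 columns per row and expects the arrays [0, 1] and [3, 4]. A rough Spark-level analogue of the same computation (illustrative only; the test drives the native planner directly through the protobuf operators rather than through Spark):

  import org.apache.spark.sql.SparkSession

  // Two Int32 columns per row, combined into a single array column, mirroring
  // the ProjectionExec built in test_create_array above.
  val spark = SparkSession.builder().master("local[1]").getOrCreate()
  spark
    .sql("select array(col_0, col_1) as col_0 from values (0, 1), (3, 4) as tbl(col_0, col_1)")
    .show()
  // +------+
  // | col_0|
  // +------+
  // |[0, 1]|
  // |[3, 4]|
  // +------+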

native/core/src/execution/shuffle/row.rs

Lines changed: 1 addition & 0 deletions

@@ -3197,6 +3197,7 @@ fn make_builders(
             // Disable dictionary encoding for array element
             let value_builder =
                 make_builders(field.data_type(), NESTED_TYPE_BUILDER_CAPACITY, 1.0)?;
+
             match field.data_type() {
                 DataType::Boolean => {
                     let builder = downcast_builder!(BooleanBuilder, value_builder);

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 11 additions & 5 deletions

@@ -61,13 +61,16 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
     logWarning(s"Comet native execution is disabled due to: $reason")
   }
 
-  def supportedDataType(dt: DataType, allowStruct: Boolean = false): Boolean = dt match {
+  def supportedDataType(dt: DataType, allowComplex: Boolean = false): Boolean = dt match {
     case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
         _: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: TimestampNTZType |
         _: DecimalType | _: DateType | _: BooleanType | _: NullType =>
       true
-    case s: StructType if allowStruct =>
-      s.fields.map(_.dataType).forall(supportedDataType(_, allowStruct))
+    case s: StructType if allowComplex =>
+      s.fields.map(_.dataType).forall(supportedDataType(_, allowComplex))
+    // TODO: Add nested array and iceberg compat support
+    // case a: ArrayType if allowComplex =>
+    //   supportedDataType(a.elementType)
     case dt =>
       emitWarning(s"unsupported Spark data type: $dt")
       false
@@ -763,7 +766,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
          binding,
          (builder, binaryExpr) => builder.setLtEq(binaryExpr))
 
-      case Literal(value, dataType) if supportedDataType(dataType, allowStruct = value == null) =>
+      case Literal(value, dataType)
+          if supportedDataType(dataType, allowComplex = value == null) =>
        val exprBuilder = ExprOuterClass.Literal.newBuilder()
 
        if (value == null) {
@@ -2716,7 +2720,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
        withInfo(join, "SortMergeJoin is not enabled")
        None
 
-      case op if isCometSink(op) && op.output.forall(a => supportedDataType(a.dataType, true)) =>
+      case op
+          if isCometSink(op) && op.output.forall(a =>
+            supportedDataType(a.dataType, allowComplex = true)) =>
        // These operators are source of Comet native execution chain
        val scanBuilder = OperatorOuterClass.Scan.newBuilder()
        val source = op.simpleStringWithNodeId()
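The rename from allowStruct to allowComplex reflects that the flag now gates complex types in general rather than structs only (the ArrayType case stays commented out pending nested-array and Iceberg-compat support). A minimal sketch of how the flag behaves, assuming QueryPlanSerde.supportedDataType is reachable from the calling code:

  import org.apache.spark.sql.types._
  import org.apache.comet.serde.QueryPlanSerde

  val personalInfo = StructType(
    Seq(StructField("firstName", StringType), StructField("age", IntegerType)))

  // Complex types are rejected unless the caller opts in, for example for
  // Comet sink operators or null literals.
  QueryPlanSerde.supportedDataType(personalInfo) // false
  QueryPlanSerde.supportedDataType(personalInfo, allowComplex = true) // true: every field is primitive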

spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala

Lines changed: 3 additions & 2 deletions

@@ -182,7 +182,7 @@ object CometNativeScanExec extends DataTypeSupport {
      case null => null
    }
 
-    val newArgs = mapProductIterator(scanExec, transform(_))
+    val newArgs = mapProductIterator(scanExec, transform)
    val wrapped = scanExec.makeCopy(newArgs).asInstanceOf[FileSourceScanExec]
    val batchScanExec = CometNativeScanExec(
      nativeOp,
@@ -202,9 +202,10 @@ object CometNativeScanExec extends DataTypeSupport {
  }
 
  override def isAdditionallySupported(dt: DataType): Boolean = {
-    // TODO add array and map
+    // TODO add map
    dt match {
      case s: StructType => s.fields.map(_.dataType).forall(isTypeSupported)
+      case a: ArrayType => isTypeSupported(a.elementType)
      case _ => false
    }
  }
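With the new ArrayType case, the native_datafusion scan reports an array column as supported whenever its element type is. A rough illustration, assuming isAdditionallySupported is accessible from calling code and that isTypeSupported resolves primitive element types as usual in DataTypeSupport:

  import org.apache.spark.sql.types._
  import org.apache.spark.sql.comet.CometNativeScanExec

  // An array of a supported primitive element type is now readable natively.
  CometNativeScanExec.isAdditionallySupported(ArrayType(IntegerType)) // expected: true
  // Maps (and therefore arrays of maps) remain on the TODO list.
  CometNativeScanExec.isAdditionallySupported(ArrayType(MapType(StringType, LongType))) // expected: false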

spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala

Lines changed: 3 additions & 1 deletion

@@ -487,9 +487,11 @@ object CometScanExec extends DataTypeSupport {
 
  override def isAdditionallySupported(dt: DataType): Boolean = {
    if (CometConf.COMET_NATIVE_SCAN_IMPL.get() == CometConf.SCAN_NATIVE_ICEBERG_COMPAT) {
-      // TODO add array and map
+      // TODO add map
      dt match {
        case s: StructType => s.fields.map(_.dataType).forall(isTypeSupported)
+        // TODO: Add nested array and iceberg compat support
+        // case a: ArrayType => isTypeSupported(a.elementType)
        case _ => false
      }
    } else {
spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala (new file)

Lines changed: 66 additions & 0 deletions

@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.exec
+
+import org.scalactic.source.Position
+import org.scalatest.Tag
+
+import org.apache.spark.sql.CometTestBase
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.internal.SQLConf
+
+import org.apache.comet.CometConf
+
+class CometNativeReaderSuite extends CometTestBase with AdaptiveSparkPlanHelper {
+  override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit
+      pos: Position): Unit = {
+    // TODO: Enable Iceberg compat tests
+    Seq(CometConf.SCAN_NATIVE_DATAFUSION /*, CometConf.SCAN_NATIVE_ICEBERG_COMPAT*/ ).foreach(
+      scan =>
+        super.test(s"$testName - $scan", testTags: _*) {
+          withSQLConf(
+            CometConf.COMET_EXEC_ENABLED.key -> "true",
+            SQLConf.USE_V1_SOURCE_LIST.key -> "parquet",
+            CometConf.COMET_ENABLED.key -> "true",
+            CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "false",
+            CometConf.COMET_NATIVE_SCAN_IMPL.key -> scan) {
+            testFun
+          }
+        })
+  }
+
+  test("native reader - read simple STRUCT fields") {
+    testSingleLineQuery(
+      """
+        |select named_struct('firstName', 'John', 'lastName', 'Doe', 'age', 35) as personal_info union all
+        |select named_struct('firstName', 'Jane', 'lastName', 'Doe', 'age', 40) as personal_info
+        |""".stripMargin,
+      "select personal_info.* from tbl")
+  }
+
+  test("native reader - read simple ARRAY fields") {
+    testSingleLineQuery(
+      """
+        |select array(1, 2, 3) as arr union all
+        |select array(2, 3, 4) as arr
+        |""".stripMargin,
+      "select arr from tbl")
+  }
+}
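The suite runs each test once per enabled scan implementation (currently only SCAN_NATIVE_DATAFUSION) with Comet execution forced on. A rough hand-run equivalent of the new ARRAY test, assuming testSingleLineQuery materializes the first query as a Parquet-backed table named tbl and then runs the second query against it:

  import org.apache.spark.sql.SparkSession
  import org.apache.comet.CometConf

  val spark = SparkSession.builder().getOrCreate()
  spark.conf.set(CometConf.COMET_ENABLED.key, "true")
  spark.conf.set(CometConf.COMET_EXEC_ENABLED.key, "true")
  spark.conf.set(CometConf.COMET_NATIVE_SCAN_IMPL.key, CometConf.SCAN_NATIVE_DATAFUSION)

  // Hypothetical temporary path; the real harness manages its own location.
  spark
    .sql("select array(1, 2, 3) as arr union all select array(2, 3, 4) as arr")
    .write
    .parquet("/tmp/comet_array_tbl")
  spark.read.parquet("/tmp/comet_array_tbl").createOrReplaceTempView("tbl")
  spark.sql("select arr from tbl").show()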
