Skip to content

Commit 824ad1a

Browse files
Feat: Support array_intersect function (apache#1271)
* Feat: Support array_intersect
* Address review comment
1 parent c3a552f commit 824ad1a

File tree

4 files changed

+40
-0
lines changed

4 files changed

+40
-0
lines changed

native/core/src/execution/planner.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ use datafusion::{
6767
use datafusion_comet_spark_expr::{create_comet_physical_fun, create_negate_expr};
6868
use datafusion_functions_nested::concat::ArrayAppend;
6969
use datafusion_functions_nested::remove::array_remove_all_udf;
70+
use datafusion_functions_nested::set_ops::array_intersect_udf;
7071
use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
7172

7273
use crate::execution::shuffle::CompressionCodec;
@@ -774,6 +775,22 @@ impl PhysicalPlanner {
774775

775776
Ok(Arc::new(case_expr))
776777
}
778+
ExprStruct::ArrayIntersect(expr) => {
779+
let left_expr =
780+
self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
781+
let right_expr =
782+
self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?;
783+
let args = vec![Arc::clone(&left_expr), right_expr];
784+
let datafusion_array_intersect = array_intersect_udf();
785+
let return_type = left_expr.data_type(&input_schema)?;
786+
let array_intersect_expr = Arc::new(ScalarFunctionExpr::new(
787+
"array_intersect",
788+
datafusion_array_intersect,
789+
args,
790+
return_type,
791+
));
792+
Ok(array_intersect_expr)
793+
}
777794
expr => Err(ExecutionError::GeneralError(format!(
778795
"Not implemented: {:?}",
779796
expr

native/proto/src/proto/expr.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ message Expr {
8686
ArrayInsert array_insert = 59;
8787
BinaryExpr array_contains = 60;
8888
BinaryExpr array_remove = 61;
89+
BinaryExpr array_intersect = 62;
8990
}
9091
}
9192

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2302,6 +2302,12 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
23022302
expr.children(1),
23032303
inputs,
23042304
(builder, binaryExpr) => builder.setArrayAppend(binaryExpr))
2305+
case _ if expr.prettyName == "array_intersect" =>
2306+
createBinaryExpr(
2307+
expr.children(0),
2308+
expr.children(1),
2309+
inputs,
2310+
(builder, binaryExpr) => builder.setArrayIntersect(binaryExpr))
23052311
case _ =>
23062312
withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
23072313
None

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2675,4 +2675,20 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
26752675
}
26762676
}
26772677
}
2678+
2679+
// Exercises the SQL function array_intersect on data read back from Parquet.
// The body runs twice, once with dictionaryEnabled = true and once with false.
// Each run writes a file via makeParquetFileAllTypes(path, dictionaryEnabled, 10000),
// registers it as temp view "t1", and passes three array_intersect queries to
// checkSparkAnswerAndOperator.
// NOTE(review): the types of columns _2/_3/_4/_9/_10/_18/_19 come from
// makeParquetFileAllTypes, which is not visible here; confirm which element
// types these three queries are meant to cover.
// NOTE(review): checkSparkAnswerAndOperator presumably compares the result with
// vanilla Spark and asserts native execution; verify against CometTestBase.
test("array_intersect") {
2680+
Seq(true, false).foreach { dictionaryEnabled =>
2681+
withTempDir { dir =>
2682+
val path = new Path(dir.toURI.toString, "test.parquet")
2683+
makeParquetFileAllTypes(path, dictionaryEnabled, 10000)
2684+
spark.read.parquet(path.toString).createOrReplaceTempView("t1")
2685+
checkSparkAnswerAndOperator(
2686+
sql("SELECT array_intersect(array(_2, _3, _4), array(_3, _4)) from t1"))
2687+
checkSparkAnswerAndOperator(
2688+
sql("SELECT array_intersect(array(_2 * -1), array(_9, _10)) from t1"))
2689+
checkSparkAnswerAndOperator(sql("SELECT array_intersect(array(_18), array(_19)) from t1"))
2690+
}
2691+
}
2692+
}
2693+
26782694
}

0 commit comments

Comments
 (0)