Commit 4d0c2ab

feat: pass the sorted input data to rust scan
This can cause problems if Spark says something is sorted while we don't actually sort it. For example, shuffle files in Spark are sorted, but ours are not, so we need to make sure the sort order is propagated and used correctly.
1 parent b9ce50c · commit 4d0c2ab

File tree: 8 files changed, +178 / -56 lines changed

native/core/src/execution/operators/scan.rs

Lines changed: 13 additions & 0 deletions

```diff
@@ -434,6 +434,19 @@ impl ScanExec {
         Ok(selection_indices_arrays)
     }
+
+    pub fn with_ordering(mut self, input_sorted: Vec<PhysicalSortExpr>) -> Self {
+        assert_ne!(input_sorted.len(), 0, "input_sorted cannot be empty");
+        let mut eq_properties = self.cache.eq_properties.clone();
+
+        eq_properties.add_ordering(
+            LexOrdering::new(input_sorted).expect("Must be able to create LexOrdering"),
+        );
+
+        self.cache = self.cache.with_eq_properties(eq_properties);
+
+        self
+    }
 }
 
 fn scan_schema(input_batch: &InputBatch, data_types: &[DataType]) -> SchemaRef {
```
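
For context, `with_ordering` does not sort anything; it only records the declared ordering in the operator's cached equivalence properties so DataFusion's optimizer can rely on it. Below is a minimal, hedged sketch of building the kind of `PhysicalSortExpr` a caller would pass in; the schema is made up and the import paths may differ from the ones used in the Comet crate.

```rust
use std::sync::Arc;

use datafusion::arrow::compute::SortOptions;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::physical_expr::{expressions::col, PhysicalSortExpr};

fn main() -> datafusion::error::Result<()> {
    // Stand-in for the scan's output schema.
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));

    // "a ASC NULLS FIRST" -- one entry per sort key of the already-sorted input.
    let sort_expr = PhysicalSortExpr {
        expr: col("a", &schema)?,
        options: SortOptions {
            descending: false,
            nulls_first: true,
        },
    };
    println!("{sort_expr}");

    // With a real ScanExec (not constructed here), the commit's planner code
    // would then call: scan = scan.with_ordering(vec![sort_expr]);
    Ok(())
}
```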

native/core/src/execution/planner.rs

Lines changed: 42 additions & 15 deletions

```diff
@@ -974,29 +974,38 @@ impl PhysicalPlanner {
     /// Create a DataFusion physical sort expression from Spark physical expression
     fn create_sort_expr<'a>(
         &'a self,
-        spark_expr: &'a Expr,
+        spark_expr: &'a spark_expression::Expr,
         input_schema: SchemaRef,
     ) -> Result<PhysicalSortExpr, ExecutionError> {
         match spark_expr.expr_struct.as_ref().unwrap() {
             ExprStruct::SortOrder(expr) => {
-                let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
-                let descending = expr.direction == 1;
-                let nulls_first = expr.null_ordering == 0;
-
-                let options = SortOptions {
-                    descending,
-                    nulls_first,
-                };
-
-                Ok(PhysicalSortExpr {
-                    expr: child,
-                    options,
-                })
+                self.sort_order_to_physical_sort_expr(expr, input_schema)
             }
             expr => Err(GeneralError(format!("{expr:?} isn't a SortOrder"))),
         }
     }
 
+    /// Create a DataFusion physical sort expression from Spark physical Sort Order
+    fn sort_order_to_physical_sort_expr<'a>(
+        &'a self,
+        spark_sort_order: &'a spark_expression::SortOrder,
+        input_schema: SchemaRef,
+    ) -> Result<PhysicalSortExpr, ExecutionError> {
+        let child = self.create_expr(spark_sort_order.child.as_ref().unwrap(), input_schema)?;
+        let descending = spark_sort_order.direction == 1;
+        let nulls_first = spark_sort_order.null_ordering == 0;
+
+        let options = SortOptions {
+            descending,
+            nulls_first,
+        };
+
+        Ok(PhysicalSortExpr {
+            expr: child,
+            options,
+        })
+    }
+
     fn create_binary_expr(
         &self,
         left: &Expr,
@@ -1467,15 +1476,28 @@ impl PhysicalPlanner {
             Some(inputs.remove(0))
         };
 
+        let input_ordering = scan.input_ordering.clone();
+
         // The `ScanExec` operator will take actual arrays from Spark during execution
-        let scan = ScanExec::new(
+        let mut scan = ScanExec::new(
             self.exec_context_id,
             input_source,
             &scan.source,
             data_types,
             scan.arrow_ffi_safe,
         )?;
 
+        if !input_ordering.is_empty() {
+            let sort_exprs: Vec<PhysicalSortExpr> = input_ordering
+                .iter()
+                .map(|expr| {
+                    self.sort_order_to_physical_sort_expr(expr, Arc::clone(&scan.schema()))
+                })
+                .collect::<Result<_, ExecutionError>>()?;
+
+            scan = scan.with_ordering(sort_exprs)
+        }
+
         Ok((
             vec![scan.clone()],
             Arc::new(SparkPlan::new(spark_plan.plan_id, Arc::new(scan), vec![])),
@@ -2844,6 +2866,7 @@ mod tests {
                 }],
                 source: "".to_string(),
                 arrow_ffi_safe: false,
+                input_ordering: vec![],
             })),
         };
 
@@ -2918,6 +2941,7 @@ mod tests {
                 }],
                 source: "".to_string(),
                 arrow_ffi_safe: false,
+                input_ordering: vec![],
             })),
         };
 
@@ -3129,6 +3153,7 @@ mod tests {
             fields: vec![create_proto_datatype()],
             source: "".to_string(),
             arrow_ffi_safe: false,
+            input_ordering: vec![],
         })),
     }
 }
@@ -3172,6 +3197,7 @@ mod tests {
             ],
             source: "".to_string(),
            arrow_ffi_safe: false,
+            input_ordering: vec![],
         })),
     };
 
@@ -3287,6 +3313,7 @@ mod tests {
             ],
             source: "".to_string(),
             arrow_ffi_safe: false,
+            input_ordering: vec![],
         })),
     };
```
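
The `direction`/`null_ordering` integers decoded above are exactly the values the Scala side writes when it serializes a `SortOrder` (see QueryPlanSerde below). Here is a small self-contained sketch of that mapping; the helper name is mine, not part of the commit.

```rust
use datafusion::arrow::compute::SortOptions;

// Hypothetical helper mirroring sort_order_to_physical_sort_expr:
// direction == 1 means descending, null_ordering == 0 means nulls first.
fn sort_options_from_proto(direction: i32, null_ordering: i32) -> SortOptions {
    SortOptions {
        descending: direction == 1,
        nulls_first: null_ordering == 0,
    }
}

fn main() {
    // `a ASC NULLS FIRST` is serialized as direction = 0, null_ordering = 0.
    let asc_nulls_first = sort_options_from_proto(0, 0);
    assert!(!asc_nulls_first.descending && asc_nulls_first.nulls_first);

    // `a DESC NULLS LAST` is serialized as direction = 1, null_ordering = 1.
    let desc_nulls_last = sort_options_from_proto(1, 1);
    assert!(desc_nulls_last.descending && !desc_nulls_last.nulls_first);
}
```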

native/proto/src/proto/operator.proto

Lines changed: 2 additions & 0 deletions

```diff
@@ -79,6 +79,8 @@ message Scan {
   string source = 2;
   // Whether native code can assume ownership of batches that it receives
   bool arrow_ffi_safe = 3;
+
+  repeated spark.spark_expression.SortOrder input_ordering = 4;
 }
 
 message NativeScan {
```

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 96 additions & 32 deletions

```diff
@@ -19,6 +19,7 @@
 
 package org.apache.comet.serde
 
+import scala.annotation.tailrec
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ListBuffer
 
@@ -819,31 +820,18 @@ object QueryPlanSerde extends Logging with CometExprShim {
           None
         }
 
-      case SortOrder(child, direction, nullOrdering, _) =>
-        val childExpr = exprToProtoInternal(child, inputs, binding)
-
-        if (childExpr.isDefined) {
-          val sortOrderBuilder = ExprOuterClass.SortOrder.newBuilder()
-          sortOrderBuilder.setChild(childExpr.get)
-
-          direction match {
-            case Ascending => sortOrderBuilder.setDirectionValue(0)
-            case Descending => sortOrderBuilder.setDirectionValue(1)
-          }
-
-          nullOrdering match {
-            case NullsFirst => sortOrderBuilder.setNullOrderingValue(0)
-            case NullsLast => sortOrderBuilder.setNullOrderingValue(1)
-          }
+      case sortOrder @ SortOrder(child, direction, nullOrdering, _) =>
+        val sortOrderProto = sortOrderingToProto(sortOrder, inputs, binding)
 
+        if (sortOrderProto.isEmpty) {
+          withInfo(expr, child)
+          None
+        } else {
           Some(
             ExprOuterClass.Expr
               .newBuilder()
-              .setSortOrder(sortOrderBuilder)
+              .setSortOrder(sortOrderProto.get)
               .build())
-        } else {
-          withInfo(expr, child)
-          None
         }
 
       case UnaryExpression(child) if expr.prettyName == "promote_precision" =>
@@ -1363,18 +1351,16 @@ object QueryPlanSerde extends Logging with CometExprShim {
           if CometConf.COMET_EXEC_WINDOW_ENABLED.get(conf) =>
         val output = child.output
 
-        val winExprs: Array[WindowExpression] = windowExpression.flatMap { expr =>
-          expr match {
-            case alias: Alias =>
-              alias.child match {
-                case winExpr: WindowExpression =>
-                  Some(winExpr)
-                case _ =>
-                  None
-              }
-            case _ =>
-              None
-          }
+        val winExprs: Array[WindowExpression] = windowExpression.flatMap {
+          case alias: Alias =>
+            alias.child match {
+              case winExpr: WindowExpression =>
+                Some(winExpr)
+              case _ =>
+                None
+            }
+          case _ =>
+            None
         }.toArray
 
         if (winExprs.length != windowExpression.length) {
@@ -1694,6 +1680,11 @@ object QueryPlanSerde extends Logging with CometExprShim {
           scanBuilder.setSource(source)
         }
 
+        if (op.children.length == 1) {
+          scanBuilder.addAllInputOrdering(
+            QueryPlanSerde.parsePlanSortOrderAsMuchAsCan(op.children.head).asJava)
+        }
+
         val ffiSafe = op match {
           case _ if isExchangeSink(op) =>
             // Source of broadcast exchange batches is ArrowStreamReader
@@ -1927,6 +1918,79 @@ object QueryPlanSerde extends Logging with CometExprShim {
         })
       nativeScanBuilder.addFilePartitions(partitionBuilder.build())
     }
+
+  def sortOrderingToProto(
+      sortOrder: SortOrder,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.SortOrder] = {
+    val childExpr = exprToProtoInternal(sortOrder.child, inputs, binding)
+
+    if (childExpr.isDefined) {
+      val sortOrderBuilder = ExprOuterClass.SortOrder.newBuilder()
+      sortOrderBuilder.setChild(childExpr.get)
+
+      sortOrder.direction match {
+        case Ascending => sortOrderBuilder.setDirectionValue(0)
+        case Descending => sortOrderBuilder.setDirectionValue(1)
+      }
+
+      sortOrder.nullOrdering match {
+        case NullsFirst => sortOrderBuilder.setNullOrderingValue(0)
+        case NullsLast => sortOrderBuilder.setNullOrderingValue(1)
+      }
+
+      Some(sortOrderBuilder.build())
+    } else {
+      withInfo(sortOrder, sortOrder.child)
+      None
+    }
+  }
+
+  /**
+   * Return the plan's input sort order.
+   *
+   * This will not return the full sort order if it can't be fully mapped to the child (i.e. if
+   * the sort order is on an expression that is not a direct child of the input).
+   *
+   * For example, for the sort: Sort by a, b, coalesce(c, d), e
+   *
+   * we will return the sort order: a, b
+   *
+   * which is still correct: the data IS ordered by a, b.
+   *
+   * We do not return a, b, e, because the data IS NOT ordered by a, b, e.
+   *
+   * This is meant to be used for scans, where we don't want to lose the input ordering
+   * information, as it can enable certain optimizations.
+   */
+  def parsePlanSortOrderAsMuchAsCan(plan: SparkPlan): Seq[ExprOuterClass.SortOrder] = {
+    if (plan.outputOrdering.isEmpty) {
+      Seq.empty
+    } else {
+      val outputAttributes = plan.output
+      val sortOrders = plan.outputOrdering.map(so => {
+        if (!isExprOneOfAttributes(so.child, outputAttributes)) {
+          None
+        } else {
+          QueryPlanSerde.sortOrderingToProto(so, outputAttributes, binding = true)
+        }
+      })
+
+      // Take the sort orders until the first None
+      sortOrders.takeWhile(_.isDefined).map(_.get)
+    }
+  }
+
+  @tailrec
+  private def isExprOneOfAttributes(expr: Expression, attrs: Seq[Attribute]): Boolean = {
+    expr match {
+      case attr: Attribute => attrs.exists(_.exprId == attr.exprId)
+      case alias: Alias => isExprOneOfAttributes(alias.child, attrs)
+      case _ => false
+    }
+  }
 }
 
 sealed trait SupportLevel
```
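
The prefix rule documented in `parsePlanSortOrderAsMuchAsCan` (keep `a, b`, drop everything from `coalesce(c, d)` onwards, including `e`) comes from the `takeWhile(_.isDefined)` at the end: a later key on its own would describe an ordering the data does not actually satisfy. A standalone sketch of just that rule, written in Rust for consistency with the other examples here (the real implementation is the Scala above):

```rust
/// Keep only the leading sort keys that were successfully converted, stopping
/// at the first one that was not.
fn usable_prefix<T>(converted: Vec<Option<T>>) -> Vec<T> {
    converted
        .into_iter()
        .take_while(Option::is_some)
        .flatten()
        .collect()
}

fn main() {
    // Sort by a, b, coalesce(c, d), e: a and b map to output attributes,
    // coalesce(c, d) does not, and e is dropped even though it would map.
    let keys = vec![Some("a"), Some("b"), None, Some("e")];
    assert_eq!(usable_prefix(keys), vec!["a", "b"]);
}
```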

spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala

Lines changed: 2 additions & 2 deletions

```diff
@@ -77,7 +77,7 @@ case class CometCollectLimitExec(
       childRDD
     } else {
       val localLimitedRDD = if (limit >= 0) {
-        CometExecUtils.getNativeLimitRDD(childRDD, output, limit)
+        CometExecUtils.getNativeLimitRDD(child, childRDD, output, limit)
       } else {
         childRDD
       }
@@ -92,7 +92,7 @@ case class CometCollectLimitExec(
 
       new CometShuffledBatchRDD(dep, readMetrics)
     }
-    CometExecUtils.getNativeLimitRDD(singlePartitionRDD, output, limit, offset)
+    CometExecUtils.getNativeLimitRDD(child, singlePartitionRDD, output, limit, offset)
   }
 }
```

spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala

Lines changed: 14 additions & 4 deletions

```diff
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, So
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
-import org.apache.comet.serde.OperatorOuterClass
+import org.apache.comet.serde.{OperatorOuterClass, QueryPlanSerde}
 import org.apache.comet.serde.OperatorOuterClass.Operator
 import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType}
 
@@ -48,13 +48,14 @@ object CometExecUtils {
    * partition. The limit operation is performed on the native side.
    */
   def getNativeLimitRDD(
+      child: SparkPlan,
       childPlan: RDD[ColumnarBatch],
       outputAttribute: Seq[Attribute],
       limit: Int,
       offset: Int = 0): RDD[ColumnarBatch] = {
     val numParts = childPlan.getNumPartitions
     childPlan.mapPartitionsWithIndexInternal { case (idx, iter) =>
-      val limitOp = CometExecUtils.getLimitNativePlan(outputAttribute, limit, offset).get
+      val limitOp = CometExecUtils.getLimitNativePlan(child, outputAttribute, limit, offset).get
       CometExec.getCometIterator(Seq(iter), outputAttribute.length, limitOp, numParts, idx)
     }
   }
@@ -90,10 +91,15 @@ object CometExecUtils {
    * child partition
    */
   def getLimitNativePlan(
+      child: SparkPlan,
       outputAttributes: Seq[Attribute],
       limit: Int,
       offset: Int = 0): Option[Operator] = {
-    val scanBuilder = OperatorOuterClass.Scan.newBuilder().setSource("LimitInput")
+    val scanBuilder = OperatorOuterClass.Scan
+      .newBuilder()
+      .setSource("LimitInput")
+      .addAllInputOrdering(QueryPlanSerde.parsePlanSortOrderAsMuchAsCan(child).asJava)
+
     val scanOpBuilder = OperatorOuterClass.Operator.newBuilder()
 
     val scanTypes = outputAttributes.flatten { attr =>
@@ -125,7 +131,11 @@
       child: SparkPlan,
       limit: Int,
       offset: Int = 0): Option[Operator] = {
-    val scanBuilder = OperatorOuterClass.Scan.newBuilder().setSource("TopKInput")
+    val scanBuilder = OperatorOuterClass.Scan
+      .newBuilder()
+      .setSource("TopKInput")
+      .addAllInputOrdering(QueryPlanSerde.parsePlanSortOrderAsMuchAsCan(child).asJava)
+
     val scanOpBuilder = OperatorOuterClass.Operator.newBuilder()
 
     val scanTypes = outputAttributes.flatten { attr =>
```
