Skip to content

Commit 9995544

Browse files
committed
fix ut
1 parent 824beeb commit 9995544

File tree

3 files changed

+49
-39
lines changed

3 files changed

+49
-39
lines changed

backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import org.apache.gluten.shuffle.NeedCustomColumnarBatchSerializer
2828
import org.apache.gluten.sql.shims.SparkShimLoader
2929
import org.apache.gluten.substrait.SubstraitContext
3030
import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode, WindowFunctionNode}
31-
import org.apache.gluten.utils.{CHJoinValidateUtil, UnknownJoinStrategy}
31+
import org.apache.gluten.utils.{CHJoinValidateUtil, PullOutProjectHelper, UnknownJoinStrategy}
3232
import org.apache.gluten.vectorized.{BlockOutputStream, CHColumnarBatchSerializer, CHNativeBlock, CHStreamReader}
3333

3434
import org.apache.spark.{ShuffleDependency, SparkEnv}
@@ -66,7 +66,7 @@ import java.util.{ArrayList => JArrayList, List => JList}
6666

6767
import scala.collection.JavaConverters._
6868

69-
class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
69+
class CHSparkPlanExecApi extends SparkPlanExecApi with Logging with PullOutProjectHelper {
7070

7171
/** Transform GetArrayItem to Substrait. */
7272
override def genGetArrayItemTransformer(
@@ -488,22 +488,33 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
488488
(Seq.empty, false)
489489
}
490490

491+
val (newChild, newOutput, newBuildKeys) = if (buildKeys.exists(isNotAttribute)) {
492+
val (newChild, newBuildKeys, _) = pullOutPreProjectForJoin(child, buildKeys)
493+
(newChild, newChild.output, newBuildKeys)
494+
} else {
495+
(child, child.output, Seq.empty)
496+
}
497+
491498
// find the key index in the output
492-
// Note: PullOutPreProject rule ensures join keys are AttributeReferences before this point,
493-
// so no pre-projection is needed here.
494499
val keyColumnIndex = if (isNullAware) {
495-
buildKeys(0) match {
496-
case b: BoundReference => b.ordinal
497-
case n: NamedExpression =>
498-
child.output.indexWhere(o => o.name.equals(n.name) && o.exprId == n.exprId)
499-
case key =>
500-
throw new GlutenException(s"Cannot find $key in the child's output: ${child.output}")
500+
def findKeyOrdinal(key: Expression, output: Seq[Attribute]): Int = {
501+
key match {
502+
case b: BoundReference => b.ordinal
503+
case n: NamedExpression =>
504+
output.indexWhere(o => (o.name.equals(n.name) && o.exprId == n.exprId))
505+
case _ => throw new GlutenException(s"Cannot find $key in the child's output: $output")
506+
}
507+
}
508+
if (newBuildKeys.isEmpty) {
509+
findKeyOrdinal(buildKeys(0), newOutput)
510+
} else {
511+
findKeyOrdinal(newBuildKeys(0), newOutput)
501512
}
502513
} else {
503514
0
504515
}
505516
val countsAndBytes =
506-
CHExecUtil.buildSideRDD(dataSize, child, isNullAware, keyColumnIndex).collect
517+
CHExecUtil.buildSideRDD(dataSize, newChild, isNullAware, keyColumnIndex).collect
507518

508519
val batches = countsAndBytes.map(_._2)
509520
val totalBatchesSize = batches.map(_.length).sum
@@ -524,10 +535,10 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
524535
numOutputRows += rowCount
525536
ClickHouseBuildSideRelation(
526537
mode,
527-
child.output,
538+
newOutput,
528539
batches.flatten,
529540
rowCount,
530-
Seq.empty,
541+
newBuildKeys,
531542
hasNullKeyValues)
532543
}
533544

gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
package org.apache.gluten.extension.columnar.rewrite
1818

1919
import org.apache.gluten.backendsapi.BackendsApiManager
20-
import org.apache.gluten.extension.columnar.heuristic.RewrittenNodeWall
2120
import org.apache.gluten.sql.shims.SparkShimLoader
2221
import org.apache.gluten.utils.PullOutProjectHelper
2322

@@ -303,28 +302,6 @@ object PullOutPreProject extends RewriteSingleNode with PullOutProjectHelper {
303302
case _ =>
304303
(join.leftKeys, join.rightKeys)
305304
}
306-
307-
def pullOutPreProjectForJoin(joinChild: SparkPlan, joinKeys: Seq[Expression])
308-
: (SparkPlan, Seq[Expression], mutable.HashMap[Expression, NamedExpression]) = {
309-
val expressionMap = new mutable.HashMap[Expression, NamedExpression]()
310-
if (joinKeys.exists(isNotAttribute)) {
311-
val newJoinKeys =
312-
joinKeys.toIndexedSeq.map(replaceExpressionWithAttribute(_, expressionMap))
313-
val preProject = ProjectExec(
314-
eliminateProjectList(joinChild.outputSet, expressionMap.values.toSeq),
315-
joinChild)
316-
joinChild match {
317-
case r: RewrittenNodeWall =>
318-
r.originalChild.logicalLink.foreach(preProject.setLogicalLink)
319-
case _ =>
320-
joinChild.logicalLink.foreach(preProject.setLogicalLink)
321-
}
322-
(preProject, newJoinKeys, expressionMap)
323-
} else {
324-
(joinChild, joinKeys, expressionMap)
325-
}
326-
}
327-
328305
val (newLeft, newLeftKeys, leftMap) = pullOutPreProjectForJoin(join.left, leftKeys)
329306
val (newRight, newRightKeys, rightMap) = pullOutPreProjectForJoin(join.right, rightKeys)
330307
val newCondition = if (leftMap.nonEmpty || rightMap.nonEmpty) {

gluten-substrait/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,15 @@ package org.apache.gluten.utils
1818

1919
import org.apache.gluten.backendsapi.BackendsApiManager
2020
import org.apache.gluten.exception.{GlutenException, GlutenNotSupportException}
21+
import org.apache.gluten.extension.columnar.heuristic.RewrittenNodeWall
2122

2223
import org.apache.spark.sql.catalyst.expressions._
2324
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete, Partial}
24-
import org.apache.spark.sql.execution.SparkPlan
25+
import org.apache.spark.sql.execution.{ProjectExec, SparkPlan}
2526
import org.apache.spark.sql.execution.aggregate._
26-
import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, CartesianProductExec, ShuffledHashJoinExec, SortMergeJoinExec}
27+
import org.apache.spark.sql.execution.joins._
2728
import org.apache.spark.sql.execution.window.WindowExec
28-
import org.apache.spark.sql.types.{ByteType, DateType, IntegerType, LongType, ShortType}
29+
import org.apache.spark.sql.types._
2930

3031
import java.sql.Date
3132
import java.util.concurrent.atomic.AtomicInteger
@@ -332,4 +333,25 @@ trait PullOutProjectHelper {
332333
newWe.copyTagsFrom(we)
333334
newWe
334335
}
336+
337+
protected def pullOutPreProjectForJoin(joinChild: SparkPlan, joinKeys: Seq[Expression])
338+
: (SparkPlan, Seq[Expression], mutable.HashMap[Expression, NamedExpression]) = {
339+
val expressionMap = new mutable.HashMap[Expression, NamedExpression]()
340+
if (joinKeys.exists(isNotAttribute)) {
341+
val newJoinKeys =
342+
joinKeys.toIndexedSeq.map(replaceExpressionWithAttribute(_, expressionMap))
343+
val preProject = ProjectExec(
344+
eliminateProjectList(joinChild.outputSet, expressionMap.values.toSeq),
345+
joinChild)
346+
joinChild match {
347+
case r: RewrittenNodeWall =>
348+
r.originalChild.logicalLink.foreach(preProject.setLogicalLink)
349+
case _ =>
350+
joinChild.logicalLink.foreach(preProject.setLogicalLink)
351+
}
352+
(preProject, newJoinKeys, expressionMap)
353+
} else {
354+
(joinChild, joinKeys, expressionMap)
355+
}
356+
}
335357
}

0 commit comments

Comments (0)