@@ -28,7 +28,7 @@ import org.apache.gluten.shuffle.NeedCustomColumnarBatchSerializer
 import org.apache.gluten.sql.shims.SparkShimLoader
 import org.apache.gluten.substrait.SubstraitContext
 import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode, WindowFunctionNode}
-import org.apache.gluten.utils.{CHJoinValidateUtil, UnknownJoinStrategy}
+import org.apache.gluten.utils.{CHJoinValidateUtil, PullOutProjectHelper, UnknownJoinStrategy}
 import org.apache.gluten.vectorized.{BlockOutputStream, CHColumnarBatchSerializer, CHNativeBlock, CHStreamReader}
 
 import org.apache.spark.{ShuffleDependency, SparkEnv}
@@ -66,7 +66,7 @@ import java.util.{ArrayList => JArrayList, List => JList}
 
 import scala.collection.JavaConverters._
 
-class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
+class CHSparkPlanExecApi extends SparkPlanExecApi with Logging with PullOutProjectHelper {
 
   /** Transform GetArrayItem to Substrait. */
   override def genGetArrayItemTransformer(
@@ -488,22 +488,33 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
        (Seq.empty, false)
      }
 
+    val (newChild, newOutput, newBuildKeys) = if (buildKeys.exists(isNotAttribute)) {
+      val (newChild, newBuildKeys, _) = pullOutPreProjectForJoin(child, buildKeys)
+      (newChild, newChild.output, newBuildKeys)
+    } else {
+      (child, child.output, Seq.empty)
+    }
+
     // find the key index in the output
-    // Note: PullOutPreProject rule ensures join keys are AttributeReferences before this point,
-    // so no pre-projection is needed here.
     val keyColumnIndex = if (isNullAware) {
-      buildKeys(0) match {
-        case b: BoundReference => b.ordinal
-        case n: NamedExpression =>
-          child.output.indexWhere(o => o.name.equals(n.name) && o.exprId == n.exprId)
-        case key =>
-          throw new GlutenException(s"Cannot find $key in the child's output: ${child.output}")
+      def findKeyOrdinal(key: Expression, output: Seq[Attribute]): Int = {
+        key match {
+          case b: BoundReference => b.ordinal
+          case n: NamedExpression =>
+            output.indexWhere(o => (o.name.equals(n.name) && o.exprId == n.exprId))
+          case _ => throw new GlutenException(s"Cannot find $key in the child's output: $output")
+        }
+      }
+      if (newBuildKeys.isEmpty) {
+        findKeyOrdinal(buildKeys(0), newOutput)
+      } else {
+        findKeyOrdinal(newBuildKeys(0), newOutput)
       }
     } else {
       0
     }
     val countsAndBytes =
-      CHExecUtil.buildSideRDD(dataSize, child, isNullAware, keyColumnIndex).collect
+      CHExecUtil.buildSideRDD(dataSize, newChild, isNullAware, keyColumnIndex).collect
 
     val batches = countsAndBytes.map(_._2)
     val totalBatchesSize = batches.map(_.length).sum
@@ -524,10 +535,10 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
       numOutputRows += rowCount
       ClickHouseBuildSideRelation(
         mode,
-        child.output,
+        newOutput,
         batches.flatten,
         rowCount,
-        Seq.empty,
+        newBuildKeys,
         hasNullKeyValues)
     }
 
0 commit comments