@@ -81,6 +81,11 @@ trait KeyGroupedPartitionedScan[T] {
filteredPartitions: Seq[Seq[T]],
partitionValueAccessor: T => InternalRow): Seq[Seq[T]] = {
assert(spjParams.keyGroupedPartitioning.isDefined)

+if (spjParams.noGrouping) {
+  return filteredPartitions.flatten.map(Seq(_))
+}

val expressions = spjParams.keyGroupedPartitioning.get

// Re-group the input partitions if we are projecting on a subset of join keys
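For context on the early return above: when noGrouping is set, grouping is skipped entirely and every input split becomes its own single-element group. A minimal, self-contained illustration of that reshaping in plain Scala (the placeholder strings stand in for Spark's split type T):

    // Each inner Seq holds the splits of one filtered storage partition.
    val filteredPartitions = Seq(Seq("split-a1", "split-a2"), Seq("split-b1"))
    // flatten + map(Seq(_)) produces one single-split group per split,
    // i.e. no grouping by partition key at all.
    val noGrouping = filteredPartitions.flatten.map(Seq(_))
    // noGrouping == Seq(Seq("split-a1"), Seq("split-a2"), Seq("split-b1"))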
@@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.analysis.{LazyExpression, NameParameterized
import org.apache.spark.sql.catalyst.expressions.codegen.ByteCodeStats
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command, CommandResult, CompoundBody, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, ReturnAnswer, Union, WithCTE}
+import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution
import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule}
import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat
import org.apache.spark.sql.catalyst.util.truncatedString
@@ -622,6 +623,11 @@ object QueryExecution {
sparkSession: SparkSession,
adaptiveExecutionRule: Option[InsertAdaptiveSparkPlan] = None,
subquery: Boolean): Seq[Rule[SparkPlan]] = {
+val requiredDistribution = if (subquery) {
Member: Not sure I get this. If it's not a subquery, we pass in any requiredDistribution?

Contributor Author: Yeah, let me change this tomorrow and pass subquery directly into EnsureRequirements; that way this will be much cleaner.

Contributor Author (@peter-toth), Jan 29, 2026: Fixed in b04bb61 and added comments in c28fc3f.

Member: Can we add more detailed comments here? It looks confusing without any context when looking at the code here.

Contributor Author: I changed this; it now passes in subquery, and I added comments in c28fc3f.

+  Some(UnspecifiedDistribution)
+} else {
+  None
+}
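To make the thread above concrete, the two shapes discussed are sketched below. This is a hedged sketch rather than the committed code; the isSubquery parameter name is hypothetical, and b04bb61/c28fc3f are authoritative for the actual change.

    // Shape 1 (this hunk): encode "this is a subquery" at the call site as
    // a required distribution of Some(UnspecifiedDistribution).
    val requiredDistribution =
      if (subquery) Some(UnspecifiedDistribution) else None
    EnsureRequirements(requiredDistribution = requiredDistribution)

    // Shape 2 (the direction the thread converged on): pass the flag itself
    // and let EnsureRequirements interpret it. isSubquery is hypothetical.
    EnsureRequirements(isSubquery = subquery)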
// `AdaptiveSparkPlanExec` is a leaf node. If inserted, all the following rules will be no-op
// as the original plan is hidden behind `AdaptiveSparkPlanExec`.
adaptiveExecutionRule.toSeq ++
@@ -630,7 +636,7 @@
PlanDynamicPruningFilters(sparkSession),
PlanSubqueries(sparkSession),
RemoveRedundantProjects,
-EnsureRequirements(),
+EnsureRequirements(requiredDistribution = requiredDistribution),
// This rule must be run after `EnsureRequirements`.
InsertSortForLimitAndOffset,
// `ReplaceHashWithSortAgg` needs to be added after `EnsureRequirements` to guarantee the
@@ -66,19 +66,22 @@ case class EnsureRequirements(
case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
ensureOrdering(child, distribution)
case (child, BroadcastDistribution(mode)) =>
-BroadcastExchangeExec(mode, child)
+val newChild = disableKeyGroupingIfNotNeeded(child)
+BroadcastExchangeExec(mode, newChild)
case (child, distribution) =>
val numPartitions = distribution.requiredNumPartitions
.getOrElse(conf.numShufflePartitions)
distribution match {
case _: StatefulOpClusteredDistribution =>
+val newChild = disableKeyGroupingIfNotNeeded(child)
ShuffleExchangeExec(
-distribution.createPartitioning(numPartitions), child,
+distribution.createPartitioning(numPartitions), newChild,
REQUIRED_BY_STATEFUL_OPERATOR)

case _ =>
+val newChild = disableKeyGroupingIfNotNeeded(child)
ShuffleExchangeExec(
-distribution.createPartitioning(numPartitions), child, shuffleOrigin)
+distribution.createPartitioning(numPartitions), newChild, shuffleOrigin)
}
}

@@ -224,8 +227,11 @@

child match {
case ShuffleExchangeExec(_, c, so, ps) =>
-ShuffleExchangeExec(newPartitioning, c, so, ps)
-case _ => ShuffleExchangeExec(newPartitioning, child)
+val newChild = disableKeyGroupingIfNotNeeded(c)
+ShuffleExchangeExec(newPartitioning, newChild, so, ps)
+case _ =>
+  val newChild = disableKeyGroupingIfNotNeeded(child)
Member: Could we make a method createShuffleExchangeExec(..., disableGrouping: Boolean) to reduce duplication?

Contributor Author: Done in b04bb61.

+  ShuffleExchangeExec(newPartitioning, newChild)
}
}
}
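A hedged sketch of the helper suggested above; the name createShuffleExchangeExec and its parameters are the reviewer's proposal, and the version committed in b04bb61 may differ.

    private def createShuffleExchangeExec(
        partitioning: Partitioning,
        child: SparkPlan,
        origin: ShuffleOrigin = ENSURE_REQUIREMENTS,
        disableGrouping: Boolean = true): ShuffleExchangeExec = {
      // The child is about to be shuffled, so a key-grouped scan partitioning
      // underneath is discarded anyway; optionally ungroup the scan first.
      val newChild =
        if (disableGrouping) disableKeyGroupingIfNotNeeded(child) else child
      ShuffleExchangeExec(partitioning, newChild, origin)
    }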
@@ -695,6 +701,21 @@ case class EnsureRequirements(
child, values, joinKeyPositions, reducers, applyPartialClustering, replicatePartitions))
}

+private def disableKeyGroupingIfNotNeeded(child: SparkPlan) = {
Member: More detailed comments on this method would be good, e.g. the conditions under which grouping can be safely disabled.

Contributor Author: I added comments in b04bb61.

+  if (canApplyPartialClusteredDistribution(child)) {
+    populateNoGroupingPartitionInfo(child)
+  } else {
+    child
+  }
+}
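A hedged sketch of the kind of doc comment the reviewer asks for, reconstructed from the behavior this PR exercises; the comments actually added in b04bb61 are authoritative.

    /**
     * Key grouping only pays off when SPJ can consume the scan's key-grouped
     * output partitioning directly. If this side of the plan is about to be
     * shuffled or broadcast, the exchange discards that partitioning, so
     * grouping would only reduce scan parallelism. In that case the scan is
     * marked with noGrouping = true and keeps one Spark partition per split.
     */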

+private def populateNoGroupingPartitionInfo(plan: SparkPlan): SparkPlan = plan match {
Member: This looks like it can be done with the transform API?

Contributor Author: Yes, I can change this to use the transform() API. I wanted to make it similar to the other 2 populate...() methods; shall I change those as well?

Contributor Author: I modified all 3 populate...() methods in b04bb61.

+  case scan: BatchScanExec =>
+    val newScan = scan.copy(spjParams = scan.spjParams.copy(noGrouping = true))
+    newScan.copyTagsFrom(scan)
+    newScan
+  case node => node.mapChildren(child => populateNoGroupingPartitionInfo(child))
+}

private def populateJoinKeyPositions(
plan: SparkPlan,
@@ -843,9 +864,14 @@
} else {
REPARTITION_BY_COL
}
+val groupingDisabledPlan = if (requiredDistribution.get == UnspecifiedDistribution) {
+  disableKeyGroupingIfNotNeeded(newPlan)
+} else {
+  newPlan
+}
val finalPlan = ensureDistributionAndOrdering(
None,
-newPlan :: Nil,
+groupingDisabledPlan :: Nil,
requiredDistribution.get :: Nil,
Seq(Nil),
shuffleOrigin)
@@ -29,13 +29,15 @@ case class StoragePartitionJoinParams(
commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
reducers: Option[Seq[Option[Reducer[_, _]]]] = None,
applyPartialClustering: Boolean = false,
-replicatePartitions: Boolean = false) {
+replicatePartitions: Boolean = false,
+noGrouping: Boolean = false) {
override def equals(other: Any): Boolean = other match {
case other: StoragePartitionJoinParams =>
this.commonPartitionValues == other.commonPartitionValues &&
this.replicatePartitions == other.replicatePartitions &&
this.applyPartialClustering == other.applyPartialClustering &&
-this.joinKeyPositions == other.joinKeyPositions
+this.joinKeyPositions == other.joinKeyPositions &&
+this.noGrouping == other.noGrouping
case _ =>
false
}
@@ -44,5 +46,6 @@
joinKeyPositions: Option[Seq[Int]],
commonPartitionValues: Option[Seq[(InternalRow, Int)]],
applyPartialClustering: java.lang.Boolean,
-replicatePartitions: java.lang.Boolean)
+replicatePartitions: java.lang.Boolean,
+noGrouping: java.lang.Boolean)
}
@@ -1580,7 +1580,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
case (true, false, false) => assert(scannedPartitions == Seq(4, 4))

// No SPJ
-case _ => assert(scannedPartitions == Seq(5, 4))
+case _ => assert(scannedPartitions == Seq(7, 7))
}

checkAnswer(df, Seq(
@@ -2114,7 +2114,7 @@
assert(scans == Seq(2, 2))
case (_, _) =>
assert(shuffles.nonEmpty, "SPJ should not be triggered")
-assert(scans == Seq(3, 2))
+assert(scans == Seq(3, 3))
}

checkAnswer(df, Seq(
@@ -2234,7 +2234,7 @@
// SPJ and not partially-clustered
case (true, false) => assert(scans == Seq(3, 3))
// No SPJ
-case _ => assert(scans == Seq(4, 4))
+case _ => assert(scans == Seq(5, 5))
}

checkAnswer(df,
@@ -2823,4 +2823,65 @@
checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0)))
}
}

test("SPARK-55092: Don't group partitions when not needed") {
val items_partitions = Array(identity("id"))
createTable(items, itemsColumns, items_partitions)

sql(s"INSERT INTO testcat.ns.$items VALUES " +
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
"(4, 'bb', 10.0, cast('2021-01-01' as timestamp)), " +
"(4, 'cc', 15.5, cast('2021-02-01' as timestamp))")

val purchases_partitions = Array(years("time"))
createTable(purchases, purchasesColumns, purchases_partitions)
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
"(1, 42.0, cast('2020-01-01' as timestamp)), " +
"(3, 19.5, cast('2020-02-01' as timestamp))")

withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true") {
val df = createJoinTestDF(Seq("id" -> "item_id"), extraColumns = Seq("year(p.time)"))

val shuffles = collectShuffles(df.queryExecution.executedPlan)
assert(shuffles.size == 1, "only shuffle one side not report partitioning")

val scans = collectScans(df.queryExecution.executedPlan)
assert(scans(0).inputRDD.partitions.length === 2,
"items scan should group as it is the driver of SPJ")
assert(scans(1).inputRDD.partitions.length === 2,
"purchases scan should not group as SPJ can't leverage it")

checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0, 2020)))
}

withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "false") {
val df = createJoinTestDF(Seq("id" -> "item_id"), extraColumns = Seq("year(p.time)"))

val shuffles = collectShuffles(df.queryExecution.executedPlan)
assert(shuffles.size == 2, "only shuffle one side not report partitioning")

val scans = collectScans(df.queryExecution.executedPlan)
assert(scans(0).inputRDD.partitions.length === 3,
"items scan should not group as it is shuffled")
assert(scans(1).inputRDD.partitions.length === 2,
"purchases scan should not group as it is shuffled")

checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0, 2020)))
}
}
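Where the expected counts come from, inferred from the inserted rows; one input split per inserted row is an assumption about the test catalog:

    // items is partitioned by identity(id) and holds ids 1, 4, 4:
    //   3 input splits over 2 distinct partition values,
    //   so a grouped scan has 2 partitions and an ungrouped scan has 3.
    // purchases is partitioned by years(time) but joined on item_id, so its
    // partitioning never matches the join key and its scan is never grouped:
    //   2 input splits, always 2 partitions.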

test("SPARK-55092: Main query output maintains partition grouping despite it is not needed") {
val items_partitions = Array(identity("id"))
createTable(items, itemsColumns, items_partitions)

sql(s"INSERT INTO testcat.ns.$items VALUES " +
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
"(4, 'bb', 10.0, cast('2021-01-01' as timestamp)), " +
"(4, 'cc', 15.5, cast('2021-02-01' as timestamp))")

val df = sql(s"SELECT * FROM testcat.ns.$items")
val scans = collectScans(df.queryExecution.executedPlan)
assert(scans(0).inputRDD.partitions.length === 2,
"items scan should group to maintain query output partitioning semantics")
}
}